likes
comments
collection
share

在FastAPI 中正确使用 async SQLAlchemy、celery、websockets

作者站长头像
站长
· 阅读数 30

在FastAPI 中正确使用 async SQLAlchemy、celery、websockets

在FastAPI 中正确使用 async SQLAlchemy、celery、websockets

从版本1.4开始,SQLAlchemy支持asyncio。在本文章中,我们将尝试使用async SQLAlchemy功能、encryptioncelerywebsocket来实现简单的项目。我们从数据库连接开始。

使用异步SQLAlchemy设置数据库

首先让我们创建异步session

 from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
 from sqlalchemy.orm import sessionmaker
 ​
 from app.core.config import settings
 ​
 engine = create_async_engine(settings.SQLALCHEMY_DATABASE_URL, echo=True)
 SessionLocal = sessionmaker(
     expire_on_commit=False,
     class_=AsyncSession,
     bind=engine,
 )

我们使用 FastAPI的注入 dependencies 能力来依赖注入 db session

 async def get_db() -> AsyncSession:
     async with SessionLocal() as session:
         yield session

所有都准备好之后我们使用 DB. 在项目中我们是用 token authentication 来控制用户的登录行为, 因此我们需要两个数据表: usersuser_tokens

 # 导入相关的依赖包
 from sqlalchemy.orm import declarative_base
 from sqlalchemy_utils import EmailType, force_auto_coercion, PasswordType
 ​
 # Base 为数据库的基类,我们的 table 需要继承 Base 才能实现 orm 的相关能力
 Base = declarative_base()
 force_auto_coercion()
 ​
 ​
 class User(Base):
     """定义 User class, 目的是实现 orm 能力"""
     __tablename__ = "users"
 ​
     id = Column(Integer, primary_key=True, index=True)
     name = Column(String(50))
     email = Column(EmailType(50), unique=True, nullable=False)
     password = Column(PasswordType(schemes=["pbkdf2_sha512"]), nullable=False)
 ​
     tokens = relationship(
         "UserToken",
         back_populates="user",
         lazy='dynamic',
         cascade="all, delete-orphan",
     )
 ​
 class UserToken(Base):
     """用户 token 的表"""
     __tablename__ = "user_tokens"
 ​
     id = Column(Integer, primary_key=True, index=True)
     user_id = Column(
         Integer, ForeignKey("users.id", ondelete='CASCADE'), nullable=False
     )
     token = Column(
         UUID(as_uuid=True), unique=True, nullable=False, default=uuid.uuid4
     )
     expires = Column(DateTime)
 ​
     user = relationship("User", back_populates="tokens", lazy='joined')Note `force_auto_coercion()`

请注意,我们在模型之前使用了 force_auto_coercion(). 在记录保存到数据库前,确保密码经过哈希处理。

现在我们构建的 web 项目都几乎不在自己手动去添加数据库,而是使用相关的数据库迁移工具。我们将使用alembic来实现这一目的。(如果你的项目中还没有使用 magration的工具,那么建议你赶快用起来)

安装 alembic:

 pip install alembic

初始化 alembic

 alembic init migrations

以上命令将创建带有 env.pyREADMEscript.py.mako 文件的 migrations 目录。

要使 alembic 与我们的数据库配合工作,我们需要更新 env.py 文件。

 import asyncio
 import os
 from logging.config import fileConfig
 ​
 from sqlalchemy import engine_from_config
 from sqlalchemy import pool
 from sqlalchemy.engine import Connection
 from sqlalchemy.ext.asyncio import AsyncEngine
 from alembic import context
 ​
 config = context.config
 ​
 if config.config_file_name is not None:
     fileConfig(config.config_file_name)
 ​
 # Here we importing and specifying our DB metadata
 from app.db.base import Base  
 target_metadata = Base.metadata
 ​
 ​
 # This method returns url of our DB
 def get_url():
     return os.getenv("SQLALCHEMY_DATABASE_URL", "")
 ​
 ​
 def run_migrations_offline() -> None:
     """Run migrations in 'offline' mode.
 ​
     This configures the context with just a URL
     and not an Engine, though an Engine is acceptable
     here as well.  By skipping the Engine creation
     we don't even need a DBAPI to be available.
 ​
     Calls to context.execute() here emit the given string to the
     script output.
 ​
     """
     # Specify which database we use with alembic
     url = get_url()
     context.configure(
         url=url,
         target_metadata=target_metadata,
         literal_binds=True,
         dialect_opts={"paramstyle": "named"},
     )
 ​
     with context.begin_transaction():
         context.run_migrations()
 ​
 ​
 def do_run_migrations(connection: Connection) -> None:
     context.configure(connection=connection, target_metadata=target_metadata)
 ​
     with context.begin_transaction():
         context.run_migrations()
 ​
 ​
 async def run_migrations_online() -> None:
     """Run migrations in 'online' mode.
 ​
     In this scenario we need to create an Engine
     and associate a connection with the context.
 ​
     """
     configuration = config.get_section(config.config_ini_section)
     configuration["sqlalchemy.url"] = get_url()
     connectable = AsyncEngine(
         engine_from_config(
             configuration,
             prefix="sqlalchemy.",
             poolclass=pool.NullPool,
             future=True,
         )
     )
 ​
     async with connectable.connect() as connection:
         await connection.run_sync(do_run_migrations)
 ​
     await connectable.dispose()
 ​
 ​
 if context.is_offline_mode():
     run_migrations_offline()
 else:
     asyncio.run(run_migrations_online())

我们使用下面的命令来创建 migration 文件:

 alembic revision --autogenerate -m "Added required tables"

运行以下命令来应用迁移并更新数据库:

 alembic upgrade head

执行完成后我们的数据库里面的表已经创建好了。

现在我们的数据库已经准备好了,我们可以尝试创建新的用户和令牌:

 from sqlalchemy import select
 from app.db.base import User
 ​
 async def get_user_by_email(db: AsyncSession, email: str) -> User:
     """根据邮箱查找用户,所以数据库中我们的 email 字段需要保证唯一"""
     statement = select(User).where(User.email == email)
     result = await db.execute(statement)
     return result.scalars().first()
 ​
 async def create_user(db: AsyncSession, user: UserCreate) -> User:
     """创建用户"""
     db_user = User(
         email=user.email,
         name=user.name,
         password=user.password,
     )
     db.add(db_user)
     await db.commit()
     await db.refresh(db_user)
     return db_user
 ​
 ​
 async def create_user_token(db: AsyncSession, user: User) -> UserToken:
     """用户登陆成功后,创建用户 token """
     db_token = UserToken(
         user=user, expires=datetime.now() + timedelta(weeks=2)
     )
     db.add(db_token)
     await db.commit()
     return db_token

编写注册新用户的代码:

 from fastapi import APIRouter, FastAPI
 from pydantic import BaseModel
 from app.crud import crud_user
 ​
 app = FastAPI()
 ​
 router = APIRouter()
 ​
 class UserBase(BaseModel):
     email: EmailStr
     name: str
 ​
 class UserCreate(UserBase):
     password: constr(strip_whitespace=True, min_length=8)
 ​
 class User(UserBase):
     id: Optional[int] = None
     token: TokenBase | None = None
 ​
     class Config:
         orm_mode = True

添加注册用户的路由:

 @router.post("/sign-up/", response_model=User)
 async def create_user(user: UserCreate, db: DBSession):
     user_db = await crud_user.get_user_by_email(db, email=user.email)
     if user_db:
         raise HTTPException(status_code=400, detail="User already registered")
     user = await crud_user.create_user(db, user=user)
     user.token = await crud_user.create_user_token(db, user=user)
     return user
 ​
 app.include_router(user_routes)

测试代码

耶!我们已经实现了注册逻辑,最好添加一些测试来检查一切是否如预期那样工作。由于我们使用异步数据库连接,因此需要使用异步测试。因此,我们需要添加一些特殊能力代码:

 import asyncio
 ​
 import pytest
 import pytest_asyncio
 from httpx import AsyncClient
 from sqlalchemy import text
 from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
 from sqlalchemy.orm import sessionmaker
 ​
 from app.api.deps import get_db
 from app.core.config import settings
 from app.db.base import Base
 from app.main import app
 ​
 ​
 @pytest.fixture(scope="session")
 def event_loop() -> asyncio.AbstractEventLoop:
     loop = asyncio.get_event_loop_policy().new_event_loop()
     yield loop
     loop.close()
 ​
 ​
 @pytest.fixture(scope="session")
 def engine():
     engine = create_async_engine(settings.TEST_SQLALCHEMY_DATABASE_URL)
     yield engine
     engine.sync_engine.dispose()
 ​
 ​
 @pytest_asyncio.fixture(scope="session")
 async def prepare_db():
     create_db_engine = create_async_engine(
         settings.POSTGRES_DATABASE_URL,
         isolation_level="AUTOCOMMIT",
     )
     async with create_db_engine.begin() as connection:
         await connection.execute(
             text(
                 "drop database if exists {name};".format(
                     name=settings.TEST_DB_NAME
                 )
             ),
         )
         await connection.execute(
             text("create database {name};".format(name=settings.TEST_DB_NAME)),
         )
 ​
 ​
 @pytest_asyncio.fixture(scope="session")
 async def db_session(engine) -> AsyncSession:
     async with engine.begin() as connection:
         await connection.run_sync(Base.metadata.drop_all)
         await connection.run_sync(Base.metadata.create_all)
         TestingSessionLocal = sessionmaker(
             expire_on_commit=False,
             class_=AsyncSession,
             bind=engine,
         )
         async with TestingSessionLocal(bind=connection) as session:
             yield session
             await session.flush()
             await session.rollback()
 ​
 ​
 @pytest.fixture(scope="session")
 def override_get_db(prepare_db, db_session: AsyncSession):
     async def _override_get_db():
         yield db_session
 ​
     return _override_get_db
 ​
 ​
 @pytest_asyncio.fixture(scope="session")
 async def async_client(override_get_db):
     app.dependency_overrides[get_db] = override_get_db
     async with AsyncClient(app=app, base_url="http://test") as ac:
         yield ac

首先,我们需要为 event_loop fixture 更改作用域。默认情况下它是函数级别的 fixture,但在这种情况下,我们必须将我们的 DB 设为函数级别,这会导致性能问题,而使用会话级别(session scope)可以解决这个问题。

另外,我们添加了 engine fixture 以使用测试数据库而非实际数据库。在 prepare_db 中,我们确保数据库已创建。在 db_session 中,我们创建表格并返回数据库连接。然后在 override_get_db 中更新项目依赖项,以确保测试期间的视图不会使用实际数据库。最后,我们创建了 async_client 来执行对我们API的异步请求。

所有准备工作已完成,现在我们进行测试:

使用 Celery tasks

AsyncIO 适合IO密集型任务。这就是为什么我们使用它来从数据库中读取数据。但如果我们需要执行一些需要大量CPU的任务呢?在这种情况下,我们应该考虑将此任务发送到单独的进程。我们可以查看文档Celery来帮助我们完成这个任务。

在我们的系统中,用户将能够创建帖子。但是,在存储到数据库之前,帖子的内容将被加密。加密是一个CPU密集型任务,因此我们需要使用celery。让我们创建所需的 Model

 from sqlalchemy import Column, ForeignKey, Integer, String, Text
 from sqlalchemy.orm import relationship
 ​
 from app.db.base_class import Base
 ​
 ​
 class UserKeys(Base):
     __tablename__ = "user_keys"
 ​
     id = Column(Integer, primary_key=True, index=True)
     user_id = Column(
         Integer, ForeignKey("users.id", ondelete='CASCADE'), nullable=False
     )
     public_key = Column(String(2000), nullable=False)
     is_revoked = Column(Boolean, default=False)
 ​
     user = relationship("User", back_populates="keys")
 ​
 ​
 class UserGroup(Base):
     __tablename__ = "user_groups"
 ​
     id = Column(Integer, primary_key=True, index=True)
     name = Column(String(50))
 ​
     users = relationship(
         "User",
         secondary="user_group_association",
         back_populates="groups",
     )
     posts = relationship(
         "Post",
         back_populates="user_group",
         cascade="all, delete-orphan",
     )
 ​
 ​
 class UserGroupAssociation(Base):
     __tablename__ = "user_group_association"
 ​
     id = Column(Integer, primary_key=True)
     user_id = Column(Integer, ForeignKey("users.id"))
     group_id = Column(Integer, ForeignKey("user_groups.id"))
 ​
 ​
 class Post(Base):
     __tablename__ = "posts"
 ​
     id = Column(Integer, primary_key=True, index=True)
     title = Column(String(100))
     content = Column(Text)
     user_id = Column(
         Integer, ForeignKey("users.id", ondelete='CASCADE'), nullable=False
     )
     group_id = Column(
         Integer,
         ForeignKey("user_groups.id", ondelete='CASCADE'),
         nullable=False,
     )
 ​
     author = relationship("User", back_populates="posts")
     user_group = relationship("UserGroup", back_populates="posts")
     keys = relationship(
         "PostKeys",
         back_populates="post",
         cascade="all, delete-orphan",
     )
 ​
 ​
 class PostKeys(Base):
     __tablename__ = "post_keys"
 ​
     id = Column(Integer, primary_key=True, index=True)
     post_id = Column(
         Integer,
         ForeignKey("posts.id", ondelete='CASCADE'),
         nullable=False,
     )
     public_key_id = Column(
         Integer, ForeignKey("user_keys.id", ondelete='CASCADE'), nullable=False
     )
     encrypted_key = Column(Text)
 ​
     post = relationship("Post", back_populates="keys")
     public_key = relationship("UserKeys")

每个用户都有自己的公钥/私钥对。他将公钥上传到服务器,并保持私钥的机密性。还有用户组 . 每个用户可以参与不同的组,但每篇帖子只能附加到一个特定的组中。因此,只有该组成员才能阅读帖子内容。

当添加新帖子时,系统会生成临时密钥,并使用该密钥加密帖子的内容,然后为每个群组成员将临时密钥用用户的公钥加密。当用户从服务器获取帖子时,他会收到使用他的公钥加密的加密内容和临时密钥。他可以使用私有秘钥解密临时秘钥,并用它来解密帖子的内容。让我们看一下代码。

 from pydantic import BaseModel
 ​
 class PostBase(BaseModel):
     title: str
     content: str
     group_id: int
 ​
 class PostInDBBase(PostBase):
     id: Optional[int] = None
     class Config:
         orm_mode = True
 ​
 async def create_post(db: AsyncSession, post: PostBase, author: User) -> Post:
     db_post = Post(
         title=post.title,
         content=post.content,
         group_id=post.group_id,
         author=author,
     )
     db.add(db_post)
     await db.commit()
     await db.refresh(db_post)
     return db_post
 ​

添加发帖路由代码:

 @router.post("/posts/", response_model=PostInDBBase, status_code=201)
 async def create_post(
     post: PostBase,
     db: DBSession,
     current_user: CurrentUser,
 ):
     plain_content = post.content
     post.content = ""
     post = await create_post(
         db=db,
         post=post,
         author=current_user,
     )
     encrypt_post_content.delay(post_id=post.id, content=plain_content)
     return post
 ​

这是一个视图,它接收帖子并将其保存到数据库。这里最有趣的部分是encrypt_post_content.delay()方法。实际上,这是一个Celery任务,将在单独的进程中执行。就是这样:

 import os
 ​
 from celery import Celery
 from cryptography.hazmat.primitives.serialization import load_pem_public_key
 from sqlalchemy import create_engine, select, update
 from sqlalchemy.orm import sessionmaker
 ​
 from app.core.crypto_tools import (
     asymmetric_encryption,
     generate_symmetric_key,
     symmetric_encryption,
 )
 from app.db.base import Post, PostKeys, User, UserGroup, UserKeys
 from app.core.config import settings
 ​
 celery = Celery("secureblogs")
 celery.conf.broker_url = os.environ.get("CELERY_BROKER_URL")
 ​
 ​
 sync_engine = create_engine(settings.SYNC_SQLALCHEMY_DATABASE_URL, echo=True)
 SyncSessionLocal = sessionmaker(
     autocommit=False,
     autoflush=False,
     bind=sync_engine,
 )
 ​
 ​
 @celery.task(name="encrypt_post_content")
 def encrypt_post_content(post_id: int, content: str):
     # generate temp key and encrypt content
     key = generate_symmetric_key()
     encrypted_content = symmetric_encryption(content, key)
     with SyncSessionLocal() as db:
         # update post instance
         post_statement = (
             update(Post)
             .returning(Post.group_id)
             .where(Post.id == post_id)
             .values(content=encrypted_content)
         )
         post = db.execute(post_statement).fetchone()
 ​
         # fetch user's public keys from DB
         users_subquery = (
             select(User.id)
             .where(User.groups.any(UserGroup.id.in_([post.group_id])))
             .subquery()
         )
         statement = select(UserKeys).where(
             (UserKeys.user_id.in_(users_subquery))
             & (UserKeys.is_revoked == False)
         )
         public_keys = db.execute(statement).scalars().all()
         db_post_keys = []
         for public_key in public_keys:
             # Save generated keys in DB
             public_pem_data = public_key.public_key
             public_key_object = load_pem_public_key(public_pem_data.encode())
             encrypted_key = asymmetric_encryption(key, public_key_object)
             db_post_keys.append(
                 PostKeys(
                     post_id=post_id,
                     public_key_id=public_key.id,
                     encrypted_key=encrypted_key,
                 )
             )
         db.bulk_save_objects(db_post_keys)
         db.commit()

Celery 任务只是一个同步的Python函数,因此为了执行数据库查询,我们在其中使用同步数据库会话。

Websockets

我们的API允许创建加密的帖子。但是等等,它只适用于现有用户。如果新用户加入群并想阅读一些帖子,该怎么办?他不能这样做,因为他不能解密临时密钥。他需要那个创建帖子的人给他寄临时钥匙。这里的websocket可能非常方便。当用户要求帖子访问时,我们会向帖子的作者发送实时通知。Post的作者收到通知并决定是批准还是拒绝请求。我们来实现它。首先添加新的 model。它包含关于请求帖子访问的用户、帖子本身和用户公钥的信息:

class ReadPostRequest(Base):
    __tablename__ = "read_post_request"

    id = Column(Integer, primary_key=True, index=True)
    post_id = Column(
        Integer,
        ForeignKey("posts.id", ondelete='CASCADE'),
        nullable=False,
    )
    user_id = Column(
        Integer, ForeignKey("users.id", ondelete='CASCADE'), nullable=False
    )
    public_key_id = Column(
        Integer, ForeignKey("user_keys.id", ondelete='CASCADE'), nullable=False
    )

    post = relationship("Post")
    requester = relationship("User")
    public_key = relationship("UserKeys")

现在我们需要允许创建新请求的API:

async def get_post(db: AsyncSession, post_id: int) -> Post:
    statement = select(Post).where(Post.id == post_id)
    result = await db.execute(statement)
    return result.scalars().first()

async def get_user_key(
    db: AsyncSession,
    user: User,
) -> UserKeys:
    statement = select(UserKeys).where(
        (UserKeys.user == user) & (UserKeys.is_revoked == False)
    )
    result = await db.execute(statement)
    return result.scalars().first()

async def add_read_post_request(
    db: AsyncSession, user: User, post_id: int
) -> ReadPostRequest:
    exists_statement = select(ReadPostRequest.id).where(
        (ReadPostRequest.user_id == user.id)
        & (ReadPostRequest.post_id == post_id)
    )
    result = await db.execute(exists_statement)
    if result.scalars().first():
        return None
    public_key_statement = select(UserKeys).where(
        (UserKeys.is_revoked == False) & (UserKeys.user_id == user.id)
    )
    result = await db.execute(public_key_statement)
    if not (public_key := result.scalars().first()):
        return None
    db_read_post_request = ReadPostRequest(
        user_id=user.id,
        post_id=post_id,
        public_key=public_key,
    )
    db.add(db_read_post_request)
    await db.commit()
    await db.refresh(db_read_post_request)
    return db_read_post_request

添加请求路由:

@router.post("/posts/{post_id}/request_read/", status_code=204)
async def add_read_post_request(
    post_id: int,
    db: DBSession,
    current_user: CurrentUser,
):
    post = await crud_post.get_post(db, post_id)
    if not post:
        raise HTTPException(status_code=404)
    user_key = await crud_user.get_user_key(db, current_user)
    request = await crud_post.add_read_post_request(db, current_user, post_id)
    if request:
        await ws_manager.send_personal_message(
            {
                'request_id': request.id,
                'post_id': post_id,
                'requested_user_id': current_user.id,
                'user_public_key': user_key.public_key,
            },
            post.user_id,
        )

除了最后一行之外,这里没有什么新内容。我们只是检查帖子是否真的存在于数据库中。然后我们获取用户的公钥,创建新请求,最后发送 WebSocket 通知。 现在让我们更仔细地看看我们是如何实现这个功能的。

from fastapi import WebSocket


class ConnectionManager:
    def __init__(self):
        self.active_connections: dict[int, WebSocket] = {}

    async def connect(self, user_id: int, websocket: WebSocket):
        await websocket.accept()
        self.active_connections[user_id] = websocket

    def disconnect(self, user_id: int):
        self.active_connections.pop(user_id)

    async def send_personal_message(self, message: dict, user_id: int):
        if websocket := self.active_connections.get(user_id):
            await websocket.send_json(message)



ws_manager = ConnectionManager()

这是我们的 WebSocket 管理器。在这里,我们有一个字典,用于保存用户ID及与每个用户ID相关联的WebSocket连接。当有人想要发送个人通知时,会使用 send_personal_message

最后让我们看一下如何创建新的WebSocket连接。

from typing import Annotated

from fastapi import (
    APIRouter,
    Depends,
    Query,
    status,
    WebSocket,
    WebSocketDisconnect,
    WebSocketException,
)

from app.api.deps import DBSession
from app.api.websockets.managers import ws_manager
from app.crud.crud_user import get_user_by_token


router = APIRouter()


async def get_token(
    websocket: WebSocket,
    token: Annotated[str | None, Query()] = None,
):
    if token is None:
        raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)
    return token

添加路由请求:

@router.websocket("/ws/post_request")
async def websocket_endpoint(
    websocket: WebSocket,
    db: DBSession,
    token: Annotated[str, Depends(get_token)],
):
    user = await get_user_by_token(db, token)
    if not user:
        raise WebSocketException(code=status.HTTP_401_UNAUTHORIZED)
    try:
        await ws_manager.connect(user.id, websocket)
        await ws_manager.send_personal_message(
            {"message": "connection accepted"},
            user.id,
        )
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        ws_manager.disconnect(user.id)

我们创建了新的 websocket· API /ws/post_request。此 API 会检查用户的令牌,如果令牌有效,则会创建新连接并发送给用户确认信息 connection accepted

总结

以上内容我们以用户发帖和浏览帖子为程序主要内容分别介绍了下面的内容:

  • 使用 FastAPI 程序
  • FastAPI 中使用 async SQLAlchemy 数据库能力。
  • FastAPI 中使用 celery 做任务队列的异步。
  • FastAPI 中使用 websockets 做信息通知。
转载自:https://juejin.cn/post/7369117524482015282
评论
请登录