likes
comments
collection
share

一文讲透如何二次封装 FastAPI 框架,看完直呼 "Pythonic"

作者站长头像
站长
· 阅读数 8

注:本文并不包含项目全部代码,查看全部代码请直接跳转到 Github

鉴于 FastAPIFlask 框架一样保留了足够多的扩展性,无法做到在企业级项目中开箱即用,本文主要讲述了如何对 FastAPI 框架进行二次封装。以下内容主要分为 8 个部分,分别是统一接口返回、全局异常处理、自定义上下文、用户鉴权、单元测试、多环境配置、数据迁移、CLI启动等。看完之后如果您有更好的建议,欢迎到下方留言,或者到 Github 发起 PR。

统一接口返回

一般来讲,返回给前端的接口数据都是统一格式,例如:

{
  "success": true,
  "msg": "",
  "data": {}
}

我们定义两个方法,请求成功的结构如下:

def success_response(data=''):
    new_body = {'success': True, 'msg': '', 'data': data}
    return new_body

请求失败的结构如下:

def failed_response(error_type, error_message, error_data=None):
    """failed response
    """
    new_body = {
        'success': False,
        'error_type': error_type,
        'msg': error_message,
        'data': ''
    }
    if error_data is not None:
        new_body['data'] = error_data
    return JSONResponse(new_body)

如果每次在接口返回时都去调用这些方法是一件非常麻烦的事,本着 DRY 的原则,开始我们初步的封装。 FastAPIresponse_model 参数需要指定模型,所以先定义 APIResponse 类:

class APIResponse(GenericModel, Generic[T]):
    success: bool
    msg: str
    data: T = None

然后自定义新的路由函数,用来封装我们的接口:

def route(
    router: APIRouter,
    path: str,
    methods: List[str],
    response_model=None,
    **options
):
    common_response_model = APIResponse[response_model]

    def wrapper(func: Callable[..., Any]):

        async def decorator(*args, **kwargs):
            response = await func(*args, **kwargs)
            if isinstance(response, Response):
                # The response may have already been wrapped, this situation should be ignored.
                return response

            return success_response(response)

        signature = inspect.signature(func)
        decorator.__signature__ = signature
        decorator.__name__ = func.__name__
        decorator.__doc__ = func.__doc__
        router.add_api_route(
            path,
            endpoint=decorator,
            response_model=common_response_model,
            methods=methods,
            **options
        )
        return decorator

    return wrapper

使用新的路由,一个统一的返回模型就做好了:

@route(router, '/', ['GET'], response_model=List[PersonOut])
async def list_examples(sa_session: AsyncSession = Depends(get_db_session)):
    persons = await sa_session.scalars(
        select(Person).order_by(Person.created_at.desc())
    )
    return persons.all()

现在每次使用 @route 还需要编写 GET 等参数,再次本着 DRY 原则,拿出大名鼎鼎的 functools 库,我们创建新的方法:

get = functools.partial(route, methods=['GET'])

上面的代码最终被改写为:

@get(router, '/', response_model=List[PersonOut])
async def list_examples(sa_session: AsyncSession = Depends(get_db_session)):
    persons = await sa_session.scalars(
        select(Person).order_by(Person.created_at.desc())
    )
    return persons.all()

统一异常处理

FastAPI 已经提供了统一的异常处理,我们只需要简单的封装。

定义最顶层的异常类:

class APIException(Exception):
    """api 异常
    """
    error_type = 'api_error'
    error_message = 'A server error occurred.'

    def __init__(self, error_type=None, error_message=None):
        if error_type is not None:
            self.error_type = error_type
        if error_message is not None:
            self.error_message = error_message

    def __repr__(self):
        return '<{} {}: {}>'.format(
            self.__class__, self.error_type, self.error_message
        )

创建统一异常处理类:

class GlobalExceptionHandler:

    def __init__(self, app: FastAPI):
        self.app = app

    @staticmethod
    async def handle_api_exception(request: Request, error: APIException):
        return failed_response(
            error_type=error.error_type, error_message=error.error_message
        )

    @staticmethod
    async def handle_exception(request: Request, error: Exception):
        logger.error(f'{request.url} {error}')
        return failed_response(error_type='server_error', error_message='Server error')

    def init(self):
        self.app.add_exception_handler(
            RequestValidationError, self.handle_request_validation_error
        )
        self.app.add_exception_handler(APIException, self.handle_api_exception)
        self.app.add_exception_handler(Exception, self.handle_exception)

然后创建一个自定义异常:

class PersonNotFound(APIException):
    error_type = 'person_not_found'
    error_message = 'Person not found'

最后只需要在代码里面抛出异常:

@get(router, '/{first_name}', response_model=PersonOut)
async def get_person(first_name: str, sa_session: AsyncSession = Depends(get_db_session)):
    person = await sa_session.get(Person, first_name)
    if not person:
        raise PersonNotFound
    return person

现在已经统一处理了接口成功和失败的响应,这已经为我们省去了大量的重复编码工作,但是这还远远不够,我们还需要让这个项目更加适合日常开发。

自定义上下文

日常开发中,我们需要使用到各种三方库,例如 Redis, Mysql , Kafka 等,得益于 FastAPIDepends 功能,使得我们引用这些中间件变得非常简单,但还是本着 DRY 的原则,我们把这些代码都封装起来。

我个人最喜欢 FastAPI 框架的一个点就是可以有效解决 PythonPycharm 里面不提示类型的问题,所以我们先定义一个类:

class AppContext:

    def __init__(self):
        self.request: Request | None = None
        self.response: Response | None = None
        self.sa_session: AsyncSession | None = None
        self.redis: Redis | None = None

创建上下文的获取方法:

async def get_app_ctx(
    request: Request,
    response: Response,
    sa_session: AsyncSession = Depends(get_db_session),
    redis: Redis = Depends(get_redis)
) -> AppContext:

    context = AppContext()
    context.request = request
    context.response = response
    context.sa_session = sa_session
    context.redis = redis
    return context

然后直接使用:

@get(router, '/', response_model=List[PersonOut])
async def list_examples(context: AppContext = Depends(get_app_ctx)):
    persons = await context.sa_session.scalars(
        select(Person).order_by(Person.created_at.desc())
    )
    return persons.all()

但是我还是觉得不太方便,所以进一步封装:

DependsOnContext: AppContext = cast(AppContext, Depends(get_app_ctx))

最终把上面的参数改写为:

@get(router, '/', response_model=List[PersonOut])
async def list_examples(context: AppContext = DependsOnContext):
    persons = await context.sa_session.scalars(
        select(Person).order_by(Person.created_at.desc())
    )
    return persons.all()

现在我们引入任何的三方库都可以在这个上下文里面去添加,后续使用起来就很方便啦。

获取当前用户

企业项目往往在网关中处理用户鉴权,传递到后端的已经是被解析出来的用户数据。但是在这里我们还是来模拟下基于 JWT 编码和解码的过程,获取当前用户信息。

登录接口如下:

@post(router, '/login', response_model=LoginOut)
async def login(user_in: UserIn, context: AppContext = DependsOnContext):
    user = await context.sa_session.scalar(
        select(User).where(and_(User.username == user_in.username))
    )
    if not user:
        raise UserNotFound

    if not await user.verify_password(user_in.password):
        raise AccountOrPasswordWrong

    token = await generate_token(UserOut.from_orm(user).dict())
    await context.redis.set(
        f'{config.SERVICE_NAME}:user:token:{user.id}', token,
        config.EXPIRED_SECONDS
    )

    return {'token': token, 'user': user}

自定义 CurrentUser 模型:

class CurrentUser(BaseModel):
    id: int
    username: str

假设前端拿到 token 后,后面的每次请求都在 header 中携带 token

async def get_current_user(
    token: str = Header('token'), redis: Redis = Depends(get_redis)
):
    if not token:
        raise ApiSignatureExpired

    payload = await parse_token(token)

    if 'id' in payload:
        try:
            current_user = CurrentUser.parse_obj(payload)
        except ValidationError:
            raise JwtTokenError

        cache_token = await redis.get(
            f'{config.SERVICE_NAME}:user:token:{current_user.id}'
        )
        if cache_token != token:
            raise ApiSignatureExpired

        return current_user

    raise ApiSignatureExpired

自定义 Depends

DependsOnUser: CurrentUser = cast(CurrentUser, Depends(get_current_user))

然后就可以获取当前登录用户信息了:

@get(router, '/', response_model=UserOut)
async def get_user_detail(current_user: CurrentUser = DependsOnUser):
    return current_user

单元测试

真实世界的项目开发中,为了保证代码的质量,自然少不了编写单元测试,在这里我对 pytest 框架就不再做过多的赘述。我在封装该模块的时候,考虑的最多的是保证单元测试的一致性。

所以在这里我覆盖了原本的数据库 session,保证每次测试运行之后回滚数据:

@pytest.fixture
async def session(app: FastAPI) -> AsyncSession:
    async with engine.begin() as connection:
        async_session_local = async_sessionmaker(
            bind=connection,
            autoflush=False,
            future=True,
            autocommit=False,
            expire_on_commit=False
        )
        async_session = async_session_local()
        # Overwrite the current database so that every time the test is run, the transaction is rollback
        app.dependency_overrides[get_db_session] = lambda: async_session
        yield async_session
        await async_session.close()

        await connection.rollback()

测试用例的编写:

async def test_list_examples(client: AsyncClient):
    response = await client.get('/v1/examples/')
    assert response.status_code == 200

    json_result = response.json()
    assert len(json_result['data']) == 0

    # 创建
    response = await client.post(
        '/v1/examples/', json={
            'first_name': 'test',
            'last_name': 'test'
        }
    )
    assert response.status_code == 200

    response = await client.post(
        '/v1/examples/', json={
            'first_name': 'test2',
            'last_name': 'test2'
        }
    )
    assert response.status_code == 200

    response = await client.get('/v1/examples/')
    assert response.status_code == 200

    json_result = response.json()
    assert len(json_result['data']) == 2

多环境配置

一般来讲我们需要区分开发、测试、预发布、正式等环境,每个环境的数据库、秘钥等配置都各不相同,在这里我通过环境变量来导入不同的配置文件。

@lru_cache(maxsize=1)
def get_config(config_name: str = None) -> BaseConfig:
    """
    if config name is none, get active profile from env
    """
    if not config_name:
        config_name = get_active_env()

    configs_module = importlib.import_module('configs')
    config_class = getattr(configs_module, config_name.capitalize())
    return config_class()

如果环境变量值为 Development,则导入的就是 Development 类下的配置文件:

class Development(BaseConfig):
    STAGE: str = 'dev'

class Docker(BaseConfig):
    STAGE = 'docker'

    DB_HOST = 'mysql'
    REDIS_URL = 'redis://redis:6379/0'

class Testing(BaseConfig):
    STAGE: str = 'test'

    # logger config
    LOGGING_LEVEL: str = 'DEBUG'

    # db config
    DB_DATABASE = 'example_test'
    DB_ENABLE_ECHO = False

class Production(BaseConfig):
    STAGE: str = 'prod'
    DEBUG: bool = False

    # logger config
    LOGGING_LEVEL: str = 'INFO'

数据迁移

这里数据迁移是基于 alembic 实现的,它能够与 Sqlachelmy 很好的集成,只需要简单的编码:

def migrate(commit):
    path = Path('models', 'migrations', 'versions')
    if path.exists():
        has_versions = any(
            filter(lambda _dir: _dir.name.endswith('.py'), path.iterdir())
        )
    else:
        path.mkdir()
        has_versions = False

    revision_args = ['revision', '--autogenerate', '-m']
    if has_versions is False:
        revision_args.append('"init db"')
    else:
        if commit:
            revision_args.append(f'"{commit}"')
        else:
            revision_args.append('"update"')

    alembic.config.main(argv=revision_args)

    migrate_args = ['--raiseerr', 'upgrade', 'head']
    alembic.config.main(argv=migrate_args)

CLI

看到这里,我们的项目不单单要运行 FastAPI 应用,还可能要运行测试、数据迁移等,这个时候单一的程序入口已经不能满足最基本的要求,所以我基于 click 框架对项目入口进一步封装。

@click.group()
@click.version_option(version='1.0.0')
def cli():
    """CLI management for FastAPI project
    """

添加项目启动方法:

@cli.command('start')
def start():
    uvicorn.run(
        app,
        host=config.HOST,
        port=config.PORT,
        debug=config.DEBUG,
        log_config=config.log_config()
    )

添加测试运行方法,并支持运行单个测试:

@cli.command('test')
@click.argument('test_names', required=False, nargs=-1)
def test(test_names):
    import pytest

    args = config.PYTEST_ARGS
    if test_names:
        args.extend(test_names)
    pytest.main(args)

添加数据迁移方法,并支持自定义 commit

@cli.command('migrate')
@click.argument('commit', required=False, nargs=-1)
def migrate(commit):
    ...

现在我们可以在命令行中执行以下命令:

# 启动项目
python manage.py start

# 运行全部测试
python manage.py test

# 运行单个测试
python manage.py test unittests/test_example.py::test_list_examples2

# 数据迁移
python manage.py migrate

总结

感谢您阅读到这里,目前为止已经涵盖了日常开发最常见的一部分,但还不足以包含企业级项目开发中的所有情况,所以我热烈欢迎您可以来继续完善这个项目,以及在评论区和我积极讨论,谢谢!