Fastapi框架-使用过程填坑篇(6)- 吐槽Request请求上下文对象只能被消费一次的问题和解决思路,竟然不能再中间件里面再消费一次~唉
1:吐槽fastapi的中间件设计
在搭建我自己的脚手架的过程中,我自己搭建的一个模板示例如下的结构.
但是在使用中间件的过程其实遇到过一个很无爱无感的体验,就是我们的Fastapi提供的所谓的中间件的自定义其实如果你把它应用到我们的中间件上的话,很抱歉,你将会永远的到一个“卡巴斯基”的效果,就是一直再等待,整个程序会永远一直无法执行下去。
如果你仅仅是把中间件用于设计我们的鉴权,还能很好进行的工作,当然前提基础是你的token必须是放置在我们的请求头里面进行传递,如果是使用其他的方式的进行如:get或post等的话,很抱歉!基于中间件设计的鉴权模式也会遇到上面说的一种死等待的效果。
现象1:路由注册方式有app和APIRouter进行注册
具体的实践示例如:
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from starlette.requests import Request
from starlette.middleware.base import BaseHTTPMiddleware
app = FastAPI()
class LogerMiddleware1(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
# 调用一次 perf_counter(),从计算机系统里随机选一个时间点A,计算其距离当前时间点B1有多少秒。
# 当第二次调用该函数时,默认从第一次调用的时间点A算起,距离当前时间点B2有多少秒。
# 两个函数取差,即实现从时间点B1到B2的计时功能。
# logger.info("111111111111111111111111111")
print("我是ProcesMiddleware---1----中间件!!!!!!!!!", request.headers)
print("ProcesMiddleware---1----中间件!中的request的ID是:", id(request))
response = await call_next(request)
return response
class AuthMiddleware2(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
# 调用一次 perf_counter(),从计算机系统里随机选一个时间点A,计算其距离当前时间点B1有多少秒。
# 当第二次调用该函数时,默认从第一次调用的时间点A算起,距离当前时间点B2有多少秒。
# 两个函数取差,即实现从时间点B1到B2的计时功能。
# logger.info("111111111111111111111111111")
print("我是ProcesMiddleware---2----中间件!!!!!!!!!", request.headers)
print("ProcesMiddleware---2----中间件!中的request的ID是:", id(request))
response = await call_next(request)
return response
from fastapi import APIRouter
bp = APIRouter(tags=['权限管理'], prefix='/api/v1')
class AdminLogin(BaseModel):
"""
管理员登陆 提交的参数信息
"""
password: str = None
username: str = None
@bp.post("/one")
async def sum_numbers(req: Request):
print('one接口内的:', id(req))
return 'one'
@bp.post("/two")
async def sum_numbers(req: Request, sdsada: AdminLogin):
print('two接口内的:', id(req))
return 'two'
@app.post("/inapprsp")
async def inapprsp(req: Request, sdsada: AdminLogin):
print('inapprsp-接口内的:', id(req))
return 'inapprsp'
# 注册中间件
app.add_middleware(ProcesMiddleware1)
app.add_middleware(ProcesMiddleware2)
# 添加路由组
app.include_router(bp)
if __name__ == '__main__':
uvicorn.run(
app=app,
host="127.0.0.1",
port=5586,
workers=1,
reload=False,
debug=False,
)
对应接口显示如下:
上面相关的接口的请求访问都是没问题的!都能正常的运行!
但是现在问题引发出来了!
首先上面定义了两个中间件,
- 假设第一个log中间件中需要读取它自己传入的req: Request的请求提信息,如json和body的内容(不是查询参数的内容query_params)
- 第2个的认证中间件,也需要读取它自己传入的从req: Request的请求提信息
如修改后的中间内容为:
主要是新增了读取请全体内容的代码:
_body = await request.body()
当然你也可以读取json()--但是再没有值的情况下直接的读取json()会异常!
此时我们再去请求我们的接口的时候(任何接口,甚至是启动的时候)你会惊讶!全线崩溃!无一幸免!
甚至想启动访问,就一直卡再那了!
当取消请求之后就会:
async for chunk in self.stream():
File "C:\Users\mayn\.virtualenvs\fastapi_5g_msg\lib\site-packages\starlette\requests.py", line 214, in stream
raise ClientDisconnect()
starlette.requests.ClientDisconnect
什么玩意!所以说明了一个问题!!!这个使用中间件获取请求提的问题,非常严重!!!
现象2:路由注册方式仅仅使用app(不使用APIRouter)
修改代码为:
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from starlette.requests import Request
from starlette.middleware.base import BaseHTTPMiddleware
app = FastAPI()
class LogerMiddleware1(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
# 调用一次 perf_counter(),从计算机系统里随机选一个时间点A,计算其距离当前时间点B1有多少秒。
# 当第二次调用该函数时,默认从第一次调用的时间点A算起,距离当前时间点B2有多少秒。
# 两个函数取差,即实现从时间点B1到B2的计时功能。
# logger.info("111111111111111111111111111")
print("我是ProcesMiddleware---1----中间件!!!!!!!!!", request.headers)
_body = await request.body()
# print("我是ProcesMiddleware---1----中间件!我读取JSON内容信息:!!!!!!!!", _body)
print("ProcesMiddleware---1----中间件!中的request的ID是:", id(request))
response = await call_next(request)
return response
class AuthMiddleware2(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
# 调用一次 perf_counter(),从计算机系统里随机选一个时间点A,计算其距离当前时间点B1有多少秒。
# 当第二次调用该函数时,默认从第一次调用的时间点A算起,距离当前时间点B2有多少秒。
# 两个函数取差,即实现从时间点B1到B2的计时功能。
# logger.info("111111111111111111111111111")
print("我是ProcesMiddleware---2----中间件!!!!!!!!!", request.headers)
_body = await request.body()
# print("我是ProcesMiddleware---2----中间件!我读取JSON内容信息:!!!!!!!!", _body)
print("ProcesMiddleware---2----中间件!中的request的ID是:", id(request))
response = await call_next(request)
return response
from fastapi import APIRouter
bp = APIRouter(tags=['权限管理'], prefix='/api/v1')
class AdminLogin(BaseModel):
"""
管理员登陆 提交的参数信息
"""
password: str = None
username: str = None
@app.post("/one")
async def sum_numbers(req: Request):
print('one接口内的:', id(req))
return 'one'
@app.post("/two")
async def sum_numbers(req: Request, sdsada: AdminLogin):
print('two接口内的:', id(req))
return 'two'
@app.post("/inapprsp")
async def inapprsp(req: Request, sdsada: AdminLogin):
print('inapprsp-接口内的:', id(req))
return 'inapprsp'
# 注册中间件
app.add_middleware(LogerMiddleware1)
app.add_middleware(AuthMiddleware2)
# 添加路由组
app.include_router(bp)
if __name__ == '__main__':
uvicorn.run(
app=app,
host="127.0.0.1",
port=5586,
workers=1,
reload=False,
debug=False,
)
主要区别有:
-
1:路由注册我修改了,仅仅使用app
-
2:中间件那有且只有一个开启获取的请求提信息的时候:
此时再去请求我们的接口: 有时候可以请求成功,有时候又和上面情况一下!有点奇葩!!!
问题尝试解决,修改中间的定义:
完整的示例代码:
#!/usr/bin/evn python
# -*- coding: utf-8 -*-
"""
-------------------------------------------------
文件名称 : ceshi
文件功能描述 : 功能描述
创建人 : 小钟同学
创建时间 : 2021/6/9
-------------------------------------------------
修改描述-2021/6/9:
-------------------------------------------------
"""
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from starlette.requests import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.types import ASGIApp, Message, Receive, Scope, Send
# Init server context
app = FastAPI()
class CommomMiddleware:
def __init__(
self,
app: ASGIApp,
):
self._app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self._app(scope, receive, send)
return
# Create Request
request = Request(scope, receive=receive)
print("新的中间件!!!开始!!我的request-ID 信息是:", id(request))
# After Request
# span = await self._create_span(request=request)
scope.setdefault("trace_ctx", '33333333333333333333333333333333')
# 如果这里再消费的话,也是无法继续执行下去
# sada =await request.body()
status_code = 500
async def wrapped_send(message: Message) -> None:
if message['type'] == 'http.response.start':
nonlocal status_code
status_code = message['status']
await send(message)
try:
await self._app(scope, receive, wrapped_send)
except Exception as err:
pass
print("新的中间件!!!结束!!")
class ProcesMiddleware1(BaseHTTPMiddleware):
'''
耗时的中间件的
'''
async def dispatch(self, request: Request, call_next):
# 调用一次 perf_counter(),从计算机系统里随机选一个时间点A,计算其距离当前时间点B1有多少秒。
# 当第二次调用该函数时,默认从第一次调用的时间点A算起,距离当前时间点B2有多少秒。
# 两个函数取差,即实现从时间点B1到B2的计时功能。
# logger.info("111111111111111111111111111")
print("我是ProcesMiddleware---1----中间件!!!!!!!!!", request.headers)
print("ProcesMiddleware---1----中间件!中的request的ID是:", id(request))
print("ProcesMiddleware--1-----中间件!中的request:scope", request.scope.get('trace_ctx'))
response = await call_next(request)
return response
class ProcesMiddleware2(BaseHTTPMiddleware):
'''
耗时的中间件的
'''
async def dispatch(self, request: Request, call_next):
# 调用一次 perf_counter(),从计算机系统里随机选一个时间点A,计算其距离当前时间点B1有多少秒。
# 当第二次调用该函数时,默认从第一次调用的时间点A算起,距离当前时间点B2有多少秒。
# 两个函数取差,即实现从时间点B1到B2的计时功能。
# logger.info("111111111111111111111111111")
print("我是ProcesMiddleware---2----中间件!!!!!!!!!", request.headers)
print("ProcesMiddleware---2----中间件!中的request的ID是:", id(request))
print("ProcesMiddleware--2-----中间件!中的request:scope", request.scope.get('trace_ctx'))
response = await call_next(request)
return response
from fastapi import APIRouter
bp = APIRouter(tags=['权限管理'], prefix='/api/v1')
# bp.route_class=ValidationErrorLoggingRoute
class AdminLogin(BaseModel):
"""
管理员登陆 提交的参数信息
"""
password: str = None
username: str = None
@bp.post("/one")
async def sum_numbers(req: Request):
print('one接口内的:', id(req))
return 'one'
@bp.post("/two")
async def sum_numbers(req: Request, sdsada: AdminLogin):
print('two接口内的:', id(req))
return 'two'
@app.post("/inapprsp")
async def inapprsp(req: Request, sdsada: AdminLogin):
print('inapprsp-接口内的:', id(req))
return 'inapprsp'
app.add_middleware(ProcesMiddleware2)
app.add_middleware(ProcesMiddleware1)
# 注意必须注册再最后面哟
app.add_middleware(CommomMiddleware)
app.include_router(bp)
if __name__ == '__main__':
uvicorn.run(
app=app,
host="127.0.0.1",
port=5586,
workers=1,
reload=False,
debug=False,
)
上面代码,只是能解决了,自定义的一个新的中间中对我们的request.scope一种方案:
scope.setdefault("trace_ctx", '33333333333333333333333333333333')
然后我们的所有其他内层内其他中间件都可以从
获取到值,这个可以给后续的提供了一个思路!
但是如果一旦继续想消费我们的body的话,一样的是会引发阻塞!!!!死等待!
问题分析,分析对象内存地址:
归根结底,我们注释上面的获取请求提的代码,查看他们的req: Request的内存地址,可以看得到一个问题: 如访问:http://127.0.0.1:5586/inapprsp
它们的中间件传递都不是同一个对象,但是他们的请求头的内容是可以传递到每一个中间件里面,神奇吧:
于是乎我翻了相关的issess,作者自己都承认,这个存在问题,github.com/tiangolo/fa…
如作者的回复:
按作者的回复的意思:
- 一旦我们的req: Request被消费过一次后,如果不在函数的作用域的范围内的话,则就没了!
- 被消费的场景包括读取body的时候,如定义我们的basemodel
2:根据作者的建议使用自定义路由的方式解决
吐槽归吐槽吧,但是Fastapi的设计也还是不错滴,只是希望,未来它可以做到不会因为一次的消费就失效的问题! 像sanic框架的中间件它就很好的可以进行传递!
作者给出的方案如下:
import time
from typing import Callable
from fastapi import APIRouter, FastAPI, Request, Response
from fastapi.routing import APIRoute
class TimedRoute(APIRoute):
def get_route_handler(self) -> Callable:
original_route_handler = super().get_route_handler()
async def custom_route_handler(request: Request) -> Response:
before = time.time()
# 这里可以获取的我们的请求的体的信息----
response: Response = await original_route_handler(request)
# 下面可以处理我们的响应体的报文信息
duration = time.time() - before
response.headers["X-Response-Time"] = str(duration)
print(f"route duration: {duration}")
print(f"route response: {response}")
print(f"route response headers: {response.headers}")
return response
return custom_route_handler
app = FastAPI()
router = APIRouter(route_class=TimedRoute)
@app.get("/")
async def not_timed():
return {"message": "Not timed"}
@router.get("/timed")
async def timed():
return {"message": "It's the time of my life"}
app.include_router(router)
然后我们的修改一下我们的上面的custom_route_handler里面的话,就可以解决我们的日志问题,但是要注意的一点是需要进行替换:
router = APIRouter(route_class=TimedRoute)
或者
router = APIRouter()
router.route_class=TimedRoute
在我自己的脚手架框架里面是进行批量的导入注册的:
最终我自己的一个基于自定义的APIRouter日志记录行为代码如下:
from time import perf_counter
from loguru import logger
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import StreamingResponse
import uuid
import json
from apps.utils import json_helper
from apps.utils import aiowrap_helper
from fastapi import APIRouter, FastAPI, Request, Response, Body
from fastapi.routing import APIRoute
from typing import Callable, List
from fastapi.responses import Response
class ContextIncludedRoute(APIRoute):
@staticmethod
def sync_trace_add_log_record(request: Request, event_des='', msg_dict={}, remarks=''):
'''
:param event_des: 日志记录事件描述
:param msg_dict: 日志记录信息字典
:param remarks: 日志备注信息
:return:
'''
# request.request_links_index = request.request_links_index + 1
if hasattr(request, 'traceid'):
log = {
# 自定义一个新的参数复制到我们的请求上下文的对象中
'traceid': getattr(request, 'traceid'),
# 定义链路所以序号
'trace_index': getattr(request, 'trace_links_index'),
# 日志时间描述
'event_des': event_des,
# 日志内容详情
'msg_dict': msg_dict,
# 日志备注信息
'remarks': remarks
}
# 为少少相关记录,删除不必要的为空的日志内容信息,
if not remarks:
log.pop('remarks')
if not msg_dict:
log.pop('msg_dict')
try:
log_msg = json_helper.dict_to_json_ensure_ascii(log) # 返回文本
logger.info(log_msg)
except:
logger.info(getattr(request, 'traceid') + ':索引:' + str(getattr(request, 'trace_links_index')) + ':日志信息写入异常')
# 封装一下关于记录序号的日志记录用于全链路的日志请求的日志
@staticmethod
async def async_trace_add_log_record(request: Request, event_des='', msg_dict={}, remarks=''):
'''
:param event_des: 日志记录事件描述
:param msg_dict: 日志记录信息字典
:param remarks: 日志备注信息
:return:
'''
# request.request_links_index = request.request_links_index + 1
# 如果没有这个标记的属性的,说明这个接口的不需要记录啦!
if hasattr(request, 'traceid'):
log = {
# 自定义一个新的参数复制到我们的请求上下文的对象中
'traceid': getattr(request, 'traceid'),
# 定义链路所以序号
'trace_index': getattr(request, 'trace_links_index'),
# 日志时间描述
'event_des': event_des,
# 日志内容详情
'msg_dict': msg_dict,
# 日志备注信息
'remarks': remarks
}
# 为少少相关记录,删除不必要的为空的日志内容信息,
if not remarks:
log.pop('remarks')
if not msg_dict:
log.pop('msg_dict')
try:
log_msg = json_helper.dict_to_json_ensure_ascii(log) # 返回文本
logger.info(log_msg)
except:
logger.info(getattr(request, 'traceid') + ':索引:' + str(getattr(request, 'trace_links_index')) + ':日志信息写入异常')
async def _init_trace_start_log_record(self, request: Request):
'''
请求记录初始化
:return:
'''
path_info = request.url.path
if path_info not in ['/favicon.ico'] and 'websocket' not in path_info:
if request.method != 'OPTIONS':
# 追踪索引
request.trace_links_index = 0
# 追踪ID
request.traceid = str(uuid.uuid4()).replace('-', '')
# 计算时间
request.start_time = perf_counter()
# 获取请求来源的IP,请求的方法
ip, method, url = request.client.host, request.method, request.url.path
# print('scope', request.scope)
# 先看表单有没有数据:
try:
body_form = await request.form()
except:
body_form = None
body =None
try:
body_bytes = await request.body()
if body_bytes:
try:
body = await request.json()
except:
pass
if body_bytes:
try:
body =body_bytes.decode('utf-8')
except:
body = body_bytes.decode('gb2312')
except:
pass
log_msg = {
# 'headers': str(request.headers),
# 'user_agent': str(request.user_agent),
# 记录请求头信息
'headers': request.headers if str(request.headers) else '',
# 记录请求URL信息
'url': url,
# 记录请求方法
'method': method,
# 记录请求来源IP
'ip': ip,
# 'path': request.path,
# 记录请求提交的参数信息
'params': {
'query_params': '' if not request.query_params else request.query_params,
'from': body_form,
'body': body
},
# 记录请求的开始时间
# 'start_time': f'{(start_time)}',
}
# 执行写入--日志具体的内容信息
await self.async_trace_add_log_record(request, event_des='请求开始', msg_dict=log_msg)
async def _init_trace_end_log_record(self, request: Request, response: Response):
# https://stackoverflow.com/questions/64115628/get-starlette-request-body-in-the-middleware-context
if hasattr(request, 'traceid'):
start_time = getattr(request, 'start_time')
end_time = f'{(perf_counter() - start_time):.2f}'
# 获取响应报文信息内容
resp_body = None
if isinstance(response,Response):
resp_body = str(response.body)
log_msg = {
# 记录请求耗时
'cost_time': end_time,
# 记录请求响应的最终报文信息
'resp_body': resp_body
}
await self.async_trace_add_log_record(request, event_des='请求结束', msg_dict=log_msg)
def get_route_handler(self) -> Callable:
original_route_handler = super().get_route_handler()
async def custom_route_handler(request: Request) -> Response:
# request_id = str(uuid4())
# 请求日志的初始化操作
await self._init_trace_start_log_record(request)
response: Response = await original_route_handler(request)
# 日志收尾记录
await self._init_trace_end_log_record(request, response)
#
# if await request.body():
# print(await request.body())
return response
return custom_route_handler
3:综合方案一个思考和实践
在之前我的使用了自定义中间件的另一种方式,其实只是对我们的一个request.scope进行扩展填充,基于这个思路上,我们可以再想一个,如果我坚持使用中间件处理我的日志的话,并且我需要得到传递的请求traceid的话,那我有一个新的思路就是:
- 在自定路由中读取我们body或js,并写入我们的请求traceid和其他需要传递的信息到我们request.scope中,
- 然后再最外层的中间件里读取我们的相关的request.scope中,写入的信息如:traceid
首先自定义我们的ContextIncludedRoute:
from time import perf_counter
from loguru import logger
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import StreamingResponse
import uuid
import json
from apps.utils import json_helper
from apps.utils import aiowrap_helper
from fastapi import APIRouter, FastAPI, Request, Response, Body
from fastapi.routing import APIRoute
from typing import Callable, List
from fastapi.responses import Response
class ContextIncludedRoute(APIRoute):
def get_route_handler(self) -> Callable:
original_route_handler = super().get_route_handler()
async def custom_route_handler(request: Request) -> Response:
# request_id = str(uuid4())
# 请求日志的初始化操作
# await self._init_trace_start_log_record(request)
try:
body_form = await request.form()
except:
body_form = None
body = None
try:
body_bytes = await request.body()
if body_bytes:
try:
body = await request.json()
except:
pass
if body_bytes:
try:
body = body_bytes.decode('utf-8')
except:
body = body_bytes.decode('gb2312')
except:
pass
request.scope.setdefault("traceid", str(uuid.uuid4()).replace('-', ''))
request.scope.setdefault("trace_links_index", 0)
request.scope.setdefault("requres_body_form", body_form)
request.scope.setdefault("requres_body", body)
# 这里记录下请求入口的时候相关的日志信息
response: Response = await original_route_handler(request)
return response
return custom_route_handler
然后再定义我们的日志记录的中间件读取相关的信息进行记录,末尾再记录的请求结果信息。 感觉这个多多少少会点绕来绕去的感觉!不过可以根据自己的需求来修改咯!
4:总结
- Request请求上下文对象只能被消费一次的问题真的希望未来能解决,不然有点打击我使用这个框架哈哈,当然只是我自己的感觉!比较多数的情况下,我们的日志记录一般不会在应用层内做相关的记录,非常的一些特殊的情况,我个人的业务情况的话,就需要做相关的请求信息的获取,所以对日志记录中间的需求非常强烈!
结尾
简单小笔记!仅供参考!
转载自:https://juejin.cn/post/6972031031155097631