Celery Source Code Analysis (6): A Complete Task Lifecycle
Now we can finally answer the second question left open at the end of the article on Celery's basic architecture: how does Celery scan for the tasks we define and register them?
We know that Python has nothing quite like Java's annotations, nor any annotation-scanning mechanism, so how does Celery figure out which of our functions were decorated with @task, and add them to its context?
Before digging into the source, we need to be clear about one concept. Recall that in an earlier article we analyzed the task message protocol: metadata such as task_id lives in the headers, while the arguments live in the body. This tells us something quite interesting: Celery does not ship the function itself through the message queue as the message payload; it only tells the worker the task's name and its arguments. Why not send the whole function as a unit of computation to the worker, so the worker could just execute the code it receives?
The benefit of Celery's approach is obvious: it drastically shrinks the message size, since an entire serialized function is usually not small. That saves data-transfer cost and improves efficiency.
The downside is just as obvious: the worker becomes tightly coupled to our codebase. The worker must be started from within our project, otherwise it has no way to locate and execute the task logic behind a given task_name.
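To make this concrete, here is a rough sketch of the shape of a protocol-v2 task message. The field names follow the v2 protocol described above; the task name and argument values are made up for illustration:

```python
import json
import uuid

# Metadata travels in the headers; only the task *name* identifies the code to run.
headers = {
    'id': str(uuid.uuid4()),
    'task': 'proj.tasks.add',   # hypothetical task name
    'lang': 'py',
}
# The arguments travel in the body: (args, kwargs, embed).
body = json.dumps([[2, 3], {}, {'callbacks': None, 'chain': None}])

# The worker never receives the function itself; it must be able to import
# proj.tasks.add locally and look it up by the name carried in the headers.
print(headers['task'], body)
```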
Since the usual way to declare a Celery task is to apply the @task decorator to a function, that decorator is the natural place to start.
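As a warm-up, here is a toy model of what the decorator ultimately accomplishes. This is my own simplified sketch, not Celery's code: a plain function is wrapped in an object that carries a name, can still be called directly, and gains a delay-style method (the names delay and the registry dict mirror Celery's API):

```python
class ToyTask:
    """A stand-in for Celery's Task proxy (heavily simplified)."""
    def __init__(self, fun, name):
        self.run = fun
        self.name = name

    def __call__(self, *args, **kwargs):
        return self.run(*args, **kwargs)

    def delay(self, *args, **kwargs):
        # Real Celery serializes and publishes a message here;
        # we just return what *would* be sent.
        return {'task': self.name, 'args': args, 'kwargs': kwargs}

registry = {}  # plays the role of app._tasks

def task(fun):
    t = ToyTask(fun, f'{fun.__module__}.{fun.__name__}')
    registry[t.name] = t
    return t

@task
def add(x, y):
    return x + y

print(add(2, 3))        # direct call still works: 5
print(add.delay(2, 3))  # "publishes" instead of executing
```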
The @task decorator
def task(*args, **kwargs):
"""Deprecated decorator, please use :func:`celery.task`."""
return current_app.task(*args, **dict({'base': Task}, **kwargs))
Stepping into it, this is very simple. Note that current_app here is the global Celery application instance in our worker. What the decorator actually does is wrap our function into a Task proxy object; otherwise, where would your function get its delay and apply_async methods from? Makes perfect sense. And the task method called here is exactly the task method of the Celery class that we earlier flagged as extremely important. Pasting the code once more; the path is celery.app.base.Celery:
def task(self, *args, **opts):
if USING_EXECV and opts.get('lazy', True):
# When using execv the task in the original module will point to a
# different app, so doing things like 'add.request' will point to
# a different task instance. This makes sure it will always use
# the task instance from the current app.
# Really need a better solution for this :(
from . import shared_task
return shared_task(*args, lazy=False, **opts)
def inner_create_task_cls(shared=True, filter=None, lazy=True, **opts):
_filt = filter
def _create_task_cls(fun):
if shared:
def cons(app):
return app._task_from_fun(fun, **opts)
cons.__name__ = fun.__name__
connect_on_app_finalize(cons)
if not lazy or self.finalized:
ret = self._task_from_fun(fun, **opts)
else:
# create a proxy object
# return a proxy object that evaluates on first use
ret = PromiseProxy(self._task_from_fun, (fun,), opts,
__doc__=fun.__doc__)
# append to the pending queue; note this and connect_on_app_finalize(cons) each run once
self._pending.append(ret)
if _filt:
return _filt(ret)
return ret
return _create_task_cls
if len(args) == 1:
if callable(args[0]):
return inner_create_task_cls(**opts)(*args)
raise TypeError('argument 1 to @task() must be a callable')
if args:
raise TypeError(
'@task() takes exactly 1 argument ({0} given)'.format(
sum([len(args), len(opts)])))
return inner_create_task_cls(**opts)
Note inner_create_task_cls: this is what our task ultimately gets wrapped into. Two places deserve special attention. The first is connect_on_app_finalize: because shared defaults to True, that line is always executed. Stepping in:
_on_app_finalizers = set()
def connect_on_app_finalize(callback):
"""Connect callback to be called when any app is finalized."""
_on_app_finalizers.add(callback)
return callback
Here we stumble on a little secret: Celery internally maintains a set. We don't yet know what this set is for, but we do know that a function has been added to it: cons, which calls the _task_from_fun method. Stepping in:
def _task_from_fun(self, fun, name=None, base=None, bind=False, **options):
if not self.finalized and not self.autofinalize:
raise RuntimeError('Contract breach: app not finalized')
name = name or self.gen_task_name(fun.__name__, fun.__module__)
base = base or self.Task
if name not in self._tasks:
run = fun if bind else staticmethod(fun)
task = type(fun.__name__, (base,), dict({
'app': self,
'name': name,
'run': run,
'_decorated': True,
'__doc__': fun.__doc__,
'__module__': fun.__module__,
'__header__': staticmethod(head_from_fun(fun, bound=bind)),
'__wrapped__': run}, **options))()
# for some reason __qualname__ cannot be set in type()
# so we have to set it here.
try:
task.__qualname__ = fun.__qualname__
except AttributeError:
pass
# register into the registry maintained by the app
self._tasks[task.name] = task
# this copies the app's various attributes into the task
task.bind(self) # connects task to this app
"""
...... some unimportant parts omitted
"""
if autoretry_for and not hasattr(task, '_orig_run'):
@wraps(task.run)
def run(*args, **kwargs):
try:
return task._orig_run(*args, **kwargs)
except Ignore:
# If Ignore signal occures task shouldn't be retried,
# even if it suits autoretry_for list
raise
except Retry:
raise
except autoretry_for as exc:
if retry_backoff:
retry_kwargs['countdown'] = \
get_exponential_backoff_interval(
factor=retry_backoff,
retries=task.request.retries,
maximum=retry_backoff_max,
full_jitter=retry_jitter)
raise task.retry(exc=exc, **retry_kwargs)
task._orig_run, task.run = task.run, run
else:
task = self._tasks[name]
return task
The readability of this code is something else. name is easy enough: if you specify one it is used, otherwise one is generated for you, following the pattern of the dotted import path, e.g. itsm.ticket.tasks.notify_task. And fun here is simply our function object.
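The default naming rule can be sketched as follows. This is a simplified version of what gen_task_name produces; the module path is the example from the text:

```python
def gen_task_name(name, module):
    # Simplified default rule: '<module path>.<function name>'
    return '.'.join(p for p in (module, name) if p)

def notify_task():
    pass

# For a function defined in itsm/ticket/tasks.py:
print(gen_task_name(notify_task.__name__, 'itsm.ticket.tasks'))
# -> itsm.ticket.tasks.notify_task
```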
We can see that the type builtin is first used to construct a task object, an instance of a dynamically created subclass of Task. It is then added to the app's _tasks registry, and the task is bound to the app. Pay particular attention to the self._tasks attribute; its contents look roughly like this:
{
"itsm.ticket.tasks.notify_task": <@task: itsm.ticket.tasks.notify_task of proj at 0x7fbcecd3d9b0 (v2 compatible)>
}
With that, we have a mapping from each task_name to its corresponding executable object.
Note that what we stored in _on_app_finalizers was not an evaluated result: we passed in the cons function itself, which means it has not been executed at this point. So where does it actually run? For that we have to go back to the worker's initialization. As mentioned before, the Worker object executes on_before_init during initialization, and inside that function we find this line: trace.setup_worker_optimizations(self.app, self.hostname)
class Worker(WorkController):
"""Worker as a program."""
def on_before_init(self, quiet=False, **kwargs):
self.quiet = quiet
# this is the line
trace.setup_worker_optimizations(self.app, self.hostname)
# this signal can be used to set up configuration for
# workers by name.
signals.celeryd_init.send(
sender=self.hostname, instance=self,
conf=self.app.conf, options=kwargs,
)
check_privileges(self.app.conf.accept_content)
When we arrive at the scene of setup_worker_optimizations, all we find there is this code:
# evaluate all task classes by finalizing the app.
app.finalize()
Back once more to the finalize method of the Celery class:
def finalize(self, auto=False):
"""Finalize the app.
This loads built-in tasks, evaluates pending task decorators,
reads configuration, etc.
"""
with self._finalize_mutex:
if not self.finalized:
if auto and not self.autofinalize:
raise RuntimeError('Contract breach: app not finalized')
self.finalized = True
_announce_app_finalized(self)
pending = self._pending
while pending:
maybe_evaluate(pending.popleft())
for task in values(self._tasks):
task.bind(self)
self.on_after_finalize.send(sender=self)
There is our line. Stepping into _announce_app_finalized(self):
def _announce_app_finalized(app):
callbacks = set(_on_app_finalizers)
for callback in callbacks:
callback(app)
Once this step has run, our self._tasks is finally complete. The design here is quite convoluted; it may take a few readings to untangle the call relationships.
At this point we have essentially answered the question posed above: how a task gets registered into the global context. When a module is imported, its decorators execute, so when every module under the project is loaded dynamically, the tasks get registered into Celery.
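The whole two-stage dance — record at decoration time, evaluate at finalize time — can be modelled in a few lines. This is my own stripped-down sketch: the names mirror Celery's internals, but the logic is heavily simplified, and appending the bare function to _pending stands in for the PromiseProxy objects:

```python
from collections import deque

_on_app_finalizers = set()          # module-level, shared across apps

def connect_on_app_finalize(callback):
    _on_app_finalizers.add(callback)
    return callback

class App:
    def __init__(self):
        self._tasks = {}            # name -> task, filled lazily
        self._pending = deque()
        self.finalized = False

    def _task_from_fun(self, fun):
        name = f'{fun.__module__}.{fun.__name__}'
        self._tasks[name] = fun
        return fun

    def task(self, fun):
        # decoration only *records* the task; nothing is registered yet
        connect_on_app_finalize(lambda app: app._task_from_fun(fun))
        self._pending.append(fun)   # stands in for the PromiseProxy
        return fun

    def finalize(self):             # what setup_worker_optimizations triggers
        self.finalized = True
        for cb in set(_on_app_finalizers):   # _announce_app_finalized
            cb(self)

app = App()

@app.task
def add(x, y):
    return x + y

print(len(app._tasks))  # 0 -- still empty right after decoration
app.finalize()
print(list(app._tasks)) # the task is now registered under its dotted name
```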
Task consumption
Having covered how tasks are discovered, we now need to work out how they are consumed. Celery maintains a multiprocess consumption model, but the details of task execution are beyond the scope of this article and may be covered in a later chapter.
Stop and think for a moment: if task registration starts from @task as its entry point, where is the entry point of task consumption?
It has to start at the moment a message is received, because consumption can only begin once a task has arrived. Remember the useful function we found in the Consumer, on_task_received? It is bound as kombu's message callback, so whenever a new message appears on the queue, on_task_received is invoked.
def create_task_handler(self, promise=promise):
strategies = self.strategies
on_unknown_message = self.on_unknown_message
on_unknown_task = self.on_unknown_task
on_invalid_task = self.on_invalid_task
callbacks = self.on_task_message
call_soon = self.call_soon
def on_task_received(message):
# payload will only be set for v1 protocol, since v2
# will defer deserializing the message body to the pool.
payload = None
try:
type_ = message.headers['task'] # protocol v2
except TypeError:
return on_unknown_message(None, message)
except KeyError:
try:
payload = message.decode()
except Exception as exc: # pylint: disable=broad-except
return self.on_decode_error(message, exc)
try:
type_, payload = payload['task'], payload # protocol v1
except (TypeError, KeyError):
return on_unknown_message(payload, message)
try:
strategy = strategies[type_]
except KeyError as exc:
return on_unknown_task(None, message, exc)
else:
try:
strategy(
message, payload,
promise(call_soon, (message.ack_log_error,)),
promise(call_soon, (message.reject_log_error,)),
callbacks,
)
except (InvalidTaskError, ContentDisallowed) as exc:
return on_invalid_task(payload, message, exc)
except DecodeError as exc:
return self.on_decode_error(message, exc)
return on_task_received
One thing here really deserves our attention: self.strategies. But wait, where are our tasks?! After all that discussion above, they don't seem to be used at all.
Not so fast; there is machinery underneath. Let's find where self.strategies is initialized:
def update_strategies(self):
loader = self.app.loader
for name, task in items(self.app.tasks):
self.strategies[name] = task.start_strategy(self.app, self)
task.__trace__ = build_tracer(name, task, loader, self.hostname,
app=self.app)
And that reconnects everything. Step into the start_strategy method to see what it does:
def start_strategy(self, app, consumer, **kwargs):
return instantiate(self.Strategy, self, app, consumer, **kwargs)
It creates an instance of a Strategy object. Tracing further, Strategy resolves to celery.worker.strategy:default. Let's step in; we might just find a pleasant surprise:
def default(task, app, consumer,
info=logger.info, error=logger.error, task_reserved=task_reserved,
to_system_tz=timezone.to_system, bytes=bytes, buffer_t=buffer_t,
proto1_to_proto2=proto1_to_proto2):
"""Default task execution strategy.
Note:
Strategies are here as an optimization, so sadly
it's not very easy to override.
"""
hostname = consumer.hostname
connection_errors = consumer.connection_errors
_does_info = logger.isEnabledFor(logging.INFO)
# task event related
# (optimized to avoid calling request.send_event)
eventer = consumer.event_dispatcher
events = eventer and eventer.enabled
send_event = eventer and eventer.send
task_sends_events = events and task.send_events
call_at = consumer.timer.call_at
apply_eta_task = consumer.apply_eta_task
rate_limits_enabled = not consumer.disable_rate_limits
get_bucket = consumer.task_buckets.__getitem__
handle = consumer.on_task_request
limit_task = consumer._limit_task
limit_post_eta = consumer._limit_post_eta
body_can_be_buffer = consumer.pool.body_can_be_buffer
Request = symbol_by_name(task.Request)
Req = create_request_cls(Request, task, consumer.pool, hostname, eventer)
revoked_tasks = consumer.controller.state.revoked
def task_message_handler(message, body, ack, reject, callbacks,
to_timestamp=to_timestamp):
if body is None and 'args' not in message.payload:
body, headers, decoded, utc = (
message.body, message.headers, False, app.uses_utc_timezone(),
)
if not body_can_be_buffer:
body = bytes(body) if isinstance(body, buffer_t) else body
else:
if 'args' in message.payload:
body, headers, decoded, utc = hybrid_to_proto2(message,
message.payload)
else:
body, headers, decoded, utc = proto1_to_proto2(message, body)
req = Req(
message,
on_ack=ack, on_reject=reject, app=app, hostname=hostname,
eventer=eventer, task=task, connection_errors=connection_errors,
body=body, headers=headers, decoded=decoded, utc=utc,
)
if _does_info:
info('Received task: %s', req)
if (req.expires or req.id in revoked_tasks) and req.revoked():
return
signals.task_received.send(sender=consumer, request=req)
if task_sends_events:
send_event(
'task-received',
uuid=req.id, name=req.name,
args=req.argsrepr, kwargs=req.kwargsrepr,
root_id=req.root_id, parent_id=req.parent_id,
retries=req.request_dict.get('retries', 0),
eta=req.eta and req.eta.isoformat(),
expires=req.expires and req.expires.isoformat(),
)
bucket = None
eta = None
if req.eta:
try:
if req.utc:
eta = to_timestamp(to_system_tz(req.eta))
else:
eta = to_timestamp(req.eta, app.timezone)
except (OverflowError, ValueError) as exc:
error("Couldn't convert ETA %r to timestamp: %r. Task: %r",
req.eta, exc, req.info(safe=True), exc_info=True)
req.reject(requeue=False)
if rate_limits_enabled:
bucket = get_bucket(task.name)
if eta and bucket:
consumer.qos.increment_eventually()
return call_at(eta, limit_post_eta, (req, bucket, 1),
priority=6)
if eta:
consumer.qos.increment_eventually()
call_at(eta, apply_eta_task, (req,), priority=6)
return task_message_handler
if bucket:
return limit_task(req, bucket, 1)
task_reserved(req)
if callbacks:
[callback(req) for callback in callbacks]
handle(req)
return task_message_handler
Interesting: as soon as Celery receives a message, it invokes the execution strategy registered for that message's task, the default being the default function above. Note that the value actually stored in strategies is the task_message_handler closure it returns.
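In other words, the consumer keeps a dispatch table keyed by task name, and each value is a closure that has already captured everything it needs. A minimal sketch of that shape (my simplification; the real handler builds a Request and hands it to the pool, and the task names here are hypothetical):

```python
def default(task_name):
    # Everything the handler needs is resolved once, up front, then
    # captured by the closure -- the optimization the docstring mentions.
    def task_message_handler(message, body):
        return f'dispatching {task_name} with body {body!r}'
    return task_message_handler

app_tasks = ['proj.tasks.add', 'proj.tasks.mul']   # stand-in for app.tasks
strategies = {name: default(name) for name in app_tasks}

# on_task_received() then does, in essence:
message = {'headers': {'task': 'proj.tasks.add'}}
print(strategies[message['headers']['task']](message, [2, 3]))
```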
The line that deserves special attention here is this one:
Req = create_request_cls(Request, task, consumer.pool, hostname, eventer)
Step in to see what it creates; this gets interesting:
def create_request_cls(base, task, pool, hostname, eventer,
ref=ref, revoked_tasks=revoked_tasks,
task_ready=task_ready, trace=trace_task_ret):
default_time_limit = task.time_limit
default_soft_time_limit = task.soft_time_limit
apply_async = pool.apply_async
acks_late = task.acks_late
events = eventer and eventer.enabled
class Request(base):
def execute_using_pool(self, pool, **kwargs):
task_id = self.task_id
if (self.expires or task_id in revoked_tasks) and self.revoked():
raise TaskRevokedError(task_id)
time_limit, soft_time_limit = self.time_limits
# note this
result = apply_async(
trace,
args=(self.type, task_id, self.request_dict, self.body,
self.content_type, self.content_encoding),
accept_callback=self.on_accepted,
timeout_callback=self.on_timeout,
callback=self.on_success,
error_callback=self.on_failure,
soft_timeout=soft_time_limit or default_soft_time_limit,
timeout=time_limit or default_time_limit,
correlation_id=task_id,
)
# cannot create weakref to None
# pylint: disable=attribute-defined-outside-init
self._apply_result = maybe(ref, result)
return result
def on_success(self, failed__retval__runtime, **kwargs):
failed, retval, runtime = failed__retval__runtime
if failed:
if isinstance(retval.exception, (
SystemExit, KeyboardInterrupt)):
raise retval.exception
return self.on_failure(retval, return_ok=True)
task_ready(self)
if acks_late:
self.acknowledge()
if events:
self.send_event(
'task-succeeded', result=retval, runtime=runtime,
)
return Request
This actually returns a nested class, Request, which has an execute_using_pool method that in turn calls apply_async. Note that at this point our task really does start executing: the task name, arguments and everything else are passed in. And as its name suggests, apply_async most likely has to do with concurrency, but that is not the focus of this chapter. Let's go back to the task_message_handler function; this time, only the key code:
def task_message_handler(message, body, ack, reject, callbacks,
to_timestamp=to_timestamp):
req = Req(
message,
on_ack=ack, on_reject=reject, app=app, hostname=hostname,
eventer=eventer, task=task, connection_errors=connection_errors,
body=body, headers=headers, decoded=decoded, utc=utc,
)
# log it
if _does_info:
info('Received task: %s', req)
if (req.expires or req.id in revoked_tasks) and req.revoked():
return
# send the signal
signals.task_received.send(sender=consumer, request=req)
# send the event
if task_sends_events:
send_event(
'task-received',
uuid=req.id, name=req.name,
args=req.argsrepr, kwargs=req.kwargsrepr,
root_id=req.root_id, parent_id=req.parent_id,
retries=req.request_dict.get('retries', 0),
eta=req.eta and req.eta.isoformat(),
expires=req.expires and req.expires.isoformat(),
)
bucket = None
eta = None
# handle ETA (delayed) tasks
if req.eta:
try:
if req.utc:
eta = to_timestamp(to_system_tz(req.eta))
else:
eta = to_timestamp(req.eta, app.timezone)
except (OverflowError, ValueError) as exc:
error("Couldn't convert ETA %r to timestamp: %r. Task: %r",
req.eta, exc, req.info(safe=True), exc_info=True)
req.reject(requeue=False)
if rate_limits_enabled:
bucket = get_bucket(task.name)
if eta and bucket:
consumer.qos.increment_eventually()
return call_at(eta, limit_post_eta, (req, bucket, 1),
priority=6)
if eta:
consumer.qos.increment_eventually()
call_at(eta, apply_eta_task, (req,), priority=6)
return task_message_handler
if bucket:
return limit_task(req, bucket, 1)
task_reserved(req)
# invoke callbacks if any were registered
if callbacks:
[callback(req) for callback in callbacks]
# here it comes, the key part
handle(req)
return task_message_handler
Reading the source, handle is actually bound as handle = consumer.on_task_request, and consumer.on_task_request does not originate in the Consumer class itself; it is passed in at construction time. Keep looking upward:
class Consumer(bootsteps.StartStopStep):
"""Bootstep starting the Consumer blueprint."""
last = True
def create(self, w):
if w.max_concurrency:
prefetch_count = max(w.max_concurrency, 1) * w.prefetch_multiplier
else:
prefetch_count = w.concurrency * w.prefetch_multiplier
c = w.consumer = self.instantiate(
w.consumer_cls, w.process_task,
hostname=w.hostname,
task_events=w.task_events,
init_callback=w.ready_callback,
initial_prefetch_count=prefetch_count,
pool=w.pool,
timer=w.timer,
app=w.app,
controller=w,
hub=w.hub,
worker_options=w.options,
disable_rate_limits=w.disable_rate_limits,
prefetch_multiplier=w.prefetch_multiplier,
)
return c
So consumer.on_task_request actually points to w.process_task. We're getting somewhere; continuing the search in Worker, all we find is this:
def _process_task(self, req):
"""Process task by sending it to the pool of workers."""
try:
req.execute_using_pool(self.pool)
except TaskRevokedError:
try:
self._quick_release() # Issue 877
except AttributeError:
pass
Only here does execute_using_pool finally get called. But there's a snag: _process_task carries a leading underscore, and nowhere in the Worker class do we find a function named process_task. Going on experience alone we'd be stuck, but remember how execute_using_pool looked related to multiprocessing? Even if it isn't, it most likely has something to do with Pool, so let's check the Worker's sub-components:
class Pool(bootsteps.StartStopStep):
def create(self, w):
semaphore = None
max_restarts = None
if w.app.conf.worker_pool in GREEN_POOLS: # pragma: no cover
warnings.warn(UserWarning(W_POOL_SETTING))
threaded = not w.use_eventloop or IS_WINDOWS
procs = w.min_concurrency
# there it is, as expected
w.process_task = w._process_task
"""
...... irrelevant code omitted
"""
return pool
def info(self, w):
return {'pool': w.pool.info if w.pool else 'N/A'}
def register_with_event_loop(self, w, hub):
w.pool.register_with_event_loop(hub)
Mystery solved. We now know how our tasks are discovered and how they are consumed. How a task is actually executed in detail remains unknown, but we have been lucky enough to find the doorway into that blank space: result = apply_async().
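To wrap up, the entire consumption chain we just traced can be compressed into a runnable toy. All the names mirror the real Celery internals met above, but everything here runs in-process and is, of course, a gross simplification:

```python
registry = {'proj.tasks.add': lambda x, y: x + y}   # app._tasks (toy version)
results = []

def execute_using_pool(name, args):
    # Request.execute_using_pool -> pool.apply_async(trace, ...) in real Celery
    results.append(registry[name](*args))

def _process_task(name, args):
    # Worker._process_task, exposed as w.process_task by the Pool bootstep
    execute_using_pool(name, args)

def default(name):
    # celery.worker.strategy:default returning task_message_handler
    def task_message_handler(body):
        _process_task(name, body)   # handle(req) == consumer.on_task_request
    return task_message_handler

strategies = {name: default(name) for name in registry}

def on_task_received(message):
    # the kombu message callback: look up the strategy by task name
    strategies[message['headers']['task']](message['body'])

on_task_received({'headers': {'task': 'proj.tasks.add'}, 'body': (2, 3)})
print(results)  # [5]
```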
Translated from: https://juejin.cn/post/7352792335904391195