实现一个 Rust 的异步运行时
// Only builds a future; nothing ever polls it, so "run in async" is never printed.
async { println!("run in async") };
当你写出上面的代码时，会面临一个很尴尬的事情：async 块只会创建一个 future，而 future 是惰性的，不被 poll 就什么也不会执行；可 Rust 标准库里面又没有直接运行 async 代码的方法，需要你引入 tokio 或者 smol 这一类的异步运行时。
本文将介绍如何写一个简单的异步运行时,运行异步代码
定义基本类型
/// Identifier assigned to each spawned task.
type TaskID = u32;
/// A spawned unit of work: a pinned, heap-allocated future that produces `()`.
struct Task {
future: Pin<Box<dyn Future<Output = ()>>>,
}
// NOTE(review): the boxed `dyn Future` is not bounded by `Send`, so this impl is only
// sound if every future stored in a `Task` can really be moved to another thread
// (a future holding an `Rc`, for example, cannot).
unsafe impl Send for Task {}
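Task 是对 future 的封装。由于 dyn Future 本身不一定实现 Send，这里用 unsafe impl Send 手动声明它可以在线程之间传递；注意这只是我们对编译器做出的保证，只有当被封装的 future 确实可以安全地跨线程移动时才成立（例如持有 Rc 的 future 就不行）。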
/// Scheduler state shared by the runtime and its worker threads.
struct TaskManager {
// Ids of tasks whose future should be polled to check whether it is Ready.
pending_check_task_ids: Vec<TaskID>,
// Every task that has not completed yet, keyed by its id.
tasks: BTreeMap<TaskID, Task>,
}
TaskManager 存储所有的 Task。其中 pending_check_task_ids 字段包含所有需要检查状态的任务 id，也就是需要调用 poll 来查看其状态是否已经 Ready 的任务；tasks 里面则包含了所有尚未完成的 Task，以 TaskID 为键。
/// The async runtime: shared scheduler state plus the worker threads that poll tasks.
struct AsyncRuntime {
// `TaskManager` behind a mutex, paired with the condition variable idle workers wait on;
// the `Arc` lets every worker thread share it.
task_manager: Arc<(Mutex<TaskManager>, Condvar)>,
// Join handles of the spawned worker threads.
workers: Vec<std::thread::JoinHandle<()>>,
}
AsyncRuntime 就是我们的异步运行时：task_manager 把 TaskManager 放进 Mutex，并和一个条件变量（Condvar）一起用 Arc 包裹起来，以便在多个 worker 线程之间共享；workers 存储所有 worker 线程的 JoinHandle。
总结一下，这个异步运行时的原理就是：worker 线程在条件变量上等待；每当有任务 id 被放进 pending_check_task_ids（新任务被 spawn，或者某个 future 的 Waker 被调用），就唤醒一个 worker 线程，去 poll 对应的 future。
定义方法
需要实现的异步运行时方法如下
impl AsyncRuntime {
// Starts a worker thread that polls queued tasks.
fn new_worker(&mut self) {}
// Creates a new async task and queues it for its first poll.
fn spawn(&mut self, future: impl Future<Output = ()> + 'static) {}
// Blocks the calling thread until every worker thread has exited.
fn wait(self) {}
}
worker
worker 线程应该在条件变量上等待，直到 pending_check_task_ids 中出现需要检查的任务 id：有就把它们全部取出来逐个检查，没有就继续等待。代码如下：
/// Starts a worker thread that waits on the condition variable until task ids are
/// pending, then polls each corresponding task.
/// (`cx` is not defined yet; building the `Context` is covered next.)
fn new_worker(&mut self) {
let join_handle = std::thread::spawn({
let task_manager = self.task_manager.clone();
move || loop {
let (lock, cond_var) = &*task_manager;
let mut task_manager = lock.lock().unwrap();
// Sleep until pending_check_task_ids contains ids to check.
while task_manager.pending_check_task_ids.is_empty() {
task_manager = cond_var.wait(task_manager).unwrap();
}
// Take all pending ids and poll each task to see whether it is Ready.
// NOTE(review): the mutex guard stays alive across `poll`; a waker that locks this
// same mutex and is invoked from inside `poll` on this thread would deadlock.
for task_id in std::mem::take(&mut task_manager.pending_check_task_ids) {
if let Some(task) = task_manager.tasks.get_mut(&task_id) {
let task = Pin::new(&mut task.future);
task.poll(cx);
}
}
}
});
self.workers.push(join_handle);
}
Context
上面的代码还缺少 cx，也就是 Context。其实 Context 很简单，查看 Rust 标准库的源代码就可以知道：Context 可以通过 Context::from_waker 直接从 Waker 构造出来，Waker 又可以通过（unsafe 的）Waker::from_raw 从 RawWaker 构造出来，而 RawWaker 可以自定义实现。代码大概如下所示：
// `data` (type-erased pointer) and `vtable` (`&'static RawWakerVTable`) are placeholders here.
let raw_waker = RawWaker::new(data, vtable);
// Unsafe because the vtable functions must correctly manage whatever `data` points to.
let waker = unsafe { Waker::from_raw(raw_waker) };
let mut context = Context::from_waker(&waker);
task.poll(&mut context);
所以要构造 Context，就要实现一个 RawWaker。RawWaker 由一个数据指针和一个虚表（RawWakerVTable）组成，说白了就是要你定义一个结构体作为数据，然后实现如下四个函数放进虚表里：
// Creates another RawWaker for the same task data.
unsafe fn clone(ptr: *const ()) -> RawWaker {}
// Schedules the task, consuming the waker (its data must be released).
unsafe fn wake(ptr: *const ()) {}
// Schedules the task without consuming the waker (its data must stay alive).
unsafe fn wake_by_ref(ptr: *const ()) {}
// Releases the waker's data.
unsafe fn drop(ptr: *const ()) {}
drop 和 clone 不用多说：clone 复制出一个指向同一份数据的 RawWaker，drop 释放这份数据。
wake 函数唤醒一个 worker 线程去检查该任务的状态，也就是把对应的 TaskID 放进 pending_check_task_ids 里面，然后唤醒线程去检查。
而 wake_by_ref 和 wake 的区别在于所有权：两者拿到的都是同一个指针，但 wake 会消耗掉这个 waker，所以函数结束时要析构 ptr 指向的对象；wake_by_ref 只是借用 waker，不应该在函数结束的时候析构 ptr 指向的对象。
理解了之后很容易就可以写出下面的代码，把它封装成一个独立的 mod：raw_waker_impl
/// `RawWaker` implementation for the runtime.
///
/// Every waker owns one strong count of an `Arc<Data>`, carried as the type-erased pointer
/// returned by `Arc::into_raw` (a pointer to the `Data` inside the `Arc`).
mod raw_waker_impl {
    use super::*;
    use std::task::RawWakerVTable;

    /// What a waker needs in order to reschedule its task.
    pub struct Data {
        task_id: TaskID,
        task_manager: Arc<(Mutex<TaskManager>, Condvar)>,
    }

    impl Data {
        /// Bundles the task id with the shared scheduler state.
        pub fn new(task_id: TaskID, task_manager: Arc<(Mutex<TaskManager>, Condvar)>) -> Self {
            Self {
                task_id,
                task_manager,
            }
        }
    }

    pub static V_TABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop);

    /// Queues `data.task_id` for polling and wakes one worker.
    fn schedule(data: &Data) {
        let (lock, cond_var) = &*data.task_manager;
        // The guard is a temporary, so the lock is released before notifying.
        lock.lock().unwrap().pending_check_task_ids.push(data.task_id);
        cond_var.notify_one();
    }

    unsafe fn clone(ptr: *const ()) -> RawWaker {
        // SAFETY: `ptr` came from `Arc::<Data>::into_raw` and is kept alive by this waker.
        // It must be cast back to `*const Data` so the `Arc` header offset is computed for
        // the right type.
        unsafe { Arc::increment_strong_count(ptr as *const Data) };
        RawWaker::new(ptr, &V_TABLE)
    }

    unsafe fn wake(ptr: *const ()) {
        // SAFETY: `wake` consumes the waker, so we take back the strong count it owned;
        // it is released when `data` goes out of scope.
        let data = unsafe { Arc::from_raw(ptr as *const Data) };
        schedule(&data);
    }

    unsafe fn wake_by_ref(ptr: *const ()) {
        // SAFETY: `ptr` points at the `Data` inside a live `Arc` (not at an `Arc<Data>`);
        // the waker is only borrowed, so the strong count must not change.
        let data = unsafe { &*(ptr as *const Data) };
        schedule(data);
    }

    unsafe fn drop(ptr: *const ()) {
        // SAFETY: reclaims and releases the strong count owned by the dropped waker.
        std::mem::drop(unsafe { Arc::from_raw(ptr as *const Data) });
    }
}
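可以看到，ptr 是由 Arc::into_raw 得到的，它指向的是 Arc<Data> 内部的 Data，而不是一个 Arc<Data> 对象本身，所以每个函数都应该先把它转换回 *const Data（在 wake_by_ref 里把它当作 *const Arc<Data> 解引用是未定义行为）：
clone 增加 Arc 的引用计数；
drop 用 Arc::from_raw 把指针还原成 Arc<Data>，交给 Rust 自动析构；
wake_by_ref 把需要检查的任务 ID 存入 pending_check_task_ids，然后唤醒一个 worker 线程，它不改变引用计数；
wake 和 wake_by_ref 逻辑一致，只不过要在函数结束的时候析构（释放这个 waker 持有的引用计数）。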
所以 Context
的构建代码就变成了下面这样
// Build a waker for this task: the data (task id + shared task manager) is moved into
// an `Arc` and turned into a raw pointer, which the vtable functions take over.
let data = Arc::into_raw(Arc::new(raw_waker_impl::Data::new(
task_id,
task_manager_clone,
))) as *const ();
let raw_waker = RawWaker::new(data, &raw_waker_impl::V_TABLE);
// Unsafe: relies on `V_TABLE` treating `data` as a pointer from `Arc::<Data>::into_raw`.
let waker = unsafe { Waker::from_raw(raw_waker) };
let mut context = Context::from_waker(&waker);
当 future 的状态变成 Ready 的时候，就把该任务从 tasks 中移除：
// A Ready future is finished: remove its task so it is dropped and never polled again.
if task.poll(&mut context).is_ready() {
task_manager.tasks.remove(&task_id);
}
spawn
spawn 创建一个 Task，然后将它的 TaskID 放入 pending_check_task_ids，并唤醒一个 worker 线程去检查它的状态是否 Ready（也就是对它进行第一次 poll）。
/// Wraps `future` in a new `Task`, registers it, and wakes one worker to poll it.
///
/// Ids come from a process-wide counter, so they are shared by all runtimes
/// (and would repeat only after the `u32` counter wraps).
pub fn spawn(&mut self, future: impl Future<Output = ()> + 'static) {
    static TASK_ID_COUNT: AtomicU32 = AtomicU32::new(0);
    let id = TASK_ID_COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    let new_task = Task {
        future: Box::pin(future),
    };

    let (mutex, condvar) = &*self.task_manager;
    {
        let mut manager = mutex.lock().unwrap();
        manager.pending_check_task_ids.push(id);
        manager.tasks.insert(id, new_task);
    } // lock released here, before a worker is notified
    condvar.notify_one();
}
wait
wait 的代码很简单，只是等待所有 worker 线程结束而已。注意 worker 线程执行的是一个没有退出条件的 loop，所以 wait 实际上会一直阻塞下去。
/// Blocks until every worker thread exits, panicking if one of them panicked.
///
/// Workers run an endless loop, so in practice this call never returns normally.
fn wait(self) {
    self.workers
        .into_iter()
        .for_each(|handle| handle.join().unwrap());
}
测试一下
/// Demo: one worker drives 99 tasks, each printing its counter once per second.
///
/// Never returns: every task loops forever and `wait` joins worker threads that never exit.
/// Needs the `async_timer` crate and `std::time::Duration` in scope.
fn test_async_runtime() {
    use async_timer::oneshot::Timer;

    let mut runtime = AsyncRuntime::default();
    runtime.new_worker();

    (1..100).for_each(|task_id| {
        runtime.spawn(async move {
            let mut count = 0;
            loop {
                Timer::new(Duration::from_secs(1)).await;
                count += 1;
                println!("task {task_id} : count -> {count}")
            }
        })
    });

    runtime.wait();
}
总结
所有代码如下所示, 仅供参考
use std::{
collections::BTreeMap,
future::Future,
pin::Pin,
sync::{atomic::AtomicU32, Arc, Condvar, Mutex},
task::{Context, RawWaker, Waker},
};
/// Identifier assigned to each spawned task.
type TaskID = u32;
/// A spawned unit of work: a pinned, heap-allocated future that produces `()`.
struct Task {
future: Pin<Box<dyn Future<Output = ()>>>,
}
// Needed so the `Mutex<TaskManager>` (which stores tasks) can be shared with worker threads.
// NOTE(review): `AsyncRuntime::spawn` accepts futures that are not `Send`, so this impl is
// unsound for futures that must stay on one thread (e.g. ones holding an `Rc`); consider
// requiring `Send` on spawned futures instead.
unsafe impl Send for Task {}
/// Scheduler state shared by the runtime and its workers; always accessed under the mutex.
#[derive(Default)]
struct TaskManager {
    /// Ids of tasks that should be polled (newly spawned or woken).
    pending_check_task_ids: Vec<TaskID>,
    /// Tasks that are neither finished nor currently being polled, keyed by id.
    tasks: BTreeMap<TaskID, Task>,
    /// Tasks a worker has taken out of `tasks` to poll, mapped to whether a wake-up arrived
    /// for them in the meantime (in which case they must be polled again).
    running: BTreeMap<TaskID, bool>,
}

/// A minimal multi-threaded executor: worker threads sleep on a condition variable and
/// poll whichever tasks have been queued in `pending_check_task_ids`.
#[derive(Default)]
struct AsyncRuntime {
    /// Scheduler state plus the condition variable used to wake idle workers.
    task_manager: Arc<(Mutex<TaskManager>, Condvar)>,
    /// Join handles of every worker thread started by `new_worker`.
    workers: Vec<std::thread::JoinHandle<()>>,
}

impl AsyncRuntime {
    /// Starts a worker thread that repeatedly waits for queued task ids and polls those tasks.
    ///
    /// The thread runs an endless loop and never exits on its own.
    pub fn new_worker(&mut self) {
        let shared = Arc::clone(&self.task_manager);
        let join_handle = std::thread::spawn(move || loop {
            // Sleep until something is queued, then grab the whole batch at once.
            let batch = {
                let (lock, cond_var) = &*shared;
                let mut task_manager = lock.lock().unwrap();
                while task_manager.pending_check_task_ids.is_empty() {
                    task_manager = cond_var.wait(task_manager).unwrap();
                }
                std::mem::take(&mut task_manager.pending_check_task_ids)
            };
            for task_id in batch {
                Self::poll_task(&shared, task_id);
            }
        });
        self.workers.push(join_handle);
    }

    /// Polls task `task_id` once, without holding the lock while `poll` runs.
    ///
    /// The waker locks the same mutex, so polling under the lock would deadlock as soon as a
    /// future wakes itself from inside `poll`, and would serialise all workers and wakers.
    fn poll_task(shared: &Arc<(Mutex<TaskManager>, Condvar)>, task_id: TaskID) {
        let (lock, cond_var) = &**shared;

        // Take the task out of the map and mark it as running.
        let mut task = {
            let mut task_manager = lock.lock().unwrap();
            if let Some(woken) = task_manager.running.get_mut(&task_id) {
                // Another worker is polling this task right now; it will re-queue it.
                *woken = true;
                return;
            }
            match task_manager.tasks.remove(&task_id) {
                Some(task) => {
                    task_manager.running.insert(task_id, false);
                    task
                }
                // Stale wake-up for a task that has already completed.
                None => return,
            }
        };

        let data = Arc::into_raw(Arc::new(raw_waker_impl::Data::new(
            task_id,
            Arc::clone(shared),
        ))) as *const ();
        // SAFETY: `data` comes from `Arc::<raw_waker_impl::Data>::into_raw`, the representation
        // the functions in `raw_waker_impl::V_TABLE` operate on; the waker now owns that count.
        let waker = unsafe { Waker::from_raw(RawWaker::new(data, &raw_waker_impl::V_TABLE)) };
        let mut context = Context::from_waker(&waker);
        let is_ready = task.future.as_mut().poll(&mut context).is_ready();

        let mut task_manager = lock.lock().unwrap();
        let woken_while_running = task_manager.running.remove(&task_id).unwrap_or(false);
        if !is_ready {
            task_manager.tasks.insert(task_id, task);
            if woken_while_running {
                // A wake-up was absorbed while the task was out of the map; re-queue it.
                task_manager.pending_check_task_ids.push(task_id);
                drop(task_manager);
                cond_var.notify_one();
            }
        }
        // A finished task is dropped here, after the lock guard has been released.
    }

    /// Wraps `future` in a `Task`, registers it, and wakes one worker to poll it.
    ///
    /// NOTE(review): `future` is not required to be `Send` even though it is polled on
    /// worker threads; this relies on the `unsafe impl Send for Task`.
    pub fn spawn(&mut self, future: impl Future<Output = ()> + 'static) {
        // Process-wide counter: ids are shared by all runtimes.
        static TASK_ID_COUNT: AtomicU32 = AtomicU32::new(0);
        let current_task_id = TASK_ID_COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let task = Task {
            future: Box::into_pin(Box::new(future)),
        };
        let (lock, cond_var) = &*self.task_manager;
        let mut task_manager = lock.lock().unwrap();
        task_manager.pending_check_task_ids.push(current_task_id);
        task_manager.tasks.insert(current_task_id, task);
        // Release the lock before waking a worker so it can take the lock immediately.
        drop(task_manager);
        cond_var.notify_one();
    }

    /// Blocks until every worker thread exits, panicking if one of them panicked.
    ///
    /// Workers started by `new_worker` never exit on their own, so this normally never returns.
    pub fn wait(self) {
        for worker in self.workers {
            worker.join().unwrap();
        }
    }
}
/// `RawWaker` implementation for the runtime.
///
/// Every waker owns one strong count of an `Arc<Data>`, carried as the type-erased pointer
/// returned by `Arc::into_raw` (a pointer to the `Data` inside the `Arc`).
mod raw_waker_impl {
    use super::*;
    use std::task::RawWakerVTable;

    /// What a waker needs in order to reschedule its task.
    pub struct Data {
        task_id: TaskID,
        task_manager: Arc<(Mutex<TaskManager>, Condvar)>,
    }

    impl Data {
        /// Bundles the task id with the shared scheduler state.
        pub fn new(task_id: TaskID, task_manager: Arc<(Mutex<TaskManager>, Condvar)>) -> Self {
            Self {
                task_id,
                task_manager,
            }
        }
    }

    pub static V_TABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop);

    /// Queues `data.task_id` for polling and wakes one worker.
    fn schedule(data: &Data) {
        let (lock, cond_var) = &*data.task_manager;
        // The guard is a temporary, so the lock is released before notifying.
        lock.lock().unwrap().pending_check_task_ids.push(data.task_id);
        cond_var.notify_one();
    }

    unsafe fn clone(ptr: *const ()) -> RawWaker {
        // SAFETY: `ptr` came from `Arc::<Data>::into_raw` and is kept alive by this waker.
        // It must be cast back to `*const Data` so the `Arc` header offset is computed for
        // the right type.
        unsafe { Arc::increment_strong_count(ptr as *const Data) };
        RawWaker::new(ptr, &V_TABLE)
    }

    unsafe fn wake(ptr: *const ()) {
        // SAFETY: `wake` consumes the waker, so we take back the strong count it owned;
        // it is released when `data` goes out of scope.
        let data = unsafe { Arc::from_raw(ptr as *const Data) };
        schedule(&data);
    }

    unsafe fn wake_by_ref(ptr: *const ()) {
        // SAFETY: `ptr` points at the `Data` inside a live `Arc` (not at an `Arc<Data>`);
        // the waker is only borrowed, so the strong count must not change.
        let data = unsafe { &*(ptr as *const Data) };
        schedule(data);
    }

    unsafe fn drop(ptr: *const ()) {
        // SAFETY: reclaims and releases the strong count owned by the dropped waker.
        std::mem::drop(unsafe { Arc::from_raw(ptr as *const Data) });
    }
}
转载自:https://juejin.cn/post/7354174438005440550