How to Write a Simple Rust Async Runtime
async { println!("run in async") };
When you write the code above, you run into an awkward fact: the Rust standard library has no way to actually run async code; you have to pull in an async runtime such as tokio or smol.
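With an off-the-shelf runtime this is a one-liner. A minimal sketch, assuming the smol crate has been added as a dependency (futures::executor::block_on works the same way):

fn main() {
    // block_on drives the future to completion on the current thread
    smol::block_on(async { println!("run in async") });
}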
This article walks through writing a simple async runtime of our own that can run async code.
Defining the basic types
type TaskID = u32;

struct Task {
    future: Pin<Box<dyn Future<Output = ()>>>,
}
unsafe impl Send for Task {}
Task is a wrapper around a future; the unsafe impl Send allows it to be passed between threads.
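For instance, any async block can be boxed, pinned, and wrapped into a Task; a quick sketch of the intended usage (the spawn method later in the article does essentially this):

let task = Task {
    future: Box::pin(async { println!("run in async") }),
};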
struct TaskManager {
    pending_check_task_ids: Vec<TaskID>,
    tasks: BTreeMap<TaskID, Task>,
}
TaskManager stores all of the tasks. The pending_check_task_ids field holds the IDs of every task whose state needs checking, i.e. the tasks that should be polled to see whether they are Ready; tasks holds all of the Task values keyed by their ID.
struct AsyncRuntime {
    task_manager: Arc<(Mutex<TaskManager>, Condvar)>,
    workers: Vec<std::thread::JoinHandle<()>>,
}
AsyncRuntime is our async runtime itself: task_manager pairs the TaskManager with a condition variable behind an Arc, and workers stores the handles of all worker threads.
To summarize how this runtime works: a condition variable wakes up worker threads, which then poll the futures whose IDs are waiting in pending_check_task_ids.
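The underlying wait/notify idiom, in isolation, looks like the standalone sketch below (not part of the runtime itself, just the Mutex plus Condvar pattern it relies on):

use std::sync::{Arc, Condvar, Mutex};

fn main() {
    // a shared queue of task IDs plus a condition variable that signals "work available"
    let shared = Arc::new((Mutex::new(Vec::<u32>::new()), Condvar::new()));

    let worker = std::thread::spawn({
        let shared = shared.clone();
        move || {
            let (lock, cond_var) = &*shared;
            let mut ids = lock.lock().unwrap();
            // sleep until another thread pushes an ID and notifies us
            while ids.is_empty() {
                ids = cond_var.wait(ids).unwrap();
            }
            println!("woke up with ids: {:?}", *ids);
        }
    });

    // producer side: push an ID, release the lock, then notify one waiting thread
    let (lock, cond_var) = &*shared;
    lock.lock().unwrap().push(42);
    cond_var.notify_one();

    worker.join().unwrap();
}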
Defining the methods
The runtime needs to implement the following methods:
impl AsyncRuntime {
    // create a worker thread that processes tasks
    fn new_worker(&mut self) {}

    // submit a new async task
    fn spawn(&mut self, future: impl Future<Output = ()> + 'static) {}

    // block until all worker threads finish
    fn wait(self) {}
}
worker
A worker should wait on the condition variable until pending_check_task_ids contains task IDs that need checking; when it does, the worker takes them out and polls them, otherwise it keeps waiting. The code looks like this:
 fn new_worker(&mut self) {
    let join_handle = std::thread::spawn({
        let task_manager = self.task_manager.clone();
        move || loop {
            let (lock, cond_var) = &*task_manager;

            let mut task_manager = lock.lock().unwrap();

            // wait until pending_check_task_ids contains task IDs that need checking
            while task_manager.pending_check_task_ids.is_empty() {
                task_manager = cond_var.wait(task_manager).unwrap();
            }

            // take out the pending tasks and poll them to check whether they are Ready
            for task_id in std::mem::take(&mut task_manager.pending_check_task_ids) {
                if let Some(task) = task_manager.tasks.get_mut(&task_id) {
                    let task = Pin::new(&mut task.future);
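                    // `cx` (a Context) is not defined yet; the next section shows how to build it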
                    task.poll(cx);
                }
            }
        }
    });

    self.workers.push(join_handle);
}
Context
The code above is missing cx, i.e. the Context. The Context is actually simple: a look at Rust's source shows that a Context can be built directly from a Waker, a Waker from a RawWaker, and a RawWaker is something you can implement yourself. The code looks roughly like this:
let raw_waker = RawWaker::new(data, vtable);
let waker = unsafe { Waker::from_raw(raw_waker) };
let mut context = Context::from_waker(&waker);
task.poll(&mut context);
So to provide a Context you have to build a RawWaker, and a RawWaker takes a data pointer and a vtable.
In plain terms, you implement a struct and then write the following four functions for it:
unsafe fn clone(ptr: *const ()) -> RawWaker {}
unsafe fn wake(ptr: *const ()) {}
unsafe fn wake_by_ref(ptr: *const ()) {}
unsafe fn drop(ptr: *const ()) {}
drop and clone need little explanation.
wake wakes a worker thread to check the task's state: it pushes the TaskID into pending_check_task_ids and then notifies a thread. The difference between wake and wake_by_ref is that wake consumes the waker, whereas wake_by_ref only borrows the pointer, so it must not destroy the object ptr points to when the function returns. With that understood, the code below is easy to write; it is wrapped in a separate module, raw_waker_impl:
mod raw_waker_impl {
    use super::*;
    use std::task::RawWakerVTable;

    pub struct Data {
        task_id: TaskID,
        task_manager: Arc<(Mutex<TaskManager>, Condvar)>,
    }

    pub static V_TABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop);

    unsafe fn clone(ptr: *const ()) -> RawWaker {
        Arc::increment_strong_count(ptr as *const Data);
        RawWaker::new(ptr, &V_TABLE)
    }

    unsafe fn wake(ptr: *const ()) {
        let data = Arc::from_raw(ptr as *const Data);
        let (lock, cond_var) = &*data.task_manager;
        let mut task_manager = lock.lock().unwrap();
        task_manager.pending_check_task_ids.push(data.task_id);
        std::mem::drop(task_manager);
        cond_var.notify_one();
    }

    unsafe fn wake_by_ref(ptr: *const ()) {
        let data = &*(ptr as *const Data);
        let (lock, cond_var) = &*data.task_manager;
        let mut task_manager = lock.lock().unwrap();
        task_manager.pending_check_task_ids.push(data.task_id);
        std::mem::drop(task_manager);
        cond_var.notify_one();
    }

    unsafe fn drop(ptr: *const ()) {
        Arc::from_raw(ptr as *const Data);
    }
}
As you can see, ptr is the raw pointer obtained from an Arc<Data>:
clone increments the Arc's reference count.
drop turns the raw pointer back into an Arc and lets Rust destroy it automatically.
wake_by_ref pushes the task ID that needs checking into pending_check_task_ids and then wakes one worker thread.
wake has the same logic as wake_by_ref, except that it also drops the data when the function ends.
So the code that builds the Context becomes:
let data = Arc::into_raw(Arc::new(raw_waker_impl::Data::new(
    task_id,
    task_manager_clone,
))) as *const ();

let raw_waker = RawWaker::new(data, &raw_waker_impl::V_TABLE);
let waker = unsafe { Waker::from_raw(raw_waker) };
let mut context = Context::from_waker(&waker);
When the future's state becomes Ready, the task is removed:
if task.poll(&mut context).is_ready() {
    task_manager.tasks.remove(&task_id);
}
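As a quick sanity check (a sketch, not part of the runtime; it assumes the Data::new constructor and the Default derive on TaskManager from the full listing at the end), you can confirm that clone and drop keep the Arc's reference count balanced:

fn waker_refcount_check() {
    let task_manager: Arc<(Mutex<TaskManager>, Condvar)> = Arc::default();
    let data = Arc::new(raw_waker_impl::Data::new(0, task_manager));

    // hand one reference over to the waker as a raw pointer
    let ptr = Arc::into_raw(data.clone()) as *const ();
    let waker = unsafe { Waker::from_raw(RawWaker::new(ptr, &raw_waker_impl::V_TABLE)) };
    assert_eq!(Arc::strong_count(&data), 2); // `data` itself + the waker

    let waker_clone = waker.clone(); // the vtable `clone` bumps the count
    assert_eq!(Arc::strong_count(&data), 3);

    drop(waker_clone); // the vtable `drop` rebuilds the Arc and releases it
    drop(waker);
    assert_eq!(Arc::strong_count(&data), 1); // only `data` is left
}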
spawn
spawn creates a Task, pushes its TaskID into pending_check_task_ids, and wakes one worker thread to check whether it is Ready:
 pub fn spawn(&mut self, future: impl Future<Output = ()> + 'static) {
    static TASK_ID_COUNT: AtomicU32 = AtomicU32::new(0);
    let current_task_id = TASK_ID_COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed);

    let task = Task {
        future: Box::into_pin(Box::new(future)),
    };

    let (lock, cond_var) = &*self.task_manager;
    let mut task_manager = lock.lock().unwrap();
    task_manager.pending_check_task_ids.push(current_task_id);
    task_manager.tasks.insert(current_task_id, task);
    drop(task_manager);
    cond_var.notify_one();
}
wait
The code for wait is simple: it just waits for all the worker threads to finish.
 fn wait(self) {
    for worker in self.workers {
        worker.join().unwrap();
    }
}
Testing it
 fn test_async_runtime() {
    use async_timer::oneshot::Timer;
    use std::time::Duration;

    let mut runtime = AsyncRuntime::default();

    runtime.new_worker();

    for task_id in 1..100 {
        runtime.spawn(async move {
            let mut count = 0;
            loop {
                Timer::new(Duration::from_secs(1)).await;
                count += 1;
                println!("task {task_id} : count -> {count}")
            }
        });
    }

    runtime.wait();
}
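If you would rather not depend on the async_timer crate, a hand-rolled timer future also works. Below is a minimal sketch (SleepFuture is not part of the original code); the important property is that it wakes the task from a separate thread, which is how this runtime expects to be woken:

use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::Duration;

struct SleepFuture {
    duration: Duration,
    started: bool,
    done: Arc<Mutex<bool>>,
}

impl SleepFuture {
    fn new(duration: Duration) -> Self {
        Self {
            duration,
            started: false,
            done: Arc::new(Mutex::new(false)),
        }
    }
}

impl Future for SleepFuture {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        let this = self.get_mut();
        if *this.done.lock().unwrap() {
            return Poll::Ready(());
        }
        if !this.started {
            this.started = true;
            let done = this.done.clone();
            let waker = cx.waker().clone();
            let duration = this.duration;
            // sleep on a separate thread, then wake the task so it gets polled again
            std::thread::spawn(move || {
                std::thread::sleep(duration);
                *done.lock().unwrap() = true;
                waker.wake();
            });
        }
        Poll::Pending
    }
}

In the test above, Timer::new(Duration::from_secs(1)).await can then be replaced with SleepFuture::new(Duration::from_secs(1)).await.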
Summary
The complete code is shown below, for reference only:
use std::{
    collections::BTreeMap,
    future::Future,
    pin::Pin,
    sync::{atomic::AtomicU32, Arc, Condvar, Mutex},
    task::{Context, RawWaker, Waker},
};

type TaskID = u32;

struct Task {
    future: Pin<Box<dyn Future<Output = ()>>>,
}

unsafe impl Send for Task {}

#[derive(Default)]
struct TaskManager {
    pending_check_task_ids: Vec<TaskID>,
    tasks: BTreeMap<TaskID, Task>,
}

#[derive(Default)]
struct AsyncRuntime {
    task_manager: Arc<(Mutex<TaskManager>, Condvar)>,
    workers: Vec<std::thread::JoinHandle<()>>,
}

impl AsyncRuntime {
    pub fn new_worker(&mut self) {
        let join_handle = std::thread::spawn({
            let task_manager = self.task_manager.clone();
            let task_manager_clone = self.task_manager.clone();
            move || loop {
                let (lock, cond_var) = &*task_manager;

                let mut task_manager = lock.lock().unwrap();

                while task_manager.pending_check_task_ids.is_empty() {
                    task_manager = cond_var.wait(task_manager).unwrap();
                }

                for task_id in std::mem::take(&mut task_manager.pending_check_task_ids) {
                    if let Some(task) = task_manager.tasks.get_mut(&task_id) {
                        let task = Pin::new(&mut task.future);

                        let task_manager_clone = task_manager_clone.clone();
                        let data = Arc::into_raw(Arc::new(raw_waker_impl::Data::new(
                            task_id,
                            task_manager_clone,
                        ))) as *const ();

                        let raw_waker = RawWaker::new(data, &raw_waker_impl::V_TABLE);
                        let waker = unsafe { Waker::from_raw(raw_waker) };
                        let mut context = Context::from_waker(&waker);

                        if task.poll(&mut context).is_ready() {
                            task_manager.tasks.remove(&task_id);
                        }
                    }
                }
            }
        });

        self.workers.push(join_handle);
    }

    pub fn spawn(&mut self, future: impl Future<Output = ()> + 'static) {
        static TASK_ID_COUNT: AtomicU32 = AtomicU32::new(0);
        let current_task_id = TASK_ID_COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed);

        let task = Task {
            future: Box::into_pin(Box::new(future)),
        };

        let (lock, cond_var) = &*self.task_manager;
        let mut task_manager = lock.lock().unwrap();
        task_manager.pending_check_task_ids.push(current_task_id);
        task_manager.tasks.insert(current_task_id, task);
        drop(task_manager);
        cond_var.notify_one();
    }

    pub fn wait(self) {
        for worker in self.workers {
            worker.join().unwrap();
        }
    }
}

mod raw_waker_impl {
    use super::*;
    use std::task::RawWakerVTable;

    pub struct Data {
        task_id: TaskID,
        task_manager: Arc<(Mutex<TaskManager>, Condvar)>,
    }

    impl Data {
        pub fn new(task_id: TaskID, task_manager: Arc<(Mutex<TaskManager>, Condvar)>) -> Self {
            Self {
                task_id,
                task_manager,
            }
        }
    }

    pub static V_TABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop);

    unsafe fn clone(ptr: *const ()) -> RawWaker {
        Arc::increment_strong_count(ptr as *const Data);
        RawWaker::new(ptr, &V_TABLE)
    }

    unsafe fn wake(ptr: *const ()) {
        let data = Arc::from_raw(ptr as *const Data);
        let (lock, cond_var) = &*data.task_manager;
        let mut task_manager = lock.lock().unwrap();
        task_manager.pending_check_task_ids.push(data.task_id);
        std::mem::drop(task_manager);
        cond_var.notify_one();
    }

    unsafe fn wake_by_ref(ptr: *const ()) {
        let data = &*(ptr as *const Data);
        let (lock, cond_var) = &*data.task_manager;
        let mut task_manager = lock.lock().unwrap();
        task_manager.pending_check_task_ids.push(data.task_id);
        std::mem::drop(task_manager);
        cond_var.notify_one();
    }

    unsafe fn drop(ptr: *const ()) {
        Arc::from_raw(ptr as *const Data);
    }
}
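Finally, a minimal entry point tying everything together (a sketch; note that wait never returns here, because the worker threads loop forever):

fn main() {
    let mut runtime = AsyncRuntime::default();
    runtime.new_worker();
    runtime.new_worker();

    runtime.spawn(async { println!("hello from the runtime") });

    // blocks forever: the workers never exit their loop
    runtime.wait();
}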
