local_executor/
task.rs

1use alloc::rc::Rc;
2use alloc::task::LocalWake;
3use core::any::Any;
4use core::cell::{RefCell, RefMut};
5use core::marker::PhantomData;
6use core::pin::Pin;
7use core::task::{Context, LocalWaker, Poll};
8
9use crate::executor::ExecutorHandle;
10
11pub struct Task<T> {
12    inner: TaskRef,
13    marker: PhantomData<T>,
14}
15
16impl<T> Task<T> {
17    pub(crate) fn new<Ex>(executor: Ex, fut: impl 'static + Future<Output = T>) -> Self
18    where
19        T: 'static,
20        Ex: ExecutorHandle,
21    {
22        let task = WakeableTaskImpl {
23            task: RefCell::new(TaskImpl::Pending { fut, waker: None }),
24            executor,
25        };
26        Self {
27            inner: TaskRef(Rc::pin(task)),
28            marker: PhantomData,
29        }
30    }
31
32    pub(crate) fn inner(&self) -> TaskRef {
33        self.inner.clone()
34    }
35
36    pub fn abort(self) {
37        self.inner.0.as_ref().abort();
38    }
39}
40
41impl<T: 'static> Future for Task<T> {
42    type Output = T;
43
44    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
45        // This `Future` will remain pending until the corresponding task is
46        // ready and wake it.
47        let mut output = Poll::Pending;
48        self.inner.0.as_ref().read(cx.local_waker(), &mut output);
49        output
50    }
51}
52
53#[derive(Clone)]
54pub(crate) struct TaskRef(Pin<Rc<dyn WakeableTask>>);
55
56impl TaskRef {
57    pub(crate) fn poll_wakeable(&self) -> Poll<()> {
58        use core::task::{ContextBuilder, Waker};
59        let waker = self.0.clone().waker();
60        let mut cx = ContextBuilder::from_waker(Waker::noop())
61            .local_waker(&waker)
62            .build();
63        self.0.as_ref().poll(&mut cx)
64    }
65}
66
67trait WakeableTask {
68    fn abort(self: Pin<&Self>);
69    fn poll(self: Pin<&Self>, cx: &mut Context) -> Poll<()>;
70    fn read(self: Pin<&Self>, waker: &LocalWaker, output: &mut dyn Any);
71    fn waker(self: Pin<Rc<Self>>) -> LocalWaker;
72}
73
/// Concrete task storage: the task state behind a `RefCell`, plus the
/// executor handle used when the task is woken.
struct WakeableTaskImpl<T, Ex> {
    task: RefCell<T>,
    executor: Ex,
}

impl<T, Ex> WakeableTaskImpl<T, Ex> {
    /// Mutably borrows the task state and returns the guard pinned.
    ///
    /// # Panics
    ///
    /// Panics if the state is already borrowed, as `RefCell::borrow_mut` does.
    fn exclusive_access(self: Pin<&Self>) -> Pin<RefMut<T>> {
        let guard = self.get_ref().task.borrow_mut();
        // SAFETY: This is a projection from `Pin<&RefCell>` to `Pin<RefMut>`.
        // It is sound because this method is the only way to reach the
        // underlying value, and every pointer it hands out is pinned.
        unsafe { Pin::new_unchecked(guard) }
    }
}
87
88impl<T, Ex> WakeableTask for WakeableTaskImpl<T, Ex>
89where
90    T: AnyTask,
91    Ex: ExecutorHandle,
92{
93    fn abort(self: Pin<&Self>) {
94        self.exclusive_access().as_mut().abort()
95    }
96    fn poll(self: Pin<&Self>, cx: &mut Context) -> Poll<()> {
97        self.exclusive_access().as_mut().poll(cx)
98    }
99    fn read(self: Pin<&Self>, waker: &LocalWaker, output: &mut dyn Any) {
100        self.exclusive_access().as_mut().read(waker, output)
101    }
102    fn waker(self: Pin<Rc<Self>>) -> LocalWaker {
103        // SAFETY: The pointer is temporarily unpinned to satisfy the signature,
104        // and then we immediately pin it back inside `LocalWake::wake`.
105        LocalWaker::from(unsafe { Pin::into_inner_unchecked(self) })
106    }
107}
108
109impl<T, Ex> LocalWake for WakeableTaskImpl<T, Ex>
110where
111    T: AnyTask,
112    Ex: ExecutorHandle,
113{
114    fn wake(self: Rc<Self>) {
115        // SAFETY: See the comments above.
116        self.executor
117            .get()
118            .wake(TaskRef(unsafe { Pin::new_unchecked(self) }))
119    }
120}
121
122trait AnyTask: 'static {
123    fn abort(self: Pin<&mut Self>);
124    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<()>;
125    fn read(self: Pin<&mut Self>, waker: &LocalWaker, output: &mut dyn Any);
126}
127
128pin_project_lite::pin_project! {
129    #[project = TaskState]
130    enum TaskImpl<F: Future> {
131        Ready { val: Poll<F::Output> },
132        Pending { #[pin] fut: F, waker: Option<LocalWaker> },
133    }
134}
135
136impl<F> AnyTask for TaskImpl<F>
137where
138    F: 'static + Future,
139{
140    fn abort(mut self: Pin<&mut Self>) {
141        self.set(Self::Ready { val: Poll::Pending });
142    }
143
144    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> {
145        match self.as_mut().project() {
146            TaskState::Ready { .. } => Poll::Ready(()),
147            TaskState::Pending { fut, waker } => {
148                let val = fut.poll(cx);
149                if val.is_pending() {
150                    return Poll::Pending;
151                }
152                let waker = waker.take();
153                self.set(Self::Ready { val });
154                _ = waker.map(LocalWaker::wake);
155                Poll::Ready(())
156            },
157        }
158    }
159
160    fn read(mut self: Pin<&mut Self>, waker: &LocalWaker, output: &mut dyn Any) {
161        match self.as_mut().project() {
162            TaskState::Ready { val } => {
163                let output = output.downcast_mut().expect("invalid task state");
164                core::mem::swap(val, output)
165            },
166            TaskState::Pending { waker: Some(w), .. } if !w.will_wake(waker) => *w = waker.clone(),
167            TaskState::Pending { waker: w, .. } => *w = Some(waker.clone()),
168        }
169    }
170}