1use alloc::rc::Rc;
2use alloc::task::LocalWake;
3use core::any::Any;
4use core::cell::{RefCell, RefMut};
5use core::marker::PhantomData;
6use core::pin::Pin;
7use core::task::{Context, LocalWaker, Poll};
8
9use crate::executor::ExecutorHandle;
10
11pub struct Task<T> {
12 inner: TaskRef,
13 marker: PhantomData<T>,
14}
15
16impl<T> Task<T> {
17 pub(crate) fn new<Ex>(executor: Ex, fut: impl 'static + Future<Output = T>) -> Self
18 where
19 T: 'static,
20 Ex: ExecutorHandle,
21 {
22 let task = WakeableTaskImpl {
23 task: RefCell::new(TaskImpl::Pending { fut, waker: None }),
24 executor,
25 };
26 Self {
27 inner: TaskRef(Rc::pin(task)),
28 marker: PhantomData,
29 }
30 }
31
32 pub(crate) fn inner(&self) -> TaskRef {
33 self.inner.clone()
34 }
35
36 pub fn abort(self) {
37 self.inner.0.as_ref().abort();
38 }
39}
40
41impl<T: 'static> Future for Task<T> {
42 type Output = T;
43
44 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
45 let mut output = Poll::Pending;
48 self.inner.0.as_ref().read(cx.local_waker(), &mut output);
49 output
50 }
51}
52
53#[derive(Clone)]
54pub(crate) struct TaskRef(Pin<Rc<dyn WakeableTask>>);
55
56impl TaskRef {
57 pub(crate) fn poll_wakeable(&self) -> Poll<()> {
58 use core::task::{ContextBuilder, Waker};
59 let waker = self.0.clone().waker();
60 let mut cx = ContextBuilder::from_waker(Waker::noop())
61 .local_waker(&waker)
62 .build();
63 self.0.as_ref().poll(&mut cx)
64 }
65}
66
67trait WakeableTask {
68 fn abort(self: Pin<&Self>);
69 fn poll(self: Pin<&Self>, cx: &mut Context) -> Poll<()>;
70 fn read(self: Pin<&Self>, waker: &LocalWaker, output: &mut dyn Any);
71 fn waker(self: Pin<Rc<Self>>) -> LocalWaker;
72}
73
/// Concrete task: mutable task state behind a `RefCell`, plus the handle of
/// the executor it is handed back to when woken.
// Field order is also drop order: the task state is dropped before the
// executor handle.
struct WakeableTaskImpl<S, E> {
    /// Task state; in this file it is only reached through `exclusive_access`,
    /// which treats it as structurally pinned.
    task: RefCell<S>,
    /// Executor handle used by `LocalWake::wake` to reschedule this task.
    executor: E,
}
78
79impl<T, Ex> WakeableTaskImpl<T, Ex> {
80 fn exclusive_access(self: Pin<&Self>) -> Pin<RefMut<T>> {
81 unsafe { Pin::new_unchecked(self.get_ref().task.borrow_mut()) }
85 }
86}
87
88impl<T, Ex> WakeableTask for WakeableTaskImpl<T, Ex>
89where
90 T: AnyTask,
91 Ex: ExecutorHandle,
92{
93 fn abort(self: Pin<&Self>) {
94 self.exclusive_access().as_mut().abort()
95 }
96 fn poll(self: Pin<&Self>, cx: &mut Context) -> Poll<()> {
97 self.exclusive_access().as_mut().poll(cx)
98 }
99 fn read(self: Pin<&Self>, waker: &LocalWaker, output: &mut dyn Any) {
100 self.exclusive_access().as_mut().read(waker, output)
101 }
102 fn waker(self: Pin<Rc<Self>>) -> LocalWaker {
103 LocalWaker::from(unsafe { Pin::into_inner_unchecked(self) })
106 }
107}
108
/// Waking a task hands a fresh `TaskRef` to it to its executor via
/// `self.executor.get().wake(..)`.
impl<T, Ex> LocalWake for WakeableTaskImpl<T, Ex>
where
    T: AnyTask,
    Ex: ExecutorHandle,
{
    fn wake(self: Rc<Self>) {
        // The executor handle is obtained (receiver evaluated) before `self`
        // is moved into the `TaskRef` argument below.
        self.executor
            .get()
            // SAFETY: this `Rc` was created pinned (`Rc::pin` in `Task::new`) and
            // only unpinned by `WakeableTask::waker` to build a `LocalWaker`.
            // Nothing in this file moves the task out of its `Rc`, so re-pinning
            // it here upholds the original pinning guarantee.
            .wake(TaskRef(unsafe { Pin::new_unchecked(self) }))
    }
}
121
122trait AnyTask: 'static {
123 fn abort(self: Pin<&mut Self>);
124 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<()>;
125 fn read(self: Pin<&mut Self>, waker: &LocalWaker, output: &mut dyn Any);
126}
127
pin_project_lite::pin_project! {
    // State machine for a single task; `TaskState` is the pin projection
    // used by the `AnyTask` impl.
    #[project = TaskState]
    enum TaskImpl<F: Future> {
        // Finished: holds `Poll::Ready(output)` until `read` swaps it out,
        // then `Poll::Pending`; `abort` also leaves `Poll::Pending` here.
        Ready { val: Poll<F::Output> },
        // Running: `fut` is structurally pinned. `waker` is the last waker
        // registered by `read` (from whoever polled the `Task` handle); it is
        // woken when the future completes.
        Pending { #[pin] fut: F, waker: Option<LocalWaker> },
    }
}
135
136impl<F> AnyTask for TaskImpl<F>
137where
138 F: 'static + Future,
139{
140 fn abort(mut self: Pin<&mut Self>) {
141 self.set(Self::Ready { val: Poll::Pending });
142 }
143
144 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> {
145 match self.as_mut().project() {
146 TaskState::Ready { .. } => Poll::Ready(()),
147 TaskState::Pending { fut, waker } => {
148 let val = fut.poll(cx);
149 if val.is_pending() {
150 return Poll::Pending;
151 }
152 let waker = waker.take();
153 self.set(Self::Ready { val });
154 _ = waker.map(LocalWaker::wake);
155 Poll::Ready(())
156 },
157 }
158 }
159
160 fn read(mut self: Pin<&mut Self>, waker: &LocalWaker, output: &mut dyn Any) {
161 match self.as_mut().project() {
162 TaskState::Ready { val } => {
163 let output = output.downcast_mut().expect("invalid task state");
164 core::mem::swap(val, output)
165 },
166 TaskState::Pending { waker: Some(w), .. } if !w.will_wake(waker) => *w = waker.clone(),
167 TaskState::Pending { waker: w, .. } => *w = Some(waker.clone()),
168 }
169 }
170}