evering/
driver.rs

1#![doc = include_str!("driver.md")]
2
3use core::cell::RefCell;
4use core::mem;
5use core::task::{Context, LocalWaker, Poll};
6
7use slab::Slab;
8
9use crate::op::Cancellation;
10
/// Identifier of an operation tracked by a `Driver`.
///
/// It is the index of the operation's slot inside the driver; once that slot
/// is freed the same value may be handed out again for a new operation.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct OpId(usize);
13
14pub struct Driver<P, Ext = ()>(RefCell<DriverInner<P, Ext>>);
15
16struct DriverInner<P, Ext> {
17    ops: Slab<RawOp<P, Ext>>,
18}
19
20struct RawOp<P, Ext> {
21    state: Lifecycle<P>,
22    ext: Ext,
23}
24
25enum Lifecycle<P> {
26    Submitted,
27    Waiting(LocalWaker),
28    Completed(P),
29    Cancelled(#[allow(dead_code)] Cancellation),
30}
31
32impl<P, Ext> Driver<P, Ext> {
33    pub const fn new() -> Self {
34        Self(RefCell::new(DriverInner { ops: Slab::new() }))
35    }
36
37    pub fn with_capacity(capacity: usize) -> Self {
38        Self(RefCell::new(DriverInner {
39            ops: Slab::with_capacity(capacity),
40        }))
41    }
42
43    pub fn len(&self) -> usize {
44        self.0.borrow().ops.len()
45    }
46
47    pub fn is_empty(&self) -> bool {
48        self.0.borrow().ops.is_empty()
49    }
50
51    pub fn contains(&self, id: OpId) -> bool {
52        self.0.borrow().ops.contains(id.0)
53    }
54
55    pub fn submit(&self) -> OpId
56    where
57        Ext: Default,
58    {
59        self.submit_ext(Ext::default())
60    }
61
62    pub fn submit_ext(&self, ext: Ext) -> OpId {
63        self.0.borrow_mut().submit(ext)
64    }
65
66    /// Submits an operation if there is sufficient spare capacity, otherwise an
67    /// error is returned with the element.
68    pub fn try_submit(&self) -> Result<OpId, Ext>
69    where
70        Ext: Default,
71    {
72        self.try_submit_ext(Ext::default())
73    }
74
75    pub fn try_submit_ext(&self, ext: Ext) -> Result<OpId, Ext> {
76        self.0.borrow_mut().try_submit(ext)
77    }
78
79    /// Completes a operation. It returns the given `payload` as an [`Err`] if
80    /// the specified operation has been cancelled.
81    ///
82    /// The given `id` is always recycled even if the corresponding operation is
83    /// cancelled.
84    pub fn complete(&self, id: OpId, payload: P) -> Result<(), P> {
85        self.0
86            .borrow_mut()
87            .complete(id, payload)
88            .map_err(|(p, _)| p)
89    }
90
91    /// Completes a operation with the submitted extension.
92    ///
93    /// For more information, see [`complete`](Self::complete).
94    pub fn complete_ext(&self, id: OpId, payload: P) -> Result<(), (P, Ext)> {
95        self.0.borrow_mut().complete(id, payload)
96    }
97
98    pub(crate) fn poll(&self, id: OpId, cx: &mut Context) -> Poll<(P, Ext)> {
99        self.0.borrow_mut().poll(id, cx)
100    }
101
102    pub(crate) fn remove(&self, id: OpId, mut callback: impl FnMut() -> Cancellation) {
103        self.0.borrow_mut().remove(id, &mut callback)
104    }
105}
106
107impl<P, Ext> Default for Driver<P, Ext>
108where
109    Ext: Default,
110{
111    fn default() -> Self {
112        Self::new()
113    }
114}
115
116impl<P, Ext> DriverInner<P, Ext> {
117    fn submit(&mut self, ext: Ext) -> OpId {
118        OpId(self.ops.insert(RawOp {
119            state: Lifecycle::Submitted,
120            ext,
121        }))
122    }
123
124    fn try_submit(&mut self, ext: Ext) -> Result<OpId, Ext> {
125        if self.ops.len() == self.ops.capacity() {
126            Err(ext)
127        } else {
128            Ok(self.submit(ext))
129        }
130    }
131
132    fn poll(&mut self, id: OpId, cx: &mut Context) -> Poll<(P, Ext)> {
133        let op = self.ops.get_mut(id.0).expect("invalid driver state");
134        match mem::replace(&mut op.state, Lifecycle::Submitted) {
135            Lifecycle::Submitted => {
136                op.state = Lifecycle::Waiting(cx.local_waker().clone());
137                Poll::Pending
138            },
139            Lifecycle::Waiting(waker) if !waker.will_wake(cx.local_waker()) => {
140                op.state = Lifecycle::Waiting(cx.local_waker().clone());
141                Poll::Pending
142            },
143            Lifecycle::Waiting(waker) => {
144                op.state = Lifecycle::Waiting(waker);
145                Poll::Pending
146            },
147            Lifecycle::Completed(payload) => {
148                // Remove this operation immediately if completed.
149                let op = self.ops.remove(id.0);
150                Poll::Ready((payload, op.ext))
151            },
152            Lifecycle::Cancelled(_) => unreachable!("invalid operation state"),
153        }
154    }
155
156    fn complete(&mut self, id: OpId, payload: P) -> Result<(), (P, Ext)> {
157        let op = self.ops.get_mut(id.0).expect("invalid driver state");
158        match mem::replace(&mut op.state, Lifecycle::Submitted) {
159            Lifecycle::Submitted => {
160                op.state = Lifecycle::Completed(payload);
161                Ok(())
162            },
163            Lifecycle::Waiting(waker) => {
164                op.state = Lifecycle::Completed(payload);
165                waker.wake();
166                Ok(())
167            },
168            Lifecycle::Completed(_) => unreachable!("invalid operation state"),
169            Lifecycle::Cancelled(_) => {
170                let op = self.ops.remove(id.0);
171                Err((payload, op.ext))
172            },
173        }
174    }
175
176    fn remove(&mut self, id: OpId, callback: &mut dyn FnMut() -> Cancellation) {
177        // The operation may have been removed inside `poll`.
178        let Some(op) = self.ops.get_mut(id.0) else {
179            return;
180        };
181        match mem::replace(&mut op.state, Lifecycle::Submitted) {
182            Lifecycle::Submitted | Lifecycle::Waiting(_) => {
183                op.state = Lifecycle::Cancelled(callback());
184            },
185            Lifecycle::Completed(_) => _ = self.ops.remove(id.0),
186            Lifecycle::Cancelled(_) => unreachable!("invalid operation state"),
187        }
188    }
189}
190
191impl<P, Ext> Drop for DriverInner<P, Ext> {
192    fn drop(&mut self) {
193        assert!(
194            self.ops
195                .iter()
196                .all(|(_, op)| matches!(op.state, Lifecycle::Completed(_))),
197            "all operations inside `Driver` must be completed before dropping"
198        );
199    }
200}
201
202pub trait DriverHandle: 'static + Unpin {
203    type Payload;
204    type Ext;
205    type Ref: core::ops::Deref<Target = Driver<Self::Payload, Self::Ext>>;
206
207    fn get(&self) -> Self::Ref;
208}
209impl<P, Ext> DriverHandle for alloc::rc::Weak<Driver<P, Ext>>
210where
211    P: 'static,
212    Ext: 'static,
213{
214    type Payload = P;
215    type Ext = Ext;
216    type Ref = alloc::rc::Rc<Driver<P, Ext>>;
217    fn get(&self) -> Self::Ref {
218        self.upgrade().expect("not inside a valid executor")
219    }
220}