1#![doc = include_str!("driver.md")]
2
3use core::cell::RefCell;
4use core::mem;
5use core::task::{Context, LocalWaker, Poll};
6
7use slab::Slab;
8
9use crate::op::Cancellation;
10
/// Identifier of an operation registered with a [`Driver`].
///
/// It wraps the operation's key in the driver's internal `Slab`. Slab keys of
/// removed entries are reused by later insertions, so an id is only meaningful
/// while its operation is still tracked (see [`Driver::contains`]).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct OpId(usize);
13
/// Bookkeeping for in-flight operations, each identified by an [`OpId`].
///
/// Every operation moves through a small lifecycle (submitted, waiting,
/// completed, cancelled), produces a payload of type `P` when completed, and
/// carries an extension value of type `Ext` (unit by default). The state lives
/// behind a [`RefCell`], so all methods take `&self`, and the driver is not
/// `Sync`.
pub struct Driver<P, Ext = ()>(RefCell<DriverInner<P, Ext>>);
15
/// Mutable state of a [`Driver`], kept behind its `RefCell`.
struct DriverInner<P, Ext> {
    /// Every tracked operation, keyed by the `usize` wrapped in [`OpId`].
    ops: Slab<RawOp<P, Ext>>,
}
19
/// One tracked operation.
struct RawOp<P, Ext> {
    /// Current position in the operation's lifecycle.
    state: Lifecycle<P>,
    /// Caller-supplied extension value; handed back together with the payload
    /// when the operation is polled to completion or completed after cancel.
    ext: Ext,
}
24
/// Lifecycle of a tracked operation.
enum Lifecycle<P> {
    /// Submitted, but neither polled while pending nor completed yet.
    /// Also used as a temporary placeholder (via `mem::replace`) while the
    /// driver moves the previous state out during a transition.
    Submitted,
    /// Polled while pending; this waker is woken when the operation completes.
    Waiting(LocalWaker),
    /// Completed with a payload that has not been polled out yet.
    Completed(P),
    /// Interest was dropped (via `Driver::remove`) before completion.
    // The `Cancellation` is never read, hence `allow(dead_code)`: storing it
    // only keeps it from being dropped until the driver completes the
    // operation and removes it from the slab.
    Cancelled(#[allow(dead_code)] Cancellation),
}
31
32impl<P, Ext> Driver<P, Ext> {
33 pub const fn new() -> Self {
34 Self(RefCell::new(DriverInner { ops: Slab::new() }))
35 }
36
37 pub fn with_capacity(capacity: usize) -> Self {
38 Self(RefCell::new(DriverInner {
39 ops: Slab::with_capacity(capacity),
40 }))
41 }
42
43 pub fn len(&self) -> usize {
44 self.0.borrow().ops.len()
45 }
46
47 pub fn is_empty(&self) -> bool {
48 self.0.borrow().ops.is_empty()
49 }
50
51 pub fn contains(&self, id: OpId) -> bool {
52 self.0.borrow().ops.contains(id.0)
53 }
54
55 pub fn submit(&self) -> OpId
56 where
57 Ext: Default,
58 {
59 self.submit_ext(Ext::default())
60 }
61
62 pub fn submit_ext(&self, ext: Ext) -> OpId {
63 self.0.borrow_mut().submit(ext)
64 }
65
66 pub fn try_submit(&self) -> Result<OpId, Ext>
69 where
70 Ext: Default,
71 {
72 self.try_submit_ext(Ext::default())
73 }
74
75 pub fn try_submit_ext(&self, ext: Ext) -> Result<OpId, Ext> {
76 self.0.borrow_mut().try_submit(ext)
77 }
78
79 pub fn complete(&self, id: OpId, payload: P) -> Result<(), P> {
85 self.0
86 .borrow_mut()
87 .complete(id, payload)
88 .map_err(|(p, _)| p)
89 }
90
91 pub fn complete_ext(&self, id: OpId, payload: P) -> Result<(), (P, Ext)> {
95 self.0.borrow_mut().complete(id, payload)
96 }
97
98 pub(crate) fn poll(&self, id: OpId, cx: &mut Context) -> Poll<(P, Ext)> {
99 self.0.borrow_mut().poll(id, cx)
100 }
101
102 pub(crate) fn remove(&self, id: OpId, mut callback: impl FnMut() -> Cancellation) {
103 self.0.borrow_mut().remove(id, &mut callback)
104 }
105}
106
107impl<P, Ext> Default for Driver<P, Ext>
108where
109 Ext: Default,
110{
111 fn default() -> Self {
112 Self::new()
113 }
114}
115
116impl<P, Ext> DriverInner<P, Ext> {
117 fn submit(&mut self, ext: Ext) -> OpId {
118 OpId(self.ops.insert(RawOp {
119 state: Lifecycle::Submitted,
120 ext,
121 }))
122 }
123
124 fn try_submit(&mut self, ext: Ext) -> Result<OpId, Ext> {
125 if self.ops.len() == self.ops.capacity() {
126 Err(ext)
127 } else {
128 Ok(self.submit(ext))
129 }
130 }
131
132 fn poll(&mut self, id: OpId, cx: &mut Context) -> Poll<(P, Ext)> {
133 let op = self.ops.get_mut(id.0).expect("invalid driver state");
134 match mem::replace(&mut op.state, Lifecycle::Submitted) {
135 Lifecycle::Submitted => {
136 op.state = Lifecycle::Waiting(cx.local_waker().clone());
137 Poll::Pending
138 },
139 Lifecycle::Waiting(waker) if !waker.will_wake(cx.local_waker()) => {
140 op.state = Lifecycle::Waiting(cx.local_waker().clone());
141 Poll::Pending
142 },
143 Lifecycle::Waiting(waker) => {
144 op.state = Lifecycle::Waiting(waker);
145 Poll::Pending
146 },
147 Lifecycle::Completed(payload) => {
148 let op = self.ops.remove(id.0);
150 Poll::Ready((payload, op.ext))
151 },
152 Lifecycle::Cancelled(_) => unreachable!("invalid operation state"),
153 }
154 }
155
156 fn complete(&mut self, id: OpId, payload: P) -> Result<(), (P, Ext)> {
157 let op = self.ops.get_mut(id.0).expect("invalid driver state");
158 match mem::replace(&mut op.state, Lifecycle::Submitted) {
159 Lifecycle::Submitted => {
160 op.state = Lifecycle::Completed(payload);
161 Ok(())
162 },
163 Lifecycle::Waiting(waker) => {
164 op.state = Lifecycle::Completed(payload);
165 waker.wake();
166 Ok(())
167 },
168 Lifecycle::Completed(_) => unreachable!("invalid operation state"),
169 Lifecycle::Cancelled(_) => {
170 let op = self.ops.remove(id.0);
171 Err((payload, op.ext))
172 },
173 }
174 }
175
176 fn remove(&mut self, id: OpId, callback: &mut dyn FnMut() -> Cancellation) {
177 let Some(op) = self.ops.get_mut(id.0) else {
179 return;
180 };
181 match mem::replace(&mut op.state, Lifecycle::Submitted) {
182 Lifecycle::Submitted | Lifecycle::Waiting(_) => {
183 op.state = Lifecycle::Cancelled(callback());
184 },
185 Lifecycle::Completed(_) => _ = self.ops.remove(id.0),
186 Lifecycle::Cancelled(_) => unreachable!("invalid operation state"),
187 }
188 }
189}
190
191impl<P, Ext> Drop for DriverInner<P, Ext> {
192 fn drop(&mut self) {
193 assert!(
194 self.ops
195 .iter()
196 .all(|(_, op)| matches!(op.state, Lifecycle::Completed(_))),
197 "all operations inside `Driver` must be completed before dropping"
198 );
199 }
200}
201
/// A handle that can produce access to a [`Driver`].
///
/// Implementors must be `'static` and `Unpin`.
// NOTE(review): those bounds suggest handles are stored inside futures —
// confirm against the users of this trait.
pub trait DriverHandle: 'static + Unpin {
    /// Payload type of the driver's operations.
    type Payload;
    /// Extension type stored alongside each of the driver's operations.
    type Ext;
    /// Pointer type that dereferences to the driver.
    type Ref: core::ops::Deref<Target = Driver<Self::Payload, Self::Ext>>;

    /// Returns access to the driver.
    ///
    /// Implementations may panic when no driver is available; the
    /// `alloc::rc::Weak` implementation in this module does so when the driver
    /// has been dropped.
    fn get(&self) -> Self::Ref;
}
209impl<P, Ext> DriverHandle for alloc::rc::Weak<Driver<P, Ext>>
210where
211 P: 'static,
212 Ext: 'static,
213{
214 type Payload = P;
215 type Ext = Ext;
216 type Ref = alloc::rc::Rc<Driver<P, Ext>>;
217 fn get(&self) -> Self::Ref {
218 self.upgrade().expect("not inside a valid executor")
219 }
220}