evering/uring.rs

#![doc = include_str!("uring.md")]

use alloc::alloc::Layout;
use core::fmt;
use core::marker::PhantomData;
use core::ptr::NonNull;
use core::sync::atomic::{AtomicU32, Ordering};

mod private {
    pub trait Sealed {}
}

#[non_exhaustive]
pub struct DisposeError {}

impl fmt::Debug for DisposeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("DisposeError").finish_non_exhaustive()
    }
}

impl fmt::Display for DisposeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Uring is still connected")
    }
}

impl core::error::Error for DisposeError {}

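/// An endpoint of a bidirectional, single-producer single-consumer ring.
///
/// # Examples
///
/// A minimal round-trip sketch, assuming this module is reachable as
/// `evering::uring`:
///
/// ```
/// use evering::uring::{Builder, Uring};
///
/// let (mut a, mut b) = Builder::<u32, u32>::new().build();
/// a.send(1).unwrap();
/// assert_eq!(b.recv(), Some(1));
/// drop(b);
/// assert!(!a.is_connected());
/// ```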
pub trait Uring: private::Sealed {
    type A;
    type B;
    type Ext;

    fn header(&self) -> &Header<Self::Ext>;

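    /// Returns the queue that this endpoint enqueues entries into.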
    fn sender(&self) -> Queue<Self::A>;

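    /// Returns the queue that this endpoint dequeues entries from.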
    fn receiver(&self) -> Queue<Self::B>;

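    /// Returns the extension data shared by both endpoints, hence the
    /// `Sync` bound.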
    fn ext(&self) -> &Self::Ext
    where
        Self::Ext: Sync,
    {
        &self.header().ext
    }

    /// Returns `true` if the remote [`Uring`] has not been dropped.
    fn is_connected(&self) -> bool {
        self.header().rc.load(Ordering::Relaxed) > 1
    }

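    /// Attempts to enqueue `val`, handing it back as `Err(val)` if the queue
    /// is full.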
    fn send(&mut self, val: Self::A) -> Result<(), Self::A> {
        unsafe { self.sender().enqueue(val) }
    }

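    /// Enqueues entries from `vals` until the queue is full or the iterator
    /// is exhausted, returning the number of entries enqueued.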
    fn send_bulk<I>(&mut self, vals: I) -> usize
    where
        I: Iterator<Item = Self::A>,
    {
        unsafe { self.sender().enqueue_bulk(vals) }
    }

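    /// Dequeues a single entry, or returns `None` if the queue is empty.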
    fn recv(&mut self) -> Option<Self::B> {
        unsafe { self.receiver().dequeue() }
    }

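    /// Returns an iterator that drains all currently enqueued entries.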
    fn recv_bulk(&mut self) -> Drain<Self::B> {
        unsafe { self.receiver().dequeue_bulk() }
    }
}

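/// Either endpoint of a symmetric [`Uring`] whose two directions carry the
/// same entry type.
///
/// # Examples
///
/// A sketch of handling both endpoints uniformly (module path assumed to be
/// `evering::uring`):
///
/// ```
/// use evering::uring::{Builder, Uring, UringEither};
///
/// let (a, b) = Builder::<u8, u8>::new().build();
/// let mut ends = [UringEither::A(a), UringEither::B(b)];
/// ends[0].send(7).unwrap();
/// assert_eq!(ends[1].recv(), Some(7));
/// ```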
pub enum UringEither<T, Ext = ()> {
    A(UringA<T, T, Ext>),
    B(UringB<T, T, Ext>),
}

impl<T, Ext> private::Sealed for UringEither<T, Ext> {}
impl<T, Ext> Uring for UringEither<T, Ext> {
    type A = T;
    type B = T;
    type Ext = Ext;

    fn header(&self) -> &Header<Ext> {
        match self {
            UringEither::A(a) => a.header(),
            UringEither::B(b) => b.header(),
        }
    }

    fn sender(&self) -> Queue<T> {
        match self {
            UringEither::A(a) => a.sender(),
            UringEither::B(b) => b.sender(),
        }
    }

    fn receiver(&self) -> Queue<T> {
        match self {
            UringEither::A(a) => a.receiver(),
            UringEither::B(b) => b.receiver(),
        }
    }
}

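/// Alias of [`UringA`]: the endpoint that sends `Sqe`s and receives `Rqe`s.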
pub type Sender<Sqe, Rqe, Ext = ()> = UringA<Sqe, Rqe, Ext>;
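/// Alias of [`UringB`]: the endpoint that receives `Sqe`s and sends `Rqe`s.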
pub type Receiver<Sqe, Rqe, Ext = ()> = UringB<Sqe, Rqe, Ext>;

pub struct UringA<A, B, Ext = ()>(RawUring<A, B, Ext>);
pub struct UringB<A, B, Ext = ()>(RawUring<A, B, Ext>);

unsafe impl<A: Send, B: Send, Ext: Send> Send for UringA<A, B, Ext> {}
unsafe impl<A: Send, B: Send, Ext: Send> Send for UringB<A, B, Ext> {}

macro_rules! common_methods {
    () => {
        pub fn into_raw(self) -> RawUring<A, B, Ext> {
            let inner = RawUring {
                header: self.0.header,
                buf_a: self.0.buf_a,
                buf_b: self.0.buf_b,
                marker: PhantomData,
            };
            core::mem::forget(self);
            inner
        }

        /// Drops this [`Uring`] and, once disconnected, all enqueued entries.
        ///
        /// If `self` is still connected, this only disconnects (as a regular
        /// drop would) and returns an error. Otherwise, the returned
        /// [`RawUring`] is safe to deallocate without synchronization.
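        ///
        /// # Examples
        ///
        /// A sketch of tearing down after the peer endpoint is gone
        /// (deallocating the returned [`RawUring`] is the caller's job and is
        /// omitted here; module path assumed to be `evering::uring`):
        ///
        /// ```
        /// use evering::uring::Builder;
        ///
        /// let (a, b) = Builder::<u8, u8>::new().build();
        /// assert!(b.dispose_raw().is_err()); // `a` is still alive
        /// let _raw = a.dispose_raw().unwrap(); // now disconnected
        /// ```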
        pub fn dispose_raw(self) -> Result<RawUring<A, B, Ext>, DisposeError> {
            let mut raw = self.into_raw();
            unsafe {
                match raw.dispose() {
                    Ok(_) => Ok(raw),
                    Err(e) => Err(e),
                }
            }
        }

        /// # Safety
        ///
        /// The specified [`RawUring`] must be a valid value returned from
        /// [`into_raw`](Self::into_raw).
        pub unsafe fn from_raw(uring: RawUring<A, B, Ext>) -> Self {
            Self(uring)
        }
    };
}

impl<A, B, Ext> UringA<A, B, Ext> {
    common_methods!();
}

impl<A, B, Ext> UringB<A, B, Ext> {
    common_methods!();
}

impl<A, B, Ext> private::Sealed for UringA<A, B, Ext> {}
impl<A, B, Ext> Uring for UringA<A, B, Ext> {
    type A = A;
    type B = B;
    type Ext = Ext;

    fn header(&self) -> &Header<Ext> {
        unsafe { self.0.header() }
    }
    fn sender(&self) -> Queue<Self::A> {
        unsafe { self.0.queue_a() }
    }
    fn receiver(&self) -> Queue<Self::B> {
        unsafe { self.0.queue_b() }
    }
}

impl<A, B, Ext> private::Sealed for UringB<A, B, Ext> {}
impl<A, B, Ext> Uring for UringB<A, B, Ext> {
    type A = B;
    type B = A;
    type Ext = Ext;

    fn header(&self) -> &Header<Ext> {
        unsafe { self.0.header() }
    }
    fn sender(&self) -> Queue<Self::A> {
        unsafe { self.0.queue_b() }
    }
    fn receiver(&self) -> Queue<Self::B> {
        unsafe { self.0.queue_a() }
    }
}

impl<A, B, Ext> Drop for UringA<A, B, Ext> {
    fn drop(&mut self) {
        unsafe { self.0.drop_in_place() }
    }
}

impl<A, B, Ext> Drop for UringB<A, B, Ext> {
    fn drop(&mut self) {
        unsafe { self.0.drop_in_place() }
    }
}

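/// The state shared by a connected pair: per-queue offsets, a reference
/// count, and user extension data.
///
/// # Examples
///
/// A sketch of inspecting queue sizes (module path assumed to be
/// `evering::uring`):
///
/// ```
/// use evering::uring::{Builder, Uring};
///
/// let (a, _b) = Builder::<u8, u8>::new().build();
/// assert_eq!(a.header().size_a(), 32); // the builder's default size
/// ```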
pub struct Header<Ext = ()> {
    off_a: Offsets,
    off_b: Offsets,
    rc: AtomicU32,
    ext: Ext,
}

impl<Ext> Header<Ext> {
    pub fn size_a(&self) -> usize {
        self.off_a.ring_mask as usize + 1
    }

    pub fn size_b(&self) -> usize {
        self.off_b.ring_mask as usize + 1
    }
}

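// Head/tail offsets of a single queue. The ring holds `ring_mask + 1` slots
// (a power of two), of which at most `ring_mask` are occupied at once: one
// slot is kept free to distinguish a full ring from an empty one.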
struct Offsets {
    head: AtomicU32,
    tail: AtomicU32,
    ring_mask: u32,
}

impl Offsets {
    fn new(size: u32) -> Self {
        debug_assert!(size.is_power_of_two());
        Self {
            head: AtomicU32::new(0),
            tail: AtomicU32::new(0),
            ring_mask: size - 1,
        }
    }

    fn inc(&self, n: u32) -> u32 {
        n.wrapping_add(1) & self.ring_mask
    }
}

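/// The raw parts of a [`Uring`] endpoint, exposed for manual memory
/// management.
///
/// # Examples
///
/// A sketch of a round trip through the raw representation (module path
/// assumed to be `evering::uring`):
///
/// ```
/// use evering::uring::{Builder, UringA};
///
/// let (a, _b) = Builder::<u8, u8>::new().build();
/// let raw = a.into_raw();
/// // SAFETY: `raw` was just returned from `into_raw`.
/// let _a = unsafe { UringA::from_raw(raw) };
/// ```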
pub struct RawUring<A, B, Ext = ()> {
    pub header: NonNull<Header<Ext>>,
    pub buf_a: NonNull<A>,
    pub buf_b: NonNull<B>,
    marker: PhantomData<fn(A, B, Ext) -> (A, B, Ext)>,
}

impl<A, B, Ext> RawUring<A, B, Ext> {
    pub const fn dangling() -> Self {
        Self {
            header: NonNull::dangling(),
            buf_a: NonNull::dangling(),
            buf_b: NonNull::dangling(),
            marker: PhantomData,
        }
    }

    unsafe fn header(&self) -> &Header<Ext> {
        unsafe { self.header.as_ref() }
    }

    unsafe fn queue_a(&self) -> Queue<'_, A> {
        Queue {
            off: unsafe { &self.header().off_a },
            buf: self.buf_a,
        }
    }

    unsafe fn queue_b(&self) -> Queue<'_, B> {
        Queue {
            off: unsafe { &self.header().off_b },
            buf: self.buf_b,
        }
    }

    unsafe fn dispose(&mut self) -> Result<(), DisposeError> {
        let rc = unsafe { &self.header().rc };
        debug_assert!(rc.load(Ordering::Relaxed) >= 1);
        // `Release` ensures that all uses of the data happen before this point.
        if rc.fetch_sub(1, Ordering::Release) != 1 {
            return Err(DisposeError {});
        }
        // `Acquire` ensures that the deletion of the data happens after this point.
        core::sync::atomic::fence(Ordering::Acquire);

        unsafe {
            self.queue_a().drop_in_place();
            self.queue_b().drop_in_place();
        }
        Ok(())
    }

    unsafe fn drop_in_place(&mut self) {
        unsafe {
            if self.dispose().is_ok() {
                let h = self.header.as_ref();
                dealloc_buffer(self.buf_a, h.off_a.ring_mask as usize + 1);
                dealloc_buffer(self.buf_b, h.off_b.ring_mask as usize + 1);
                dealloc(self.header);
            }
        }
    }
}

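/// One directional queue of a [`Uring`], borrowed via [`Uring::sender`] or
/// [`Uring::receiver`].
///
/// # Examples
///
/// A sketch of querying queue occupancy (module path assumed to be
/// `evering::uring`):
///
/// ```
/// use evering::uring::{Builder, Uring};
///
/// let (mut a, _b) = Builder::<u8, u8>::new().build();
/// assert!(a.sender().is_empty());
/// a.send(1).unwrap();
/// assert_eq!(a.sender().len(), 1);
/// ```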
pub struct Queue<'a, T> {
    off: &'a Offsets,
    buf: NonNull<T>,
}

impl<'a, T> Queue<'a, T> {
    pub fn len(&self) -> usize {
        let head = self.off.head.load(Ordering::Relaxed);
        let tail = self.off.tail.load(Ordering::Relaxed);
        (tail.wrapping_sub(head) & self.off.ring_mask) as usize
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    unsafe fn enqueue(&mut self, val: T) -> Result<(), T> {
        let Self { off, buf } = self;
        debug_assert!((off.ring_mask + 1).is_power_of_two());

        let tail = off.tail.load(Ordering::Relaxed);
        let head = off.head.load(Ordering::Acquire);

        let next_tail = off.inc(tail);
        if next_tail == head {
            return Err(val);
        }

        unsafe { buf.add(tail as usize).write(val) };
        off.tail.store(next_tail, Ordering::Release);

        Ok(())
    }

    unsafe fn enqueue_bulk(&mut self, mut vals: impl Iterator<Item = T>) -> usize {
        let Self { off, buf } = self;
        debug_assert!((off.ring_mask + 1).is_power_of_two());

        let mut tail = off.tail.load(Ordering::Relaxed);
        let head = off.head.load(Ordering::Acquire);

        let mut n = 0;
        let mut next_tail;
        loop {
            next_tail = off.inc(tail);
            if next_tail == head {
                break;
            }
            let Some(val) = vals.next() else {
                break;
            };
            unsafe { buf.add(tail as usize).write(val) };
            off.tail.store(next_tail, Ordering::Release);
            n += 1;
            tail = next_tail;
        }

        n
    }

    unsafe fn dequeue(&mut self) -> Option<T> {
        let Self { off, buf } = self;
        debug_assert!((off.ring_mask + 1).is_power_of_two());

        let head = off.head.load(Ordering::Relaxed);
        let tail = off.tail.load(Ordering::Acquire);

        if head == tail {
            return None;
        }
        let next_head = off.inc(head);

        let val = unsafe { buf.add(head as usize).read() };
        off.head.store(next_head, Ordering::Release);

        Some(val)
    }

    unsafe fn dequeue_bulk(&mut self) -> Drain<'a, T> {
        let Self { off, buf } = self;
        debug_assert!((off.ring_mask + 1).is_power_of_two());

        let head = off.head.load(Ordering::Relaxed);
        let tail = off.tail.load(Ordering::Acquire);

        Drain {
            off,
            buf: *buf,
            head,
            tail,
        }
    }

    unsafe fn drop_in_place(&mut self) {
        debug_assert!((self.off.ring_mask + 1).is_power_of_two());
        unsafe {
            let mut head = self.off.head.as_ptr().read();
            let tail = self.off.tail.as_ptr().read();
            while head != tail {
                self.buf.add(head as usize).drop_in_place();
                head = self.off.inc(head);
            }
        }
    }
}

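/// A draining iterator over a [`Queue`], returned by [`Uring::recv_bulk`].
///
/// Each call to [`next`](Iterator::next) releases one slot back to the remote
/// endpoint; entries not yet yielded simply remain enqueued.
///
/// # Examples
///
/// A sketch of draining in bulk (module path assumed to be `evering::uring`):
///
/// ```
/// use evering::uring::{Builder, Uring};
///
/// let (mut a, mut b) = Builder::<u8, u8>::new().build();
/// a.send_bulk(0u8..3);
/// let got: Vec<u8> = b.recv_bulk().collect();
/// assert_eq!(got, [0, 1, 2]);
/// ```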
pub struct Drain<'a, T> {
    off: &'a Offsets,
    buf: NonNull<T>,
    head: u32,
    tail: u32,
}

impl<T> Iterator for Drain<'_, T> {
    type Item = T;
    fn next(&mut self) -> Option<Self::Item> {
        if self.head == self.tail {
            return None;
        }
        let next_head = self.off.inc(self.head);
        let val = unsafe { self.buf.add(self.head as usize).read() };
        self.off.head.store(next_head, Ordering::Release);
        self.head = next_head;
        Some(val)
    }
}

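/// A builder for a connected pair of [`Uring`] endpoints.
///
/// # Examples
///
/// A sketch with custom queue sizes (the setters take `&mut self`, so
/// configure the builder before calling [`build`](Builder::build); module
/// path assumed to be `evering::uring`):
///
/// ```
/// use evering::uring::Builder;
///
/// let mut builder = Builder::<u8, u8>::new();
/// builder.size_a(64).size_b(16);
/// let (_a, _b) = builder.build();
/// ```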
pub struct Builder<A, B, Ext = ()> {
    size_a: usize,
    size_b: usize,
    ext: Ext,
    marker: PhantomData<(A, B)>,
}

impl<A, B, Ext> Builder<A, B, Ext> {
    pub fn new() -> Self
    where
        Ext: Default,
    {
        Self::new_ext(Ext::default())
    }

    pub fn new_ext(ext: Ext) -> Self {
        Self {
            size_a: 32,
            size_b: 32,
            ext,
            marker: PhantomData,
        }
    }

    pub fn size_a(&mut self, size: usize) -> &mut Self {
        assert!(size.is_power_of_two());
        self.size_a = size;
        self
    }

    pub fn size_b(&mut self, size: usize) -> &mut Self {
        assert!(size.is_power_of_two());
        self.size_b = size;
        self
    }

    pub fn build_header(self) -> Header<Ext> {
        Header {
            off_a: Offsets::new(self.size_a as u32),
            off_b: Offsets::new(self.size_b as u32),
            rc: AtomicU32::new(2),
            ext: self.ext,
        }
    }

    pub fn build(self) -> (UringA<A, B, Ext>, UringB<A, B, Ext>) {
        let header;
        let buf_a;
        let buf_b;

        unsafe {
            header = alloc::<Header<Ext>>();
            buf_a = alloc_buffer(self.size_a);
            buf_b = alloc_buffer(self.size_b);

            header.write(self.build_header());
        }

        let ring_a = UringA(RawUring {
            header,
            buf_a,
            buf_b,
            marker: PhantomData,
        });
        let ring_b = UringB(RawUring {
            header,
            buf_a,
            buf_b,
            marker: PhantomData,
        });

        (ring_a, ring_b)
    }
}

impl<A, B, Ext: Default> Default for Builder<A, B, Ext> {
    fn default() -> Self {
        Self::new()
    }
}

unsafe fn alloc_buffer<T>(size: usize) -> NonNull<T> {
    let layout = Layout::array::<T>(size).unwrap();
    if layout.size() == 0 {
        // Zero-sized layouts (e.g. when `T` is a ZST) must not be passed to
        // the global allocator; a dangling pointer is valid for them.
        return NonNull::dangling();
    }
    NonNull::new(unsafe { alloc::alloc::alloc(layout) })
        .unwrap_or_else(|| alloc::alloc::handle_alloc_error(layout))
        .cast()
}

unsafe fn alloc<T>() -> NonNull<T> {
    let layout = Layout::new::<T>();
    NonNull::new(unsafe { alloc::alloc::alloc(layout) })
        .unwrap_or_else(|| alloc::alloc::handle_alloc_error(layout))
        .cast()
}

unsafe fn dealloc_buffer<T>(ptr: NonNull<T>, size: usize) {
    let layout = Layout::array::<T>(size).unwrap();
    if layout.size() == 0 {
        // Matches `alloc_buffer`: nothing was allocated for zero-sized layouts.
        return;
    }
    unsafe { alloc::alloc::dealloc(ptr.as_ptr().cast(), layout) }
}

unsafe fn dealloc<T>(ptr: NonNull<T>) {
    let layout = Layout::new::<T>();
    unsafe { alloc::alloc::dealloc(ptr.as_ptr().cast(), layout) }
}

#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, AtomicUsize};

    use super::*;

    #[test]
    fn queue_len() {
        let mut len_a = 0;
        let mut len_b = 0;
        let (mut pa, mut pb) = Builder::<(), ()>::new().build();
        for _ in 0..32 {
            match fastrand::u8(0..4) {
                0 => len_a += pa.send(()).map_or(0, |_| 1),
                1 => len_b += pb.send(()).map_or(0, |_| 1),
                2 => len_a -= pb.recv().map_or(0, |_| 1),
                3 => len_b -= pa.recv().map_or(0, |_| 1),
                _ => unreachable!(),
            }
            assert_eq!(pa.sender().len(), pb.receiver().len());
            assert_eq!(pa.receiver().len(), pb.sender().len());
            assert_eq!(pa.sender().len(), len_a);
            assert_eq!(pb.sender().len(), len_b);
        }
    }

    #[test]
    fn uring_drop() {
        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

        #[derive(Debug)]
        struct DropCounter(char);
        impl Drop for DropCounter {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::Relaxed);
            }
        }

        let input = std::iter::repeat_with(fastrand::alphabetic)
            .take(30)
            .collect::<Vec<_>>();

        let (mut pa, mut pb) = Builder::<DropCounter, DropCounter>::new().build();
        std::thread::scope(|cx| {
            cx.spawn(|| {
                for i in input.iter().copied().map(DropCounter) {
                    if i.0.is_uppercase() {
                        pa.send(i).unwrap();
                    } else {
                        _ = pa.recv();
                    }
                }
                drop(pa);
            });
            cx.spawn(|| {
                for i in input.iter().copied().map(DropCounter) {
                    if i.0.is_lowercase() {
                        pb.send(i).unwrap();
                    } else {
                        _ = pb.recv();
                    }
                }
                drop(pb);
            });
        });

        assert_eq!(DROP_COUNT.load(Ordering::Relaxed), input.len() * 2);
    }

    #[test]
    fn uring_threaded() {
        let input = std::iter::repeat_with(fastrand::alphabetic)
            .take(30)
            .collect::<Vec<_>>();

        let (mut pa, mut pb) = Builder::<char, char>::new().build();
        let (pa_finished, pb_finished) = (AtomicBool::new(false), AtomicBool::new(false));
        std::thread::scope(|cx| {
            cx.spawn(|| {
                let mut r = vec![];
                for i in input.iter().copied() {
                    pa.send(i).unwrap();
                    while let Some(i) = pa.recv() {
                        r.push(i);
                    }
                }
                pa_finished.store(true, Ordering::Release);
                while !pb_finished.load(Ordering::Acquire) {
                    std::thread::yield_now();
                }
                while let Some(i) = pa.recv() {
                    r.push(i);
                }
                assert_eq!(r, input);
            });
            cx.spawn(|| {
                let mut r = vec![];
                for i in input.iter().copied() {
                    pb.send(i).unwrap();
                    while let Some(i) = pb.recv() {
                        r.push(i);
                    }
                }
                pb_finished.store(true, Ordering::Release);
                while !pa_finished.load(Ordering::Acquire) {
                    std::thread::yield_now();
                }
                while let Some(i) = pb.recv() {
                    r.push(i);
                }
                assert_eq!(r, input);
            });
        });
    }

    #[test]
    fn uring_threaded_bulk() {
        let input = std::iter::repeat_with(fastrand::alphabetic)
            .take(30)
            .collect::<Vec<_>>();

        let (mut pa, mut pb) = Builder::<char, char>::new().build();
        let (pa_finished, pb_finished) = (AtomicBool::new(false), AtomicBool::new(false));
        std::thread::scope(|cx| {
            cx.spawn(|| {
                let mut r = vec![];
                pa.send_bulk(input.iter().copied());
                pa_finished.store(true, Ordering::Release);
                while !pb_finished.load(Ordering::Acquire) {
                    r.extend(pa.recv_bulk());
                    std::thread::yield_now();
                }
                r.extend(pa.recv_bulk());
                assert_eq!(r, input);
            });
            cx.spawn(|| {
                let mut r = vec![];
                pb.send_bulk(input.iter().copied());
                pb_finished.store(true, Ordering::Release);
                while !pa_finished.load(Ordering::Acquire) {
                    r.extend(pb.recv_bulk());
                    std::thread::yield_now();
                }
                r.extend(pb.recv_bulk());
                assert_eq!(r, input);
            });
        });
    }
}