maitake_sync/
wait_map.rs

1//! A map of [`Waker`]s associated with keys, so that a task can be woken by
2//! key.
3//!
4//! See the documentation for the [`WaitMap`] type for details.
5use crate::{
6    blocking::{DefaultMutex, Mutex, ScopedRawMutex},
7    loom::{
8        cell::UnsafeCell,
9        sync::atomic::{AtomicUsize, Ordering::*},
10    },
11    util::{fmt, CachePadded, WakeBatch},
12};
13use cordyceps::{
14    list::{self, List},
15    Linked,
16};
17use core::{
18    fmt::Debug,
19    future::Future,
20    marker::PhantomPinned,
21    mem,
22    pin::Pin,
23    ptr::{self, NonNull},
24    task::{Context, Poll, Waker},
25};
26use mycelium_bitfield::{enum_from_bits, FromBits};
27use pin_project::{pin_project, pinned_drop};
28
29#[cfg(test)]
30mod tests;
31
/// Errors returned by [`WaitMap::wait`], indicating a failed wake.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum WaitError {
    /// The [`WaitMap`] has already been [closed].
    ///
    /// [closed]: WaitMap::close
    Closed,

    /// The received data has already been extracted.
    AlreadyConsumed,

    /// The [`Wait`] was never added to the [`WaitMap`].
    NeverAdded,

    /// The [`WaitMap`] already had an item matching the given key.
    Duplicate,
}

/// The result of a call to [`WaitMap::wait()`].
pub type WaitResult<T> = Result<T, WaitError>;

/// Ready poll value carrying [`WaitError::Closed`].
const fn closed<T>() -> Poll<WaitResult<T>> {
    Poll::Ready(WaitResult::<T>::Err(WaitError::Closed))
}

/// Ready poll value carrying [`WaitError::AlreadyConsumed`].
const fn consumed<T>() -> Poll<WaitResult<T>> {
    Poll::Ready(WaitResult::<T>::Err(WaitError::AlreadyConsumed))
}

/// Ready poll value carrying [`WaitError::NeverAdded`].
const fn never_added<T>() -> Poll<WaitResult<T>> {
    Poll::Ready(WaitResult::<T>::Err(WaitError::NeverAdded))
}

/// Ready poll value carrying [`WaitError::Duplicate`].
const fn duplicate<T>() -> Poll<WaitResult<T>> {
    Poll::Ready(WaitResult::<T>::Err(WaitError::Duplicate))
}

/// Ready poll value carrying the successfully delivered `data`.
const fn notified<T>(data: T) -> Poll<WaitResult<T>> {
    Poll::Ready(WaitResult::<T>::Ok(data))
}
74
75/// A map of [`Waker`]s associated with keys, allowing tasks to be woken by
76/// their key.
77///
78/// A `WaitMap` allows any number of tasks to [wait] asynchronously and be
79/// woken when a value with a certain key arrives. This can be used to
80/// implement structures like "async mailboxes", where an async function
81/// requests some data (such as a response) associated with a certain
82/// key (such as a message ID). When the data is received, the key can
83/// be used to provide the task with the desired data, as well as wake
84/// the task for further processing.
85///
86/// # Overriding the blocking mutex
87///
88/// This type uses a [blocking `Mutex`](crate::blocking::Mutex) internally to
89/// synchronize access to its wait list. By default, this is a [`DefaultMutex`]. To
90/// use an alternative [`ScopedRawMutex`] implementation, use the
91/// [`new_with_raw_mutex`](Self::new_with_raw_mutex) constructor. See [the documentation
92/// on overriding mutex
93/// implementations](crate::blocking#overriding-mutex-implementations) for more
94/// details.
95///
96/// # Examples
97///
98/// Waking a single task at a time by calling [`wake`][wake]:
99///
100/// ```ignore
101/// use std::sync::Arc;
/// use maitake::scheduler::Scheduler;
103/// use maitake_sync::wait_map::{WaitMap, WakeOutcome};
104///
105/// const TASKS: usize = 10;
106///
107/// // In order to spawn tasks, we need a `Scheduler` instance.
108/// let scheduler = Scheduler::new();
109///
110/// // Construct a new `WaitMap`.
111/// let q = Arc::new(WaitMap::new());
112///
113/// // Spawn some tasks that will wait on the queue.
114/// // We'll use the task index (0..10) as the key.
115/// for i in 0..TASKS {
116///     let q = q.clone();
117///     scheduler.spawn(async move {
118///         let val = q.wait(i).await.unwrap();
119///         assert_eq!(val, i + 100);
120///     });
121/// }
122///
123/// // Tick the scheduler once.
124/// let tick = scheduler.tick();
125///
126/// // No tasks should complete on this tick, as they are all waiting
127/// // to be woken by the queue.
128/// assert_eq!(tick.completed, 0, "no tasks have been woken");
129///
130/// // We now wake each of the tasks, using the same key (0..10),
131/// // and provide them with a value that is their `key + 100`,
132/// // e.g. 100..110. Only the task that has been woken will be
133/// // notified.
134/// for i in 0..TASKS {
135///     let result = q.wake(&i, i + 100);
136///     assert!(matches!(result, WakeOutcome::Woke));
137///
138///     // Tick the scheduler.
139///     let tick = scheduler.tick();
140///
141///     // Exactly one task should have completed
142///     assert_eq!(tick.completed, 1);
143/// }
144///
145/// // Tick the scheduler.
146/// let tick = scheduler.tick();
147///
148/// // No additional tasks should be completed
149/// assert_eq!(tick.completed, 0);
150/// assert!(!tick.has_remaining);
151/// ```
152///
153/// # Implementation Notes
154///
/// This type is currently implemented using an [intrusive doubly-linked
/// list][ilist].
///
/// The *[intrusive]* aspect of this map is important, as it means that it does
/// not allocate memory. Instead, nodes in the linked list are stored in the
/// futures of tasks waiting to be woken by key. This means that it is not
/// necessary to allocate any heap memory for each task waiting to be woken.
162///
163/// However, the intrusive linked list introduces one new danger: because
164/// futures can be *cancelled*, and the linked list nodes live within the
165/// futures trying to wait on the queue, we *must* ensure that the node
166/// is unlinked from the list before dropping a cancelled future. Failure to do
167/// so would result in the list containing dangling pointers. Therefore, we must
168/// use a *doubly-linked* list, so that nodes can edit both the previous and
169/// next node when they have to remove themselves. This is kind of a bummer, as
170/// it means we can't use something nice like this [intrusive queue by Dmitry
171/// Vyukov][2], and there are not really practical designs for lock-free
172/// doubly-linked lists that don't rely on some kind of deferred reclamation
173/// scheme such as hazard pointers or QSBR.
174///
175/// Instead, we just stick a [`Mutex`] around the linked list, which must be
176/// acquired to pop nodes from it, or for nodes to remove themselves when
177/// futures are cancelled. This is a bit sad, but the critical sections for this
178/// mutex are short enough that we still get pretty good performance despite it.
179///
180/// [`Waker`]: core::task::Waker
181/// [wait]: WaitMap::wait
182/// [wake]: WaitMap::wake
183/// [`UnsafeCell`]: core::cell::UnsafeCell
184/// [ilist]: cordyceps::List
185/// [intrusive]: https://fuchsia.dev/fuchsia-src/development/languages/c-cpp/fbl_containers_guide/introduction
186/// [2]: https://www.1024cores.net/home/lock-free-algorithms/queues/intrusive-mpsc-node-based-queue
pub struct WaitMap<K: PartialEq, V, Lock: ScopedRawMutex = DefaultMutex> {
    /// The wait queue's state variable.
    ///
    /// This holds a [`State`] value (`Empty`, `Waiting`, or `Closed`) encoded
    /// as a `usize`.
    state: CachePadded<AtomicUsize>,

    /// The linked list of waiters.
    ///
    /// # Safety
    ///
    /// This is protected by a mutex; the mutex *must* be acquired when
    /// manipulating the linked list, OR when manipulating waiter nodes that may
    /// be linked into the list. If a node is known to not be linked, it is safe
    /// to modify that node (such as by waking the stored [`Waker`]) without
    /// holding the lock; otherwise, it may be modified through the list, so the
    /// lock must be held when modifying the node.
    ///
    /// The raw mutex implementation is chosen by the `Lock` type parameter,
    /// which is a [`DefaultMutex`] unless overridden.
    ///
    /// NOTE(review): this comment previously stated that a spinlock from
    /// `mycelium_util` is used here (replaced by a `loom` mutex when running
    /// `loom` tests). Confirm whether that still describes `DefaultMutex`.
    queue: Mutex<List<Waiter<K, V>>, Lock>,
}
209
210impl<K, V, Lock> Debug for WaitMap<K, V, Lock>
211where
212    K: PartialEq,
213    Lock: ScopedRawMutex,
214{
215    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
216        f.debug_struct("WaitMap")
217            .field("state", &self.state)
218            .field("queue", &self.queue)
219            .finish()
220    }
221}
222
/// Future returned from [`WaitMap::wait()`].
///
/// This future is fused, so once it has completed, any future calls to poll
/// will immediately return [`Poll::Ready`].
///
/// # Notes
///
/// This future is `!Unpin`, as it is unsafe to [`core::mem::forget`] a
/// `Wait` future once it has been polled: after the first poll, the waiter
/// node stored inside the future may be linked into the [`WaitMap`]'s
/// intrusive list. For instance, the following code must not compile:
///
/// ```compile_fail
/// use maitake_sync::wait_map::Wait;
///
/// // Calls to this function should only compile if `T` is `Unpin`.
/// fn assert_unpin<T: Unpin>() {}
///
/// assert_unpin::<Wait<'_, usize, ()>>();
/// ```
#[derive(Debug)]
#[pin_project(PinnedDrop)]
#[must_use = "futures do nothing unless `.await`ed or `poll`ed"]
pub struct Wait<'a, K: PartialEq, V, Lock: ScopedRawMutex = DefaultMutex> {
    /// The [`WaitMap`] that this future is waiting on.
    queue: &'a WaitMap<K, V, Lock>,

    /// Entry in the wait queue linked list.
    ///
    /// This field is structurally pinned, since the list may hold a pointer
    /// to it while the future is waiting.
    #[pin]
    waiter: Waiter<K, V>,
}
253
254impl<'map, 'wait, K: PartialEq, V, Lock: ScopedRawMutex> Wait<'map, K, V, Lock> {
255    /// Returns a future that completes when the `Wait` item has been
256    /// added to the [`WaitMap`], and is ready to receive data
257    ///
258    /// This is useful for ensuring that a receiver is ready before
259    /// sending a message that will elicit the expected response.
260    ///
261    /// # Example
262    ///
263    /// ```ignore
264    /// use std::sync::Arc;
265    /// use maitake::scheduler;
266    /// use maitake_sync::wait_map::{WaitMap, WakeOutcome};
267    /// use futures_util::pin_mut;
268    ///
269    /// let scheduler = Scheduler::new();
270    /// let q = Arc::new(WaitMap::new());
271    ///
272    /// let q2 = q.clone();
273    /// scheduler.spawn(async move {
274    ///     let wait = q2.wait(0);
275    ///
276    ///     // At this point, we have created the future, but it has not yet
277    ///     // been added to the queue. We could immediately await 'wait',
278    ///     // but then we would be unable to progress further. We must
279    ///     // first pin the `wait` future, to ensure that it does not move
280    ///     // until it has been completed.
281    ///     pin_mut!(wait);
282    ///     wait.as_mut().subscribe().await.unwrap();
283    ///
284    ///     // We now know the waiter has been enqueued, at this point we could
285    ///     // send a message that will cause key == 0 to be returned, without
286    ///     // worrying about racing with the expected response, e.g:
287    ///     //
288    ///     // sender.send_with_id(0, SomeMessage).await?;
289    ///     //
290    ///     let val = wait.await.unwrap();
291    ///     assert_eq!(val, 10);
292    /// });
293    ///
294    /// assert!(matches!(q.wake(&0, 100), WakeOutcome::NoMatch(_)));
295    ///
296    /// let tick = scheduler.tick();
297    ///
298    /// assert!(matches!(q.wake(&0, 100), WakeOutcome::Woke));
299    /// ```
300    pub fn subscribe(self: Pin<&'wait mut Self>) -> Subscribe<'wait, 'map, K, V, Lock> {
301        Subscribe { wait: self }
302    }
303
304    /// Deprecated alias for [`Wait::subscribe`]. See that method for details.
305    #[deprecated(
306        since = "0.1.3",
307        note = "renamed to `subscribe` for consistency, use that instead"
308    )]
309    #[allow(deprecated)] // let us use the deprecated type alias
310    pub fn enqueue(self: Pin<&'wait mut Self>) -> EnqueueWait<'wait, 'map, K, V, Lock> {
311        self.subscribe()
312    }
313}
314
/// A waiter node which may be linked into a wait queue.
///
/// One `Waiter` lives inside each [`Wait`] future; the list stores pointers
/// to it rather than owning it.
#[pin_project]
struct Waiter<K: PartialEq, V> {
    /// The intrusive linked list node.
    ///
    /// Accessed through the `UnsafeCell`; see the safety notes on
    /// `WaitMap::queue` for when the lock must be held.
    #[pin]
    node: UnsafeCell<Node<K, V>>,

    /// The future's state.
    state: WaitState,

    /// The key this waiter is registered under; compared against the key
    /// passed to [`WaitMap::wake`] to select which waiter to wake.
    key: K,
}
327
328impl<K: PartialEq, V> Debug for Waiter<K, V> {
329    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
330        f.debug_struct("Waiter")
331            .field("node", &self.node)
332            .field("state", &self.state)
333            .field("key", &fmt::display(core::any::type_name::<K>()))
334            .field("val", &fmt::display(core::any::type_name::<V>()))
335            .finish()
336    }
337}
338
/// The mutable portion of a [`Waiter`]: its list links and its wakeup slot.
///
/// `#[repr(C)]` fixes the field order so that `links` really is the first
/// field (see the safety note on `links`).
#[repr(C)]
struct Node<K: PartialEq, V> {
    /// Intrusive linked list pointers.
    ///
    /// # Safety
    ///
    /// This *must* be the first field in the struct in order for the `Linked`
    /// impl to be sound.
    links: list::Links<Waiter<K, V>>,

    /// The node's waker, if it has yet to be woken, or the data assigned to the
    /// node, if it has been woken.
    waker: Wakeup<V>,

    // This type is !Unpin due to the heuristic from:
    // <https://github.com/rust-lang/rust/pull/82834>
    _pin: PhantomPinned,
}
357
358impl<K: PartialEq, V> Debug for Node<K, V> {
359    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
360        f.debug_struct("Node")
361            .field("links", &self.links)
362            .field("waker", &self.waker)
363            .finish()
364    }
365}
366
// Declared through `enum_from_bits!` with a `u8` representation.
// NOTE(review): presumably the macro also generates the bit-conversion impls
// for this enum; confirm against `mycelium_bitfield`.
enum_from_bits! {
    /// The state of a [`Waiter`] node in a [`WaitMap`].
    ///
    /// A waiter is only linked into the map's list while it is `Waiting`.
    #[derive(Debug, Eq, PartialEq)]
    enum WaitState<u8> {
        /// The waiter has not yet been enqueued.
        ///
        /// When in this state, the node is **not** part of the linked list, and
        /// can be dropped without removing it from the list.
        Start = 0b01,

        /// The waiter is waiting.
        ///
        /// When in this state, the node **is** part of the linked list. If the
        /// node is dropped in this state, it **must** be removed from the list
        /// before dropping it. Failure to ensure this will result in dangling
        /// pointers in the linked list!
        Waiting = 0b10,

        /// The waiter has been woken.
        ///
        /// When in this state, the node is **not** part of the linked list, and
        /// can be dropped without removing it from the list.
        Completed = 0b11,
    }
}
392
/// The queue's current state.
///
/// This is stored in `WaitMap::state` as a `usize`.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(u8)]
enum State {
    /// No waiters are queued, and there is no pending notification.
    /// Waiting while the queue is in this state will enqueue the waiter
    /// and transition the queue to [`State::Waiting`].
    Empty = 0b00,

    /// There are one or more waiters in the queue. Waiting while
    /// the queue is in this state will not transition the state. Waking while
    /// in this state will wake the appropriate waiter in the queue; if this empties
    /// the queue, then the queue will transition to [`State::Empty`].
    Waiting = 0b01,

    // TODO(AJM): We have a state gap here. Is this okay?
    // (The bit pattern `0b10` is not assigned to any variant.)
    /// The queue is closed. Waiting while in this state will return
    /// [`WaitError::Closed`] without transitioning the queue's state.
    ///
    /// *Note*: This *must* correspond to all state bits being set, as it's set
    /// via a [`fetch_or`].
    ///
    /// [`fetch_or`]: core::sync::atomic::AtomicUsize::fetch_or
    Closed = 0b11,
}
418
/// The contents of a waiter node's wakeup slot.
///
/// The slot starts `Empty`, holds the task's [`Waker`] while waiting, and is
/// replaced with the outcome (`DataReceived` or `Closed`) when the waiter is
/// woken.
#[derive(Clone)]
enum Wakeup<V> {
    /// The Waiter has been created, but no wake has occurred. This should
    /// be the ONLY state while in `WaitState::Start`
    Empty,

    /// The Waiter has moved to the `WaitState::Waiting` state. We now
    /// have the relevant waker, and are still waiting for data. This
    /// corresponds to `WaitState::Waiting`.
    Waiting(Waker),

    /// The Waiter has received data, and is waiting for the woken task
    /// to notice, and take the data by polling+completing the future.
    /// This corresponds to `WaitState::Completed`.
    ///
    /// This state stores the received value; taking the value out of the waiter
    /// advances the state to `Retreived`.
    DataReceived(V),

    /// The waiter has received data, and already given it away, and has
    /// no more data to give. This corresponds to `WaitState::Completed`.
    ///
    /// NOTE(review): the variant name is a misspelling of "Retrieved"; it is
    /// left as-is here because code outside this definition refers to it.
    Retreived,

    /// The Queue the waiter is part of has been closed. No data will
    /// be received from this future. This corresponds to
    /// `WaitState::Completed`.
    Closed,
}
447
448// === impl WaitMap ===
449
impl<K: PartialEq, V> WaitMap<K, V> {
    // NOTE(review): `loom_const_fn!` presumably makes this a `const fn` except
    // when building under `loom`; confirm against the macro definition.
    loom_const_fn! {
        /// Returns a new, empty `WaitMap`.
        ///
        /// This constructor returns a `WaitMap` that uses a [`DefaultMutex`] as
        /// the [`ScopedRawMutex`] implementation for wait list synchronization.
        /// To use a different [`ScopedRawMutex`] implementation, use the
        /// [`new_with_raw_mutex`](Self::new_with_raw_mutex) constructor, instead. See
        /// [the documentation on overriding mutex
        /// implementations](crate::blocking#overriding-mutex-implementations)
        /// for more details.
        #[must_use]
        pub fn new() -> Self {
            Self::new_with_raw_mutex(DefaultMutex::new())
        }
    }
}
467
468impl<K, V, Lock> Default for WaitMap<K, V, Lock>
469where
470    K: PartialEq,
471    Lock: ScopedRawMutex + Default,
472{
473    fn default() -> Self {
474        Self::new_with_raw_mutex(Lock::default())
475    }
476}
477
impl<K, V, Lock> WaitMap<K, V, Lock>
where
    K: PartialEq,
    Lock: ScopedRawMutex,
{
    loom_const_fn! {
        /// Returns a new `WaitMap`, using the provided [`ScopedRawMutex`]
        /// implementation for wait-list synchronization.
        ///
        /// The returned map starts in the `Empty` state with no waiters.
        ///
        /// This constructor allows a `WaitMap` to be constructed with any type that
        /// implements [`ScopedRawMutex`] as the underlying raw blocking mutex
        /// implementation. See [the documentation on overriding mutex
        /// implementations](crate::blocking#overriding-mutex-implementations)
        /// for more details.
        #[must_use]
        pub fn new_with_raw_mutex(lock: Lock) -> Self {
            Self {
                state: CachePadded::new(AtomicUsize::new(State::Empty.into_usize())),
                queue: Mutex::new_with_raw_mutex(List::new(), lock),
            }
        }
    }
}
501
502impl<K: PartialEq, V, Lock: ScopedRawMutex> WaitMap<K, V, Lock> {
503    /// Wake a certain task in the queue.
504    ///
505    /// If the queue is empty, a wakeup is stored in the `WaitMap`, and the
506    /// next call to [`wait`] will complete immediately.
507    ///
508    /// [`wait`]: WaitMap::wait
509    #[inline]
510    pub fn wake(&self, key: &K, val: V) -> WakeOutcome<V> {
511        // snapshot the queue's current state.
512        let mut state = self.load();
513
514        // check if any tasks are currently waiting on this queue. if there are
515        // no waiting tasks, store the wakeup to be consumed by the next call to
516        // `wait`.
517        match state {
518            // Something is waiting!
519            State::Waiting => {}
520
521            // if the queue is closed, bail.
522            State::Closed => return WakeOutcome::Closed(val),
523
524            // if the queue is empty, bail.
525            State::Empty => return WakeOutcome::NoMatch(val),
526        }
527
528        // okay, there are tasks waiting on the queue; we must acquire the lock
529        // on the linked list and wake the next task from the queue.
530        let mut val = Some(val);
531        let maybe_waker = self.queue.with_lock(|queue| {
532            test_debug!("wake: -> locked");
533
534            // the queue's state may have changed while we were waiting to acquire
535            // the lock, so we need to acquire a new snapshot.
536            state = self.load();
537
538            let node = self.node_match_locked(key, &mut *queue, state)?;
539            // if there's a node, give it the value and take the waker and
540            // return it. we return the waker from this closure rather than
541            // waking it, because we need to release the lock before waking the
542            // task.
543            let val = val
544                .take()
545                .expect("value is only taken elsewhere if there is no waker, but there is one");
546            let waker = Waiter::<K, V>::wake(node, &mut *queue, Wakeup::DataReceived(val));
547            Some(waker)
548        });
549
550        if let Some(waker) = maybe_waker {
551            waker.wake();
552            WakeOutcome::Woke
553        } else {
554            let val =
555                val.expect("value is only taken elsewhere if there is a waker, and there isn't");
556            WakeOutcome::NoMatch(val)
557        }
558    }
559
560    /// Returns `true` if this `WaitMap` is [closed](Self::close).
561    #[must_use]
562    pub fn is_closed(&self) -> bool {
563        self.load() == State::Closed
564    }
565
566    /// Close the queue, indicating that it may no longer be used.
567    ///
568    /// Once a queue is closed, all [`wait`] calls (current or future) will
569    /// return an error.
570    ///
571    /// This method is generally used when implementing higher-level
572    /// synchronization primitives or resources: when an event makes a resource
573    /// permanently unavailable, the queue can be closed.
574    ///
575    /// [`wait`]: Self::wait
576    pub fn close(&self) {
577        let state = self.state.fetch_or(State::Closed.into_usize(), SeqCst);
578        let state = test_dbg!(State::from_bits(state));
579        if state != State::Waiting {
580            return;
581        }
582
583        let mut batch = WakeBatch::new();
584        let mut waiters_remaining = true;
585        while waiters_remaining {
586            waiters_remaining = self.queue.with_lock(|waiters| {
587                while let Some(node) = waiters.pop_back() {
588                    let waker = Waiter::wake(node, waiters, Wakeup::Closed);
589                    if !batch.add_waker(waker) {
590                        // there's still room in the wake set, just keep adding to it.
591                        return true;
592                    }
593                }
594                false
595            });
596            batch.wake_all();
597        }
598    }
599
600    /// Wait to be woken up by this queue.
601    ///
602    /// This returns a [`Wait`] future that will complete when the task is
603    /// woken by a call to [`wake`] with a matching `key`, or when the `WaitMap`
604    /// is dropped.
605    ///
606    /// **Note**: `key`s must be unique. If the given key already exists in the
607    /// `WaitMap`, the future will resolve to an Error the first time it is polled
608    ///
609    /// [`wake`]: Self::wake
610    pub fn wait(&self, key: K) -> Wait<'_, K, V, Lock> {
611        Wait {
612            queue: self,
613            waiter: self.waiter(key),
614        }
615    }
616
617    /// Returns a [`Waiter`] entry in this queue.
618    ///
619    /// This is factored out into a separate function because it's used by both
620    /// [`WaitMap::wait`] and [`WaitMap::wait_owned`].
621    fn waiter(&self, key: K) -> Waiter<K, V> {
622        let state = WaitState::Start;
623        Waiter {
624            state,
625            node: UnsafeCell::new(Node {
626                links: list::Links::new(),
627                waker: Wakeup::Empty,
628                _pin: PhantomPinned,
629            }),
630            key,
631        }
632    }
633
634    #[cfg_attr(test, track_caller)]
635    fn load(&self) -> State {
636        #[allow(clippy::let_and_return)]
637        let state = State::from_bits(self.state.load(SeqCst));
638        test_debug!("state.load() = {state:?}");
639        state
640    }
641
642    #[cfg_attr(test, track_caller)]
643    fn store(&self, state: State) {
644        test_debug!("state.store({state:?}");
645        self.state.store(state as usize, SeqCst);
646    }
647
648    #[cfg_attr(test, track_caller)]
649    fn compare_exchange(&self, current: State, new: State) -> Result<State, State> {
650        #[allow(clippy::let_and_return)]
651        let res = self
652            .state
653            .compare_exchange(current as usize, new as usize, SeqCst, SeqCst)
654            .map(State::from_bits)
655            .map_err(State::from_bits);
656        test_debug!("state.compare_exchange({current:?}, {new:?}) = {res:?}");
657        res
658    }
659
660    #[cold]
661    #[inline(never)]
662    fn node_match_locked(
663        &self,
664        key: &K,
665        queue: &mut List<Waiter<K, V>>,
666        curr: State,
667    ) -> Option<NonNull<Waiter<K, V>>> {
668        let state = curr;
669
670        // is the queue still in the `Waiting` state? it is possible that we
671        // transitioned to a different state while locking the queue.
672        if test_dbg!(state) != State::Waiting {
673            // If we are not waiting, we are either empty or closed.
674            // Not much to do.
675            return None;
676        }
677
678        let mut cursor = queue.cursor_front_mut();
679        let opt_node = cursor.remove_first(|t| &t.key == key);
680
681        // if we took the final waiter currently in the queue, transition to the
682        // `Empty` state.
683        if test_dbg!(queue.is_empty()) {
684            self.store(State::Empty);
685        }
686
687        opt_node
688    }
689}
690
/// The result of an attempted [`WaitMap::wake()`] operation.
///
/// When the value could not be delivered, it is returned to the caller inside
/// the variant so that it is not lost.
#[derive(Debug)]
pub enum WakeOutcome<V> {
    /// The task was successfully woken, and the data was provided.
    Woke,

    /// No task matching the given key was found in the queue. The undelivered
    /// data is returned.
    NoMatch(V),

    /// The queue was already closed when the wake was attempted,
    /// and the data was not provided to any task. The undelivered data is
    /// returned.
    Closed(V),
}
704
705// === impl WaitError ===
706
707impl fmt::Display for WaitError {
708    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
709        match self {
710            Self::Closed => f.pad("WaitMap closed"),
711            Self::Duplicate => f.pad("duplicate key"),
712            &Self::AlreadyConsumed => f.pad("received data has already been consumed"),
713            Self::NeverAdded => f.pad("Wait was never added to WaitMap"),
714        }
715    }
716}
717
// Marks `WaitError` as a `core::error::Error`.
// NOTE(review): presumably `feature!` compiles this impl only when the
// `core-error` feature is enabled; confirm against the macro definition.
feature! {
    #![feature = "core-error"]
    impl core::error::Error for WaitError {}
}
722
// === impl Subscribe ===
724
/// A future that ensures a [`Wait`] has been added to a [`WaitMap`].
///
/// It resolves to `Ok(())` once the waiter is enqueued (or was already
/// enqueued or completed), and to an error if enqueueing fails.
///
/// See [`Wait::subscribe`] for more information and usage example.
#[must_use = "futures do nothing unless `.await`ed or `poll`ed"]
#[derive(Debug)]
pub struct Subscribe<'a, 'b, K, V, Lock = DefaultMutex>
where
    K: PartialEq,
    Lock: ScopedRawMutex,
{
    /// The pinned [`Wait`] future being subscribed; it is borrowed, so the
    /// caller can still await it after this future completes.
    wait: Pin<&'a mut Wait<'b, K, V, Lock>>,
}
737
738impl<K, V, Lock> Future for Subscribe<'_, '_, K, V, Lock>
739where
740    K: PartialEq,
741    Lock: ScopedRawMutex,
742{
743    type Output = WaitResult<()>;
744
745    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
746        let this = self.wait.as_mut().project();
747        if let WaitState::Start = test_dbg!(&this.waiter.state) {
748            this.waiter.start_to_wait(this.queue, cx)
749        } else {
750            Poll::Ready(Ok(()))
751        }
752    }
753}
754
/// Deprecated alias for [`Subscribe`]. See the [`Wait::subscribe`]
/// documentation for more details.
///
/// This alias is the return type of the deprecated `Wait::enqueue` method.
#[deprecated(
    since = "0.1.3",
    note = "renamed to `Subscribe` for consistency, use that instead"
)]
pub type EnqueueWait<'a, 'b, K, V, Lock> = Subscribe<'a, 'b, K, V, Lock>;
762
763impl<K: PartialEq, V> Waiter<K, V> {
764    /// Wake the task that owns this `Waiter`.
765    ///
766    /// # Safety
767    ///
768    /// This is only safe to call while the list is locked. The `list`
769    /// parameter ensures this method is only called while holding the lock, so
770    /// this can be safe.
771    ///
772    /// Of course, that must be the *same* list that this waiter is a member of,
773    /// and currently, there is no way to ensure that...
774    #[inline(always)]
775    #[cfg_attr(loom, track_caller)]
776    fn wake(this: NonNull<Self>, list: &mut List<Self>, wakeup: Wakeup<V>) -> Waker {
777        Waiter::with_node(this, list, |node| {
778            let waker = test_dbg!(mem::replace(&mut node.waker, wakeup));
779            match waker {
780                Wakeup::Waiting(waker) => waker,
781                _ => unreachable!("tried to wake a waiter in the {:?} state!", waker),
782            }
783        })
784    }
785
    /// Runs `f` with mutable access to this waiter's [`Node`] and returns
    /// whatever `f` returns.
    ///
    /// # Safety
    ///
    /// This is only safe to call while the list is locked. The dummy `_list`
    /// parameter ensures this method is only called while holding the lock, so
    /// this can be safe.
    ///
    /// Of course, that must be the *same* list that this waiter is a member of,
    /// and currently, there is no way to ensure that...
    ///
    /// NOTE(review): `this` must also point to a live `Waiter`; nothing in
    /// this signature enforces that, so callers are relied upon to pass a
    /// pointer obtained from the list.
    #[inline(always)]
    #[cfg_attr(loom, track_caller)]
    fn with_node<T>(
        mut this: NonNull<Self>,
        _list: &mut List<Self>,
        f: impl FnOnce(&mut Node<K, V>) -> T,
    ) -> T {
        unsafe {
            // safety: this is only called while holding the lock on the queue,
            // so it's safe to mutate the waiter.
            this.as_mut().node.with_mut(|node| f(&mut *node))
        }
    }
807
    /// Moves a `Wait` from the `Start` condition.
    ///
    /// Under the list lock, this transitions the map to `Waiting` (if it was
    /// `Empty`), checks that no other waiter uses the same key, stores the
    /// task's waker in the node, and links the node at the front of the list.
    ///
    /// # Returns
    ///
    /// - `Poll::Ready(Ok(()))` once the waiter has been enqueued.
    /// - `Poll::Ready(Err(WaitError::Closed))` if the map is closed.
    /// - `Poll::Ready(Err(WaitError::Duplicate))` if the key is already
    ///   present in the list.
    ///
    /// In both error cases the waiter is left in the `Start` state and is not
    /// linked. This function never returns `Poll::Pending`.
    ///
    /// Caller MUST ensure the `Wait` is in the start condition before calling.
    fn start_to_wait<Lock>(
        mut self: Pin<&mut Self>,
        queue: &WaitMap<K, V, Lock>,
        cx: &mut Context<'_>,
    ) -> Poll<WaitResult<()>>
    where
        Lock: ScopedRawMutex,
    {
        // Try to wait...
        test_debug!("poll_wait: locking...");
        queue.queue.with_lock(move |waiters| {
            test_debug!("poll_wait: -> locked");
            let mut this = self.as_mut().project();

            debug_assert!(
                matches!(this.state, WaitState::Start),
                "start_to_wait should ONLY be called from the Start state!"
            );

            let mut queue_state = queue.load();

            // transition the queue to the waiting state
            'to_waiting: loop {
                match test_dbg!(queue_state) {
                    // the queue is `Empty`, transition to `Waiting`
                    State::Empty => match queue.compare_exchange(queue_state, State::Waiting) {
                        Ok(_) => break 'to_waiting,
                        // the state changed under us; retry with what was
                        // actually observed.
                        Err(actual) => queue_state = actual,
                    },
                    // the queue is already `Waiting`
                    State::Waiting => break 'to_waiting,
                    // closed maps accept no new waiters.
                    State::Closed => return closed(),
                }
            }

            // Check if key already exists
            //
            // Note: It's okay not to re-update the state here, if we were empty
            // this check will never trigger, if we are already waiting, we should
            // still be waiting.
            let mut cursor = waiters.cursor_front_mut();
            if cursor.any(|n| &n.key == this.key) {
                return duplicate();
            }

            // enqueue the node
            *this.state = WaitState::Waiting;
            this.node.as_mut().with_mut(|node| {
                unsafe {
                    // safety: we may mutate the node because we are
                    // holding the lock.
                    (*node).waker = Wakeup::Waiting(cx.waker().clone());
                }
            });
            // the list stores a raw pointer to this waiter. the pin is unwrapped
            // only to take the waiter's address; the waiter itself is not moved
            // here.
            let ptr = unsafe { NonNull::from(Pin::into_inner_unchecked(self)) };
            waiters.push_front(ptr);

            Poll::Ready(Ok(()))
        })
    }
871
872    fn poll_wait<Lock>(
873        mut self: Pin<&mut Self>,
874        queue: &WaitMap<K, V, Lock>,
875        cx: &mut Context<'_>,
876    ) -> Poll<WaitResult<V>>
877    where
878        Lock: ScopedRawMutex,
879    {
880        test_debug!(ptr = ?fmt::ptr(self.as_mut()), "Waiter::poll_wait");
881        let this = self.as_mut().project();
882
883        match test_dbg!(&this.state) {
884            WaitState::Start => {
885                let _ = self.start_to_wait(queue, cx)?;
886                Poll::Pending
887            }
888            WaitState::Waiting => {
889                // We must lock the linked list in order to safely mutate our node in
890                // the list. We don't actually need the mutable reference to the
891                // queue here, though.
892                queue.queue.with_lock(|_waiters| {
893                    this.node.with_mut(|node| unsafe {
894                        // safety: we may mutate the node because we are
895                        // holding the lock.
896                        let node = &mut *node;
897                        let result;
898                        node.waker = match mem::replace(&mut node.waker, Wakeup::Empty) {
899                            // We already had a waker, but are now getting another one.
900                            // Store the new one, droping the old one
901                            Wakeup::Waiting(waker) => {
902                                result = Poll::Pending;
903                                if !waker.will_wake(cx.waker()) {
904                                    Wakeup::Waiting(cx.waker().clone())
905                                } else {
906                                    Wakeup::Waiting(waker)
907                                }
908                            }
909                            // We have received the data, take the data out of the
910                            // future, and provide it to the poller
911                            Wakeup::DataReceived(val) => {
912                                result = notified(val);
913                                Wakeup::Retreived
914                            }
915                            Wakeup::Retreived => {
916                                result = consumed();
917                                Wakeup::Retreived
918                            }
919
920                            Wakeup::Closed => {
921                                *this.state = WaitState::Completed;
922                                result = closed();
923                                Wakeup::Closed
924                            }
925                            Wakeup::Empty => {
926                                result = never_added();
927                                Wakeup::Closed
928                            }
929                        };
930                        result
931                    })
932                })
933            }
934            WaitState::Completed => consumed(),
935        }
936    }
937
938    /// Release this `Waiter` from the queue.
939    ///
940    /// This is called from the `drop` implementation for the [`Wait`] and
941    /// [`WaitOwned`] futures.
942    fn release<Lock>(mut self: Pin<&mut Self>, queue: &WaitMap<K, V, Lock>)
943    where
944        Lock: ScopedRawMutex,
945    {
946        let state = *(self.as_mut().project().state);
947        let ptr = NonNull::from(unsafe { Pin::into_inner_unchecked(self) });
948        test_debug!(self = ?fmt::ptr(ptr), ?state, ?queue, "Waiter::release");
949
950        // if we're not enqueued, we don't have to do anything else.
951        if state != WaitState::Waiting {
952            return;
953        }
954
955        queue.queue.with_lock(|waiters| {
956            let state = queue.load();
957
958            // remove the node
959            unsafe {
960                // safety: we have the lock on the queue, so this is safe.
961                waiters.remove(ptr);
962            };
963
964            // if we removed the last waiter from the queue, transition the state to
965            // `Empty`.
966            if test_dbg!(waiters.is_empty()) && state == State::Waiting {
967                queue.store(State::Empty);
968            }
969        })
970    }
971}
972
// Intrusive-list support: lets a `Waiter` be stored in a `cordyceps` `List`
// through the `links` field kept inside its `node` cell.
//
// NOTE(review): the safety contract of `Linked` is defined by `cordyceps`;
// presumably it is upheld because waiters are only enqueued while pinned.
// Confirm against the `Linked` documentation.
unsafe impl<K: PartialEq, V> Linked<list::Links<Waiter<K, V>>> for Waiter<K, V> {
    /// Handles are plain `NonNull` pointers, so converting between a handle
    /// and a raw pointer is the identity in both directions.
    type Handle = NonNull<Waiter<K, V>>;

    /// Returns the handle unchanged; it already is the raw pointer.
    fn into_ptr(r: Self::Handle) -> NonNull<Self> {
        r
    }

    /// Returns the pointer unchanged; it already is the handle.
    unsafe fn from_ptr(ptr: NonNull<Self>) -> Self::Handle {
        ptr
    }

    /// Returns a pointer to the `links` field inside `target`'s node.
    ///
    /// `target` is dereferenced as a raw pointer, so it must point to a live
    /// `Waiter`.
    unsafe fn links(target: NonNull<Self>) -> NonNull<list::Links<Waiter<K, V>>> {
        // Safety: using `ptr::addr_of!` avoids creating a temporary
        // reference, which stacked borrows dislikes.
        let node = ptr::addr_of!((*target.as_ptr()).node);
        (*node).with_mut(|node| {
            // Project to the `links` field by address only, again without
            // materializing a reference to the node's contents.
            let links = ptr::addr_of_mut!((*node).links);
            // Safety: since the `target` pointer is `NonNull`, we can assume
            // that pointers to its members are also not null, making this use
            // of `new_unchecked` fine.
            NonNull::new_unchecked(links)
        })
    }
}
997
998// === impl Wait ===
999
1000impl<K, V, Lock> Future for Wait<'_, K, V, Lock>
1001where
1002    K: PartialEq,
1003    Lock: ScopedRawMutex,
1004{
1005    type Output = WaitResult<V>;
1006
1007    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
1008        let this = self.project();
1009        this.waiter.poll_wait(this.queue, cx)
1010    }
1011}
1012
1013#[pinned_drop]
1014impl<K, V, Lock> PinnedDrop for Wait<'_, K, V, Lock>
1015where
1016    K: PartialEq,
1017    Lock: ScopedRawMutex,
1018{
1019    fn drop(mut self: Pin<&mut Self>) {
1020        let this = self.project();
1021        this.waiter.release(this.queue);
1022    }
1023}
1024
// === impl State ===
1026
1027impl State {
1028    #[inline]
1029    fn from_bits(bits: usize) -> Self {
1030        Self::try_from_bits(bits).expect("This shouldn't be possible")
1031    }
1032}
1033
1034impl FromBits<usize> for State {
1035    const BITS: u32 = 2;
1036    type Error = core::convert::Infallible;
1037
1038    fn try_from_bits(bits: usize) -> Result<Self, Self::Error> {
1039        Ok(match bits as u8 {
1040            bits if bits == Self::Empty as u8 => Self::Empty,
1041            bits if bits == Self::Waiting as u8 => Self::Waiting,
1042            bits if bits == Self::Closed as u8 => Self::Closed,
1043            _ => unsafe {
1044                // TODO(AJM): this isn't *totally* true anymore...
1045                unreachable_unchecked!("all potential 2-bit patterns should be covered!")
1046            },
1047        })
1048    }
1049
1050    fn into_bits(self) -> usize {
1051        self.into_usize()
1052    }
1053}
1054
1055impl State {
1056    const fn into_usize(self) -> usize {
1057        self as u8 as usize
1058    }
1059}
1060
1061// === impl WaitOwned ===
1062
1063feature! {
1064    #![feature = "alloc"]
1065
1066    use alloc::sync::Arc;
1067
    /// Future returned from [`WaitMap::wait_owned()`].
    ///
    /// This is identical to the [`Wait`] future, except that it takes an
    /// [`Arc`] reference to the [`WaitMap`], allowing the returned future to
    /// live for the `'static` lifetime.
    ///
    /// This future is fused, so once it has completed, any future calls to poll
    /// will immediately return [`Poll::Ready`].
    ///
    /// # Notes
    ///
    /// This future is `!Unpin`, as it is unsafe to [`core::mem::forget`] a
    /// `WaitOwned` future once it has been polled. For instance, the following
    /// code must not compile:
    ///
    /// ```compile_fail
    /// use maitake_sync::wait_map::WaitOwned;
    ///
    /// // Calls to this function should only compile if `T` is `Unpin`.
    /// fn assert_unpin<T: Unpin>() {}
    ///
    /// // `WaitOwned` takes no lifetime parameter, so the only reason this can
    /// // fail to compile is the missing `Unpin` implementation.
    /// assert_unpin::<WaitOwned<usize, ()>>();
    /// ```
    #[derive(Debug)]
    #[pin_project(PinnedDrop)]
    pub struct WaitOwned<K: PartialEq, V, Lock: ScopedRawMutex = DefaultMutex> {
        /// The `WaitMap` being waited on, kept alive by this future.
        queue: Arc<WaitMap<K, V, Lock>>,

        /// Entry in the wait queue.
        #[pin]
        waiter: Waiter<K, V>,
    }
1100
1101    impl<K: PartialEq, V, Lock: ScopedRawMutex> WaitMap<K, V, Lock> {
1102        /// Wait to be woken up by this queue, returning a future that's valid
1103        /// for the `'static` lifetime.
1104        ///
1105        /// This is identical to the [`wait`] method, except that it takes a
1106        /// [`Arc`] reference to the [`WaitMap`], allowing the returned future to
1107        /// live for the `'static` lifetime.
1108        ///
1109        /// This returns a [`WaitOwned`] future that will complete when the task is
1110        /// woken by a call to [`wake`] with a matching `key`, or when the `WaitMap`
1111        /// is dropped.
1112        ///
1113        /// **Note**: `key`s must be unique. If the given key already exists in the
1114        /// `WaitMap`, the future will resolve to an Error the first time it is polled
1115        ///
1116        /// [`wake`]: Self::wake
1117        /// [`wait`]: Self::wait
1118        pub fn wait_owned(self: &Arc<Self>, key: K) -> WaitOwned<K, V, Lock> {
1119            let waiter = self.waiter(key);
1120            let queue = self.clone();
1121            WaitOwned { queue, waiter }
1122        }
1123    }
1124
1125    impl<K: PartialEq, V, Lock: ScopedRawMutex> Future for WaitOwned<K, V, Lock> {
1126        type Output = WaitResult<V>;
1127
1128        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
1129            let this = self.project();
1130            this.waiter.poll_wait(&*this.queue, cx)
1131        }
1132    }
1133
1134    #[pinned_drop]
1135    impl<K, V, Lock> PinnedDrop for WaitOwned<K, V, Lock>
1136    where
1137        K: PartialEq,
1138        Lock: ScopedRawMutex,
1139    {
1140        fn drop(mut self: Pin<&mut Self>) {
1141            let this = self.project();
1142            this.waiter.release(&*this.queue);
1143        }
1144    }
1145}
1146
1147impl<V> fmt::Debug for Wakeup<V> {
1148    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1149        match self {
1150            Self::Empty => f.write_str("Wakeup::Empty"),
1151            Self::Waiting(waker) => f.debug_tuple("Wakeup::Waiting").field(waker).finish(),
1152            Self::DataReceived(_) => f.write_str("Wakeup::DataReceived(..)"),
1153            Self::Retreived => f.write_str("Wakeup::Retrieved"),
1154            Self::Closed => f.write_str("Wakeup::Closed"),
1155        }
1156    }
1157}