maitake/task/
state.rs

1use super::PollResult;
2use crate::{
3    loom::sync::atomic::{
4        self, AtomicUsize,
5        Ordering::{self, *},
6    },
7    sync::util::Backoff,
8};
9
10use core::fmt;
11use mycelium_util::unreachable_unchecked;
12
mycelium_bitfield::bitfield! {
    /// A snapshot of a task's current state.
    ///
    /// The whole state is packed into a single `usize`: a handful of
    /// single-bit flags, the two-bit join waker state, and the task's
    /// reference count in the remaining bits. A `StateCell` stores this word
    /// atomically.
    #[derive(PartialEq, Eq)]
    pub(crate) struct State<usize> {
        /// If set, this task is currently being polled.
        pub(crate) const POLLING: bool;

        /// If set, this task's [`Waker`] has been woken.
        ///
        /// [`Waker`]: core::task::Waker
        pub(crate) const WOKEN: bool;

        /// If set, this task's [`Future`] has completed (i.e., it has returned
        /// [`Poll::Ready`]).
        ///
        /// [`Future`]: core::future::Future
        /// [`Poll::Ready`]: core::task::Poll::Ready
        pub(crate) const COMPLETED: bool;

        /// If set, this task has been canceled.
        pub(crate) const CANCELED: bool;

        /// If set, this task has a [`JoinHandle`] awaiting its completion.
        ///
        /// If the `JoinHandle` is dropped, this flag is unset.
        ///
        /// This flag does *not* indicate the presence of a [`Waker`] in the
        /// `join_waker` slot; it only indicates that a [`JoinHandle`] for this
        /// task *exists*. The join waker may not currently be registered if
        /// this flag is set.
        ///
        /// [`Waker`]: core::task::Waker
        /// [`JoinHandle`]: super::JoinHandle
        pub(crate) const HAS_JOIN_HANDLE: bool;

        /// The state of the task's [`JoinHandle`] [`Waker`].
        ///
        /// This is a `JoinWakerState` value: empty, being registered,
        /// waiting, or woken.
        ///
        /// [`Waker`]: core::task::Waker
        /// [`JoinHandle`]: super::JoinHandle
        const JOIN_WAKER: JoinWakerState;

        /// If set, this task has output ready to be taken by a [`JoinHandle`].
        ///
        /// [`JoinHandle`]: super::JoinHandle
        pub(crate) const HAS_OUTPUT: bool;

        /// The number of currently live references to this task.
        ///
        /// When this is 0, the task may be deallocated.
        ///
        /// This field takes up all of the remaining bits of the state word.
        const REFS = ..;
    }
}
65
/// An atomic cell that stores a task's current [`State`].
///
/// The wrapped `AtomicUsize` holds the packed bits of a [`State`]; snapshots
/// are read out of the cell with `load`, and most updates are performed by a
/// compare-and-swap loop over such snapshots.
#[repr(transparent)]
pub(super) struct StateCell(AtomicUsize);
69
70#[derive(Copy, Clone, Debug, PartialEq, Eq)]
71pub(super) enum ScheduleAction {
72    /// The task should be enqueued.
73    Enqueue,
74
75    /// The task does not need to be enqueued.
76    None,
77}
78
79#[derive(Copy, Clone, Debug, PartialEq, Eq)]
80pub(super) enum JoinAction {
81    /// It's safe to take the task's output!
82    TakeOutput,
83
84    /// The task was canceled, it cannot be joined.
85    Canceled {
86        /// If `true`, the task completed successfully before it was cancelled.
87        completed: bool,
88    },
89
90    /// Register the *first* join waker; there is no previous join waker and the
91    /// slot is not initialized.
92    Register,
93
94    /// The task is not ready to read the output, but a previous join waker is
95    /// registered.
96    Reregister,
97}
98
99#[derive(Copy, Clone, Debug, PartialEq, Eq)]
100pub(super) enum OrDrop<T> {
101    /// Another action should be performed.
102    Action(T),
103
104    /// The task should be deallocated.
105    Drop,
106}
107
108#[derive(Copy, Clone, Debug, PartialEq, Eq)]
109pub(super) enum StartPollAction {
110    /// It's okay to poll the task.
111    Poll,
112
113    /// The task was canceled, and its [`JoinHandle`] waker may need to be woken.
114    ///
115    /// [`JoinHandle`]: super::JoinHandle
116    Canceled {
117        /// If `true`, the task's join waker must be woken.
118        wake_join_waker: bool,
119    },
120
121    /// The task is not in a valid state to start a poll. Do nothing.
122    CantPoll,
123}
124
/// The result of waking a task by value: either a [`ScheduleAction`] to
/// perform, or an indication that the task should be deallocated.
pub(super) type WakeAction = OrDrop<ScheduleAction>;
126
127impl State {
128    #[inline]
129    pub(crate) fn ref_count(self) -> usize {
130        self.get(Self::REFS)
131    }
132
133    fn drop_ref(self) -> Self {
134        Self(self.0 - REF_ONE)
135    }
136
137    fn clone_ref(self) -> Self {
138        Self(self.0 + REF_ONE)
139    }
140}
141
/// The lowest bit of the `REFS` field: adding this value to (or subtracting
/// it from) a packed state word changes the reference count by exactly one.
const REF_ONE: usize = State::REFS.first_bit();
/// The raw bit mask of the `REFS` field within a packed state word.
const REF_MAX: usize = State::REFS.raw_mask();
144
/// The state of a task's join waker slot, stored in the two-bit `JOIN_WAKER`
/// field of [`State`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
enum JoinWakerState {
    /// No join waker is present; the slot is uninitialized.
    Empty = 0,
    /// A join waker is in the process of *being* registered.
    Registering = 1,
    /// A join waker is registered, and the slot is initialized.
    Waiting = 2,
    /// The join waker has been woken.
    Woken = 3,
}
157
158// === impl StateCell ===
159
160impl State {
161    fn has_join_waker(&mut self, should_wait: &mut bool) -> bool {
162        match self.get(State::JOIN_WAKER) {
163            JoinWakerState::Empty => false,
164            JoinWakerState::Registering => {
165                *should_wait = true;
166                debug_assert!(
167                    self.get(State::HAS_JOIN_HANDLE),
168                    "a task cannot register a join waker if it does not have a join handle!",
169                );
170                true
171            }
172            JoinWakerState::Waiting => {
173                debug_assert!(
174                    self.get(State::HAS_JOIN_HANDLE),
175                    "a task cannot have a join waker if it does not have a join handle!",
176                );
177                *should_wait = false;
178                self.set(State::JOIN_WAKER, JoinWakerState::Empty);
179                true
180            }
181            JoinWakerState::Woken => {
182                debug_assert!(
183                    false,
184                    "join waker should not be woken until task has completed, wtf"
185                );
186                false
187            }
188        }
189    }
190}
191
192impl StateCell {
193    #[cfg(not(loom))]
194    pub const fn new() -> Self {
195        Self(AtomicUsize::new(REF_ONE))
196    }
197
198    #[cfg(loom)]
199    pub fn new() -> Self {
200        Self(AtomicUsize::new(REF_ONE))
201    }
202
203    pub(super) fn start_poll(&self) -> StartPollAction {
204        let mut should_wait_for_join_waker = false;
205        let action = self.transition(|state| {
206            // cannot start polling a task which is being polled on another
207            // thread, or a task which has completed
208            if test_dbg!(state.get(State::POLLING)) || test_dbg!(state.get(State::COMPLETED)) {
209                return StartPollAction::CantPoll;
210            }
211
212            // if the task has been canceled, don't poll it.
213            if test_dbg!(state.get(State::CANCELED)) {
214                let wake_join_waker = state.has_join_waker(&mut should_wait_for_join_waker);
215                return StartPollAction::Canceled { wake_join_waker };
216            }
217
218            state
219                // the task is now being polled.
220                .set(State::POLLING, true)
221                // if the task was woken, consume the wakeup.
222                .set(State::WOKEN, false);
223            StartPollAction::Poll
224        });
225
226        if should_wait_for_join_waker {
227            debug_assert!(matches!(action, StartPollAction::Canceled { .. }));
228            self.wait_for_join_waker(self.load(Acquire));
229        }
230
231        action
232    }
233
234    pub(super) fn end_poll(&self, completed: bool) -> PollResult {
235        let mut should_wait_for_join_waker = false;
236        let action = self.transition(|state| {
237            // Cannot end a poll if a task is not being polled!
238            debug_assert!(state.get(State::POLLING));
239            debug_assert!(!state.get(State::COMPLETED));
240            debug_assert!(
241                state.ref_count() > 0,
242                "cannot poll a task that has zero references, what is happening!"
243            );
244
245            state
246                .set(State::POLLING, false)
247                .set(State::COMPLETED, completed);
248
249            // Was the task woken during the poll?
250            if !test_dbg!(completed) && test_dbg!(state.get(State::WOKEN)) {
251                return PollResult::PendingSchedule;
252            }
253
254            let had_join_waker = if test_dbg!(completed) {
255                // set the output flag so that the joinhandle knows it is now
256                // safe to read the task's output.
257                state.set(State::HAS_OUTPUT, true);
258                state.has_join_waker(&mut should_wait_for_join_waker)
259            } else {
260                false
261            };
262
263            if had_join_waker {
264                PollResult::ReadyJoined
265            } else if completed {
266                PollResult::Ready
267            } else {
268                PollResult::Pending
269            }
270        });
271
272        if should_wait_for_join_waker {
273            debug_assert_eq!(action, PollResult::ReadyJoined);
274            self.wait_for_join_waker(self.load(Acquire));
275        }
276
277        action
278    }
279
280    /// Transition to the woken state by value, returning `true` if the task
281    /// should be enqueued.
282    pub(super) fn wake_by_val(&self) -> WakeAction {
283        self.transition(|state| {
284            // If the task was woken *during* a poll, it will be re-queued by the
285            // scheduler at the end of the poll if needed, so don't enqueue it now.
286            if test_dbg!(state.get(State::POLLING)) {
287                *state = state.with(State::WOKEN, true).drop_ref();
288                assert!(state.ref_count() > 0);
289
290                return OrDrop::Action(ScheduleAction::None);
291            }
292
293            // If the task is already completed or woken, we don't need to
294            // requeue it, but decrement the ref count for the waker that was
295            // used for this wakeup.
296            if test_dbg!(state.get(State::COMPLETED)) || test_dbg!(state.get(State::WOKEN)) {
297                let new_state = state.drop_ref();
298                *state = new_state;
299                return if new_state.ref_count() == 0 {
300                    OrDrop::Drop
301                } else {
302                    OrDrop::Action(ScheduleAction::None)
303                };
304            }
305
306            // Otherwise, transition to the notified state and enqueue the task.
307            *state = state.with(State::WOKEN, true).clone_ref();
308            OrDrop::Action(ScheduleAction::Enqueue)
309        })
310    }
311
312    /// Transition to the woken state by ref, returning `true` if the task
313    /// should be enqueued.
314    pub(super) fn wake_by_ref(&self) -> ScheduleAction {
315        self.transition(|state| {
316            if test_dbg!(state.get(State::COMPLETED)) || test_dbg!(state.get(State::WOKEN)) {
317                return ScheduleAction::None;
318            }
319
320            if test_dbg!(state.get(State::POLLING)) {
321                state.set(State::WOKEN, true);
322                return ScheduleAction::None;
323            }
324
325            *state = state.with(State::WOKEN, true).clone_ref();
326            ScheduleAction::Enqueue
327        })
328    }
329
330    pub(super) fn set_woken(&self) {
331        self.0.fetch_or(State::WOKEN.raw_mask(), AcqRel);
332    }
333
334    #[inline]
335    pub(super) fn clone_ref(&self) {
336        // Using a relaxed ordering is alright here, as knowledge of the
337        // original reference prevents other threads from erroneously deleting
338        // the object.
339        //
340        // As explained in the [Boost documentation][1], Increasing the
341        // reference counter can always be done with memory_order_relaxed: New
342        // references to an object can only be formed from an existing
343        // reference, and passing an existing reference from one thread to
344        // another must already provide any required synchronization.
345        //
346        // [1]: (www.boost.org/doc/libs/1_55_0/doc/html/atomic/usage_examples.html)
347        let old_refs = self.0.fetch_add(REF_ONE, Relaxed);
348        test_dbg!(State::REFS.unpack(old_refs));
349
350        // However we need to guard against massive refcounts in case someone
351        // is `mem::forget`ing tasks. If we don't do this the count can overflow
352        // and users will use-after free. We racily saturate to `isize::MAX` on
353        // the assumption that there aren't ~2 billion threads incrementing
354        // the reference count at once. This branch will never be taken in
355        // any realistic program.
356        //
357        // We abort because such a program is incredibly degenerate, and we
358        // don't care to support it.
359        if test_dbg!(old_refs > REF_MAX) {
360            panic!("task reference count overflow");
361        }
362    }
363
364    #[inline]
365    pub(super) fn drop_ref(&self) -> bool {
366        test_debug!("StateCell::drop_ref");
367        // We do not need to synchronize with other cores unless we are going to
368        // delete the task.
369        let old_refs = self.0.fetch_sub(REF_ONE, Release);
370
371        // Manually shift over the refcount to clear the state bits. We don't
372        // use the packing spec here, because it would also mask out any high
373        // bits, and we can avoid doing the bitwise-and (since there are no
374        // higher bits that are not part of the ref count). This is probably a
375        // premature optimization lol.
376        test_dbg!(State::REFS.unpack(old_refs));
377        let old_refs = old_refs >> State::REFS.least_significant_index();
378
379        // Did we drop the last ref?
380        if test_dbg!(old_refs) > 1 {
381            return false;
382        }
383
384        atomic::fence(Acquire);
385        true
386    }
387
388    /// Cancel the task.
389    ///
390    /// Returns `true` if the task was successfully canceled.
391    pub(super) fn cancel(&self) -> bool {
392        test_debug!("StateCell::cancel");
393        // XXX(eliza): this *could* probably just be a `fetch_or`, instead of a
394        // whole `transition`...
395        self.transition(|state| {
396            // you can't cancel a task that has already been canceled, that doesn't make sense.
397            if state.get(State::CANCELED) {
398                return false;
399            }
400
401            // this task is CANCELED! can't believe some of you are still
402            // following it, smh...
403            state.set(State::CANCELED, true).set(State::WOKEN, true);
404
405            true
406        })
407    }
408
409    #[inline]
410    pub(super) fn create_join_handle(&self) {
411        test_debug!("StateCell::create_join_handle");
412        self.transition(|state| {
413            debug_assert!(
414                !state.get(State::HAS_JOIN_HANDLE),
415                "task already has a join handle, cannot create a new one! state={state:?}"
416            );
417
418            *state = state.with(State::HAS_JOIN_HANDLE, true);
419        })
420    }
421
422    #[inline]
423    pub(super) fn drop_join_handle(&self) {
424        test_debug!("StateCell::drop_join_handle");
425        const MASK: usize = !State::HAS_JOIN_HANDLE.raw_mask();
426        let _prev = self.0.fetch_and(MASK, Release);
427        test_trace!(
428            "drop_join_handle; prev_state:\n{}\nstate:\n{}",
429            State::from_bits(_prev),
430            self.load(Acquire),
431        );
432        debug_assert!(
433            State(_prev).get(State::HAS_JOIN_HANDLE),
434            "tried to drop a join handle when the task did not have a join handle!\nstate: {:#?}",
435            State(_prev),
436        )
437    }
438
439    /// Returns whether if it's okay to take the task's output.
440    pub(super) fn try_join(&self) -> JoinAction {
441        fn should_register(state: &mut State) -> JoinAction {
442            let action = match state.get(State::JOIN_WAKER) {
443                JoinWakerState::Empty => JoinAction::Register,
444                x => {
445                    debug_assert_eq!(x, JoinWakerState::Waiting);
446                    JoinAction::Reregister
447                }
448            };
449            state.set(State::JOIN_WAKER, JoinWakerState::Registering);
450
451            action
452        }
453
454        self.transition(|state| {
455            let has_output = test_dbg!(state.get(State::HAS_OUTPUT));
456
457            if test_dbg!(state.get(State::CANCELED)) {
458                return JoinAction::Canceled {
459                    completed: has_output,
460                };
461            }
462
463            // If the task has not completed, we can't take its join output.
464            if test_dbg!(!state.get(State::COMPLETED)) {
465                return should_register(state);
466            }
467
468            // If the task does not have output, we cannot take it.
469            if !has_output {
470                return should_register(state);
471            }
472
473            *state = state.with(State::HAS_OUTPUT, false);
474            JoinAction::TakeOutput
475        })
476    }
477
478    pub(super) fn set_join_waker_registered(&self) {
479        self.transition(|state| {
480            debug_assert_eq!(state.get(State::JOIN_WAKER), JoinWakerState::Registering);
481            state
482                .set(State::HAS_JOIN_HANDLE, true)
483                .set(State::JOIN_WAKER, JoinWakerState::Waiting);
484        })
485    }
486
487    /// Returns `true` if this task has an un-dropped [`JoinHandle`] [`Waker`] that
488    /// needs to be dropped.
489    ///
490    /// [`JoinHandle`]: super::JoinHandle
491    /// [`Waker`]: core::task::Waker
492    pub(super) fn join_waker_needs_drop(&self) -> bool {
493        let state = self.load(Acquire);
494        match test_dbg!(state.get(State::JOIN_WAKER)) {
495            JoinWakerState::Empty | JoinWakerState::Woken => return false,
496            JoinWakerState::Registering => self.wait_for_join_waker(state),
497            JoinWakerState::Waiting => {}
498        }
499
500        true
501    }
502
503    pub(super) fn load(&self, order: Ordering) -> State {
504        State(self.0.load(order))
505    }
506
507    /// Advance this task's state by running the provided
508    /// `transition` function on the current [`State`].
509    #[cfg_attr(test, track_caller)]
510    fn transition<T>(&self, mut transition: impl FnMut(&mut State) -> T) -> T {
511        let mut current = self.load(Acquire);
512        loop {
513            test_trace!("StateCell::transition; current:\n{}", current);
514            let mut next = current;
515            // Run the transition function.
516            let res = transition(&mut next);
517
518            if test_dbg!(current.0 == next.0) {
519                return res;
520            }
521
522            test_trace!("StateCell::transition; next:\n{}", next);
523            match self
524                .0
525                .compare_exchange_weak(current.0, next.0, AcqRel, Acquire)
526            {
527                Ok(_) => return res,
528                Err(actual) => current = State(actual),
529            }
530        }
531    }
532
533    fn wait_for_join_waker(&self, mut state: State) {
534        test_trace!("StateCell::wait_for_join_waker");
535        let mut boff = Backoff::new();
536        loop {
537            state.set(State::JOIN_WAKER, JoinWakerState::Waiting);
538            let next = state.with(State::JOIN_WAKER, JoinWakerState::Woken);
539            match self
540                .0
541                .compare_exchange_weak(state.0, next.0, AcqRel, Acquire)
542            {
543                Ok(_) => return,
544                Err(actual) => state = State(actual),
545            }
546            boff.spin();
547        }
548    }
549}
550
551impl fmt::Debug for StateCell {
552    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
553        self.load(Relaxed).fmt(f)
554    }
555}
556
557impl mycelium_bitfield::FromBits<usize> for JoinWakerState {
558    type Error = core::convert::Infallible;
559
560    /// The number of bits required to represent a value of this type.
561    const BITS: u32 = 2;
562
563    #[inline]
564    #[allow(clippy::literal_string_with_formatting_args)]
565    fn try_from_bits(bits: usize) -> Result<Self, Self::Error> {
566        match bits {
567            b if b == Self::Registering as usize => Ok(Self::Registering),
568            b if b == Self::Waiting as usize => Ok(Self::Waiting),
569            b if b == Self::Empty as usize => Ok(Self::Empty),
570            b if b == Self::Woken as usize => Ok(Self::Woken),
571            _ => unsafe {
572                // this should never happen unless the bitpacking code is broken
573                unreachable_unchecked!("invalid join waker state {bits:#b}")
574            },
575        }
576    }
577
578    #[inline]
579    fn into_bits(self) -> usize {
580        self as u8 as usize
581    }
582}
583
#[cfg(all(test, not(loom)))]
mod tests {
    use super::*;

    /// Checks the packing specs of [`State`] via `State::assert_valid`.
    // No sense spending time running these trivial tests under Miri...
    #[cfg_attr(miri, ignore)]
    #[test]
    fn packing_specs_valid() {
        State::assert_valid();
    }

    /// Smoke test: a fresh [`StateCell`] can be formatted with the alternate
    /// (`{:#?}`) debug format.
    // No sense spending time running these trivial tests under Miri...
    #[cfg_attr(miri, ignore)]
    #[test]
    fn debug_alt() {
        let state = StateCell::new();
        println!("{state:#?}");
    }
}