1use super::PollResult;
2use crate::{
3 loom::sync::atomic::{
4 self, AtomicUsize,
5 Ordering::{self, *},
6 },
7 sync::util::Backoff,
8};
9
10use core::fmt;
11use mycelium_util::unreachable_unchecked;
12
mycelium_bitfield::bitfield! {
    /// A snapshot of a task's state, packed into a single `usize`.
    ///
    /// Snapshots are loaded from and stored to a `StateCell`; the `StateCell`
    /// methods implement the transitions between states.
    ///
    /// NOTE(review): the macro presumably assigns bits in declaration order,
    /// so reordering these fields would change the packed layout.
    #[derive(PartialEq, Eq)]
    pub(crate) struct State<usize> {
        /// Set while the task is being polled: `StateCell::start_poll` sets it
        /// and `StateCell::end_poll` clears it.
        pub(crate) const POLLING: bool;

        /// Set when the task has been woken (by the wake methods, `set_woken`,
        /// or `cancel`); cleared when `StateCell::start_poll` begins a poll.
        pub(crate) const WOKEN: bool;

        /// Set by `StateCell::end_poll` when the poll completed the task.
        pub(crate) const COMPLETED: bool;

        /// Set by `StateCell::cancel`; a canceled task is not polled again.
        pub(crate) const CANCELED: bool;

        /// Set while a join handle exists for this task (see
        /// `StateCell::create_join_handle` / `StateCell::drop_join_handle`).
        pub(crate) const HAS_JOIN_HANDLE: bool;

        /// Registration state of the join handle's waker; see
        /// `JoinWakerState`.
        const JOIN_WAKER: JoinWakerState;

        /// Set when the task completes; cleared when `StateCell::try_join`
        /// hands the output to the join handle (`JoinAction::TakeOutput`).
        pub(crate) const HAS_OUTPUT: bool;

        /// Number of live references to the task, stored in all remaining
        /// bits. `StateCell::drop_ref` reports when the last one is released.
        const REFS = ..;
    }
}
65
/// The atomic cell holding a task's packed `State`.
///
/// Every state change goes through an atomic operation on the wrapped
/// `AtomicUsize` (see the methods on `StateCell`).
#[repr(transparent)]
pub(super) struct StateCell(AtomicUsize);
69
/// What the caller of a wake operation should do with the task.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(super) enum ScheduleAction {
    /// The task was newly woken and a reference was added for it; the caller
    /// is presumably expected to enqueue it for polling.
    Enqueue,

    /// No scheduling is needed (the task is being polled, has completed, or
    /// was already woken).
    None,
}
78
/// The outcome of a join handle's attempt to join the task
/// (`StateCell::try_join`).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(super) enum JoinAction {
    /// The task completed with output; `HAS_OUTPUT` has been cleared and the
    /// caller may take the output.
    TakeOutput,

    /// The task was canceled.
    Canceled {
        /// Whether the task had output available (`HAS_OUTPUT`) when the
        /// join was attempted.
        completed: bool,
    },

    /// No join waker was stored; the state is now `Registering` and the
    /// caller should register its waker.
    Register,

    /// A join waker was already stored (`Waiting`); the state is now
    /// `Registering` and the caller should replace it.
    Reregister,
}
98
/// Either an action to perform, or an instruction that the task's last
/// reference was released.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(super) enum OrDrop<T> {
    /// Perform the wrapped action.
    Action(T),

    /// The reference count reached zero (see `StateCell::wake_by_val`);
    /// presumably the caller must now drop the task.
    Drop,
}
107
/// The outcome of `StateCell::start_poll`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(super) enum StartPollAction {
    /// `POLLING` was set; the caller should poll the task.
    Poll,

    /// The task was canceled, so it will not be polled.
    Canceled {
        /// Whether a join waker was registered (or being registered) and
        /// should be woken.
        wake_join_waker: bool,
    },

    /// The task is already being polled or has completed; nothing changed.
    CantPoll,
}
124
/// The result of a by-value wake (`StateCell::wake_by_val`): a scheduling
/// action, or `Drop` if the last reference was released.
pub(super) type WakeAction = OrDrop<ScheduleAction>;
126
127impl State {
128 #[inline]
129 pub(crate) fn ref_count(self) -> usize {
130 self.get(Self::REFS)
131 }
132
133 fn drop_ref(self) -> Self {
134 Self(self.0 - REF_ONE)
135 }
136
137 fn clone_ref(self) -> Self {
138 Self(self.0 + REF_ONE)
139 }
140}
141
/// The packed-word value of a single reference: the lowest bit of the `REFS`
/// field. Adding or subtracting it changes the reference count by one.
const REF_ONE: usize = State::REFS.first_bit();
/// The raw bit mask of the `REFS` field within the packed word.
///
/// Note: this is a *mask*, not the maximum count; shift it right by
/// `State::REFS.least_significant_index()` to get the largest count the field
/// can hold.
const REF_MAX: usize = State::REFS.raw_mask();
144
/// Registration state of a join handle's waker, stored in the two
/// `JOIN_WAKER` bits of `State`.
///
/// The discriminants are the on-the-wire bit patterns: `into_bits` casts the
/// variant directly and `try_from_bits` matches on these values.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[repr(u8)]
enum JoinWakerState {
    /// No join waker is stored.
    Empty = 0b00,
    /// A join handle is registering its waker (set by `StateCell::try_join`).
    Registering = 0b01,
    /// A join waker is registered (set by
    /// `StateCell::set_join_waker_registered`).
    Waiting = 0b10,
    /// The join waker has been claimed by `StateCell::wait_for_join_waker`.
    Woken = 0b11,
}
157
158impl State {
161 fn has_join_waker(&mut self, should_wait: &mut bool) -> bool {
162 match self.get(State::JOIN_WAKER) {
163 JoinWakerState::Empty => false,
164 JoinWakerState::Registering => {
165 *should_wait = true;
166 debug_assert!(
167 self.get(State::HAS_JOIN_HANDLE),
168 "a task cannot register a join waker if it does not have a join handle!",
169 );
170 true
171 }
172 JoinWakerState::Waiting => {
173 debug_assert!(
174 self.get(State::HAS_JOIN_HANDLE),
175 "a task cannot have a join waker if it does not have a join handle!",
176 );
177 *should_wait = false;
178 self.set(State::JOIN_WAKER, JoinWakerState::Empty);
179 true
180 }
181 JoinWakerState::Woken => {
182 debug_assert!(
183 false,
184 "join waker should not be woken until task has completed, wtf"
185 );
186 false
187 }
188 }
189 }
190}
191
192impl StateCell {
193 #[cfg(not(loom))]
194 pub const fn new() -> Self {
195 Self(AtomicUsize::new(REF_ONE))
196 }
197
198 #[cfg(loom)]
199 pub fn new() -> Self {
200 Self(AtomicUsize::new(REF_ONE))
201 }
202
203 pub(super) fn start_poll(&self) -> StartPollAction {
204 let mut should_wait_for_join_waker = false;
205 let action = self.transition(|state| {
206 if test_dbg!(state.get(State::POLLING)) || test_dbg!(state.get(State::COMPLETED)) {
209 return StartPollAction::CantPoll;
210 }
211
212 if test_dbg!(state.get(State::CANCELED)) {
214 let wake_join_waker = state.has_join_waker(&mut should_wait_for_join_waker);
215 return StartPollAction::Canceled { wake_join_waker };
216 }
217
218 state
219 .set(State::POLLING, true)
221 .set(State::WOKEN, false);
223 StartPollAction::Poll
224 });
225
226 if should_wait_for_join_waker {
227 debug_assert!(matches!(action, StartPollAction::Canceled { .. }));
228 self.wait_for_join_waker(self.load(Acquire));
229 }
230
231 action
232 }
233
234 pub(super) fn end_poll(&self, completed: bool) -> PollResult {
235 let mut should_wait_for_join_waker = false;
236 let action = self.transition(|state| {
237 debug_assert!(state.get(State::POLLING));
239 debug_assert!(!state.get(State::COMPLETED));
240 debug_assert!(
241 state.ref_count() > 0,
242 "cannot poll a task that has zero references, what is happening!"
243 );
244
245 state
246 .set(State::POLLING, false)
247 .set(State::COMPLETED, completed);
248
249 if !test_dbg!(completed) && test_dbg!(state.get(State::WOKEN)) {
251 return PollResult::PendingSchedule;
252 }
253
254 let had_join_waker = if test_dbg!(completed) {
255 state.set(State::HAS_OUTPUT, true);
258 state.has_join_waker(&mut should_wait_for_join_waker)
259 } else {
260 false
261 };
262
263 if had_join_waker {
264 PollResult::ReadyJoined
265 } else if completed {
266 PollResult::Ready
267 } else {
268 PollResult::Pending
269 }
270 });
271
272 if should_wait_for_join_waker {
273 debug_assert_eq!(action, PollResult::ReadyJoined);
274 self.wait_for_join_waker(self.load(Acquire));
275 }
276
277 action
278 }
279
280 pub(super) fn wake_by_val(&self) -> WakeAction {
283 self.transition(|state| {
284 if test_dbg!(state.get(State::POLLING)) {
287 *state = state.with(State::WOKEN, true).drop_ref();
288 assert!(state.ref_count() > 0);
289
290 return OrDrop::Action(ScheduleAction::None);
291 }
292
293 if test_dbg!(state.get(State::COMPLETED)) || test_dbg!(state.get(State::WOKEN)) {
297 let new_state = state.drop_ref();
298 *state = new_state;
299 return if new_state.ref_count() == 0 {
300 OrDrop::Drop
301 } else {
302 OrDrop::Action(ScheduleAction::None)
303 };
304 }
305
306 *state = state.with(State::WOKEN, true).clone_ref();
308 OrDrop::Action(ScheduleAction::Enqueue)
309 })
310 }
311
312 pub(super) fn wake_by_ref(&self) -> ScheduleAction {
315 self.transition(|state| {
316 if test_dbg!(state.get(State::COMPLETED)) || test_dbg!(state.get(State::WOKEN)) {
317 return ScheduleAction::None;
318 }
319
320 if test_dbg!(state.get(State::POLLING)) {
321 state.set(State::WOKEN, true);
322 return ScheduleAction::None;
323 }
324
325 *state = state.with(State::WOKEN, true).clone_ref();
326 ScheduleAction::Enqueue
327 })
328 }
329
330 pub(super) fn set_woken(&self) {
331 self.0.fetch_or(State::WOKEN.raw_mask(), AcqRel);
332 }
333
334 #[inline]
335 pub(super) fn clone_ref(&self) {
336 let old_refs = self.0.fetch_add(REF_ONE, Relaxed);
348 test_dbg!(State::REFS.unpack(old_refs));
349
350 if test_dbg!(old_refs > REF_MAX) {
360 panic!("task reference count overflow");
361 }
362 }
363
364 #[inline]
365 pub(super) fn drop_ref(&self) -> bool {
366 test_debug!("StateCell::drop_ref");
367 let old_refs = self.0.fetch_sub(REF_ONE, Release);
370
371 test_dbg!(State::REFS.unpack(old_refs));
377 let old_refs = old_refs >> State::REFS.least_significant_index();
378
379 if test_dbg!(old_refs) > 1 {
381 return false;
382 }
383
384 atomic::fence(Acquire);
385 true
386 }
387
388 pub(super) fn cancel(&self) -> bool {
392 test_debug!("StateCell::cancel");
393 self.transition(|state| {
396 if state.get(State::CANCELED) {
398 return false;
399 }
400
401 state.set(State::CANCELED, true).set(State::WOKEN, true);
404
405 true
406 })
407 }
408
409 #[inline]
410 pub(super) fn create_join_handle(&self) {
411 test_debug!("StateCell::create_join_handle");
412 self.transition(|state| {
413 debug_assert!(
414 !state.get(State::HAS_JOIN_HANDLE),
415 "task already has a join handle, cannot create a new one! state={state:?}"
416 );
417
418 *state = state.with(State::HAS_JOIN_HANDLE, true);
419 })
420 }
421
422 #[inline]
423 pub(super) fn drop_join_handle(&self) {
424 test_debug!("StateCell::drop_join_handle");
425 const MASK: usize = !State::HAS_JOIN_HANDLE.raw_mask();
426 let _prev = self.0.fetch_and(MASK, Release);
427 test_trace!(
428 "drop_join_handle; prev_state:\n{}\nstate:\n{}",
429 State::from_bits(_prev),
430 self.load(Acquire),
431 );
432 debug_assert!(
433 State(_prev).get(State::HAS_JOIN_HANDLE),
434 "tried to drop a join handle when the task did not have a join handle!\nstate: {:#?}",
435 State(_prev),
436 )
437 }
438
439 pub(super) fn try_join(&self) -> JoinAction {
441 fn should_register(state: &mut State) -> JoinAction {
442 let action = match state.get(State::JOIN_WAKER) {
443 JoinWakerState::Empty => JoinAction::Register,
444 x => {
445 debug_assert_eq!(x, JoinWakerState::Waiting);
446 JoinAction::Reregister
447 }
448 };
449 state.set(State::JOIN_WAKER, JoinWakerState::Registering);
450
451 action
452 }
453
454 self.transition(|state| {
455 let has_output = test_dbg!(state.get(State::HAS_OUTPUT));
456
457 if test_dbg!(state.get(State::CANCELED)) {
458 return JoinAction::Canceled {
459 completed: has_output,
460 };
461 }
462
463 if test_dbg!(!state.get(State::COMPLETED)) {
465 return should_register(state);
466 }
467
468 if !has_output {
470 return should_register(state);
471 }
472
473 *state = state.with(State::HAS_OUTPUT, false);
474 JoinAction::TakeOutput
475 })
476 }
477
478 pub(super) fn set_join_waker_registered(&self) {
479 self.transition(|state| {
480 debug_assert_eq!(state.get(State::JOIN_WAKER), JoinWakerState::Registering);
481 state
482 .set(State::HAS_JOIN_HANDLE, true)
483 .set(State::JOIN_WAKER, JoinWakerState::Waiting);
484 })
485 }
486
487 pub(super) fn join_waker_needs_drop(&self) -> bool {
493 let state = self.load(Acquire);
494 match test_dbg!(state.get(State::JOIN_WAKER)) {
495 JoinWakerState::Empty | JoinWakerState::Woken => return false,
496 JoinWakerState::Registering => self.wait_for_join_waker(state),
497 JoinWakerState::Waiting => {}
498 }
499
500 true
501 }
502
503 pub(super) fn load(&self, order: Ordering) -> State {
504 State(self.0.load(order))
505 }
506
507 #[cfg_attr(test, track_caller)]
510 fn transition<T>(&self, mut transition: impl FnMut(&mut State) -> T) -> T {
511 let mut current = self.load(Acquire);
512 loop {
513 test_trace!("StateCell::transition; current:\n{}", current);
514 let mut next = current;
515 let res = transition(&mut next);
517
518 if test_dbg!(current.0 == next.0) {
519 return res;
520 }
521
522 test_trace!("StateCell::transition; next:\n{}", next);
523 match self
524 .0
525 .compare_exchange_weak(current.0, next.0, AcqRel, Acquire)
526 {
527 Ok(_) => return res,
528 Err(actual) => current = State(actual),
529 }
530 }
531 }
532
533 fn wait_for_join_waker(&self, mut state: State) {
534 test_trace!("StateCell::wait_for_join_waker");
535 let mut boff = Backoff::new();
536 loop {
537 state.set(State::JOIN_WAKER, JoinWakerState::Waiting);
538 let next = state.with(State::JOIN_WAKER, JoinWakerState::Woken);
539 match self
540 .0
541 .compare_exchange_weak(state.0, next.0, AcqRel, Acquire)
542 {
543 Ok(_) => return,
544 Err(actual) => state = State(actual),
545 }
546 boff.spin();
547 }
548 }
549}
550
/// Formats a `StateCell` by loading a snapshot of its state and formatting
/// that snapshot.
///
/// The load is `Relaxed`, so the printed value is a point-in-time snapshot
/// that may already be out of date.
impl fmt::Debug for StateCell {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // NOTE(review): this delegates to the `fmt` method that resolves for
        // `State` (presumably its `Debug` impl) — confirm.
        self.load(Relaxed).fmt(f)
    }
}
556
557impl mycelium_bitfield::FromBits<usize> for JoinWakerState {
558 type Error = core::convert::Infallible;
559
560 const BITS: u32 = 2;
562
563 #[inline]
564 #[allow(clippy::literal_string_with_formatting_args)]
565 fn try_from_bits(bits: usize) -> Result<Self, Self::Error> {
566 match bits {
567 b if b == Self::Registering as usize => Ok(Self::Registering),
568 b if b == Self::Waiting as usize => Ok(Self::Waiting),
569 b if b == Self::Empty as usize => Ok(Self::Empty),
570 b if b == Self::Woken as usize => Ok(Self::Woken),
571 _ => unsafe {
572 unreachable_unchecked!("invalid join waker state {bits:#b}")
574 },
575 }
576 }
577
578 #[inline]
579 fn into_bits(self) -> usize {
580 self as u8 as usize
581 }
582}
583
#[cfg(all(test, not(loom)))]
mod tests {
    use super::*;

    /// Runs the bitfield-generated validity check on the `State` packing spec.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn packing_specs_valid() {
        State::assert_valid();
    }

    /// Smoke-tests the alternate (`{:#?}`) `Debug` output of a fresh cell.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn debug_alt() {
        let cell = StateCell::new();
        println!("{cell:#?}");
    }
}