maitake/scheduler/
steal.rs

1use super::*;
2use crate::loom::sync::atomic::AtomicUsize;
3use cordyceps::mpsc_queue::{self, MpscQueue};
4use core::marker::PhantomData;
5use mycelium_util::fmt;
6
7/// An injector queue for spawning tasks on multiple [`Scheduler`] instances.
8pub struct Injector<S> {
9    /// The queue.
10    queue: MpscQueue<Header>,
11
12    /// The number of tasks in the queue.
13    tasks: AtomicUsize,
14
15    /// An `Injector` can only be used with [`Schedule`] implementations that
16    /// are the same type, because the task allocation is sized based on the
17    /// scheduler value.
18    _scheduler_type: PhantomData<fn(S)>,
19}
20
21/// A handle for stealing tasks from a [`Scheduler`]'s run queue, or an
22/// [`Injector`] queue.
23///
24/// While this handle exists, no other worker can steal tasks from the queue.
25pub struct Stealer<'worker, S> {
26    queue: mpsc_queue::Consumer<'worker, Header>,
27
28    /// The initial task count in the target queue when this `Stealer` was created.
29    snapshot: usize,
30
31    /// A reference to the target queue's current task count. This is used to
32    /// decrement the task count when stealing.
33    tasks: &'worker AtomicUsize,
34
35    /// The type of the [`Schedule`] implementation that tasks are being stolen
36    /// from.
37    ///
38    /// This must be the same type as the scheduler that is stealing tasks, as
39    /// the size of the scheduler value stored in the task must be the same.
40    _scheduler_type: PhantomData<fn(S)>,
41}
42
/// Errors returned by [`Injector::try_steal`], [`Scheduler::try_steal`], and
/// [`StaticScheduler::try_steal`].
#[derive(Debug, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub enum TryStealError {
    /// Tasks could not be stolen because the targeted queue already has a
    /// consumer.
    Busy,
    /// No tasks were available to steal.
    Empty,
}

// Public error types should be printable for end users, not just via `Debug`.
impl fmt::Display for TryStealError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let msg = match self {
            Self::Busy => "the targeted queue already has a consumer",
            Self::Empty => "no tasks were available to steal",
        };
        f.write_str(msg)
    }
}
54
55impl<S: Schedule> Injector<S> {
56    /// Returns a new injector queue.
57    ///
58    /// # Safety
59    ///
60    /// The "stub" provided must ONLY EVER be used for a single
61    /// `Injector` instance. Re-using the stub for multiple distributors
62    /// or schedulers may lead to undefined behavior.
63    #[must_use]
64    #[cfg(not(loom))]
65    pub const unsafe fn new_with_static_stub(stub: &'static TaskStub) -> Self {
66        Self {
67            queue: MpscQueue::new_with_static_stub(&stub.hdr),
68            tasks: AtomicUsize::new(0),
69            _scheduler_type: PhantomData,
70        }
71    }
72
73    /// Spawns a pre-allocated task on the injector queue.
74    ///
75    /// The spawned task will be executed by any
76    /// [`Scheduler`]/[`StaticScheduler`] instance that runs tasks from this
77    /// queue.
78    ///
79    /// This method is used to spawn a task that requires some bespoke
80    /// procedure of allocation, typically of a custom [`Storage`] implementor.
81    /// See the documentation for the [`Storage`] trait for more details on
82    /// using custom task storage.
83    ///
84    /// When the "alloc" feature flag is available, tasks that do not require
85    /// custom storage may be spawned using the [`Injector::spawn`] method,
86    /// instead.
87    ///
88    /// This method returns a [`JoinHandle`] that can be used to await the
89    /// task's output. Dropping the [`JoinHandle`] _detaches_ the spawned task,
90    /// allowing it to run in the background without awaiting its output.
91    ///
92    /// [`Storage`]: crate::task::Storage
93    pub fn spawn_allocated<STO, F>(&self, task: STO::StoredTask) -> JoinHandle<F::Output>
94    where
95        F: Future + Send + 'static,
96        F::Output: Send + 'static,
97        STO: Storage<S, F>,
98    {
99        self.tasks.fetch_add(1, Relaxed);
100        let (task, join) = TaskRef::build_allocated::<S, F, STO>(TaskRef::NO_BUILDER, task);
101        self.queue.enqueue(task);
102        join
103    }
104
105    /// Attempt to take tasks from the injector queue.
106    ///
107    /// # Returns
108    ///
109    /// - `Ok(`[`Stealer`]`)) if tasks can be spawned from the injector
110    ///   queue.
111    /// - `Err`([`TryStealError::Empty`]`)` if there were no tasks in this
112    ///   injector queue.
113    /// - `Err`([`TryStealError::Busy`]`)` if another worker was already
114    ///   taking tasks from this injector queue.
115    pub fn try_steal(&self) -> Result<Stealer<'_, S>, TryStealError> {
116        Stealer::try_new(&self.queue, &self.tasks)
117    }
118}
119
120impl<S> fmt::Debug for Injector<S> {
121    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
122        // determine if alt-mode is enabled *before* constructing the
123        // `DebugStruct`, because that mutably borrows the formatter.
124        let alt = f.alternate();
125
126        let Self {
127            queue,
128            tasks,
129            _scheduler_type,
130        } = self;
131        let mut debug = f.debug_struct("Injector");
132        debug
133            .field("queue", queue)
134            .field("tasks", &tasks.load(Relaxed));
135
136        // only include the kind of wordy type name field if alt-mode
137        // (multi-line) formatting is enabled.
138        if alt {
139            debug.field(
140                "scheduler",
141                &format_args!("PhantomData<{}>", core::any::type_name::<S>()),
142            );
143        }
144
145        debug.finish()
146    }
147}
148
149// === impl Stealer ===
150
151impl<'worker, S: Schedule> Stealer<'worker, S> {
152    fn try_new(
153        queue: &'worker MpscQueue<Header>,
154        tasks: &'worker AtomicUsize,
155    ) -> Result<Self, TryStealError> {
156        let snapshot = tasks.load(Acquire);
157        if snapshot == 0 {
158            return Err(TryStealError::Empty);
159        }
160
161        let queue = queue.try_consume().ok_or(TryStealError::Busy)?;
162        Ok(Self {
163            queue,
164            snapshot,
165            tasks,
166            _scheduler_type: PhantomData,
167        })
168    }
169
170    /// Returns the number of tasks that were in the targeted queue when this
171    /// `Stealer` was created.
172    ///
173    /// This number is *not* guaranteed to be greater than the *current* number
174    /// of tasks returned by [`task_count`], as new tasks may be enqueued while
175    /// stealing.
176    ///
177    /// [`task_count`]: Self::task_count
178    pub fn initial_task_count(&self) -> usize {
179        self.snapshot
180    }
181
182    /// Returns the number of tasks currently in the targeted queue.
183    pub fn task_count(&self) -> usize {
184        self.tasks.load(Acquire)
185    }
186
187    /// Steal one task from the targeted queue and spawn it on the provided
188    /// `scheduler`.
189    ///
190    /// # Returns
191    ///
192    /// - `true` if a task was successfully stolen.
193    /// - `false` if the targeted queue is empty.
194    pub fn spawn_one(&self, scheduler: &S) -> bool {
195        let Some(task) = self.queue.dequeue() else {
196            return false;
197        };
198        test_trace!(?task, "stole");
199
200        // decrement the target queue's task count
201        self.tasks.fetch_sub(1, Release);
202
203        // TODO(eliza): probably handle cancelation by throwing out canceled
204        // tasks here before binding them?
205        unsafe {
206            task.bind_scheduler(scheduler.clone());
207        }
208        scheduler.schedule(task);
209        true
210    }
211
212    /// Steal up to `max` tasks from the targeted queue and spawn them on the
213    /// provided scheduler.
214    ///
215    /// # Returns
216    ///
217    /// The number of tasks stolen. This may be less than `max` if the targeted
218    /// queue contained fewer tasks than `max`.
219    pub fn spawn_n(&self, scheduler: &S, max: usize) -> usize {
220        let mut stolen = 0;
221        while stolen <= max && self.spawn_one(scheduler) {
222            stolen += 1;
223        }
224
225        stolen
226    }
227
228    /// Steal half of the tasks currently in the targeted queue and spawn them
229    /// on the provided scheduler.
230    ///
231    /// This is a convenience method that is equivalent to the following:
232    ///
233    /// ```
234    /// # fn docs() {
235    /// # use maitake::scheduler::{StaticScheduler, Stealer};
236    /// # let scheduler = unimplemented!();
237    /// # let stealer: Stealer<'_, &'static StaticScheduler> = unimplemented!();
238    /// stealer.spawn_n(&scheduler, stealer.initial_task_count() / 2);
239    /// # }
240    /// ```
241    ///
242    /// # Returns
243    ///
244    /// The number of tasks stolen.
245    pub fn spawn_half(&self, scheduler: &S) -> usize {
246        self.spawn_n(scheduler, self.initial_task_count() / 2)
247    }
248}
249
250impl<S> fmt::Debug for Stealer<'_, S> {
251    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
252        // determine if alt-mode is enabled *before* constructing the
253        // `DebugStruct`, because that mutably borrows the formatter.
254        let alt = f.alternate();
255
256        let Self {
257            queue,
258            snapshot,
259            tasks,
260            _scheduler_type,
261        } = self;
262        let mut debug = f.debug_struct("Stealer");
263        debug
264            .field("queue", queue)
265            .field("snapshot", snapshot)
266            .field("tasks", &tasks.load(Relaxed));
267
268        // only include the kind of wordy type name field if alt-mode
269        // (multi-line) formatting is enabled.
270        if alt {
271            debug.field(
272                "scheduler",
273                &format_args!("PhantomData<{}>", core::any::type_name::<S>()),
274            );
275        }
276
277        debug.finish()
278    }
279}
280
281// === impls on Scheduler types ===
282
283impl StaticScheduler {
284    /// Attempt to steal tasks from this scheduler's run queue.
285    ///
286    /// # Returns
287    ///
288    /// - `Ok(`[`Stealer`]`)) if tasks can be stolen from this scheduler's
289    ///   queue.
290    /// - `Err`([`TryStealError::Empty`]`)` if there were no tasks in this
291    ///   scheduler's run queue.
292    /// - `Err`([`TryStealError::Busy`]`)` if another worker was already
293    ///   stealing from this scheduler's run queue.
294    pub fn try_steal(&self) -> Result<Stealer<'_, &'static StaticScheduler>, TryStealError> {
295        Stealer::try_new(&self.0.run_queue, &self.0.queued)
296    }
297}
298
299feature! {
300    #![feature = "alloc"]
301
302    use alloc::boxed::Box;
303    use super::{BoxStorage, Task};
304
305    impl<S: Schedule> Injector<S> {
306        /// Returns a new `Injector` queue with a dynamically heap-allocated
307        /// [`TaskStub`].
308        #[must_use]
309        pub fn new() -> Self {
310            let stub_task = Box::new(Task::new_stub());
311            let (stub_task, _) =
312                TaskRef::new_allocated::<task::Stub, task::Stub, BoxStorage>(task::Stub, stub_task);
313            Self {
314                queue: MpscQueue::new_with_stub(stub_task),
315                tasks: AtomicUsize::new(0),
316                _scheduler_type: PhantomData,
317            }
318
319        }
320
321        /// Spawns a new task on the injector queue, to execute on any
322        /// [`Scheduler`]/[`StaticScheduler`] instance that runs tasks from this
323        /// queue.
324        ///
325        /// This method returns a [`JoinHandle`] that can be used to await the
326        /// task's output. Dropping the [`JoinHandle`] _detaches_ the spawned task,
327        /// allowing it to run in the background without awaiting its output.
328        #[inline]
329        #[track_caller]
330        pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
331        where
332            F: Future + Send + 'static,
333            F::Output: Send + 'static,
334        {
335            let task = Box::new(Task::<S, _, BoxStorage>::new(future));
336            self.spawn_allocated::<BoxStorage, _>(task)
337        }
338    }
339
340    impl<S: Schedule> Default for Injector<S> {
341        fn default() -> Self {
342            Self::new()
343        }
344    }
345
346
347    impl Scheduler {
348        /// Attempt to steal tasks from this scheduler's run queue.
349        ///
350        /// # Returns
351        ///
352        /// - `Ok(`[`Stealer`]`)) if tasks can be stolen from this scheduler's
353        ///   queue.
354        /// - `Err`([`TryStealError::Empty`]`)` if there were no tasks in this
355        ///   scheduler's run queue.
356        /// - `Err`([`TryStealError::Busy`]`)` if another worker was already
357        ///   stealing from this scheduler's run queue.
358        pub fn try_steal(&self) -> Result<Stealer<'_, Scheduler>, TryStealError> {
359            Stealer::try_new(&self.0.run_queue, &self.0.queued)
360        }
361    }
362
363}