maitake/scheduler/steal.rs
1use super::*;
2use crate::loom::sync::atomic::AtomicUsize;
3use cordyceps::mpsc_queue::{self, MpscQueue};
4use core::marker::PhantomData;
5use mycelium_util::fmt;
6
7/// An injector queue for spawning tasks on multiple [`Scheduler`] instances.
8pub struct Injector<S> {
9 /// The queue.
10 queue: MpscQueue<Header>,
11
12 /// The number of tasks in the queue.
13 tasks: AtomicUsize,
14
15 /// An `Injector` can only be used with [`Schedule`] implementations that
16 /// are the same type, because the task allocation is sized based on the
17 /// scheduler value.
18 _scheduler_type: PhantomData<fn(S)>,
19}
20
21/// A handle for stealing tasks from a [`Scheduler`]'s run queue, or an
22/// [`Injector`] queue.
23///
24/// While this handle exists, no other worker can steal tasks from the queue.
25pub struct Stealer<'worker, S> {
26 queue: mpsc_queue::Consumer<'worker, Header>,
27
28 /// The initial task count in the target queue when this `Stealer` was created.
29 snapshot: usize,
30
31 /// A reference to the target queue's current task count. This is used to
32 /// decrement the task count when stealing.
33 tasks: &'worker AtomicUsize,
34
35 /// The type of the [`Schedule`] implementation that tasks are being stolen
36 /// from.
37 ///
38 /// This must be the same type as the scheduler that is stealing tasks, as
39 /// the size of the scheduler value stored in the task must be the same.
40 _scheduler_type: PhantomData<fn(S)>,
41}
42
/// Why an attempt to obtain a [`Stealer`] failed.
///
/// Returned by [`Injector::try_steal`], [`Scheduler::try_steal`], and
/// [`StaticScheduler::try_steal`].
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TryStealError {
    /// The targeted queue is already being consumed by another worker, so
    /// no tasks could be taken from it.
    Busy,

    /// The targeted queue held no tasks to take.
    Empty,
}
54
55impl<S: Schedule> Injector<S> {
56 /// Returns a new injector queue.
57 ///
58 /// # Safety
59 ///
60 /// The "stub" provided must ONLY EVER be used for a single
61 /// `Injector` instance. Re-using the stub for multiple distributors
62 /// or schedulers may lead to undefined behavior.
63 #[must_use]
64 #[cfg(not(loom))]
65 pub const unsafe fn new_with_static_stub(stub: &'static TaskStub) -> Self {
66 Self {
67 queue: MpscQueue::new_with_static_stub(&stub.hdr),
68 tasks: AtomicUsize::new(0),
69 _scheduler_type: PhantomData,
70 }
71 }
72
73 /// Spawns a pre-allocated task on the injector queue.
74 ///
75 /// The spawned task will be executed by any
76 /// [`Scheduler`]/[`StaticScheduler`] instance that runs tasks from this
77 /// queue.
78 ///
79 /// This method is used to spawn a task that requires some bespoke
80 /// procedure of allocation, typically of a custom [`Storage`] implementor.
81 /// See the documentation for the [`Storage`] trait for more details on
82 /// using custom task storage.
83 ///
84 /// When the "alloc" feature flag is available, tasks that do not require
85 /// custom storage may be spawned using the [`Injector::spawn`] method,
86 /// instead.
87 ///
88 /// This method returns a [`JoinHandle`] that can be used to await the
89 /// task's output. Dropping the [`JoinHandle`] _detaches_ the spawned task,
90 /// allowing it to run in the background without awaiting its output.
91 ///
92 /// [`Storage`]: crate::task::Storage
93 pub fn spawn_allocated<STO, F>(&self, task: STO::StoredTask) -> JoinHandle<F::Output>
94 where
95 F: Future + Send + 'static,
96 F::Output: Send + 'static,
97 STO: Storage<S, F>,
98 {
99 self.tasks.fetch_add(1, Relaxed);
100 let (task, join) = TaskRef::build_allocated::<S, F, STO>(TaskRef::NO_BUILDER, task);
101 self.queue.enqueue(task);
102 join
103 }
104
105 /// Attempt to take tasks from the injector queue.
106 ///
107 /// # Returns
108 ///
109 /// - `Ok(`[`Stealer`]`)) if tasks can be spawned from the injector
110 /// queue.
111 /// - `Err`([`TryStealError::Empty`]`)` if there were no tasks in this
112 /// injector queue.
113 /// - `Err`([`TryStealError::Busy`]`)` if another worker was already
114 /// taking tasks from this injector queue.
115 pub fn try_steal(&self) -> Result<Stealer<'_, S>, TryStealError> {
116 Stealer::try_new(&self.queue, &self.tasks)
117 }
118}
119
120impl<S> fmt::Debug for Injector<S> {
121 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
122 // determine if alt-mode is enabled *before* constructing the
123 // `DebugStruct`, because that mutably borrows the formatter.
124 let alt = f.alternate();
125
126 let Self {
127 queue,
128 tasks,
129 _scheduler_type,
130 } = self;
131 let mut debug = f.debug_struct("Injector");
132 debug
133 .field("queue", queue)
134 .field("tasks", &tasks.load(Relaxed));
135
136 // only include the kind of wordy type name field if alt-mode
137 // (multi-line) formatting is enabled.
138 if alt {
139 debug.field(
140 "scheduler",
141 &format_args!("PhantomData<{}>", core::any::type_name::<S>()),
142 );
143 }
144
145 debug.finish()
146 }
147}
148
149// === impl Stealer ===
150
151impl<'worker, S: Schedule> Stealer<'worker, S> {
152 fn try_new(
153 queue: &'worker MpscQueue<Header>,
154 tasks: &'worker AtomicUsize,
155 ) -> Result<Self, TryStealError> {
156 let snapshot = tasks.load(Acquire);
157 if snapshot == 0 {
158 return Err(TryStealError::Empty);
159 }
160
161 let queue = queue.try_consume().ok_or(TryStealError::Busy)?;
162 Ok(Self {
163 queue,
164 snapshot,
165 tasks,
166 _scheduler_type: PhantomData,
167 })
168 }
169
170 /// Returns the number of tasks that were in the targeted queue when this
171 /// `Stealer` was created.
172 ///
173 /// This number is *not* guaranteed to be greater than the *current* number
174 /// of tasks returned by [`task_count`], as new tasks may be enqueued while
175 /// stealing.
176 ///
177 /// [`task_count`]: Self::task_count
178 pub fn initial_task_count(&self) -> usize {
179 self.snapshot
180 }
181
182 /// Returns the number of tasks currently in the targeted queue.
183 pub fn task_count(&self) -> usize {
184 self.tasks.load(Acquire)
185 }
186
187 /// Steal one task from the targeted queue and spawn it on the provided
188 /// `scheduler`.
189 ///
190 /// # Returns
191 ///
192 /// - `true` if a task was successfully stolen.
193 /// - `false` if the targeted queue is empty.
194 pub fn spawn_one(&self, scheduler: &S) -> bool {
195 let Some(task) = self.queue.dequeue() else {
196 return false;
197 };
198 test_trace!(?task, "stole");
199
200 // decrement the target queue's task count
201 self.tasks.fetch_sub(1, Release);
202
203 // TODO(eliza): probably handle cancelation by throwing out canceled
204 // tasks here before binding them?
205 unsafe {
206 task.bind_scheduler(scheduler.clone());
207 }
208 scheduler.schedule(task);
209 true
210 }
211
212 /// Steal up to `max` tasks from the targeted queue and spawn them on the
213 /// provided scheduler.
214 ///
215 /// # Returns
216 ///
217 /// The number of tasks stolen. This may be less than `max` if the targeted
218 /// queue contained fewer tasks than `max`.
219 pub fn spawn_n(&self, scheduler: &S, max: usize) -> usize {
220 let mut stolen = 0;
221 while stolen <= max && self.spawn_one(scheduler) {
222 stolen += 1;
223 }
224
225 stolen
226 }
227
228 /// Steal half of the tasks currently in the targeted queue and spawn them
229 /// on the provided scheduler.
230 ///
231 /// This is a convenience method that is equivalent to the following:
232 ///
233 /// ```
234 /// # fn docs() {
235 /// # use maitake::scheduler::{StaticScheduler, Stealer};
236 /// # let scheduler = unimplemented!();
237 /// # let stealer: Stealer<'_, &'static StaticScheduler> = unimplemented!();
238 /// stealer.spawn_n(&scheduler, stealer.initial_task_count() / 2);
239 /// # }
240 /// ```
241 ///
242 /// # Returns
243 ///
244 /// The number of tasks stolen.
245 pub fn spawn_half(&self, scheduler: &S) -> usize {
246 self.spawn_n(scheduler, self.initial_task_count() / 2)
247 }
248}
249
250impl<S> fmt::Debug for Stealer<'_, S> {
251 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
252 // determine if alt-mode is enabled *before* constructing the
253 // `DebugStruct`, because that mutably borrows the formatter.
254 let alt = f.alternate();
255
256 let Self {
257 queue,
258 snapshot,
259 tasks,
260 _scheduler_type,
261 } = self;
262 let mut debug = f.debug_struct("Stealer");
263 debug
264 .field("queue", queue)
265 .field("snapshot", snapshot)
266 .field("tasks", &tasks.load(Relaxed));
267
268 // only include the kind of wordy type name field if alt-mode
269 // (multi-line) formatting is enabled.
270 if alt {
271 debug.field(
272 "scheduler",
273 &format_args!("PhantomData<{}>", core::any::type_name::<S>()),
274 );
275 }
276
277 debug.finish()
278 }
279}
280
281// === impls on Scheduler types ===
282
283impl StaticScheduler {
284 /// Attempt to steal tasks from this scheduler's run queue.
285 ///
286 /// # Returns
287 ///
288 /// - `Ok(`[`Stealer`]`)) if tasks can be stolen from this scheduler's
289 /// queue.
290 /// - `Err`([`TryStealError::Empty`]`)` if there were no tasks in this
291 /// scheduler's run queue.
292 /// - `Err`([`TryStealError::Busy`]`)` if another worker was already
293 /// stealing from this scheduler's run queue.
294 pub fn try_steal(&self) -> Result<Stealer<'_, &'static StaticScheduler>, TryStealError> {
295 Stealer::try_new(&self.0.run_queue, &self.0.queued)
296 }
297}
298
299feature! {
300 #![feature = "alloc"]
301
302 use alloc::boxed::Box;
303 use super::{BoxStorage, Task};
304
305 impl<S: Schedule> Injector<S> {
306 /// Returns a new `Injector` queue with a dynamically heap-allocated
307 /// [`TaskStub`].
308 #[must_use]
309 pub fn new() -> Self {
310 let stub_task = Box::new(Task::new_stub());
311 let (stub_task, _) =
312 TaskRef::new_allocated::<task::Stub, task::Stub, BoxStorage>(task::Stub, stub_task);
313 Self {
314 queue: MpscQueue::new_with_stub(stub_task),
315 tasks: AtomicUsize::new(0),
316 _scheduler_type: PhantomData,
317 }
318
319 }
320
321 /// Spawns a new task on the injector queue, to execute on any
322 /// [`Scheduler`]/[`StaticScheduler`] instance that runs tasks from this
323 /// queue.
324 ///
325 /// This method returns a [`JoinHandle`] that can be used to await the
326 /// task's output. Dropping the [`JoinHandle`] _detaches_ the spawned task,
327 /// allowing it to run in the background without awaiting its output.
328 #[inline]
329 #[track_caller]
330 pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
331 where
332 F: Future + Send + 'static,
333 F::Output: Send + 'static,
334 {
335 let task = Box::new(Task::<S, _, BoxStorage>::new(future));
336 self.spawn_allocated::<BoxStorage, _>(task)
337 }
338 }
339
340 impl<S: Schedule> Default for Injector<S> {
341 fn default() -> Self {
342 Self::new()
343 }
344 }
345
346
347 impl Scheduler {
348 /// Attempt to steal tasks from this scheduler's run queue.
349 ///
350 /// # Returns
351 ///
352 /// - `Ok(`[`Stealer`]`)) if tasks can be stolen from this scheduler's
353 /// queue.
354 /// - `Err`([`TryStealError::Empty`]`)` if there were no tasks in this
355 /// scheduler's run queue.
356 /// - `Err`([`TryStealError::Busy`]`)` if another worker was already
357 /// stealing from this scheduler's run queue.
358 pub fn try_steal(&self) -> Result<Stealer<'_, Scheduler>, TryStealError> {
359 Stealer::try_new(&self.0.run_queue, &self.0.queued)
360 }
361 }
362
363}