// maitake/task/join_handle.rs

1use super::{Context, Poll, TaskId, TaskRef};
2use core::{future::Future, marker::PhantomData, pin::Pin};
3use mycelium_util::fmt;
4
5/// An owned permission to join a [task] (await its termination).
6///
7/// This is equivalent to the standard library's [`std::thread::JoinHandle`]
8/// type, but for asynchronous tasks rather than OS threads.
9///
10/// A `JoinHandle` *detaches* the associated task when it is dropped, which
11/// means that there is no longer any handle to the task and no way to await
12/// its termination.
13///
14/// `JoinHandle`s implement [`Future`], so a task's output can be awaited by
15/// `.await`ing its `JoinHandle`.
16///
17/// This `struct` is returned by the [`Scheduler::spawn`] and
18/// [`Scheduler::spawn_allocated`] methods, and the [`task::Builder::spawn`] and
19/// [`task::Builder::spawn_allocated`] methods.
20///
21/// [task]: crate::task
22/// [`std::thread::JoinHandle`]: https://doc.rust-lang.org/stable/std/thread/struct.JoinHandle.html
23/// [`Scheduler::spawn`]: crate::scheduler::Scheduler::spawn
24/// [`Scheduler::spawn_allocated`]: crate::scheduler::Scheduler::spawn_allocated
25/// [`task::Builder::spawn`]: crate::task::Builder::spawn
26/// [`task::Builder::spawn_allocated`]: crate::task::Builder::spawn_allocated
27#[derive(PartialEq, Eq)]
28// This clippy lint appears to be triggered incorrectly; this type *does* derive
29// `Eq` based on its `PartialEq<Self>` impl, but it also implements `PartialEq`
30// with types other than `Self` (which cannot impl `Eq`).
31#[allow(clippy::derive_partial_eq_without_eq)]
32pub struct JoinHandle<T> {
33    task: JoinHandleState,
34    id: TaskId,
35    _t: PhantomData<fn(T)>,
36}
37
38/// Errors returned by awaiting a [`JoinHandle`].
39#[derive(PartialEq, Eq)]
40pub struct JoinError<T> {
41    kind: JoinErrorKind,
42    id: TaskId,
43    output: Option<T>,
44}
45
46#[derive(PartialEq, Eq, Debug)]
47enum JoinHandleState {
48    Task(TaskRef),
49    Empty,
50    Error(JoinErrorKind),
51}
52
/// The reason a join failed.
#[derive(Debug, Eq, PartialEq)]
#[non_exhaustive]
pub(crate) enum JoinErrorKind {
    /// The task was canceled.
    Canceled {
        /// Set when the cancellation happened after the task had already
        /// completed successfully.
        completed: bool,
    },

    /// A stub was awaited; the stub task can never join.
    StubNever,

    /// The scheduler has been dropped.
    Shutdown,
}
68
69impl<T> JoinHandle<T> {
70    /// Converts a `TaskRef` into a `JoinHandle`.
71    ///
72    /// # Safety
73    ///
74    /// The pointed type must actually output a `T`-typed value.
75    pub(super) unsafe fn from_task_ref(task: TaskRef) -> Self {
76        task.state().create_join_handle();
77        let id = task.id();
78        Self {
79            task: JoinHandleState::Task(task),
80            id,
81            _t: PhantomData,
82        }
83    }
84
85    pub(crate) fn error(kind: JoinErrorKind) -> Self {
86        Self {
87            id: TaskId::stub(),
88            task: JoinHandleState::Error(kind),
89            _t: PhantomData,
90        }
91    }
92
93    /// Returns a [`TaskRef`] referencing the task this [`JoinHandle`] is
94    /// associated with.
95    ///
96    /// This increases the task's reference count; its storage is not
97    /// deallocated until all such [`TaskRef`]s are dropped.
98    #[must_use]
99    pub fn task_ref(&self) -> TaskRef {
100        match self.task {
101            JoinHandleState::Task(ref task) => task.clone(),
102            JoinHandleState::Empty => {
103                panic!("`TaskRef` only taken while polling a `JoinHandle`; this is a bug")
104            }
105            JoinHandleState::Error(ref error) => panic!("`JoinHandle` errored: {error:?}"),
106        }
107    }
108
109    /// Returns `true` if this task has completed.
110    ///
111    /// Tasks are considered completed when the spawned [`Future`] has returned
112    /// [`Poll::Ready`], or if the task has been canceled by the [`cancel()`]
113    /// method.
114    ///
115    /// **Note**: This method can return `false` after [`cancel()`] has
116    /// been called. This is because calling `cancel` *begins* the process of
117    /// cancelling a task. The task is not considered canceled until it has been
118    /// polled by the scheduler after calling [`cancel()`].
119    ///
120    /// [`cancel()`]: Self::cancel
121    #[inline]
122    #[must_use]
123    pub fn is_complete(&self) -> bool {
124        match self.task {
125            JoinHandleState::Task(ref task) => task.is_complete(),
126            // if the `JoinHandle`'s `TaskRef` has been taken, we know the
127            // `Future` impl for `JoinHandle` completed, and the task has
128            // _definitely_ completed.
129            _ => true,
130        }
131    }
132
133    /// Forcibly cancel the task.
134    ///
135    /// Canceling a task sets a flag indicating that it has been canceled and
136    /// should terminate. The next time a canceled task is polled by the
137    /// scheduler, it will terminate instead of polling the inner [`Future`]. If
138    /// the task has a [`JoinHandle`], that [`JoinHandle`] will complete with a
139    /// [`JoinError`]. The task then will be deallocated once all
140    /// [`JoinHandle`]s and [`TaskRef`]s referencing it have been dropped.
141    ///
142    /// This method returns `true` if the task was canceled successfully, and
143    /// `false` if the task could not be canceled (i.e., it has already completed,
144    /// has already been canceled, cancel culture has gone TOO FAR, et cetera).
145    pub fn cancel(&self) -> bool {
146        match self.task {
147            JoinHandleState::Task(ref task) => task.cancel(),
148            _ => false,
149        }
150    }
151
152    /// Returns a [`TaskId`] that uniquely identifies this [task].
153    ///
154    /// The returned ID does *not* increment the task's reference count, and may
155    /// persist even after the task it identifies has completed and been
156    /// deallocated.
157    ///
158    /// [task]: crate::task
159    #[must_use]
160    #[inline]
161    #[track_caller]
162    pub fn id(&self) -> TaskId {
163        self.id
164    }
165}
166
167impl<T> Future for JoinHandle<T> {
168    type Output = Result<T, JoinError<T>>;
169
170    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
171        let this = self.get_mut();
172        let task = match core::mem::replace(&mut this.task, JoinHandleState::Empty) {
173            JoinHandleState::Task(task) => task,
174            JoinHandleState::Empty => {
175                panic!("`TaskRef` only taken while polling a `JoinHandle`; this is a bug")
176            }
177            JoinHandleState::Error(kind) => {
178                return Poll::Ready(Err(JoinError {
179                    kind,
180                    id: this.id,
181                    output: None,
182                }))
183            }
184        };
185        let poll = unsafe {
186            // Safety: the `JoinHandle` must have been constructed with the
187            // task's actual output type!
188            task.poll_join::<T>(cx)
189        };
190        if poll.is_pending() {
191            this.task = JoinHandleState::Task(task);
192        } else {
193            // clear join interest
194            task.state().drop_join_handle();
195        }
196        poll
197    }
198}
199
200impl<T> Drop for JoinHandle<T> {
201    fn drop(&mut self) {
202        // if the JoinHandle has not already been consumed, clear the join
203        // handle flag on the task.
204        if let JoinHandleState::Task(ref task) = self.task {
205            test_debug!(
206                task = ?self.task,
207                task.tid = task.id().as_u64(),
208                consumed = false,
209                "drop JoinHandle",
210            );
211            task.state().drop_join_handle();
212        } else {
213            test_debug!(
214                task = ?self.task,
215                consumed = true,
216                "drop JoinHandle",
217            );
218        }
219    }
220}
221
222// ==== PartialEq impls for JoinHandle/TaskRef ====
223
224impl<T> PartialEq<TaskRef> for JoinHandle<T> {
225    fn eq(&self, other: &TaskRef) -> bool {
226        match self.task {
227            JoinHandleState::Task(ref task) => task == other,
228            _ => false,
229        }
230    }
231}
232
233impl<T> PartialEq<&'_ TaskRef> for JoinHandle<T> {
234    fn eq(&self, other: &&TaskRef) -> bool {
235        match self.task {
236            JoinHandleState::Task(ref task) => task == *other,
237            _ => false,
238        }
239    }
240}
241
242impl<T> PartialEq<JoinHandle<T>> for TaskRef {
243    fn eq(&self, other: &JoinHandle<T>) -> bool {
244        match other.task {
245            JoinHandleState::Task(ref task) => self == task,
246            _ => false,
247        }
248    }
249}
250
251impl<T> PartialEq<&'_ JoinHandle<T>> for TaskRef {
252    fn eq(&self, other: &&JoinHandle<T>) -> bool {
253        match other.task {
254            JoinHandleState::Task(ref task) => self == task,
255            _ => false,
256        }
257    }
258}
259
260// ==== PartialEq impls for JoinHandle/TaskId ====
261
262impl<T> PartialEq<TaskId> for JoinHandle<T> {
263    #[inline]
264    fn eq(&self, other: &TaskId) -> bool {
265        self.id == other
266    }
267}
268
269impl<T> PartialEq<&'_ TaskId> for JoinHandle<T> {
270    #[inline]
271    fn eq(&self, other: &&TaskId) -> bool {
272        self.id == *other
273    }
274}
275
276impl<T> PartialEq<JoinHandle<T>> for TaskId {
277    #[inline]
278    fn eq(&self, other: &JoinHandle<T>) -> bool {
279        self == other.id
280    }
281}
282
283impl<T> PartialEq<&'_ JoinHandle<T>> for TaskId {
284    #[inline]
285    fn eq(&self, other: &&JoinHandle<T>) -> bool {
286        self == other.id
287    }
288}
289
290impl<T> fmt::Debug for JoinHandle<T> {
291    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
292        f.debug_struct("JoinHandle")
293            .field("output", &core::any::type_name::<T>())
294            .field("task", &self.task)
295            .field("id", &self.id)
296            .finish()
297    }
298}
299
300// === impl JoinError ===
301
302impl JoinError<()> {
303    #[inline]
304    pub(super) fn canceled(completed: bool, id: TaskId) -> Poll<Result<(), Self>> {
305        Poll::Ready(Err(Self {
306            kind: JoinErrorKind::Canceled { completed },
307            id,
308            output: None,
309        }))
310    }
311
312    #[allow(dead_code)]
313    #[inline]
314    pub(crate) fn stub() -> Self {
315        Self {
316            kind: JoinErrorKind::StubNever,
317            id: TaskId::stub(),
318            output: None,
319        }
320    }
321
322    #[must_use]
323    pub(super) fn with_output<T>(self, output: Option<T>) -> JoinError<T> {
324        JoinError {
325            kind: self.kind,
326            id: self.id,
327            output,
328        }
329    }
330}
331
332impl<T> JoinError<T> {
333    /// Returns `true` if a task failed to join because it was canceled.
334    pub fn is_canceled(&self) -> bool {
335        matches!(self.kind, JoinErrorKind::Canceled { .. })
336    }
337
338    /// Returns `true` if the task completed successfully before it was canceled.
339    pub fn is_completed(&self) -> bool {
340        match self.kind {
341            JoinErrorKind::Canceled { completed } => completed,
342            _ => false,
343        }
344    }
345
346    /// Returns the [`TaskId`] of the task that failed to join.
347    pub fn id(&self) -> TaskId {
348        self.id
349    }
350
351    /// Returns the task's output, if the task completed successfully before it
352    /// was canceled.
353    ///
354    /// Otherwise, returns `None`.
355    pub fn output(self) -> Option<T> {
356        self.output
357    }
358}
359
360impl<T> fmt::Display for JoinError<T> {
361    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
362        match self.kind {
363            JoinErrorKind::Canceled { completed } => {
364                let completed = if completed {
365                    " (after completing successfully)"
366                } else {
367                    ""
368                };
369                write!(f, "task {} was canceled{completed}", self.id)
370            }
371            JoinErrorKind::StubNever => f.write_str("the stub task can never join"),
372            JoinErrorKind::Shutdown => f.write_str("the scheduler has already shut down"),
373        }
374    }
375}
376
377impl<T> fmt::Debug for JoinError<T> {
378    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
379        f.debug_struct("JoinError")
380            .field("id", &self.id)
381            .field("kind", &self.kind)
382            .finish()
383    }
384}
385
// Lets `JoinError` be used as a `core::error::Error`; the `Debug` and
// `Display` impls for `JoinError` satisfy the trait's supertraits.
// NOTE(review): `feature!` is a crate-local macro; presumably it gates this
// impl on the `core-error` cargo feature. Confirm against the macro definition.
feature! {
    #![feature = "core-error"]
    impl<T> core::error::Error for JoinError<T> {}
}