1use super::{Context, Poll, TaskId, TaskRef};
2use core::{future::Future, marker::PhantomData, pin::Pin};
3use mycelium_util::fmt;
4
/// A handle to a spawned task that can be `.await`ed to obtain the task's
/// result, as a `Result<T, JoinError<T>>`.
///
/// Dropping a `JoinHandle` that still holds its task releases the handle's
/// join reference on the task (see the `Drop` impl). A `JoinHandle` can be
/// compared with a `TaskRef` or a `TaskId`.
#[derive(PartialEq, Eq)]
// NOTE(review): `Eq` is derived right here, so this lint should not apply;
// the allow is presumably a workaround for a clippy false positive (possibly
// triggered by the extra `PartialEq<TaskRef>`/`PartialEq<TaskId>` impls) —
// confirm before removing it.
#[allow(clippy::derive_partial_eq_without_eq)]
pub struct JoinHandle<T> {
    /// Whether this handle still holds the task, has already resolved
    /// (`Empty`), or was created in an error state.
    task: JoinHandleState,
    /// ID of the joined task; `TaskId::stub()` for handles created by
    /// `JoinHandle::error`. Kept separately so it stays available after the
    /// `TaskRef` has been released.
    id: TaskId,
    /// Records the output type `T` without storing one. A `fn(T)` phantom
    /// never makes the handle `!Send`/`!Sync` on account of `T`, and makes the
    /// type contravariant in `T`. Note that the derived `PartialEq`/`Eq` still
    /// add a `T: PartialEq`/`T: Eq` bound.
    _t: PhantomData<fn(T)>,
}
37
/// The error produced by awaiting a [`JoinHandle`] whose task could not be
/// joined, e.g. because it was canceled or the scheduler had shut down.
///
/// Use [`JoinError::is_canceled`] / [`JoinError::is_completed`] to inspect
/// the reason, [`JoinError::id`] for the task's ID, and
/// [`JoinError::output`] to recover an output attached to the error, if any.
#[derive(PartialEq, Eq)]
pub struct JoinError<T> {
    /// Why the join failed.
    kind: JoinErrorKind,
    /// ID of the task that could not be joined.
    id: TaskId,
    /// Output attached via `JoinError::with_output`; `None` otherwise.
    output: Option<T>,
}
45
/// Internal state of a [`JoinHandle`].
#[derive(PartialEq, Eq, Debug)]
enum JoinHandleState {
    /// The handle still holds a reference to the task it joins.
    Task(TaskRef),
    /// No task reference is held. `poll` leaves this state behind while it
    /// runs and keeps it once it has returned `Ready`.
    Empty,
    /// Created by `JoinHandle::error`; polling yields a `JoinError` of this
    /// kind.
    Error(JoinErrorKind),
}
52
/// The reason a task could not be joined.
#[derive(Debug, PartialEq, Eq)]
// NOTE(review): `#[non_exhaustive]` only restricts matching from *other*
// crates, so it has no effect on a `pub(crate)` enum.
#[non_exhaustive]
pub(crate) enum JoinErrorKind {
    /// The task was canceled.
    Canceled {
        /// `true` if the task had already completed successfully when it was
        /// canceled (reported by `Display` as "after completing
        /// successfully").
        completed: bool,
    },

    /// Used by `JoinError::stub`; displayed as "the stub task can never
    /// join".

    StubNever,

    /// The scheduler had already shut down.

    Shutdown,

}
68
69impl<T> JoinHandle<T> {
70 pub(super) unsafe fn from_task_ref(task: TaskRef) -> Self {
76 task.state().create_join_handle();
77 let id = task.id();
78 Self {
79 task: JoinHandleState::Task(task),
80 id,
81 _t: PhantomData,
82 }
83 }
84
85 pub(crate) fn error(kind: JoinErrorKind) -> Self {
86 Self {
87 id: TaskId::stub(),
88 task: JoinHandleState::Error(kind),
89 _t: PhantomData,
90 }
91 }
92
93 #[must_use]
99 pub fn task_ref(&self) -> TaskRef {
100 match self.task {
101 JoinHandleState::Task(ref task) => task.clone(),
102 JoinHandleState::Empty => {
103 panic!("`TaskRef` only taken while polling a `JoinHandle`; this is a bug")
104 }
105 JoinHandleState::Error(ref error) => panic!("`JoinHandle` errored: {error:?}"),
106 }
107 }
108
109 #[inline]
122 #[must_use]
123 pub fn is_complete(&self) -> bool {
124 match self.task {
125 JoinHandleState::Task(ref task) => task.is_complete(),
126 _ => true,
130 }
131 }
132
133 pub fn cancel(&self) -> bool {
146 match self.task {
147 JoinHandleState::Task(ref task) => task.cancel(),
148 _ => false,
149 }
150 }
151
152 #[must_use]
160 #[inline]
161 #[track_caller]
162 pub fn id(&self) -> TaskId {
163 self.id
164 }
165}
166
167impl<T> Future for JoinHandle<T> {
168 type Output = Result<T, JoinError<T>>;
169
170 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
171 let this = self.get_mut();
172 let task = match core::mem::replace(&mut this.task, JoinHandleState::Empty) {
173 JoinHandleState::Task(task) => task,
174 JoinHandleState::Empty => {
175 panic!("`TaskRef` only taken while polling a `JoinHandle`; this is a bug")
176 }
177 JoinHandleState::Error(kind) => {
178 return Poll::Ready(Err(JoinError {
179 kind,
180 id: this.id,
181 output: None,
182 }))
183 }
184 };
185 let poll = unsafe {
186 task.poll_join::<T>(cx)
189 };
190 if poll.is_pending() {
191 this.task = JoinHandleState::Task(task);
192 } else {
193 task.state().drop_join_handle();
195 }
196 poll
197 }
198}
199
200impl<T> Drop for JoinHandle<T> {
201 fn drop(&mut self) {
202 if let JoinHandleState::Task(ref task) = self.task {
205 test_debug!(
206 task = ?self.task,
207 task.tid = task.id().as_u64(),
208 consumed = false,
209 "drop JoinHandle",
210 );
211 task.state().drop_join_handle();
212 } else {
213 test_debug!(
214 task = ?self.task,
215 consumed = true,
216 "drop JoinHandle",
217 );
218 }
219 }
220}
221
222impl<T> PartialEq<TaskRef> for JoinHandle<T> {
225 fn eq(&self, other: &TaskRef) -> bool {
226 match self.task {
227 JoinHandleState::Task(ref task) => task == other,
228 _ => false,
229 }
230 }
231}
232
233impl<T> PartialEq<&'_ TaskRef> for JoinHandle<T> {
234 fn eq(&self, other: &&TaskRef) -> bool {
235 match self.task {
236 JoinHandleState::Task(ref task) => task == *other,
237 _ => false,
238 }
239 }
240}
241
242impl<T> PartialEq<JoinHandle<T>> for TaskRef {
243 fn eq(&self, other: &JoinHandle<T>) -> bool {
244 match other.task {
245 JoinHandleState::Task(ref task) => self == task,
246 _ => false,
247 }
248 }
249}
250
251impl<T> PartialEq<&'_ JoinHandle<T>> for TaskRef {
252 fn eq(&self, other: &&JoinHandle<T>) -> bool {
253 match other.task {
254 JoinHandleState::Task(ref task) => self == task,
255 _ => false,
256 }
257 }
258}
259
260impl<T> PartialEq<TaskId> for JoinHandle<T> {
263 #[inline]
264 fn eq(&self, other: &TaskId) -> bool {
265 self.id == other
266 }
267}
268
269impl<T> PartialEq<&'_ TaskId> for JoinHandle<T> {
270 #[inline]
271 fn eq(&self, other: &&TaskId) -> bool {
272 self.id == *other
273 }
274}
275
276impl<T> PartialEq<JoinHandle<T>> for TaskId {
277 #[inline]
278 fn eq(&self, other: &JoinHandle<T>) -> bool {
279 self == other.id
280 }
281}
282
283impl<T> PartialEq<&'_ JoinHandle<T>> for TaskId {
284 #[inline]
285 fn eq(&self, other: &&JoinHandle<T>) -> bool {
286 self == other.id
287 }
288}
289
290impl<T> fmt::Debug for JoinHandle<T> {
291 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
292 f.debug_struct("JoinHandle")
293 .field("output", &core::any::type_name::<T>())
294 .field("task", &self.task)
295 .field("id", &self.id)
296 .finish()
297 }
298}
299
300impl JoinError<()> {
303 #[inline]
304 pub(super) fn canceled(completed: bool, id: TaskId) -> Poll<Result<(), Self>> {
305 Poll::Ready(Err(Self {
306 kind: JoinErrorKind::Canceled { completed },
307 id,
308 output: None,
309 }))
310 }
311
312 #[allow(dead_code)]
313 #[inline]
314 pub(crate) fn stub() -> Self {
315 Self {
316 kind: JoinErrorKind::StubNever,
317 id: TaskId::stub(),
318 output: None,
319 }
320 }
321
322 #[must_use]
323 pub(super) fn with_output<T>(self, output: Option<T>) -> JoinError<T> {
324 JoinError {
325 kind: self.kind,
326 id: self.id,
327 output,
328 }
329 }
330}
331
332impl<T> JoinError<T> {
333 pub fn is_canceled(&self) -> bool {
335 matches!(self.kind, JoinErrorKind::Canceled { .. })
336 }
337
338 pub fn is_completed(&self) -> bool {
340 match self.kind {
341 JoinErrorKind::Canceled { completed } => completed,
342 _ => false,
343 }
344 }
345
346 pub fn id(&self) -> TaskId {
348 self.id
349 }
350
351 pub fn output(self) -> Option<T> {
356 self.output
357 }
358}
359
360impl<T> fmt::Display for JoinError<T> {
361 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
362 match self.kind {
363 JoinErrorKind::Canceled { completed } => {
364 let completed = if completed {
365 " (after completing successfully)"
366 } else {
367 ""
368 };
369 write!(f, "task {} was canceled{completed}", self.id)
370 }
371 JoinErrorKind::StubNever => f.write_str("the stub task can never join"),
372 JoinErrorKind::Shutdown => f.write_str("the scheduler has already shut down"),
373 }
374 }
375}
376
377impl<T> fmt::Debug for JoinError<T> {
378 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
379 f.debug_struct("JoinError")
380 .field("id", &self.id)
381 .field("kind", &self.kind)
382 .finish()
383 }
384}
385
// `core::error::Error` requires only `Debug + Display`. `JoinError<T>`
// implements both in this module for every `T`, so this impl needs no bounds.
// NOTE(review): `feature!` is presumably a crate-local macro that gates these
// items on the `core-error` Cargo feature — confirm against its definition.
feature! {
    #![feature = "core-error"]
    impl<T> core::error::Error for JoinError<T> {}
}