maitake_sync/
loom.rs

1#[allow(unused_imports)]
2pub(crate) use self::inner::*;
3
4#[cfg(loom)]
5mod inner {
6    #![allow(dead_code)]
7    #![allow(unused_imports)]
8
9    #[cfg(feature = "alloc")]
10    pub(crate) mod alloc {
11        use super::sync::Arc;
12        use core::{
13            future::Future,
14            pin::Pin,
15            task::{Context, Poll},
16        };
17        pub(crate) use loom::alloc::*;
18
19        #[derive(Debug)]
20        #[pin_project::pin_project]
21        pub(crate) struct TrackFuture<F> {
22            #[pin]
23            inner: F,
24            track: Arc<()>,
25        }
26
27        impl<F: Future> Future for TrackFuture<F> {
28            type Output = TrackFuture<F::Output>;
29            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
30                let this = self.project();
31                this.inner.poll(cx).map(|inner| TrackFuture {
32                    inner,
33                    track: this.track.clone(),
34                })
35            }
36        }
37
38        impl<F> TrackFuture<F> {
39            /// Wrap a `Future` in a `TrackFuture` that participates in Loom's
40            /// leak checking.
41            #[track_caller]
42            pub(crate) fn new(inner: F) -> Self {
43                Self {
44                    inner,
45                    track: Arc::new(()),
46                }
47            }
48
49            /// Stop tracking this future, and return the inner value.
50            pub(crate) fn into_inner(self) -> F {
51                self.inner
52            }
53        }
54
55        #[track_caller]
56        pub(crate) fn track_future<F: Future>(inner: F) -> TrackFuture<F> {
57            TrackFuture::new(inner)
58        }
59
60        // PartialEq impl so that `assert_eq!(..., Ok(...))` works
61        impl<F: PartialEq> PartialEq for TrackFuture<F> {
62            fn eq(&self, other: &Self) -> bool {
63                self.inner == other.inner
64            }
65        }
66    }
67
68    #[cfg(test)]
69    pub(crate) use loom::future;
70    pub(crate) use loom::{cell, hint, model, thread};
71
72    pub(crate) mod sync {
73        pub(crate) use loom::sync::*;
74
75        pub(crate) mod blocking {
76            use core::{
77                marker::PhantomData,
78                ops::{Deref, DerefMut},
79            };
80
81            #[cfg(feature = "tracing")]
82            use core::panic::Location;
83
84            use core::fmt;
85
86            /// Mock version of mycelium's spinlock, but using
87            /// `loom::sync::Mutex`. The API is slightly different, since the
88            /// mycelium mutex does not support poisoning.
89            pub(crate) struct Mutex<T, Lock = crate::spin::Spinlock>(
90                loom::sync::Mutex<T>,
91                PhantomData<Lock>,
92            );
93
94            pub(crate) struct MutexGuard<'a, T, Lock = crate::spin::Spinlock> {
95                guard: loom::sync::MutexGuard<'a, T>,
96                #[cfg(feature = "tracing")]
97                location: &'static Location<'static>,
98                _p: PhantomData<Lock>,
99            }
100
101            impl<T, Lock> Mutex<T, Lock> {
102                #[track_caller]
103                pub(crate) fn new(t: T) -> Self {
104                    Self(loom::sync::Mutex::new(t), PhantomData)
105                }
106
107                #[track_caller]
108                pub(crate) fn new_with_raw_mutex(t: T, _: Lock) -> Self {
109                    Self::new(t)
110                }
111
112                #[track_caller]
113                pub fn with_lock<U>(&self, f: impl FnOnce(&mut T) -> U) -> U {
114                    let mut guard = self.lock();
115                    let res = f(&mut *guard);
116                    res
117                }
118
119                #[track_caller]
120                pub fn try_lock(&self) -> Option<MutexGuard<'_, T, Lock>> {
121                    #[cfg(feature = "tracing")]
122                    let location = Location::caller();
123                    #[cfg(feature = "tracing")]
124                    tracing::debug!(%location, "Mutex::try_lock");
125
126                    match self.0.try_lock() {
127                        Ok(guard) => {
128                            #[cfg(feature = "tracing")]
129                            tracing::debug!(%location, "Mutex::try_lock -> locked!");
130                            Some(MutexGuard {
131                                guard,
132
133                                #[cfg(feature = "tracing")]
134                                location,
135                                _p: PhantomData,
136                            })
137                        }
138                        Err(_) => {
139                            #[cfg(feature = "tracing")]
140                            tracing::debug!(%location, "Mutex::try_lock -> already locked");
141                            None
142                        }
143                    }
144                }
145
146                #[track_caller]
147                pub fn lock(&self) -> MutexGuard<'_, T, Lock> {
148                    #[cfg(feature = "tracing")]
149                    let location = Location::caller();
150
151                    #[cfg(feature = "tracing")]
152                    tracing::debug!(%location, "Mutex::lock");
153
154                    let guard = self
155                        .0
156                        .lock()
157                        .map(|guard| MutexGuard {
158                            guard,
159
160                            #[cfg(feature = "tracing")]
161                            location,
162                            _p: PhantomData,
163                        })
164                        .expect("loom mutex will never poison");
165
166                    #[cfg(feature = "tracing")]
167                    tracing::debug!(%location, "Mutex::lock -> locked");
168                    guard
169                }
170            }
171
172            impl<T: fmt::Debug, Lock> fmt::Debug for Mutex<T, Lock> {
173                fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
174                    self.0.fmt(f)
175                }
176            }
177
178            impl<T, Lock> Deref for MutexGuard<'_, T, Lock> {
179                type Target = T;
180                #[inline]
181                fn deref(&self) -> &Self::Target {
182                    self.guard.deref()
183                }
184            }
185
186            impl<T, Lock> DerefMut for MutexGuard<'_, T, Lock> {
187                #[inline]
188                fn deref_mut(&mut self) -> &mut Self::Target {
189                    self.guard.deref_mut()
190                }
191            }
192
193            impl<T: fmt::Debug, Lock> fmt::Debug for MutexGuard<'_, T, Lock> {
194                fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195                    self.guard.fmt(f)
196                }
197            }
198
199            impl<T, Lock> Drop for MutexGuard<'_, T, Lock> {
200                #[track_caller]
201                fn drop(&mut self) {
202                    #[cfg(feature = "tracing")]
203                    tracing::debug!(
204                        location.dropped = %Location::caller(),
205                        location.locked = %self.location,
206                        "MutexGuard::drop: unlocking",
207                    );
208                }
209            }
210        }
211    }
212}
213
214#[cfg(not(loom))]
215mod inner {
216    #![allow(dead_code, unused_imports)]
217    pub(crate) mod sync {
218        #[cfg(any(feature = "alloc", test))]
219        pub use alloc::sync::*;
220        pub use core::sync::*;
221
222        pub(crate) mod atomic {
223            pub use portable_atomic::*;
224            pub use core::sync::atomic::Ordering;
225        }
226
227        pub use crate::blocking;
228    }
229
230    pub(crate) use portable_atomic::hint;
231
232    #[cfg(test)]
233    pub(crate) mod thread {
234
235        pub(crate) use std::thread::{yield_now, JoinHandle};
236
237        pub(crate) fn spawn<F, T>(f: F) -> JoinHandle<T>
238        where
239            F: FnOnce() -> T + Send + 'static,
240            T: Send + 'static,
241        {
242            use super::sync::atomic::{AtomicUsize, Ordering::Relaxed};
243            thread_local! {
244                static CHILDREN: AtomicUsize = const { AtomicUsize::new(1) };
245            }
246
247            let track = super::alloc::track::Registry::current();
248            let subscriber = tracing::Dispatch::default();
249            let span = tracing::Span::current();
250            let num = CHILDREN.with(|children| children.fetch_add(1, Relaxed));
251            std::thread::spawn(move || {
252                let _tracing = tracing::dispatcher::set_default(&subscriber);
253                let _span = tracing::info_span!(parent: span, "thread", message = num).entered();
254
255                tracing::info!(num, "spawned child thread");
256                let _tracking = track.map(|track| track.set_default());
257                let res = f();
258                tracing::info!(num, "child thread completed");
259
260                res
261            })
262        }
263    }
264
265    #[cfg(test)]
266    pub(crate) mod model {
267        #[non_exhaustive]
268        #[derive(Default)]
269        pub(crate) struct Builder {
270            pub(crate) max_threads: usize,
271            pub(crate) max_branches: usize,
272            pub(crate) max_permutations: Option<usize>,
273            // pub(crate) max_duration: Option<Duration>,
274            pub(crate) preemption_bound: Option<usize>,
275            // pub(crate) checkpoint_file: Option<PathBuf>,
276            pub(crate) checkpoint_interval: usize,
277            pub(crate) location: bool,
278            pub(crate) log: bool,
279        }
280
281        impl Builder {
282            pub(crate) fn new() -> Self {
283                Self::default()
284            }
285
286            pub(crate) fn check(&self, f: impl FnOnce()) {
287                let _trace = crate::util::test::trace_init();
288                let _span = tracing::info_span!(
289                    "test",
290                    message = std::thread::current().name().unwrap_or("<unnamed>")
291                )
292                .entered();
293                let registry = super::alloc::track::Registry::default();
294                let _tracking = registry.set_default();
295
296                tracing::info!("started test...");
297                f();
298                tracing::info!("test completed successfully!");
299
300                registry.check();
301            }
302        }
303    }
304
305    #[cfg(test)]
306    pub(crate) fn model(f: impl FnOnce()) {
307        model::Builder::new().check(f)
308    }
309
310    pub(crate) mod cell {
311        #[derive(Debug)]
312        pub(crate) struct UnsafeCell<T: ?Sized>(core::cell::UnsafeCell<T>);
313
314        impl<T> UnsafeCell<T> {
315            pub const fn new(data: T) -> UnsafeCell<T> {
316                UnsafeCell(core::cell::UnsafeCell::new(data))
317            }
318        }
319
320        impl<T: ?Sized> UnsafeCell<T> {
321            #[inline(always)]
322            pub fn with<F, R>(&self, f: F) -> R
323            where
324                F: FnOnce(*const T) -> R,
325            {
326                f(self.0.get())
327            }
328
329            #[inline(always)]
330            pub fn with_mut<F, R>(&self, f: F) -> R
331            where
332                F: FnOnce(*mut T) -> R,
333            {
334                f(self.0.get())
335            }
336
337            #[inline(always)]
338            pub(crate) fn get(&self) -> ConstPtr<T> {
339                ConstPtr(self.0.get())
340            }
341
342            #[inline(always)]
343            pub(crate) fn get_mut(&self) -> MutPtr<T> {
344                MutPtr(self.0.get())
345            }
346        }
347
348        impl<T> UnsafeCell<T> {
349            #[inline(always)]
350            #[must_use]
351            pub(crate) fn into_inner(self) -> T {
352                self.0.into_inner()
353            }
354        }
355
356        #[derive(Debug)]
357        pub(crate) struct ConstPtr<T: ?Sized>(*const T);
358
359        impl<T: ?Sized> ConstPtr<T> {
360            #[inline(always)]
361            pub(crate) unsafe fn deref(&self) -> &T {
362                &*self.0
363            }
364
365            #[inline(always)]
366            pub fn with<F, R>(&self, f: F) -> R
367            where
368                F: FnOnce(*const T) -> R,
369            {
370                f(self.0)
371            }
372        }
373
374        #[derive(Debug)]
375        pub(crate) struct MutPtr<T: ?Sized>(*mut T);
376
377        impl<T: ?Sized> MutPtr<T> {
378            // Clippy knows that it's Bad and Wrong to construct a mutable reference
379            // from an immutable one...but this function is intended to simulate a raw
380            // pointer, so we have to do that here.
381            #[allow(clippy::mut_from_ref)]
382            #[inline(always)]
383            pub(crate) unsafe fn deref(&self) -> &mut T {
384                &mut *self.0
385            }
386
387            #[inline(always)]
388            pub fn with<F, R>(&self, f: F) -> R
389            where
390                F: FnOnce(*mut T) -> R,
391            {
392                f(self.0)
393            }
394        }
395    }
396
397    pub(crate) mod alloc {
398        #[cfg(test)]
399        use core::{
400            future::Future,
401            pin::Pin,
402            task::{Context, Poll},
403        };
404
405        #[cfg(test)]
406        use std::sync::Arc;
407        #[cfg(test)]
408        pub(in crate::loom) mod track {
409            use std::{
410                cell::RefCell,
411                sync::{
412                    atomic::{AtomicBool, Ordering},
413                    Arc, Mutex, Weak,
414                },
415            };
416
417            #[derive(Clone, Debug, Default)]
418            pub(crate) struct Registry(Arc<Mutex<RegistryInner>>);
419
420            #[derive(Debug, Default)]
421            struct RegistryInner {
422                tracks: Vec<Weak<TrackData>>,
423                next_id: usize,
424            }
425
426            #[derive(Debug)]
427            pub(super) struct TrackData {
428                was_leaked: AtomicBool,
429                type_name: &'static str,
430                location: &'static core::panic::Location<'static>,
431                id: usize,
432            }
433
434            thread_local! {
435                static REGISTRY: RefCell<Option<Registry>> = const { RefCell::new(None) };
436            }
437
438            impl Registry {
439                pub(in crate::loom) fn current() -> Option<Registry> {
440                    REGISTRY.with(|current| current.borrow().clone())
441                }
442
443                pub(in crate::loom) fn set_default(&self) -> impl Drop {
444                    struct Unset(Option<Registry>);
445                    impl Drop for Unset {
446                        fn drop(&mut self) {
447                            let _ =
448                                REGISTRY.try_with(|current| *current.borrow_mut() = self.0.take());
449                        }
450                    }
451
452                    REGISTRY.with(|current| {
453                        let mut current = current.borrow_mut();
454                        let unset = Unset(current.clone());
455                        *current = Some(self.clone());
456                        unset
457                    })
458                }
459
460                #[track_caller]
461                pub(super) fn start_tracking<T>() -> Option<Arc<TrackData>> {
462                    // we don't use `Option::map` here because it creates a
463                    // closure, which breaks `#[track_caller]`, since the caller
464                    // of `insert` becomes the closure, which cannot have a
465                    // `#[track_caller]` attribute on it.
466                    #[allow(clippy::manual_map)]
467                    match Self::current() {
468                        Some(registry) => Some(registry.insert::<T>()),
469                        _ => None,
470                    }
471                }
472
473                #[track_caller]
474                pub(super) fn insert<T>(&self) -> Arc<TrackData> {
475                    let mut inner = self.0.lock().unwrap();
476                    let id = inner.next_id;
477                    inner.next_id += 1;
478                    let location = core::panic::Location::caller();
479                    let type_name = std::any::type_name::<T>();
480                    let data = Arc::new(TrackData {
481                        type_name,
482                        location,
483                        id,
484                        was_leaked: AtomicBool::new(false),
485                    });
486                    let weak = Arc::downgrade(&data);
487                    tracing::trace!(
488                        target: "maitake_sync::alloc",
489                        id,
490                        "type" = %type_name,
491                        %location,
492                        "started tracking allocation",
493                    );
494                    inner.tracks.push(weak);
495                    data
496                }
497
498                pub(in crate::loom) fn check(&self) {
499                    let leaked = self
500                        .0
501                        .lock()
502                        .unwrap()
503                        .tracks
504                        .iter()
505                        .filter_map(|weak| {
506                            let data = weak.upgrade()?;
507                            data.was_leaked.store(true, Ordering::SeqCst);
508                            Some(format!(
509                                " - id {}, {} allocated at {}",
510                                data.id, data.type_name, data.location
511                            ))
512                        })
513                        .collect::<Vec<_>>();
514                    if !leaked.is_empty() {
515                        let leaked = leaked.join("\n  ");
516                        panic!("the following allocations were leaked:\n  {leaked}");
517                    }
518                }
519            }
520
521            impl Drop for TrackData {
522                fn drop(&mut self) {
523                    if !self.was_leaked.load(Ordering::SeqCst) {
524                        tracing::trace!(
525                            target: "maitake_sync::alloc",
526                            id = self.id,
527                            "type" = %self.type_name,
528                            location = %self.location,
529                            "dropped all references to a tracked allocation",
530                        );
531                    }
532                }
533            }
534        }
535
536        #[cfg(test)]
537        #[derive(Debug)]
538        #[pin_project::pin_project]
539        pub(crate) struct TrackFuture<F> {
540            #[pin]
541            inner: F,
542            track: Option<Arc<track::TrackData>>,
543        }
544
545        #[cfg(test)]
546        impl<F: Future> Future for TrackFuture<F> {
547            type Output = TrackFuture<F::Output>;
548            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
549                let this = self.project();
550                this.inner.poll(cx).map(|inner| TrackFuture {
551                    inner,
552                    track: this.track.clone(),
553                })
554            }
555        }
556
557        #[cfg(test)]
558        impl<F> TrackFuture<F> {
559            /// Wrap a `Future` in a `TrackFuture` that participates in Loom's
560            /// leak checking.
561            #[track_caller]
562            pub(crate) fn new(inner: F) -> Self {
563                let track = track::Registry::start_tracking::<F>();
564                Self { inner, track }
565            }
566
567            /// Stop tracking this future, and return the inner value.
568            pub(crate) fn into_inner(self) -> F {
569                self.inner
570            }
571        }
572
573        #[cfg(test)]
574        #[track_caller]
575        pub(crate) fn track_future<F: Future>(inner: F) -> TrackFuture<F> {
576            TrackFuture::new(inner)
577        }
578
579        // PartialEq impl so that `assert_eq!(..., Ok(...))` works
580        #[cfg(test)]
581        impl<F: PartialEq> PartialEq for TrackFuture<F> {
582            fn eq(&self, other: &Self) -> bool {
583                self.inner == other.inner
584            }
585        }
586
587        /// Track allocations, detecting leaks
588        #[derive(Debug, Default)]
589        pub struct Track<T> {
590            value: T,
591
592            #[cfg(test)]
593            track: Option<Arc<track::TrackData>>,
594        }
595
596        impl<T> Track<T> {
597            /// Track a value for leaks
598            #[inline(always)]
599            #[track_caller]
600            pub fn new(value: T) -> Track<T> {
601                Track {
602                    value,
603
604                    #[cfg(test)]
605                    track: track::Registry::start_tracking::<T>(),
606                }
607            }
608
609            /// Get a reference to the value
610            #[inline(always)]
611            pub fn get_ref(&self) -> &T {
612                &self.value
613            }
614
615            /// Get a mutable reference to the value
616            #[inline(always)]
617            pub fn get_mut(&mut self) -> &mut T {
618                &mut self.value
619            }
620
621            /// Stop tracking the value for leaks
622            #[inline(always)]
623            pub fn into_inner(self) -> T {
624                self.value
625            }
626        }
627    }
628
629    #[cfg(test)]
630    pub(crate) mod future {
631        pub(crate) use tokio_test::block_on;
632    }
633
634    #[cfg(test)]
635    pub(crate) fn traceln(args: std::fmt::Arguments) {
636        eprintln!("{args}");
637    }
638
639    #[cfg(not(test))]
640    pub(crate) fn traceln(_: core::fmt::Arguments) {}
641}