// hal_x86_64/cpu/local.rs
1use super::Msr;
2use alloc::boxed::Box;
3use core::{
4 arch::asm,
5 marker::PhantomPinned,
6 pin::Pin,
7 ptr,
8 sync::atomic::{AtomicPtr, AtomicUsize, Ordering},
9};
10use hal_core::CoreLocal;
11use mycelium_util::{fmt, sync::Lazy};
12
13#[repr(C)]
14#[derive(Debug)]
15pub struct GsLocalData {
16 _self: *const Self,
19 magic: usize,
20 _must_pin: PhantomPinned,
22 userdata: [AtomicPtr<()>; Self::MAX_LOCAL_KEYS],
27}
28
29pub struct LocalKey<T> {
30 idx: Lazy<usize>,
31 initializer: fn() -> T,
32}
33
34impl GsLocalData {
35 const MAGIC: usize = 0xC0FFEE;
37 pub const MAX_LOCAL_KEYS: usize = 64;
38
39 const fn new() -> Self {
40 #[allow(clippy::declare_interior_mutable_const)] const LOCAL_SLOT_INIT: AtomicPtr<()> = AtomicPtr::new(ptr::null_mut());
42 Self {
43 _self: ptr::null(),
44 _must_pin: PhantomPinned,
45 magic: Self::MAGIC,
46 userdata: [LOCAL_SLOT_INIT; Self::MAX_LOCAL_KEYS],
47 }
48 }
49
50 #[must_use]
53 pub fn try_current() -> Option<Pin<&'static Self>> {
54 if !Self::has_local_data() {
55 return None;
56 }
57 unsafe {
58 let ptr: *const Self;
59 asm!("mov {}, gs:0x0", out(reg) ptr);
60 debug_assert_eq!(
61 (*ptr).magic,
62 Self::MAGIC,
63 "weird magic mismatch, this should never happen??"
64 );
65 Some(Pin::new_unchecked(&*ptr))
66 }
67 }
68
69 #[track_caller]
76 #[must_use]
77 pub fn current() -> Pin<&'static Self> {
78 Self::try_current()
79 .expect("GsLocalData::current() called before local data was initialized on this core!")
80 }
81
82 pub fn with<T, U>(&self, key: &LocalKey<T>, f: impl FnOnce(&T) -> U) -> U {
84 let idx = *key.idx.get();
85 let slot = match self.userdata.get(idx) {
86 Some(slot) => slot,
87 None => panic!(
88 "local key had an index greater than GsLocalData::MAX_LOCAL_KEYS: index = {idx}, max = {}",
89 Self::MAX_LOCAL_KEYS
90 ),
91 };
92
93 let mut ptr = slot.load(Ordering::Acquire);
96 if ptr.is_null() {
97 let data = Box::new((key.initializer)());
98 let data_ptr = Box::into_raw(data) as *mut ();
99 slot.compare_exchange(ptr, data_ptr, Ordering::AcqRel, Ordering::Acquire)
100 .expect("CAS should be uncontended!");
101 ptr = data_ptr;
102 }
103
104 let data = unsafe { &*(ptr as *const T) };
105 f(data)
106 }
107
108 #[track_caller]
112 pub fn init() {
113 if Self::has_local_data() {
114 tracing::warn!("this CPU core already has local data initialized!");
115 debug_assert!(false, "this CPU core already has local data initialized!");
116 return;
117 }
118
119 let ptr = Box::into_raw(Box::new(Self::new()));
120 tracing::trace!(?ptr, "initializing local data");
121 unsafe {
122 (*ptr)._self = ptr as *const _;
124 Msr::ia32_gs_base().write(ptr as u64);
125 }
126 }
127
128 fn has_local_data() -> bool {
130 if Msr::ia32_gs_base().read() == 0 {
132 return false;
133 }
134
135 let word: usize;
137 unsafe {
138 asm!("mov {}, gs:0x8", out(reg) word);
139 }
140 word == Self::MAGIC
141 }
142}
143
144impl<T: 'static> LocalKey<T> {
147 #[must_use]
148 #[track_caller]
149 pub const fn new(initializer: fn() -> T) -> Self {
150 Self {
151 idx: Lazy::new(Self::next_index),
152 initializer,
153 }
154 }
155
156 #[track_caller]
157 pub fn with<U>(&self, f: impl FnOnce(&T) -> U) -> U {
158 GsLocalData::current().with(self, f)
159 }
160
161 #[track_caller]
162 fn next_index() -> usize {
163 static NEXT_INDEX: AtomicUsize = AtomicUsize::new(0);
164 let idx = NEXT_INDEX.fetch_add(1, Ordering::Relaxed);
165 assert!(
166 idx < GsLocalData::MAX_LOCAL_KEYS,
167 "maximum number of local keys ({}) exceeded",
168 GsLocalData::MAX_LOCAL_KEYS
169 );
170 idx
171 }
172}
173
174impl<T> fmt::Debug for LocalKey<T> {
175 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
176 f.debug_struct("LocalKey")
177 .field("type", &core::any::type_name::<T>())
178 .field("initializer", &fmt::ptr(self.initializer))
179 .field("idx", &self.idx)
180 .finish()
181 }
182}
183
184impl<T: 'static> CoreLocal<T> for LocalKey<T> {
185 fn new(initializer: fn() -> T) -> Self {
186 Self::new(initializer)
187 }
188
189 #[track_caller]
190 fn with<F, U>(&self, f: F) -> U
191 where
192 F: FnOnce(&T) -> U,
193 {
194 self.with(f)
195 }
196}