mycelium_kernel/drivers/pci.rs

1use crate::shell;
2use alloc::collections::{
3    btree_map::{self, BTreeMap},
4    btree_set::{self, BTreeSet},
5};
6use core::{iter, num::NonZeroU16};
7pub use mycelium_pci::*;
8use mycelium_util::{fmt, sync::InitOnce};
9
10#[derive(Debug, Default)]
11pub struct DeviceRegistry {
12    // TODO(eliza): these BTreeMaps could be `[T; 256]`...
13    by_class: BTreeMap<Class, BySubclass>,
14    by_vendor: BTreeMap<u16, BTreeMap<u16, Devices>>,
15    by_bus_group: BTreeMap<u16, BusGroup>,
16    len: usize,
17}
18
19#[derive(Clone, Debug, Default)]
20pub struct Devices(BTreeSet<Address>);
21
22#[derive(Clone, Debug, Default)]
23pub struct BySubclass(BTreeMap<Subclass, Devices>);
24
25#[derive(Clone, Debug, Default)]
26pub struct BusGroup(BTreeMap<u8, Bus>);
27
28#[derive(Clone, Debug, Default)]
29pub struct Bus([Option<BusDevice>; 32]);
30
31type BusDevice = [Option<(Subclass, device::Id)>; 8];
32
33pub static DEVICES: InitOnce<DeviceRegistry> = InitOnce::uninitialized();
34
35type BusDeviceFilter = fn((usize, &Option<BusDevice>)) -> Option<(usize, &BusDevice)>;
36type SubclassDeviceFilter = fn(&(&Subclass, &Devices)) -> bool;
37
38pub const LSPCI_CMD: shell::Command = shell::Command::new("lspci")
39    .with_help("list PCI devices")
40    .with_usage("[ADDRESS]")
41    .with_subcommands(&[shell::Command::new("class")
42        .with_help(
43            "list PCI devices by class. if no class code is provided, lists all devices by class.",
44        )
45        .with_usage("[CLASS]")
46        .with_fn(|ctx| {
47            fn log_device(device: Address, subclass: Subclass) {
48                let Some(header) = config::ConfigReg::new(device).read_header() else {
49                    tracing::error!(target: "pci", "[{device}]: invalid device header!");
50                    return;
51                };
52                let prog_if = subclass.prog_if(header.raw_prog_if());
53                match header.id() {
54                    device::Id::Known(id) => tracing::info!(
55                        target: "pci",
56                        vendor = %id.vendor().name(),
57                        device = %id.name(),
58                        %prog_if,
59                        "[{device}]",
60                    ),
61                    device::Id::Unknown(id) => tracing::warn!(
62                        target: "pci",
63                        vendor = fmt::hex(id.vendor_id),
64                        device = fmt::hex(id.device_id),
65                        %prog_if,
66                        "[{device}]: unknown ID or vendor"
67                    ),
68                }
69            }
70
71            fn list_classes<'a>(classes: impl IntoIterator<Item = (Class, &'a BySubclass)>) {
72                for (class, subclasses) in classes {
73                    let _span = tracing::info_span!("class", message = %class.name()).entered();
74                    for (subclass, devices) in subclasses {
75                        let _span =
76                            tracing::info_span!("subclass", message = %subclass.name()).entered();
77                        for device in devices {
78                            log_device(device, *subclass)
79                        }
80                    }
81                }
82            }
83
84            // if no class code was provided, list all classes
85            if ctx.command().is_empty() {
86                tracing::info!("listing all PCI devices by class");
87                list_classes(DEVICES.get().classes());
88                return Ok(());
89            }
90
91            // otherwise, list devices in the provided class.
92            let class = {
93                let class_code = ctx
94                    .command()
95                    .parse::<u8>()
96                    .map_err(|_| ctx.invalid_argument("a PCI class must be a valid `u8` value"))?;
97                Class::from_id(class_code)
98                    .map_err(|_| ctx.invalid_argument("not a valid PCI class"))?
99            };
100            if let Some(subclasses) = DEVICES.get().class(&class) {
101                list_classes(Some((class, subclasses)));
102            } else {
103                tracing::info!("no {} devices found", class.name());
104            }
105
106            Ok(())
107        })])
108    .with_fn(|ctx| {
109        if !ctx.command().is_empty() {
110            let _addr = match ctx.command().parse::<Address>() {
111                Ok(addr) => addr,
112                Err(error) => {
113                    tracing::error!(%error, "invalid PCI address");
114                    return Err(ctx.invalid_argument("invalid PCI address"));
115                }
116            };
117
118            return Err(ctx.other_error("looking up individual devices is not yet implemented"));
119        }
120
121        tracing::info!("listing all PCI devices by address");
122        for (bus_group, group) in DEVICES.get().bus_groups() {
123            let _span = tracing::info_span!("bus group", "{bus_group:04x}").entered();
124            for (bus_num, bus) in group.buses() {
125                let _span = tracing::info_span!("bus", "{bus_num:02x}").entered();
126                for (device_num, device) in bus.devices() {
127                    let _span = tracing::info_span!("device", "{device_num:02x}").entered();
128                    for (fn_num, (subclass, id)) in device
129                        .iter()
130                        .enumerate()
131                        .filter_map(|(fn_num, func)| Some((fn_num, func.as_ref()?)))
132                    {
133                        match id {
134                            device::Id::Known(id) => tracing::info!(
135                                target: " pci",
136                                class = %subclass.class().name(),
137                                device = %subclass.name(),
138                                vendor = %id.vendor().name(),
139                                device = %id.name(),
140                                "[{bus_group:04x}:{bus_num:02x}:{device_num:02x}.{fn_num}]"
141                            ),
142                            device::Id::Unknown(id) => tracing::warn!(
143                                target: " pci",
144                                class = %subclass.class().name(),
145                                device = %subclass.name(),
146                                vendor = fmt::hex(id.vendor_id),
147                                device = fmt::hex(id.device_id),
148                                "[{bus_group:04x}:{bus_num:02x}:{device_num:02x}.{fn_num}]",
149                            ),
150                        }
151                    }
152                }
153            }
154        }
155
156        Ok(())
157    });
158
159impl DeviceRegistry {
160    pub fn insert(&mut self, addr: Address, class: Classes, id: device::Id) -> bool {
161        // class->subclass->addr registry
162        let mut new = self
163            .by_class
164            .entry(class.class())
165            .or_default()
166            .0
167            .entry(class.subclass())
168            .or_default()
169            .0
170            .insert(addr);
171        // vendor->device->addr registry
172        new &= self
173            .by_vendor
174            .entry(id.vendor_id())
175            .or_default()
176            .entry(id.device_id())
177            .or_default()
178            .0
179            .insert(addr);
180        // address (bus group->bus->device->function) registry
181        let Bus(bus) = self
182            .by_bus_group
183            .entry(addr.group().map(Into::into).unwrap_or(0))
184            .or_default()
185            .0
186            .entry(addr.bus())
187            .or_default();
188        let bus_device = bus
189            .get_mut(addr.device() as usize)
190            .expect("invalid address: the device must be 5 bits")
191            .get_or_insert_with(|| [None; 8]);
192        new &= bus_device
193            .get_mut(addr.function() as usize)
194            .expect("invalid address: the function must be 3 bits")
195            .replace((class.subclass(), id))
196            .is_none();
197        new
198    }
199
200    pub fn class(&self, class: &Class) -> Option<&BySubclass> {
201        self.by_class.get(class)
202    }
203
204    pub fn len(&self) -> usize {
205        self.len
206    }
207
208    pub fn is_empty(&self) -> bool {
209        let is_empty = self.len() == 0;
210        debug_assert_eq!(is_empty, self.by_class.is_empty());
211        debug_assert_eq!(is_empty, self.by_vendor.is_empty());
212        is_empty
213    }
214
215    pub fn classes(&self) -> impl Iterator<Item = (Class, &BySubclass)> + '_ {
216        self.by_class.iter().filter_map(|(class, by_subclass)| {
217            if !by_subclass.0.is_empty() {
218                Some((*class, by_subclass))
219            } else {
220                None
221            }
222        })
223    }
224
225    pub fn bus_groups(&self) -> btree_map::Iter<'_, u16, BusGroup> {
226        self.by_bus_group.iter()
227    }
228
229    pub fn iter(&self) -> impl Iterator<Item = (Address, Subclass, device::Id)> + '_ {
230        self.bus_groups().flat_map(|(&group_addr, group)| {
231            group.buses().flat_map(move |(&bus_addr, bus)| {
232                bus.devices().flat_map(move |(device_addr, device)| {
233                    device
234                        .iter()
235                        .enumerate()
236                        .filter_map(move |(func_num, func)| {
237                            let (subclass, id) = func.as_ref()?;
238                            let addr = Address::new()
239                                .with_group(NonZeroU16::new(group_addr))
240                                .with_bus(bus_addr)
241                                .with_device(device_addr as u8)
242                                .with_function(func_num as u8);
243                            Some((addr, *subclass, *id))
244                        })
245                })
246            })
247        })
248    }
249}
250
251// === impl BySubclass ===
252
253impl BySubclass {
254    pub fn subclass(&self, subclass: class::Subclass) -> Option<&Devices> {
255        self.0.get(&subclass)
256    }
257
258    pub fn iter(
259        &self,
260    ) -> iter::Filter<btree_map::Iter<'_, Subclass, Devices>, SubclassDeviceFilter> {
261        self.0.iter().filter(|&(_, devices)| !devices.is_empty())
262    }
263}
264
265impl<'a> IntoIterator for &'a BySubclass {
266    type Item = (&'a Subclass, &'a Devices);
267    type IntoIter = iter::Filter<btree_map::Iter<'a, Subclass, Devices>, SubclassDeviceFilter>;
268    #[inline]
269    fn into_iter(self) -> Self::IntoIter {
270        self.iter()
271    }
272}
273
274// === impl Devices ===
275
276impl Devices {
277    pub fn iter(&self) -> iter::Copied<btree_set::Iter<'_, Address>> {
278        self.0.iter().copied()
279    }
280
281    #[inline]
282    #[must_use]
283    pub fn is_empty(&self) -> bool {
284        self.0.is_empty()
285    }
286
287    #[inline]
288    #[must_use]
289    pub fn len(&self) -> usize {
290        self.0.len()
291    }
292}
293
294impl<'a> IntoIterator for &'a Devices {
295    type Item = Address;
296    type IntoIter = iter::Copied<btree_set::Iter<'a, Address>>;
297    #[inline]
298    fn into_iter(self) -> Self::IntoIter {
299        self.iter()
300    }
301}
302
303// === impl BusGroup ===
304
305impl BusGroup {
306    pub fn buses(&self) -> btree_map::Iter<'_, u8, Bus> {
307        self.0.iter()
308    }
309}
310
311// === impl Bus ===
312
313impl Bus {
314    pub fn devices(
315        &self,
316    ) -> iter::FilterMap<iter::Enumerate<core::slice::Iter<'_, Option<BusDevice>>>, BusDeviceFilter>
317    {
318        self.0
319            .iter()
320            .enumerate()
321            .filter_map((|(addr, device)| Some((addr, device.as_ref()?))) as BusDeviceFilter)
322    }
323}