// mycelium_pci/addr.rs

1use core::{fmt, num::NonZeroU16, str::FromStr};
2use hex::{FromHex, FromHexError};
3use mycelium_bitfield::{bitfield, Pack32};
4
5/// A PCI device address.
6#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
7pub struct Address(AddressBits);
8
9#[derive(Debug)]
10pub struct ParseError {
11    kind: ParseErrorKind,
12}
13
14#[derive(Debug)]
15
16enum ParseErrorKind {
17    ExpectedChar(char),
18    NotANumber {
19        reason: FromHexError,
20        pos: &'static str,
21    },
22    InvalidNumber {
23        num: u16,
24        max: u16,
25        name: &'static str,
26    },
27    Msg(&'static str),
28}
29
bitfield! {
    /// Packed 32-bit representation of a PCI [`Address`].
    ///
    /// Field widths: FUNCTION 3 bits, DEVICE 5 bits, BUS 8 bits (`u8`) and
    /// GROUP 16 bits (`u16`), 32 bits in total, filling the `u32` exactly.
    #[derive(Eq, PartialEq, Ord, PartialOrd)]
    pub(crate) struct AddressBits<u32> {
        // Function number within a device (3 bits: 0..=7).
        pub(crate) const FUNCTION = 3;
        // Device number within a bus segment (5 bits: 0..=31).
        pub(crate) const DEVICE = 5;
        // Bus segment number.
        pub(crate) const BUS: u8;
        // PCI Express segment group; `Address::group` treats 0 as "no group".
        // NOTE(review): presumably fields are packed from the least-significant
        // bit in declaration order, making GROUP the most significant; confirm
        // against the mycelium_bitfield docs.
        pub(crate) const GROUP: u16;
    }
}
39
40impl Address {
41    #[inline]
42    #[must_use]
43    pub const fn new() -> Self {
44        Self(AddressBits::new())
45    }
46
47    /// Returns the device's segment group, if it is a PCI Express device.
48    ///
49    /// PCI Express supports up to 65535 segment groups, each with 256 bus
50    /// segments. Standard PCI does not support segment groups.
51    #[inline]
52    #[must_use]
53    pub fn group(self) -> Option<NonZeroU16> {
54        NonZeroU16::new(self.0.get(AddressBits::GROUP))
55    }
56
57    /// Returns the device's bus segment.
58    ///
59    /// PCI supports up to 256 bus segments.
60    #[inline]
61    #[must_use]
62    pub fn bus(self) -> u8 {
63        self.0.get(AddressBits::BUS)
64    }
65
66    /// Returns the device number within its bus segment.
67    #[inline]
68    #[must_use]
69    pub fn device(self) -> u8 {
70        self.0.get(AddressBits::DEVICE) as u8
71    }
72
73    /// Returns which function of the device this address refers to.
74    ///
75    /// A device may support up to 8 separate functions.
76    #[inline]
77    #[must_use]
78    pub fn function(self) -> u8 {
79        self.0.get(AddressBits::FUNCTION) as u8
80    }
81
82    #[inline]
83    #[must_use]
84    pub fn with_group(self, group: Option<NonZeroU16>) -> Self {
85        let value = group.map(NonZeroU16::get).unwrap_or(0);
86        Self(self.0.with(AddressBits::GROUP, value))
87    }
88
89    #[inline]
90    #[must_use]
91    pub fn with_bus(self, bus: u8) -> Self {
92        Self(self.0.with(AddressBits::BUS, bus))
93    }
94
95    #[inline]
96    #[must_use]
97    pub fn with_device(self, device: u8) -> Self {
98        Self(self.0.with(AddressBits::DEVICE, device as u32))
99    }
100
101    #[inline]
102    #[must_use]
103    pub fn with_function(self, function: u8) -> Self {
104        Self(self.0.with(AddressBits::FUNCTION, function as u32))
105    }
106
107    pub(crate) fn bitfield(self) -> AddressBits {
108        self.0
109    }
110}
111
112impl Default for Address {
113    fn default() -> Self {
114        Self::new()
115    }
116}
117
118impl fmt::Display for Address {
119    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
120        fmt::LowerHex::fmt(self, f)
121    }
122}
123
124impl fmt::UpperHex for Address {
125    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126        if let Some(group) = self.group() {
127            write!(f, "{group:04X}:")?;
128        }
129        write!(
130            f,
131            "{:02X}:{:02X}.{}",
132            self.bus(),
133            self.device(),
134            self.function()
135        )
136    }
137}
138
139impl fmt::LowerHex for Address {
140    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
141        if let Some(group) = self.group() {
142            write!(f, "{group:04x}:")?;
143        }
144        write!(
145            f,
146            "{:02x}:{:02x}.{}",
147            self.bus(),
148            self.device(),
149            self.function()
150        )
151    }
152}
153
154impl fmt::Debug for Address {
155    #[inline]
156    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
157        write!(f, "Address({self:x})")
158    }
159}
160
161impl FromStr for Address {
162    type Err = ParseError;
163    fn from_str(s: &str) -> Result<Self, Self::Err> {
164        fn parse_u8<T>(
165            s: &str,
166            mask: Pack32<T, AddressBits>,
167            name: &'static str,
168        ) -> Result<u8, ParseError> {
169            let [num] = <[u8; 1]>::from_hex(s).map_err(ParseError::not_a_number(name))?;
170            ParseError::validate_mask(num, mask, name)
171        }
172
173        let s = s.trim();
174        let mut addr = Address::new();
175        let mut split = s.split(':');
176        let first = split.next().ok_or_else(|| ParseError::expected_char(':'))?;
177        let second = split
178            .next()
179            .ok_or_else(|| ParseError::msg("expected a device number after ':'"))?;
180        let (bus, dev_fn) = if let Some(third) = split.next() {
181            // if there are two colons, the first part is the bus group.
182            let bytes =
183                <[u8; 2]>::from_hex(first).map_err(ParseError::not_a_number("bus group"))?;
184            let group = u16::from_be_bytes(bytes);
185            addr = addr.with_group(NonZeroU16::new(group));
186            (second, third)
187        } else {
188            (first, second)
189        };
190
191        let bus = parse_u8(bus, AddressBits::BUS, "bus number")?;
192        addr = addr.with_bus(bus);
193
194        let mut dev_fn = dev_fn.split('.');
195        let device = dev_fn
196            .next()
197            .ok_or_else(|| ParseError::msg("expected device number"))
198            .and_then(|dev| parse_u8(dev, AddressBits::DEVICE, "device number"))?;
199        addr = addr.with_device(device);
200
201        let func = dev_fn
202            .next()
203            .map(|func| {
204                // use `u8`'s `FromStr` impl rather than the `hex` crate for the
205                // function number, as `hex` refuses to parse single-digit
206                // strings as hex :<
207                func.parse::<u8>()
208                    .map_err(|_| {
209                        ParseError::msg("function number is not a number in the range 0-8")
210                    })
211                    .and_then(|num| {
212                        ParseError::validate_mask(num, AddressBits::FUNCTION, "function number")
213                    })
214            })
215            .transpose()?;
216        if let Some(func) = func {
217            addr = addr.with_function(func);
218        }
219        Ok(addr)
220    }
221}
222
223// === impl ParseError ===
224
225impl ParseError {
226    fn validate_mask<T>(
227        num: u8,
228        mask: Pack32<T, AddressBits>,
229        name: &'static str,
230    ) -> Result<u8, Self> {
231        let max = mask.max_value() as u16;
232        if num as u16 > max {
233            return Err(Self {
234                kind: ParseErrorKind::InvalidNumber {
235                    num: num as u16,
236                    max,
237                    name,
238                },
239            });
240        }
241
242        Ok(num)
243    }
244
245    fn not_a_number(pos: &'static str) -> impl Fn(FromHexError) -> Self {
246        move |reason| Self {
247            kind: ParseErrorKind::NotANumber { reason, pos },
248        }
249    }
250    fn expected_char(c: char) -> Self {
251        Self {
252            kind: ParseErrorKind::ExpectedChar(c),
253        }
254    }
255
256    fn msg(msg: &'static str) -> Self {
257        Self {
258            kind: ParseErrorKind::Msg(msg),
259        }
260    }
261}
262
263impl fmt::Display for ParseError {
264    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
265        match self.kind {
266            ParseErrorKind::ExpectedChar(c) => write!(f, "expected a '{c}'")?,
267            ParseErrorKind::InvalidNumber { num, max, name } => {
268                write!(f, "{name} must be less than {max:#x} (got {num:#x})")?
269            }
270            ParseErrorKind::NotANumber { reason, pos } => {
271                write!(f, "{pos} was not a valid hexadecimal number ({reason})")?
272            }
273            ParseErrorKind::Msg(msg) => f.write_str(msg)?,
274        };
275        f.write_str(
276            ", PCI addresses must be in the format '(<BUS GROUP>:)<BUS>:<DEVICE>(.<FUNCTION>)'",
277        )
278    }
279}
280
#[cfg(test)]
mod test {
    use super::*;
    use proptest::{prop_assert, prop_assert_eq, proptest};

    #[test]
    fn addrs_are_valid() {
        AddressBits::assert_valid();
    }

    proptest! {
        // Every field value round-trips through the setters and getters.
        // The bus range is inclusive so that bus 0xff is also covered.
        #[test]
        fn addr_roundtrips(bus in 0u8..=u8::MAX, device in 0u8..32u8, function in 0u8..8u8) {
            let addr = Address::new().with_bus(bus).with_device(device).with_function(function);

            prop_assert_eq!(addr.bus(), bus, "bus, addr: {}", addr);
            prop_assert_eq!(addr.device(), device, "device, addr: {}", addr);
            prop_assert_eq!(addr.function(), function, "function, addr: {}", addr);
        }

        // Formatting an address (lower- or uppercase hex) and parsing it back
        // yields the same address.
        #[test]
        fn addr_display_roundtrips(
            group in 0u16..=u16::MAX,
            bus in 0u8..=u8::MAX,
            device in 0u8..32u8,
            function in 0u8..8u8
        ) {
            let addr = Address::new()
                .with_group(NonZeroU16::new(group))
                .with_bus(bus)
                .with_device(device)
                .with_function(function);
            prop_assert_eq!(addr.group(), NonZeroU16::new(group), "group, addr: {}", addr);

            for text in [format!("{addr:x}"), format!("{addr:X}")] {
                match text.parse::<Address>() {
                    Ok(parsed) => {
                        prop_assert_eq!(parsed, addr, "text: {:?}", text);
                    }
                    Err(err) => {
                        prop_assert!(false, "failed to parse {:?}: {}", text, err);
                    }
                }
            }
        }
    }

    #[track_caller]
    fn test_parse(s: &str, expected: Address) {
        let addr = s.parse::<Address>().expect(s);
        assert_eq!(addr, expected);
    }

    #[test]
    fn parse_pci_addr_no_fn() {
        test_parse("00:02", Address::new().with_device(0x02));
        test_parse("0f:0f", Address::new().with_bus(0x000f).with_device(0x0f));
    }

    #[test]
    fn parse_pcie_addr_no_fn() {
        test_parse(
            "0000:0a:01",
            Address::new()
                .with_group(None)
                .with_bus(0x0a)
                .with_device(0x01),
        );
        test_parse(
            "1234:0f:0f",
            Address::new()
                .with_group(NonZeroU16::new(0x1234))
                .with_bus(0x0f)
                .with_device(0x0f),
        );
        test_parse(
            "ffff:0a:0b",
            Address::new()
                .with_group(NonZeroU16::new(0xffff))
                .with_bus(0x0a)
                .with_device(0x0b),
        );
    }

    #[test]
    fn parse_invalid() {
        // Malformed text, wrong digit counts, and out-of-range fields must all
        // be rejected.
        for s in [
            "hello world",
            "",
            "00",
            "0:2",
            "zz:00",
            "10000:00:00",
            "00:20",
            "00:00.8",
        ] {
            let res = s.parse::<Address>();
            assert!(res.is_err(), "{s:?} should not parse, got {res:?}");
        }
        println!("{}", "hello world".parse::<Address>().unwrap_err());
    }

    #[test]
    fn parse_pci_addr_with_fn() {
        test_parse("00:02.0", Address::new().with_device(0x02));
        test_parse("0f:0f.0", Address::new().with_bus(0x000f).with_device(0x0f));
        test_parse("00:02.1", Address::new().with_device(0x02).with_function(1));
        test_parse(
            "0f:0f.1",
            Address::new()
                .with_bus(0x000f)
                .with_device(0x0f)
                .with_function(1),
        );
    }

    #[test]
    fn parse_pcie_addr_with_fn() {
        test_parse(
            "0000:0a:01.0",
            Address::new()
                .with_group(None)
                .with_bus(0x0a)
                .with_device(0x01),
        );
        test_parse(
            "1234:0f:0f.0",
            Address::new()
                .with_group(NonZeroU16::new(0x1234))
                .with_bus(0x0f)
                .with_device(0x0f),
        );
        test_parse(
            "ffff:0a:0b.0",
            Address::new()
                .with_group(NonZeroU16::new(0xffff))
                .with_bus(0x0a)
                .with_device(0x0b),
        );

        test_parse(
            "0000:0a:01.1",
            Address::new()
                .with_group(None)
                .with_bus(0x0a)
                .with_device(0x01)
                .with_function(1),
        );
        test_parse(
            "1234:0f:0f.1",
            Address::new()
                .with_group(NonZeroU16::new(0x1234))
                .with_bus(0x0f)
                .with_device(0x0f)
                .with_function(1),
        );
        test_parse(
            "ffff:0a:0b.1",
            Address::new()
                .with_group(NonZeroU16::new(0xffff))
                .with_bus(0x0a)
                .with_device(0x0b)
                .with_function(1),
        );
    }
}