1use core::{fmt, num::NonZeroU16, str::FromStr};
2use hex::{FromHex, FromHexError};
3use mycelium_bitfield::{bitfield, Pack32};
4
5#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
7pub struct Address(AddressBits);
8
/// Error returned when a string cannot be parsed as an [`Address`].
///
/// Its `Display` output describes the specific failure and then reminds the
/// reader of the expected PCI address format.
#[derive(Debug)]
pub struct ParseError {
    // The specific failure; private to this module.
    kind: ParseErrorKind,
}
13
/// The specific reason an [`Address`] failed to parse.
#[derive(Debug)]
enum ParseErrorKind {
    /// A required separator character was missing.
    ExpectedChar(char),
    /// A component could not be decoded as hexadecimal.
    NotANumber {
        /// The underlying decoding error from the `hex` crate.
        reason: FromHexError,
        /// Human-readable name of the offending component (e.g. "bus number").
        pos: &'static str,
    },
    /// A component decoded successfully but is too large for its bit field.
    InvalidNumber {
        /// The value that was parsed.
        num: u16,
        /// The largest value the field can hold.
        max: u16,
        /// Human-readable name of the field.
        name: &'static str,
    },
    /// Any other failure, described by a static message.
    Msg(&'static str),
}
29
bitfield! {
    /// Packed 32-bit representation of an [`Address`].
    ///
    /// Field widths: function 3 bits, device 5 bits, bus 8 bits, group 16 bits
    /// (32 bits total). NOTE(review): exact bit positions follow
    /// `mycelium_bitfield`'s declaration-order packing (presumably `FUNCTION`
    /// occupies the least-significant bits) — confirm against the crate docs.
    #[derive(Eq, PartialEq, Ord, PartialOrd)]
    pub(crate) struct AddressBits<u32> {
        // PCI function number: 3 bits, so 0-7.
        pub(crate) const FUNCTION = 3;
        // PCI device number: 5 bits, so 0-31.
        pub(crate) const DEVICE = 5;
        // Bus number; typed as `u8`, so it spans 8 bits.
        pub(crate) const BUS: u8;
        // Bus group; a stored 0 means "no group" (see `Address::group`).
        pub(crate) const GROUP: u16;
    }
}
39
40impl Address {
41 #[inline]
42 #[must_use]
43 pub const fn new() -> Self {
44 Self(AddressBits::new())
45 }
46
47 #[inline]
52 #[must_use]
53 pub fn group(self) -> Option<NonZeroU16> {
54 NonZeroU16::new(self.0.get(AddressBits::GROUP))
55 }
56
57 #[inline]
61 #[must_use]
62 pub fn bus(self) -> u8 {
63 self.0.get(AddressBits::BUS)
64 }
65
66 #[inline]
68 #[must_use]
69 pub fn device(self) -> u8 {
70 self.0.get(AddressBits::DEVICE) as u8
71 }
72
73 #[inline]
77 #[must_use]
78 pub fn function(self) -> u8 {
79 self.0.get(AddressBits::FUNCTION) as u8
80 }
81
82 #[inline]
83 #[must_use]
84 pub fn with_group(self, group: Option<NonZeroU16>) -> Self {
85 let value = group.map(NonZeroU16::get).unwrap_or(0);
86 Self(self.0.with(AddressBits::GROUP, value))
87 }
88
89 #[inline]
90 #[must_use]
91 pub fn with_bus(self, bus: u8) -> Self {
92 Self(self.0.with(AddressBits::BUS, bus))
93 }
94
95 #[inline]
96 #[must_use]
97 pub fn with_device(self, device: u8) -> Self {
98 Self(self.0.with(AddressBits::DEVICE, device as u32))
99 }
100
101 #[inline]
102 #[must_use]
103 pub fn with_function(self, function: u8) -> Self {
104 Self(self.0.with(AddressBits::FUNCTION, function as u32))
105 }
106
107 pub(crate) fn bitfield(self) -> AddressBits {
108 self.0
109 }
110}
111
112impl Default for Address {
113 fn default() -> Self {
114 Self::new()
115 }
116}
117
118impl fmt::Display for Address {
119 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
120 fmt::LowerHex::fmt(self, f)
121 }
122}
123
124impl fmt::UpperHex for Address {
125 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126 if let Some(group) = self.group() {
127 write!(f, "{group:04X}:")?;
128 }
129 write!(
130 f,
131 "{:02X}:{:02X}.{}",
132 self.bus(),
133 self.device(),
134 self.function()
135 )
136 }
137}
138
139impl fmt::LowerHex for Address {
140 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
141 if let Some(group) = self.group() {
142 write!(f, "{group:04x}:")?;
143 }
144 write!(
145 f,
146 "{:02x}:{:02x}.{}",
147 self.bus(),
148 self.device(),
149 self.function()
150 )
151 }
152}
153
154impl fmt::Debug for Address {
155 #[inline]
156 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
157 write!(f, "Address({self:x})")
158 }
159}
160
161impl FromStr for Address {
162 type Err = ParseError;
163 fn from_str(s: &str) -> Result<Self, Self::Err> {
164 fn parse_u8<T>(
165 s: &str,
166 mask: Pack32<T, AddressBits>,
167 name: &'static str,
168 ) -> Result<u8, ParseError> {
169 let [num] = <[u8; 1]>::from_hex(s).map_err(ParseError::not_a_number(name))?;
170 ParseError::validate_mask(num, mask, name)
171 }
172
173 let s = s.trim();
174 let mut addr = Address::new();
175 let mut split = s.split(':');
176 let first = split.next().ok_or_else(|| ParseError::expected_char(':'))?;
177 let second = split
178 .next()
179 .ok_or_else(|| ParseError::msg("expected a device number after ':'"))?;
180 let (bus, dev_fn) = if let Some(third) = split.next() {
181 let bytes =
183 <[u8; 2]>::from_hex(first).map_err(ParseError::not_a_number("bus group"))?;
184 let group = u16::from_be_bytes(bytes);
185 addr = addr.with_group(NonZeroU16::new(group));
186 (second, third)
187 } else {
188 (first, second)
189 };
190
191 let bus = parse_u8(bus, AddressBits::BUS, "bus number")?;
192 addr = addr.with_bus(bus);
193
194 let mut dev_fn = dev_fn.split('.');
195 let device = dev_fn
196 .next()
197 .ok_or_else(|| ParseError::msg("expected device number"))
198 .and_then(|dev| parse_u8(dev, AddressBits::DEVICE, "device number"))?;
199 addr = addr.with_device(device);
200
201 let func = dev_fn
202 .next()
203 .map(|func| {
204 func.parse::<u8>()
208 .map_err(|_| {
209 ParseError::msg("function number is not a number in the range 0-8")
210 })
211 .and_then(|num| {
212 ParseError::validate_mask(num, AddressBits::FUNCTION, "function number")
213 })
214 })
215 .transpose()?;
216 if let Some(func) = func {
217 addr = addr.with_function(func);
218 }
219 Ok(addr)
220 }
221}
222
223impl ParseError {
226 fn validate_mask<T>(
227 num: u8,
228 mask: Pack32<T, AddressBits>,
229 name: &'static str,
230 ) -> Result<u8, Self> {
231 let max = mask.max_value() as u16;
232 if num as u16 > max {
233 return Err(Self {
234 kind: ParseErrorKind::InvalidNumber {
235 num: num as u16,
236 max,
237 name,
238 },
239 });
240 }
241
242 Ok(num)
243 }
244
245 fn not_a_number(pos: &'static str) -> impl Fn(FromHexError) -> Self {
246 move |reason| Self {
247 kind: ParseErrorKind::NotANumber { reason, pos },
248 }
249 }
250 fn expected_char(c: char) -> Self {
251 Self {
252 kind: ParseErrorKind::ExpectedChar(c),
253 }
254 }
255
256 fn msg(msg: &'static str) -> Self {
257 Self {
258 kind: ParseErrorKind::Msg(msg),
259 }
260 }
261}
262
263impl fmt::Display for ParseError {
264 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
265 match self.kind {
266 ParseErrorKind::ExpectedChar(c) => write!(f, "expected a '{c}'")?,
267 ParseErrorKind::InvalidNumber { num, max, name } => {
268 write!(f, "{name} must be less than {max:#x} (got {num:#x})")?
269 }
270 ParseErrorKind::NotANumber { reason, pos } => {
271 write!(f, "{pos} was not a valid hexadecimal number ({reason})")?
272 }
273 ParseErrorKind::Msg(msg) => f.write_str(msg)?,
274 };
275 f.write_str(
276 ", PCI addresses must be in the format '(<BUS GROUP>:)<BUS>:<DEVICE>(.<FUNCTION>)'",
277 )
278 }
279}
280
#[cfg(test)]
mod test {
    use super::*;
    use proptest::{prop_assert_eq, proptest};

    #[test]
    fn addrs_are_valid() {
        AddressBits::assert_valid();
    }

    proptest! {
        // Every representable address survives the accessors, and both hex
        // formats parse back to the same address. Ranges are inclusive of the
        // maximum group and bus values (0xffff and 0xff).
        #[test]
        fn addr_roundtrips(group in 0u16..=u16::MAX, bus in 0u8..=u8::MAX, device in 0u8..32u8, function in 0u8..8u8) {
            let addr = Address::new()
                .with_group(NonZeroU16::new(group))
                .with_bus(bus)
                .with_device(device)
                .with_function(function);

            prop_assert_eq!(addr.group(), NonZeroU16::new(group), "group, addr: {}", addr);
            prop_assert_eq!(addr.bus(), bus, "bus, addr: {}", addr);
            prop_assert_eq!(addr.device(), device, "device, addr: {}", addr);
            prop_assert_eq!(addr.function(), function, "function, addr: {}", addr);

            let lower = format!("{addr:x}");
            prop_assert_eq!(lower.parse::<Address>().ok(), Some(addr), "lower hex: {}", lower);
            let upper = format!("{addr:X}");
            prop_assert_eq!(upper.parse::<Address>().ok(), Some(addr), "upper hex: {}", upper);
        }
    }

    #[track_caller]
    fn test_parse(s: &str, expected: Address) {
        let addr = s.parse::<Address>().expect(s);
        assert_eq!(addr, expected);
    }

    #[test]
    fn parse_pci_addr_no_fn() {
        test_parse("00:02", Address::new().with_device(0x02));
        test_parse("0f:0f", Address::new().with_bus(0x000f).with_device(0x0f));
    }

    #[test]
    fn parse_pcie_addr_no_fn() {
        test_parse(
            "0000:0a:01",
            Address::new()
                .with_group(None)
                .with_bus(0x0a)
                .with_device(0x01),
        );
        test_parse(
            "1234:0f:0f",
            Address::new()
                .with_group(NonZeroU16::new(0x1234))
                .with_bus(0x0f)
                .with_device(0x0f),
        );
        test_parse(
            "ffff:0a:0b",
            Address::new()
                .with_group(NonZeroU16::new(0xffff))
                .with_bus(0x0a)
                .with_device(0x0b),
        );
    }

    #[test]
    fn parse_max_values() {
        // Every field at its maximum: group 0xffff, bus 0xff, device 0x1f
        // (5 bits), function 7 (3 bits).
        test_parse(
            "ffff:ff:1f.7",
            Address::new()
                .with_group(NonZeroU16::new(0xffff))
                .with_bus(0xff)
                .with_device(0x1f)
                .with_function(7),
        );
    }

    #[test]
    fn parse_invalid() {
        for s in [
            "hello world",
            "",
            "00",
            // bus must be exactly two hex digits
            "0:02",
            // bus is not hexadecimal
            "zz:02",
            // device 0x20 does not fit in 5 bits
            "00:20",
            // function 8 does not fit in 3 bits
            "00:02.8",
            // function is not a number
            "00:02.x",
            // group is not hexadecimal
            "g000:00:02",
        ] {
            let err = s.parse::<Address>().expect_err(s);
            println!("{s:?}: {err}");
        }
    }

    #[test]
    fn parse_pci_addr_with_fn() {
        test_parse("00:02.0", Address::new().with_device(0x02));
        test_parse("0f:0f.0", Address::new().with_bus(0x000f).with_device(0x0f));
        test_parse("00:02.1", Address::new().with_device(0x02).with_function(1));
        test_parse(
            "0f:0f.1",
            Address::new()
                .with_bus(0x000f)
                .with_device(0x0f)
                .with_function(1),
        );
    }

    #[test]
    fn parse_pcie_addr_with_fn() {
        test_parse(
            "0000:0a:01.0",
            Address::new()
                .with_group(None)
                .with_bus(0x0a)
                .with_device(0x01),
        );
        test_parse(
            "1234:0f:0f.0",
            Address::new()
                .with_group(NonZeroU16::new(0x1234))
                .with_bus(0x0f)
                .with_device(0x0f),
        );
        test_parse(
            "ffff:0a:0b.0",
            Address::new()
                .with_group(NonZeroU16::new(0xffff))
                .with_bus(0x0a)
                .with_device(0x0b),
        );

        test_parse(
            "0000:0a:01.1",
            Address::new()
                .with_group(None)
                .with_bus(0x0a)
                .with_device(0x01)
                .with_function(1),
        );
        test_parse(
            "1234:0f:0f.1",
            Address::new()
                .with_group(NonZeroU16::new(0x1234))
                .with_bus(0x0f)
                .with_device(0x0f)
                .with_function(1),
        );
        test_parse(
            "ffff:0a:0b.1",
            Address::new()
                .with_group(NonZeroU16::new(0xffff))
                .with_bus(0x0a)
                .with_device(0x0b)
                .with_function(1),
        );
    }
}