use crate::{cpu, mm, segment, time, VAddr};
use core::{arch::asm, marker::PhantomData, time::Duration};
use hal_core::interrupt::Control;
use hal_core::interrupt::{ctx, Handlers};
use hal_core::mem::page;
use mycelium_util::{
bits, fmt,
sync::{
blocking::{Mutex, MutexGuard},
spin::Spinlock,
InitOnce,
},
};
pub mod apic;
pub mod idt;
pub mod pic;
use self::apic::{IoApicSet, LocalApic};
pub use idt::Idt;
pub use pic::CascadedPic;
/// The system's hardware interrupt controller.
///
/// Wraps whichever interrupt model `Controller::enable_hardware_interrupts`
/// selected: either the legacy cascaded 8259 PICs, or the I/O APIC(s) plus the
/// local APIC.
#[derive(Debug)]
pub struct Controller {
    // The active interrupt model; decides how IRQs are masked and acknowledged.
    model: InterruptModel,
}
/// The state handed to an interrupt handler.
///
/// Holds the CPU-pushed interrupt stack frame plus a vector-specific
/// error-code payload `T`. `T` is `()` for vectors without an error code.
#[derive(Debug)]
#[repr(C)]
pub struct Context<'a, T = ()> {
    registers: &'a mut Registers,
    code: T,
}

/// The raw error code pushed by the CPU for exceptions that provide one.
pub type ErrorCode = u64;

/// Error-code payload for "code faults" (exceptions caused by executing an
/// instruction), as passed to `Handlers::code_fault`.
pub struct CodeFault<'a> {
    // Human-readable exception name and vector, e.g. "Invalid Opcode (0x6)".
    kind: &'static str,
    // Decoded error-code details, if the exception provides an error code.
    error_code: Option<&'a dyn fmt::Display>,
}

/// Signature of an `x86-interrupt` service routine that receives a
/// [`Context`] with error-code payload `T`.
pub type Isr<T> = extern "x86-interrupt" fn(&mut Context<T>);
#[derive(Debug, thiserror::Error)]
pub enum PeriodicTimerError {
#[error("could not start PIT periodic timer: {0}")]
Pit(#[from] time::PitError),
#[error(transparent)]
InvalidDuration(#[from] time::InvalidDuration),
#[error("could access local APIC: {0}")]
Apic(#[from] apic::local::LocalApicError),
#[error("could not start local APIC periodic timer: {0}")]
ApicTimer(#[from] apic::local::TimerError),
}
/// An interrupt vector number, tagged at the type level with the error-code
/// payload type `T` that its handler receives.
#[derive(Debug)]
#[repr(C)]
pub struct Interrupt<T = ()> {
    vector: u8,
    // Marks the payload type without storing a value of it.
    _t: PhantomData<T>,
}

/// The hardware interrupt controller(s) in use on this system.
#[derive(Debug)]
enum InterruptModel {
    /// Legacy cascaded 8259 PICs, guarded by a spinlock-backed mutex.
    Pic(Mutex<pic::CascadedPic, Spinlock>),
    /// I/O APIC set for external IRQs, plus a handle to the local APIC.
    Apic {
        io: apic::IoApicSet,
        local: apic::local::Handle,
    },
}
bits::bitfield! {
    // Error code pushed by the CPU for a page fault (#PF).
    // NOTE(review): the bit numbers below assume the macro packs fields from
    // bit 0 upward in declaration order. That yields the Intel SDM layout, in
    // which SGX sits at bit 15.
    pub struct PageFaultCode<u32> {
        // Bit 0: set for a page-protection violation, clear for a non-present page.
        pub const PRESENT: bool;
        // Bit 1: the faulting access was a write.
        pub const WRITE: bool;
        // Bit 2: the fault occurred while executing in user mode.
        pub const USER: bool;
        // Bit 3: a reserved bit was set in a paging-structure entry.
        pub const RESERVED_WRITE: bool;
        // Bit 4: the fault was caused by an instruction fetch.
        pub const INSTRUCTION_FETCH: bool;
        // Bit 5: protection-key violation.
        pub const PROTECTION_KEY: bool;
        // Bit 6: shadow-stack access.
        pub const SHADOW_STACK: bool;
        // Bits 7-14: reserved.
        const _RESERVED0 = 8;
        // Bit 15: SGX-specific access-control violation.
        pub const SGX: bool;
    }
}
bits::bitfield! {
    // Error code pushed for segment-selector-related exceptions. The `isr`
    // handlers decode it for invalid TSS, segment not present, stack-segment
    // faults, and non-zero general protection fault codes.
    pub struct SelectorErrorCode<u16> {
        // Bit 0: the exception was triggered by an event external to the program.
        const EXTERNAL: bool;
        // Which descriptor table the index refers to.
        const TABLE: cpu::DescriptorTable;
        // Selector index within that table (13 bits).
        const INDEX = 13;
    }
}
/// The interrupt stack frame that `extern "x86-interrupt"` service routines
/// receive from the CPU.
///
/// The field order and padding form an ABI layout (`#[repr(C)]`). Do not
/// reorder or resize fields.
#[repr(C)]
pub struct Registers {
    /// Saved instruction pointer (`rip`).
    pub instruction_ptr: VAddr,
    /// Saved code segment selector (`cs`).
    pub code_segment: segment::Selector,
    /// Padding after `code_segment`.
    _pad: [u16; 3],
    /// Saved CPU flags (`rflags`).
    pub cpu_flags: u64,
    /// Saved stack pointer (`rsp`).
    pub stack_ptr: VAddr,
    /// Saved stack segment selector (`ss`).
    pub stack_segment: segment::Selector,
    /// Padding after `stack_segment`.
    _pad2: [u16; 3],
}
/// The global interrupt descriptor table, guarded by a spinlock-backed mutex.
/// `Controller::init` fills and loads it.
static IDT: Mutex<idt::Idt, Spinlock> = Mutex::new_with_raw_mutex(idt::Idt::new(), Spinlock::new());
/// The global interrupt controller. It is initialized once by
/// `Controller::enable_hardware_interrupts` and read by the IRQ handlers in
/// `isr`.
static INTERRUPT_CONTROLLER: InitOnce<Controller> = InitOnce::uninitialized();
/// Legacy ISA IRQ lines. Each variant's discriminant is its ISA IRQ number.
#[derive(Copy, Clone, Debug)]
#[repr(u8)]
pub enum IsaInterrupt {
    /// Programmable interval timer.
    PitTimer = 0,
    /// PS/2 keyboard.
    Ps2Keyboard = 1,
    // IRQ 2 has no variant (presumably the PIC cascade line).
    /// Second serial port.
    Com2 = 3,
    /// First serial port.
    Com1 = 4,
    /// Second parallel port.
    Lpt2 = 5,
    /// Floppy disk controller.
    Floppy = 6,
    /// First parallel port.
    Lpt1 = 7,
    /// CMOS real-time clock.
    CmosRtc = 8,
    /// Free for peripherals.
    Periph1 = 9,
    /// Free for peripherals.
    Periph2 = 10,
    /// Free for peripherals.
    Periph3 = 11,
    /// PS/2 mouse.
    Ps2Mouse = 12,
    /// FPU / coprocessor.
    Fpu = 13,
    /// Primary ATA channel.
    AtaPrimary = 14,
    /// Secondary ATA channel.
    AtaSecondary = 15,
}
impl IsaInterrupt {
pub const ALL: [IsaInterrupt; 15] = [
IsaInterrupt::PitTimer,
IsaInterrupt::Ps2Keyboard,
IsaInterrupt::Com2,
IsaInterrupt::Com1,
IsaInterrupt::Lpt2,
IsaInterrupt::Floppy,
IsaInterrupt::Lpt1,
IsaInterrupt::CmosRtc,
IsaInterrupt::Periph1,
IsaInterrupt::Periph2,
IsaInterrupt::Periph3,
IsaInterrupt::Ps2Mouse,
IsaInterrupt::Fpu,
IsaInterrupt::AtaPrimary,
IsaInterrupt::AtaSecondary,
];
}
#[must_use]
fn disable_scoped() -> impl Drop + Send + Sync {
unsafe {
crate::cpu::intrinsics::cli();
}
mycelium_util::defer(|| unsafe {
crate::cpu::intrinsics::sti();
})
}
impl Controller {
    /// Locks and returns the global interrupt descriptor table.
    pub fn idt() -> MutexGuard<'static, idt::Idt, Spinlock> {
        IDT.lock()
    }

    /// Installs the interrupt service routines for `H` into the global IDT,
    /// then loads that IDT on the current CPU core.
    ///
    /// # Panics
    ///
    /// Panics if registering the handlers returns an error.
    #[tracing::instrument(level = "info", name = "interrupt::Controller::init")]
    pub fn init<H: Handlers<Registers>>() {
        tracing::info!("initializing IDT...");
        let mut idt = IDT.lock();
        idt.register_handlers::<H>()
            .expect("registering interrupt handlers should not fail");
        // SAFETY: the table lives in the `IDT` static, so it remains valid for
        // as long as the CPU may reference it after being loaded.
        unsafe {
            idt.load_raw();
        }
    }

    /// Masks (disables) the ISA IRQ `irq` on the active interrupt controller.
    pub fn mask_isa_irq(&self, irq: IsaInterrupt) {
        match self.model {
            InterruptModel::Pic(ref pics) => pics.lock().mask(irq),
            InterruptModel::Apic { ref io, .. } => io.set_isa_masked(irq, true),
        }
    }

    /// Unmasks (enables) the ISA IRQ `irq` on the active interrupt controller.
    pub fn unmask_isa_irq(&self, irq: IsaInterrupt) {
        match self.model {
            InterruptModel::Pic(ref pics) => pics.lock().unmask(irq),
            InterruptModel::Apic { ref io, .. } => io.set_isa_masked(irq, false),
        }
    }

    /// Returns the local APIC handle, or `LocalApicError::NoApic` when the
    /// legacy PIC model is in use.
    fn local_apic_handle(&self) -> Result<&apic::local::Handle, apic::local::LocalApicError> {
        match self.model {
            InterruptModel::Pic(_) => Err(apic::local::LocalApicError::NoApic),
            InterruptModel::Apic { ref local, .. } => Ok(local),
        }
    }

    /// Runs `f` with the current core's local APIC and returns its result.
    ///
    /// # Errors
    ///
    /// Returns `LocalApicError::NoApic` under the PIC model. Otherwise returns
    /// whatever error `Handle::with` reports; presumably that happens when the
    /// local APIC is not yet initialized on this core.
    pub fn with_local_apic<T>(
        &self,
        f: impl FnOnce(&LocalApic) -> T,
    ) -> Result<T, apic::local::LocalApicError> {
        self.local_apic_handle()?.with(f)
    }

    /// Initializes the local APIC for the calling core. Interrupts are
    /// disabled while this runs.
    ///
    /// # Errors
    ///
    /// Returns `LocalApicError::NoApic` if the PIC model is in use.
    pub fn initialize_local_apic<A>(
        &self,
        frame_alloc: &A,
        pagectrl: &mut impl page::Map<mm::size::Size4Kb, A>,
    ) -> Result<(), apic::local::LocalApicError>
    where
        A: page::Alloc<mm::size::Size4Kb>,
    {
        let _deferred = disable_scoped();
        let hdl = self.local_apic_handle()?;
        // SAFETY: interrupts are disabled on this core for the duration of
        // initialization (see `_deferred` above).
        unsafe {
            hdl.initialize(frame_alloc, pagectrl, Idt::LOCAL_APIC_SPURIOUS as u8);
        }
        Ok(())
    }

    /// Signals end-of-interrupt for the ISA IRQ `irq` to the active interrupt
    /// controller.
    ///
    /// # Safety
    ///
    /// NOTE(review): the exact contract is not spelled out in this module.
    /// Presumably this must only be called from the handler for `irq`, after
    /// that interrupt has been serviced. Acknowledging an interrupt that is
    /// not in service may leave the controller in an inconsistent state.
    ///
    /// # Panics
    ///
    /// Under the APIC model, panics if the local APIC has not been
    /// initialized on the current core.
    pub unsafe fn end_isa_irq(&self, irq: IsaInterrupt) {
        match self.model {
            InterruptModel::Pic(ref pics) => pics.lock().end_interrupt(irq),
            InterruptModel::Apic { ref local, .. } => local
                .with(|apic| unsafe { apic.end_interrupt() })
                .expect("interrupts should not be handled on this core until the local APIC is initialized"),
        }
    }

    /// Configures the hardware interrupt controller and enables interrupts.
    ///
    /// Steps:
    ///
    /// 1. Remaps the PIC interrupt vectors.
    /// 2. If ACPI reports the APIC model, disables the 8259 PICs and sets up
    ///    the I/O APIC(s) and the local APIC. Otherwise falls back to the
    ///    8259 PICs.
    /// 3. Stores the result in the global controller.
    /// 4. Enables interrupts and unmasks the PIT timer and PS/2 keyboard IRQs.
    ///
    /// NOTE(review): presumably this must be called only once, since it
    /// initializes the `InitOnce` global.
    pub fn enable_hardware_interrupts(
        acpi: Option<&acpi::InterruptModel>,
        frame_alloc: &impl page::Alloc<mm::size::Size4Kb>,
    ) -> &'static Self {
        let mut pics = pic::CascadedPic::new();
        // Remap the PIC vectors first, so spurious legacy IRQs cannot land on
        // CPU exception vectors, whichever model is chosen below.
        // SAFETY: the target vectors are the IDT's dedicated PIC ranges.
        unsafe {
            tracing::debug!(
                big = Idt::PIC_BIG_START,
                little = Idt::PIC_LITTLE_START,
                "remapping PIC interrupt vectors"
            );
            pics.set_irq_address(Idt::PIC_BIG_START as u8, Idt::PIC_LITTLE_START as u8);
        }

        let controller = match acpi {
            Some(acpi::InterruptModel::Apic(apic_info)) => {
                tracing::info!("detected APIC interrupt model");
                let mut pagectrl = mm::PageCtrl::current();

                // SAFETY: the APICs replace the PICs, so the PICs are no
                // longer needed.
                unsafe {
                    pics.disable();
                }
                tracing::info!("disabled 8259 PICs");

                let io = IoApicSet::new(apic_info, frame_alloc, &mut pagectrl, Idt::ISA_BASE as u8);
                let local = apic::local::Handle::new();
                // SAFETY: interrupts have not been enabled yet at this point.
                unsafe {
                    local.initialize(frame_alloc, &mut pagectrl, Idt::LOCAL_APIC_SPURIOUS as u8);
                }
                let model = InterruptModel::Apic { local, io };

                tracing::trace!(interrupt_model = ?model);

                INTERRUPT_CONTROLLER.init(Self { model })
            }
            model => {
                if model.is_none() {
                    tracing::warn!("platform does not support ACPI; falling back to 8259 PIC");
                } else {
                    tracing::warn!(
                        "ACPI does not indicate APIC interrupt model; falling back to 8259 PIC"
                    )
                }
                tracing::info!("configuring 8259 PIC interrupts...");

                // SAFETY: the PIC vectors were remapped above.
                unsafe {
                    pics.enable();
                }

                INTERRUPT_CONTROLLER.init(Self {
                    model: InterruptModel::Pic(Mutex::new_with_raw_mutex(pics, Spinlock::new())),
                })
            }
        };

        // SAFETY: the global controller is initialized above, so IRQ handlers
        // that call `INTERRUPT_CONTROLLER.get_unchecked()` may now run.
        unsafe {
            crate::cpu::intrinsics::sti();
        }

        controller.unmask_isa_irq(IsaInterrupt::PitTimer);
        controller.unmask_isa_irq(IsaInterrupt::Ps2Keyboard);

        controller
    }

    /// Starts a periodic timer that fires every `interval`.
    ///
    /// Uses the PIT under the PIC model, and the local APIC timer (calibrated
    /// first, divisor 16) under the APIC model.
    ///
    /// # Errors
    ///
    /// Returns a [`PeriodicTimerError`] if the timer could not be started, or
    /// if the local APIC could not be accessed.
    pub fn start_periodic_timer(&self, interval: Duration) -> Result<(), PeriodicTimerError> {
        match self.model {
            InterruptModel::Pic(_) => crate::time::PIT
                .lock()
                .start_periodic_timer(interval)
                .map_err(Into::into),
            InterruptModel::Apic { ref local, .. } => local.with(|apic| {
                apic.calibrate_timer(apic::local::register::TimerDivisor::By16);
                apic.start_periodic_timer(interval, Idt::LOCAL_APIC_TIMER as u8)?;
                Ok(())
            })?,
        }
    }
}
impl<T> hal_core::interrupt::Context for Context<'_, T> {
    type Registers = Registers;

    /// Borrows the interrupt stack frame saved by the CPU.
    fn registers(&self) -> &Registers {
        &*self.registers
    }

    /// Mutably borrows the interrupt stack frame saved by the CPU.
    ///
    /// # Safety
    ///
    /// The caller's obligations are defined by the `hal_core` trait. Any
    /// changes are written to the frame the interrupted code will resume
    /// from.
    unsafe fn registers_mut(&mut self) -> &mut Registers {
        &mut *self.registers
    }
}
impl ctx::PageFault for Context<'_, PageFaultCode> {
    /// Returns the faulting virtual address, read from the `CR2` register.
    fn fault_vaddr(&self) -> crate::VAddr {
        crate::control_regs::Cr2::read()
    }

    /// Returns the page-fault error code for `Debug` formatting.
    fn debug_error_code(&self) -> &dyn fmt::Debug {
        let Self { code, .. } = self;
        code
    }

    /// Returns the page-fault error code for `Display` formatting.
    fn display_error_code(&self) -> &dyn fmt::Display {
        let Self { code, .. } = self;
        code
    }
}
impl ctx::CodeFault for Context<'_, CodeFault<'_>> {
    /// Reports whether the fault happened in user mode.
    ///
    /// NOTE(review): this is hard-coded to `false`; it does not yet inspect
    /// the privilege level of the saved `code_segment`.
    fn is_user_mode(&self) -> bool {
        false
    }

    /// Returns the saved instruction pointer of the faulting instruction.
    fn instruction_ptr(&self) -> crate::VAddr {
        let Self { registers, .. } = self;
        registers.instruction_ptr
    }

    /// Returns the exception's human-readable name and vector.
    fn fault_kind(&self) -> &'static str {
        let fault = &self.code;
        fault.kind
    }

    /// Returns decoded error-code details, if the exception has any.
    fn details(&self) -> Option<&dyn fmt::Display> {
        let fault = &self.code;
        fault.error_code
    }
}
impl Context<'_, ErrorCode> {
    /// Returns the raw error code the CPU pushed for this exception.
    pub fn error_code(&self) -> ErrorCode {
        let Self { code, .. } = self;
        *code
    }
}

impl Context<'_, PageFaultCode> {
    /// Returns the decoded page-fault error code.
    pub fn page_fault_code(&self) -> PageFaultCode {
        let Self { code, .. } = self;
        *code
    }
}
impl hal_core::interrupt::Control for Idt {
    type Registers = Registers;

    /// Disables maskable interrupts on the current core (`cli`).
    #[inline]
    unsafe fn disable(&mut self) {
        crate::cpu::intrinsics::cli();
    }

    /// Enables maskable interrupts on the current core (`sti`).
    #[inline]
    unsafe fn enable(&mut self) {
        crate::cpu::intrinsics::sti();
        tracing::trace!("interrupts enabled");
    }

    /// Returns whether maskable interrupts are currently enabled on this core.
    ///
    /// The answer comes from the interrupt-enable flag (IF, bit 9) of `RFLAGS`.
    /// This previously panicked with `unimplemented!`.
    fn is_enabled(&self) -> bool {
        const RFLAGS_IF: u64 = 1 << 9;
        let rflags: u64;
        // SAFETY: `pushfq; pop` only copies RFLAGS into a general-purpose
        // register via the stack; it has no other side effects and leaves the
        // flags unchanged.
        unsafe {
            asm!("pushfq", "pop {}", out(reg) rflags, options(nomem, preserves_flags));
        }
        rflags & RFLAGS_IF != 0
    }

    /// Installs this module's ISRs into the IDT.
    ///
    /// The ISRs cover CPU exceptions, the PIT and PS/2 keyboard ISA IRQs, the
    /// local APIC spurious and timer vectors, and the test vector 69. Each
    /// one forwards to the matching `H` callback.
    fn register_handlers<H>(&mut self) -> Result<(), hal_core::interrupt::RegistrationError>
    where
        H: Handlers<Registers>,
    {
        let span = tracing::debug_span!("Idt::register_handlers");
        let _enter = span.enter();

        // CPU exceptions.
        self.register_isr(Self::DIVIDE_BY_ZERO, isr::div_0::<H> as *const ());
        self.register_isr(Self::OVERFLOW, isr::overflow::<H> as *const ());
        self.register_isr(Self::BOUND_RANGE_EXCEEDED, isr::br::<H> as *const ());
        self.register_isr(Self::INVALID_OPCODE, isr::ud::<H> as *const ());
        self.register_isr(Self::DEVICE_NOT_AVAILABLE, isr::no_fpu::<H> as *const ());
        self.register_isr(
            Self::ALIGNMENT_CHECK,
            isr::alignment_check::<H> as *const (),
        );
        self.register_isr(
            Self::SIMD_FLOATING_POINT,
            isr::simd_fp_exn::<H> as *const (),
        );
        self.register_isr(Self::X87_FPU_EXCEPTION, isr::x87_exn::<H> as *const ());
        self.register_isr(Self::PAGE_FAULT, isr::page_fault::<H> as *const ());
        self.register_isr(Self::INVALID_TSS, isr::invalid_tss::<H> as *const ());
        self.register_isr(
            Self::SEGMENT_NOT_PRESENT,
            isr::segment_not_present::<H> as *const (),
        );
        self.register_isr(
            Self::STACK_SEGMENT_FAULT,
            isr::stack_segment::<H> as *const (),
        );
        self.register_isr(Self::GENERAL_PROTECTION_FAULT, isr::gpf::<H> as *const ());
        self.register_isr(Self::DOUBLE_FAULT, isr::double_fault::<H> as *const ());

        // Hardware IRQs.
        self.register_isa_isr(IsaInterrupt::PitTimer, isr::pit_timer::<H> as *const ());
        self.register_isa_isr(IsaInterrupt::Ps2Keyboard, isr::keyboard::<H> as *const ());

        // Local APIC vectors.
        self.register_isr(Self::LOCAL_APIC_SPURIOUS, isr::spurious as *const ());
        self.register_isr(Self::LOCAL_APIC_TIMER, isr::apic_timer::<H> as *const ());

        // Software test interrupt, raised by `fire_test_interrupt`.
        self.register_isr(69, isr::test::<H> as *const ());

        Ok(())
    }
}
/// Forcibly releases the locks on the tracing output sinks: the VGA text
/// writer and, if present, the COM1 serial port.
///
/// # Safety
///
/// This breaks locks that another context may still hold. It is meant for
/// fault paths (see the `isr` module) where the faulting code may have died
/// while holding a sink's lock. NOTE(review): the exact contract of the
/// underlying `force_unlock` calls is defined elsewhere.
unsafe fn force_unlock_tracing() {
    crate::vga::writer().force_unlock();
    let Some(serial) = crate::serial::com1() else {
        return;
    };
    serial.force_unlock();
}
impl fmt::Debug for Registers {
    /// Formats the saved frame and shows `cpu_flags` in binary. The padding
    /// fields are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Registers");
        out.field("instruction_ptr", &self.instruction_ptr);
        out.field("code_segment", &self.code_segment);
        out.field("cpu_flags", &format_args!("{:#b}", self.cpu_flags));
        out.field("stack_ptr", &self.stack_ptr);
        out.field("stack_segment", &self.stack_segment);
        out.finish()
    }
}
impl fmt::Display for Registers {
    /// Writes one line per saved register (`rip`, `cs`, `flags`, `rsp`, `ss`).
    /// Each line ends with a newline.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(
            f,
            " rip: {:?}\n cs: {:?}\n flags: {:#b}\n rsp: {:?}\n ss: {:?}",
            self.instruction_ptr,
            self.code_segment,
            self.cpu_flags,
            self.stack_ptr,
            self.stack_segment,
        )
    }
}
/// Raises software interrupt vector 69 on the current core.
///
/// `Idt::register_handlers` installs the test ISR on that vector.
pub fn fire_test_interrupt() {
    // SAFETY: assumes the IDT loaded on this core has a handler for vector 69,
    // as installed by `register_handlers`.
    unsafe {
        asm!("int {vector}", vector = const 69);
    }
}
impl SelectorErrorCode {
    /// Pairs this error code with a human-readable name for the kind of
    /// segment it refers to, for use in fault messages.
    #[inline]
    fn named(self, segment_kind: &'static str) -> NamedSelectorErrorCode {
        NamedSelectorErrorCode {
            code: self,
            segment_kind,
        }
    }

    /// Returns a `Display` adapter that renders the descriptor table, the
    /// index, whether the event was external, and the raw bits.
    fn display(&self) -> impl fmt::Display {
        struct Pretty(SelectorErrorCode);

        impl fmt::Display for Pretty {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                let Self(code) = self;
                write!(
                    f,
                    "{} index {}",
                    code.get(SelectorErrorCode::TABLE),
                    code.get(SelectorErrorCode::INDEX),
                )?;
                let external = code.get(SelectorErrorCode::EXTERNAL);
                if external {
                    f.write_str(" (from an external source)")?;
                }
                write!(f, " (error code {:#b})", code.bits())
            }
        }

        Pretty(*self)
    }
}
/// A `SelectorErrorCode` labelled with the kind of segment it refers to.
/// It is displayed as `"<kind> at <decoded selector>"`.
struct NamedSelectorErrorCode {
    /// Human-readable segment kind, e.g. `"task-state segment (TSS)"`.
    segment_kind: &'static str,
    /// The decoded selector error code.
    code: SelectorErrorCode,
}

impl fmt::Display for NamedSelectorErrorCode {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { segment_kind, code } = self;
        write!(f, "{segment_kind} at {}", code.display())
    }
}
mod isr {
    //! Raw `extern "x86-interrupt"` service routines that
    //! `Idt::register_handlers` installs into the IDT.
    //!
    //! Each routine wraps the CPU-pushed frame (and the error code, for
    //! vectors that have one) in a [`Context`] and forwards it to the matching
    //! [`Handlers`] callback.
    use super::*;

    // Generates one code-fault ISR per `fn name("Kind (vector)")` entry. Each
    // ISR forwards to `Handlers::code_fault` with no error-code details.
    // Entries written `fn name("Kind (vector)", code)` generate an ISR that
    // also takes the CPU-pushed error code and reports it as the details.
    macro_rules! gen_code_faults {
        ($(fn $name:ident($($rest:tt)+),)+) => {
            $(
                gen_code_faults! {@ $name($($rest)+); }
            )+
        };
        (@ $name:ident($kind:literal);) => {
            pub(super) extern "x86-interrupt" fn $name<H: Handlers<Registers>>(mut registers: Registers) {
                let code = CodeFault {
                    error_code: None,
                    kind: $kind,
                };
                H::code_fault(Context { registers: &mut registers, code });
            }
        };
        (@ $name:ident($kind:literal, code);) => {
            pub(super) extern "x86-interrupt" fn $name<H: Handlers<Registers>>(
                mut registers: Registers,
                code: u64,
            ) {
                let code = CodeFault {
                    error_code: Some(&code),
                    kind: $kind,
                };
                H::code_fault(Context { registers: &mut registers, code });
            }
        };
    }

    gen_code_faults! {
        fn div_0("Divide-By-Zero (0x0)"),
        fn overflow("Overflow (0x4)"),
        fn br("Bound Range Exceeded (0x5)"),
        fn ud("Invalid Opcode (0x6)"),
        fn no_fpu("Device (FPU) Not Available (0x7)"),
        fn alignment_check("Alignment Check (0x11)", code),
        fn simd_fp_exn("SIMD Floating-Point Exception (0x13)"),
        fn x87_exn("x87 Floating-Point Exception (0x10)"),
    }

    /// Page fault: forwards the decoded error code to `Handlers::page_fault`.
    pub(super) extern "x86-interrupt" fn page_fault<H: Handlers<Registers>>(
        mut registers: Registers,
        code: PageFaultCode,
    ) {
        H::page_fault(Context {
            registers: &mut registers,
            code,
        });
    }

    /// Double fault: forwards the raw error code to `Handlers::double_fault`.
    pub(super) extern "x86-interrupt" fn double_fault<H: Handlers<Registers>>(
        mut registers: Registers,
        code: u64,
    ) {
        H::double_fault(Context {
            registers: &mut registers,
            code,
        });
    }

    /// PIT IRQ. Calls `Handlers::timer_tick` when the PIT reports that a tick
    /// should be delivered, then acknowledges the IRQ.
    pub(super) extern "x86-interrupt" fn pit_timer<H: Handlers<Registers>>(_regs: Registers) {
        if crate::time::Pit::handle_interrupt() {
            H::timer_tick()
        }
        // SAFETY: ISA IRQs are only unmasked after `INTERRUPT_CONTROLLER` is
        // initialized in `enable_hardware_interrupts`, and this is the
        // handler for the PIT IRQ being acknowledged.
        unsafe {
            INTERRUPT_CONTROLLER
                .get_unchecked()
                .end_isa_irq(IsaInterrupt::PitTimer);
        }
    }

    /// Local APIC timer. Calls `Handlers::timer_tick`, then signals
    /// end-of-interrupt to the local APIC.
    pub(super) extern "x86-interrupt" fn apic_timer<H: Handlers<Registers>>(_regs: Registers) {
        H::timer_tick();
        // SAFETY: the controller is initialized before interrupts are enabled,
        // and this vector can only fire under the APIC model.
        unsafe {
            match INTERRUPT_CONTROLLER.get_unchecked().model {
                InterruptModel::Pic(_) => unreachable!(),
                InterruptModel::Apic { ref local, .. } => {
                    match local.with(|apic| apic.end_interrupt()) {
                        Ok(_) => {}
                        Err(e) => unreachable!(
                            "the local APIC timer will not have fired if the \
                            local APIC is uninitialized on this core! {e:?}",
                        ),
                    }
                }
            }
        }
    }

    /// PS/2 keyboard IRQ. Reads the scancode from port `0x60`, forwards it to
    /// `Handlers::ps2_keyboard`, then acknowledges the IRQ.
    pub(super) extern "x86-interrupt" fn keyboard<H: Handlers<Registers>>(_regs: Registers) {
        static PORT: cpu::Port = cpu::Port::at(0x60);
        // SAFETY: reading the PS/2 data port is how the pending scancode is
        // retrieved in response to this IRQ.
        let scancode = unsafe { PORT.readb() };
        H::ps2_keyboard(scancode);
        // SAFETY: see `pit_timer`; this is the handler for the keyboard IRQ.
        unsafe {
            INTERRUPT_CONTROLLER
                .get_unchecked()
                .end_isa_irq(IsaInterrupt::Ps2Keyboard);
        }
    }

    /// Software test interrupt (vector 69): forwards to
    /// `Handlers::test_interrupt`.
    pub(super) extern "x86-interrupt" fn test<H: Handlers<Registers>>(mut registers: Registers) {
        H::test_interrupt(Context {
            registers: &mut registers,
            code: (),
        });
    }

    /// Invalid TSS: decodes the selector error code and reports a code fault.
    pub(super) extern "x86-interrupt" fn invalid_tss<H: Handlers<Registers>>(
        mut registers: Registers,
        code: u64,
    ) {
        // SAFETY: fault path; the tracing sinks are force-unlocked so the
        // error below can be emitted even if the faulting code held them.
        unsafe {
            force_unlock_tracing();
        }
        let selector = SelectorErrorCode(code as u16);
        tracing::error!(?selector, "invalid task-state segment!");

        let msg = selector.named("task-state segment (TSS)");
        let code = CodeFault {
            error_code: Some(&msg),
            kind: "Invalid TSS (0xA)",
        };
        H::code_fault(Context {
            registers: &mut registers,
            code,
        });
    }

    /// Segment not present: decodes the selector error code and reports a
    /// code fault.
    pub(super) extern "x86-interrupt" fn segment_not_present<H: Handlers<Registers>>(
        mut registers: Registers,
        code: u64,
    ) {
        // SAFETY: fault path; see `invalid_tss`.
        unsafe {
            force_unlock_tracing();
        }
        let selector = SelectorErrorCode(code as u16);
        tracing::error!(?selector, "a segment was not present!");

        // This fault can involve any segment, not only the stack segment (the
        // stack segment has its own `stack_segment` fault), so use a generic
        // label.
        let msg = selector.named("segment");
        let code = CodeFault {
            error_code: Some(&msg),
            kind: "Segment Not Present (0xB)",
        };
        H::code_fault(Context {
            registers: &mut registers,
            code,
        });
    }

    /// Stack-segment fault: decodes the selector error code and reports a
    /// code fault.
    pub(super) extern "x86-interrupt" fn stack_segment<H: Handlers<Registers>>(
        mut registers: Registers,
        code: u64,
    ) {
        // SAFETY: fault path; see `invalid_tss`.
        unsafe {
            force_unlock_tracing();
        }
        let selector = SelectorErrorCode(code as u16);
        tracing::error!(?selector, "a stack-segment fault is happening");

        let msg = selector.named("stack segment");
        let code = CodeFault {
            error_code: Some(&msg),
            kind: "Stack-Segment Fault (0xC)",
        };
        H::code_fault(Context {
            registers: &mut registers,
            code,
        });
    }

    /// General protection fault. A non-zero error code is decoded as a
    /// selector; a zero code is reported without details.
    pub(super) extern "x86-interrupt" fn gpf<H: Handlers<Registers>>(
        mut registers: Registers,
        code: u64,
    ) {
        // SAFETY: fault path; see `invalid_tss`.
        unsafe {
            force_unlock_tracing();
        }

        let segment = if code > 0 {
            Some(SelectorErrorCode(code as u16))
        } else {
            None
        };

        tracing::error!(?segment, "lmao, a general protection fault is happening");
        let error_code = segment.map(|seg| seg.named("selector"));
        let code = CodeFault {
            error_code: error_code.as_ref().map(|code| code as &dyn fmt::Display),
            kind: "General Protection Fault (0xD)",
        };
        H::code_fault(Context {
            registers: &mut registers,
            code,
        });
    }

    /// Local APIC spurious-interrupt vector. Intentionally does nothing.
    pub(super) extern "x86-interrupt" fn spurious() {
        // NOTE(review): no end-of-interrupt is sent here; confirm that this is
        // the intended handling for spurious local APIC interrupts.
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Registers` is the CPU-pushed interrupt frame, so its `#[repr(C)]`
    /// layout must span exactly five 8-byte slots.
    #[test]
    fn registers_is_correct_size() {
        const EXPECTED: usize = 5 * core::mem::size_of::<u64>();
        assert_eq!(core::mem::size_of::<Registers>(), EXPECTED);
    }
}