diff --git a/src/arch/macros.rs b/src/arch/macros.rs index 6e3e2c2..91c216a 100644 --- a/src/arch/macros.rs +++ b/src/arch/macros.rs @@ -30,9 +30,19 @@ macro_rules! syscall_enum { } impl $Name { + /// A slice of all possible syscalls. + pub(crate) const ALL: &'static [Self] = &[ + Self::$first_syscall, + $( + Self::$syscall, + )* + ]; + /// Constructs a new syscall from the given ID. If the ID does not /// represent a valid syscall, returns `None`. pub const fn new(id: usize) -> Option { + // TODO: Get rid of this huge match and use the SysnoSet for + // checking validity. match id { $(#[$first_inner])* $first_num => Some(Self::$first_syscall), @@ -61,10 +71,10 @@ macro_rules! syscall_enum { return None; } - let mut next_id = self.id() as usize + 1; + let mut next_id = self.id() + 1; - while next_id < Self::len() { - if let Some(next) = Self::new(next_id) { + while next_id < Self::last().id() { + if let Some(next) = Self::new(next_id as usize) { return Some(next); } @@ -90,7 +100,19 @@ macro_rules! syscall_enum { } /// Returns the length of the syscall table, including any gaps. + #[deprecated = "Sysno::len() is misleading. Use Sysno::table_size() instead."] pub const fn len() -> usize { + Self::table_size() + } + + /// Returns the total number of valid syscalls. + pub const fn count() -> usize { + Self::ALL.len() + } + + /// Returns the length of the syscall table, including any gaps. + /// This is not the same thing as the total number of syscalls. + pub const fn table_size() -> usize { (Self::last().id() - Self::first().id()) as usize + 1 } @@ -127,6 +149,13 @@ macro_rules! syscall_enum { } } + impl From for $Name { + fn from(id: u32) -> Self { + Self::new(id as usize) + .unwrap_or_else(|| panic!("invalid syscall: {}", id)) + } + } + impl From for $Name { fn from(id: i32) -> Self { Self::new(id as usize) diff --git a/src/lib.rs b/src/lib.rs index 234f2b0..18911e4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,10 +18,12 @@ mod macros; mod arch; mod args; mod errno; +mod set; pub use arch::Sysno; pub use args::SyscallArgs; pub use errno::{Errno, ErrnoSentinel}; +pub use set::SysnoSet; pub mod raw { //! Exposes raw syscalls that simply return a `usize` instead of a `Result`. @@ -263,7 +265,7 @@ mod tests { #[test] fn test_syscall_len() { - assert!(Sysno::len() > 300); - assert!(Sysno::len() < 1000); + assert!(Sysno::table_size() > 300); + assert!(Sysno::table_size() < 1000); } } diff --git a/src/set.rs b/src/set.rs new file mode 100644 index 0000000..b4a9b4e --- /dev/null +++ b/src/set.rs @@ -0,0 +1,421 @@ +//! Enables the creation of a syscall bitset. + +use super::Sysno; + +use core::fmt; +use core::num::NonZeroUsize; + +const fn bits_per() -> usize { + core::mem::size_of::().saturating_mul(8) +} + +/// Returns the number of words of type `T` required to hold the specified +/// number of `bits`. +const fn words(bits: usize) -> usize { + let width = bits_per::(); + if width == 0 { + return 0; + } + + bits / width + ((bits % width != 0) as usize) +} + +/// A set of syscalls. +/// +/// This provides constant-time lookup of syscalls within a bitset. This is +/// useful for efficient +#[derive(Clone, Eq, PartialEq)] +pub struct SysnoSet { + data: [usize; words::(Sysno::table_size())], +} + +impl SysnoSet { + /// The set of all valid syscalls. + const ALL: &'static SysnoSet = &SysnoSet::new(Sysno::ALL); + + const WORD_WIDTH: usize = usize::BITS as usize; + + /// Initialize the syscall set with the given slice of syscalls. + /// + /// Since this is a `const fn`, this can be used at compile-time. + pub const fn new(syscalls: &[Sysno]) -> Self { + let mut set = Self::empty(); + + // Use a plain-old while-loop because for-loops are not yet allowed in + // const-fns. + let mut i = 0; + let n = syscalls.len(); + while i < n { + let sysno = syscalls[i]; + + let bit = (sysno.id() as usize) - (Sysno::first().id() as usize); + let idx = bit / Self::WORD_WIDTH; + + set.data[idx] |= 1 << (bit % Self::WORD_WIDTH); + + i += 1; + } + + set + } + + /// Creates an empty set of syscalls. + pub const fn empty() -> Self { + Self { + data: [0; words::(Sysno::table_size())], + } + } + + /// Creates a set containing all valid syscalls. + pub const fn all() -> Self { + Self { + data: Self::ALL.data, + } + } + + /// Returns true if the set contains the given syscall. + pub const fn contains(&self, sysno: Sysno) -> bool { + let bit = (sysno.id() as usize) - (Sysno::first().id() as usize); + let idx = bit / Self::WORD_WIDTH; + + (self.data[idx] & (1 << (bit % Self::WORD_WIDTH))) != 0 + } + + /// Clears the set, removing all syscalls. + pub fn clear(&mut self) { + for word in &mut self.data { + *word = 0; + } + } + + /// Returns the number of syscalls in the set. This is an O(n) operation as + /// it must count the number of bits in the bitset. + pub fn count(&self) -> usize { + self.data + .iter() + .fold(0, |acc, x| acc + x.count_ones() as usize) + } + + /// Inserts the given syscall into the set. + pub fn insert(&mut self, sysno: Sysno) { + let bit = (sysno.id() as usize) - (Sysno::first().id() as usize); + let idx = bit / Self::WORD_WIDTH; + + self.data[idx] |= 1 << (bit % Self::WORD_WIDTH); + } + + /// Removes the given syscall from the set. + pub fn remove(&mut self, sysno: Sysno) { + let bit = (sysno.id() as usize) - (Sysno::first().id() as usize); + let idx = bit / Self::WORD_WIDTH; + + self.data[idx] &= !(1 << (bit % Self::WORD_WIDTH)); + } + + /// Does a set union with this set and another. + pub const fn union(mut self, other: &Self) -> Self { + let mut i = 0; + let n = self.data.len(); + while i < n { + self.data[i] |= other.data[i]; + i += 1; + } + + self + } + + /// Does a set intersection with this set and another. + pub const fn intersection(mut self, other: &Self) -> Self { + let mut i = 0; + let n = self.data.len(); + while i < n { + self.data[i] &= other.data[i]; + i += 1; + } + + self + } + + /// Calculates the difference with this set and another. That is, the + /// resulting set only includes the syscalls that are in `self` but not in + /// `other`. + pub const fn difference(mut self, other: &Self) -> Self { + let mut i = 0; + let n = self.data.len(); + while i < n { + self.data[i] &= !other.data[i]; + i += 1; + } + + self + } + + /// Calculates the symmetric difference with this set and another. That is, + /// the resulting set only includes the syscalls that are in `self` or in + /// `other`, but not in both. + pub const fn symmetric_difference(mut self, other: &Self) -> Self { + let mut i = 0; + let n = self.data.len(); + while i < n { + self.data[i] ^= other.data[i]; + i += 1; + } + + self + } + + /// Returns an iterator that iterates over the syscalls contained in the set. + pub fn iter(&self) -> SysnoSetIter { + SysnoSetIter::new(self.data.iter()) + } +} + +impl fmt::Debug for SysnoSet { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{{")?; + + let mut iter = self.iter(); + if let Some(sysno) = iter.next() { + write!(f, "{}", sysno)?; + } + + for sysno in iter { + write!(f, ", {}", sysno)?; + } + + write!(f, "}}")?; + + Ok(()) + } +} + +/// Helper for iterating over the non-zero values of the words in the bitset. +struct NonZeroUsizeIter<'a> { + iter: core::slice::Iter<'a, usize>, + count: usize, +} + +impl<'a> NonZeroUsizeIter<'a> { + pub fn new(iter: core::slice::Iter<'a, usize>) -> Self { + Self { iter, count: 0 } + } +} + +impl<'a> Iterator for NonZeroUsizeIter<'a> { + type Item = NonZeroUsize; + + fn next(&mut self) -> Option { + for item in &mut self.iter { + self.count += 1; + + if let Some(item) = NonZeroUsize::new(*item) { + return Some(item); + } + } + + None + } +} + +/// An iterator over the syscalls contained in a [`SysnoSet`]. +pub struct SysnoSetIter<'a> { + // Our iterator over nonzero words in the bitset. + iter: NonZeroUsizeIter<'a>, + + // The current word in the set we're operating on. This is only None if the + // iterator has been exhausted. The next bit that is set is found by + // counting the number of leading zeros. When found, we just mask it off. + current: Option, +} + +impl<'a> SysnoSetIter<'a> { + fn new(iter: core::slice::Iter<'a, usize>) -> Self { + let mut iter = NonZeroUsizeIter::new(iter); + let current = iter.next(); + Self { iter, current } + } +} + +impl<'a> Iterator for SysnoSetIter<'a> { + type Item = Sysno; + + fn next(&mut self) -> Option { + // Construct a mask where all but the last bit is set. This is then + // shifted to remove the first bit we find. + const MASK: usize = !1usize; + + if let Some(word) = self.current.take() { + let index = self.iter.count.wrapping_sub(1); + + // Get the index of the next bit. For example: + // 0b0000000010000 + // ^ + // Here, there are 4 trailing zeros, so 4 is the next set bit. Since + // we're only iterating over non-zero words, we are guaranteed to + // get a valid index. + let bit = word.trailing_zeros(); + + // Mask off that bit and store the resulting word for next time. + let next_word = + NonZeroUsize::new(word.get() & MASK.rotate_left(bit)); + + self.current = next_word.or_else(|| self.iter.next()); + + let offset = Sysno::first().id() as u32; + let sysno = index as u32 * usize::BITS + bit + offset; + + // TODO: Use an unchecked conversion to speed this up. + return Some(Sysno::from(sysno)); + } + + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_words() { + assert_eq!(words::(42), 1); + assert_eq!(words::(0), 0); + assert_eq!(words::(42), 2); + assert_eq!(words::<()>(42), 0); + } + + #[test] + fn test_bits_per() { + assert_eq!(bits_per::<()>(), 0); + assert_eq!(bits_per::(), 8); + assert_eq!(bits_per::(), 32); + assert_eq!(bits_per::(), 64); + } + + #[test] + fn test_const_new() { + static SYSCALLS: SysnoSet = + SysnoSet::new(&[Sysno::open, Sysno::read, Sysno::close]); + + assert_eq!(SYSCALLS.contains(Sysno::open), true); + assert_eq!(SYSCALLS.contains(Sysno::read), true); + assert_eq!(SYSCALLS.contains(Sysno::close), true); + assert_eq!(SYSCALLS.contains(Sysno::write), false); + } + + #[test] + fn test_contains() { + let set = SysnoSet::empty(); + assert_eq!(set.contains(Sysno::open), false); + assert_eq!(set.contains(Sysno::first()), false); + assert_eq!(set.contains(Sysno::last()), false); + + let set = SysnoSet::all(); + assert_eq!(set.contains(Sysno::open), true); + assert_eq!(set.contains(Sysno::first()), true); + assert_eq!(set.contains(Sysno::last()), true); + } + + #[test] + fn test_count() { + let mut set = SysnoSet::empty(); + assert_eq!(set.count(), 0); + set.insert(Sysno::open); + set.insert(Sysno::last()); + assert_eq!(set.count(), 2); + } + + #[test] + fn test_insert() { + let mut set = SysnoSet::empty(); + set.insert(Sysno::open); + set.insert(Sysno::read); + set.insert(Sysno::close); + assert!(set.contains(Sysno::open)); + assert!(set.contains(Sysno::read)); + assert!(set.contains(Sysno::close)); + assert_eq!(set.count(), 3); + } + + #[test] + fn test_remove() { + let mut set = SysnoSet::all(); + set.remove(Sysno::open); + assert!(!set.contains(Sysno::open)); + assert!(set.contains(Sysno::close)); + } + + #[test] + fn test_all() { + let mut all = SysnoSet::all(); + assert_eq!(all.count(), Sysno::count()); + + all.contains(Sysno::open); + all.contains(Sysno::first()); + all.contains(Sysno::last()); + + all.clear(); + + assert_eq!(all.count(), 0); + } + + #[test] + fn test_union() { + let a = SysnoSet::new(&[Sysno::read, Sysno::open, Sysno::close]); + let b = SysnoSet::new(&[Sysno::write, Sysno::open, Sysno::close]); + assert_eq!( + a.union(&b), + SysnoSet::new(&[ + Sysno::read, + Sysno::write, + Sysno::open, + Sysno::close + ]) + ); + } + + #[test] + fn test_intersection() { + let a = SysnoSet::new(&[Sysno::read, Sysno::open, Sysno::close]); + let b = SysnoSet::new(&[Sysno::write, Sysno::open, Sysno::close]); + assert_eq!( + a.intersection(&b), + SysnoSet::new(&[Sysno::open, Sysno::close]) + ); + } + + #[test] + fn test_difference() { + let a = SysnoSet::new(&[Sysno::read, Sysno::open, Sysno::close]); + let b = SysnoSet::new(&[Sysno::write, Sysno::open, Sysno::close]); + assert_eq!(a.difference(&b), SysnoSet::new(&[Sysno::read])); + } + + #[test] + fn test_symmetric_difference() { + let a = SysnoSet::new(&[Sysno::read, Sysno::open, Sysno::close]); + let b = SysnoSet::new(&[Sysno::write, Sysno::open, Sysno::close]); + assert_eq!( + a.symmetric_difference(&b), + SysnoSet::new(&[Sysno::read, Sysno::write,]) + ); + } + + #[test] + fn test_iter() { + let syscalls = &[Sysno::read, Sysno::open, Sysno::close]; + let set = SysnoSet::new(syscalls); + + assert_eq!(set.iter().collect::>(), syscalls); + } + + #[test] + fn test_iter_full() { + assert_eq!(SysnoSet::all().iter().count(), Sysno::count()); + } + + #[test] + fn test_iter_empty() { + assert_eq!(SysnoSet::empty().iter().collect::>(), &[]); + } +}