diff --git a/mfio-netfs/src/net/client.rs b/mfio-netfs/src/net/client.rs index fad249d..aaa5175 100644 --- a/mfio-netfs/src/net/client.rs +++ b/mfio-netfs/src/net/client.rs @@ -69,25 +69,25 @@ impl IntoOp for Write { } } -struct ShardedPacket> { - shards: BTreeMap, +struct ShardedPacket { + shards: BTreeMap, } -impl> From for ShardedPacket { +impl From for ShardedPacket { fn from(pkt: T) -> Self { Self { - shards: std::iter::once((0, pkt)).collect(), + shards: std::iter::once((Default::default(), pkt)).collect(), } } } -impl> ShardedPacket { +impl ShardedPacket { fn is_empty(&self) -> bool { // TODO: do this or self.len() == 0? self.shards.is_empty() } - fn extract(&mut self, idx: u64, len: u64) -> Option { + fn extract(&mut self, idx: T::Bounds, len: T::Bounds) -> Option { let (&shard_idx, _) = self.shards.range(..=idx).next_back()?; let mut shard = self.shards.remove(&shard_idx)?; diff --git a/mfio-rt/src/native/impls/io_uring/mod.rs b/mfio-rt/src/native/impls/io_uring/mod.rs index a97387e..23e2298 100644 --- a/mfio-rt/src/native/impls/io_uring/mod.rs +++ b/mfio-rt/src/native/impls/io_uring/mod.rs @@ -86,7 +86,7 @@ impl Operation { match self { Operation::FileRead(pkt, buf) => match res { Ok(read) if (read as u64) < pkt.len() => { - let (left, right) = pkt.split_at(read as u64); + let (left, right) = pkt.split_at(read); if let Err(pkt) = left { assert!(!buf.0.is_null()); let buf = unsafe { &*buf.0 }; @@ -111,7 +111,7 @@ impl Operation { }, Operation::FileWrite(pkt, _) => match res { Ok(read) if (read as u64) < pkt.len() => { - let (left, right) = pkt.split_at(read as u64); + let (left, right) = pkt.split_at(read); deferred_pkts.ok(left); right.error(io_err(State::Nop)); } diff --git a/mfio-rt/src/native/impls/iocp/mod.rs b/mfio-rt/src/native/impls/iocp/mod.rs index f858b62..fe607dc 100644 --- a/mfio-rt/src/native/impls/iocp/mod.rs +++ b/mfio-rt/src/native/impls/iocp/mod.rs @@ -274,7 +274,7 @@ impl OperationMode { match self { Self::FileRead(pkt, buf) => match res { Ok(read) if (read as u64) < pkt.len() => { - let (left, right) = pkt.split_at(read as u64); + let (left, right) = pkt.split_at(read); if let Err(pkt) = left { assert!(!buf.0.is_null()); let buf = unsafe { &*buf.0 }; @@ -299,7 +299,7 @@ impl OperationMode { }, Self::FileWrite(pkt, _) => match res { Ok(read) if (read as u64) < pkt.len() => { - let (left, right) = pkt.split_at(read as u64); + let (left, right) = pkt.split_at(read); deferred_pkts.ok(left); right.error(io_err(State::Nop)); } diff --git a/mfio-rt/src/native/impls/mio/file.rs b/mfio-rt/src/native/impls/mio/file.rs index 654e9c6..5f4c7be 100644 --- a/mfio-rt/src/native/impls/mio/file.rs +++ b/mfio-rt/src/native/impls/mio/file.rs @@ -191,7 +191,7 @@ impl FileInner { } break; } else if l > 0 { - let (a, b) = pkt.split_at(l as _); + let (a, b) = pkt.split_at(l); if let Err(pkt) = a { let _ = unsafe { pkt.transfer_data(self.tmp_buf.as_mut_ptr().cast()) @@ -270,7 +270,7 @@ impl FileInner { if l == len as usize { break; } else if l > 0 { - pkt = pkt.split_at(l as _).1; + pkt = pkt.split_at(l).1; } else { pkt.error(io_err(State::Nop)); break; diff --git a/mfio-rt/src/native/impls/thread.rs b/mfio-rt/src/native/impls/thread.rs index 71426a2..fc37939 100644 --- a/mfio-rt/src/native/impls/thread.rs +++ b/mfio-rt/src/native/impls/thread.rs @@ -268,7 +268,7 @@ impl From>> match buf.try_alloc() { Ok(mut alloced) => match copy_buf(&mut alloced[..]) { Ok(read) if (read as u64) < alloced.len() => { - let (_, right) = alloced.split_at(read as _); + let (_, right) = alloced.split_at(read); right.error(io_err(State::Nop)); } Err(e) => alloced.error(io_err(e.kind().into())), @@ -284,7 +284,7 @@ impl From>> } match copy_buf(&mut tmp_buf[..(buf.len() as usize)]) { Ok(read) if (read as u64) < buf.len() => { - let (left, right) = buf.split_at(read as u64); + let (left, right) = buf.split_at(read); let _ = unsafe { left.transfer_data(tmp_buf.as_ptr().cast()) }; @@ -321,7 +321,7 @@ impl From>> let alloced: ReadPacketObj = alloced; match file.write_at(&alloced[..], pos) { Ok(written) if (written as u64) < alloced.len() => { - let (_, right) = alloced.split_at(written as u64); + let (_, right) = alloced.split_at(written); right.error(io_err(State::Nop)); } Err(e) => alloced.error(io_err(e.kind().into())), @@ -342,7 +342,7 @@ impl From>> }; match file.write_at(tmp_buf, pos) { Ok(written) if (written as u64) < buf.len() => { - let (_, right) = buf.split_at(written as u64); + let (_, right) = buf.split_at(written); right.error(io_err(State::Nop)); } Err(e) => buf.error(io_err(e.kind().into())), diff --git a/mfio-rt/src/util/stream.rs b/mfio-rt/src/util/stream.rs index b2725f2..706b33e 100644 --- a/mfio-rt/src/util/stream.rs +++ b/mfio-rt/src/util/stream.rs @@ -265,7 +265,7 @@ impl StreamBuf { let spare_len = core::cmp::min(spare.len(), self.read_cached); if (spare_len as u64) < packet.len() { - let (a, b) = packet.split_at(spare_len as u64); + let (a, b) = packet.split_at(spare_len); let transferred = unsafe { a.transfer_data(spare.as_mut_ptr().cast()) }; self.read_buf.release(transferred.len() as usize); self.read_cached -= transferred.len() as usize; @@ -365,7 +365,7 @@ impl StreamBuf { let packet = if len as u64 >= packet.len() { packet } else { - let (a, b) = packet.split_at(len as u64); + let (a, b) = packet.split_at(len); self.read_ops2.push_front(b); a }; @@ -575,7 +575,7 @@ impl StreamBuf { self.read_ops1.push_front(Err(pkt)); break; } else if (spare_len as u64) < pkt.len() { - let (a, b) = pkt.split_at(spare_len as u64); + let (a, b) = pkt.split_at(spare_len); self.read_buf.reserve(spare_len); self.read_ops2.push_back(Err(a)); pkt = b; @@ -650,7 +650,7 @@ impl StreamBuf { *queued = Some(pkt); break; } else if (spare_len as u64) < pkt.len() { - let (a, b) = pkt.split_at(spare_len as u64); + let (a, b) = pkt.split_at(spare_len); let pkt = unsafe { a.transfer_data(spare.as_mut_ptr().cast()) }; self.write_buf.reserve(spare_len); transferred.push_back(pkt); diff --git a/mfio/Cargo.toml b/mfio/Cargo.toml index e334eca..c30b223 100644 --- a/mfio/Cargo.toml +++ b/mfio/Cargo.toml @@ -34,6 +34,8 @@ mfio-derive = { version = "0.1", path = "../mfio-derive" } http = { version = "0.2", optional = true } log = "0.4" rangemap = "1" +num = { version = "0.4", default-features = false } +atomic-traits = { version = "0.3", default-features = false } # This is only needed when std feature is disabled, but we can't do negative bounds spin = "0.9" diff --git a/mfio/src/io/packet/mod.rs b/mfio/src/io/packet/mod.rs index 43f5920..b70bdae 100644 --- a/mfio/src/io/packet/mod.rs +++ b/mfio/src/io/packet/mod.rs @@ -2,6 +2,7 @@ use crate::std_prelude::*; use super::OpaqueStore; use crate::error::Error; +use atomic_traits::{fetch::Min, Atomic}; pub use cglue::task::{CWaker, FastCWaker}; use core::cell::UnsafeCell; use core::future::Future; @@ -12,6 +13,7 @@ use core::num::NonZeroI32; use core::pin::Pin; use core::sync::atomic::*; use core::task::{Context, Poll}; +use num::{Integer, NumCast, One, Saturating, ToPrimitive}; use rangemap::RangeSet; use tarc::BaseArc; @@ -649,7 +651,7 @@ pub struct Packet { /// This value is initialized to !0, and upon each errored packet segment, is minned /// atomically. Upon I/O is complete, this allows the caller to check for the size of the /// contiguous memory region being successfully processed without gaps. - error_clamp: AtomicU64, + error_clamp: ::Atomic, /// Note that this may be raced against so it should not be relied as "the minimum error". min_error: AtomicI32, // We need miri to treat packets magically. Without marking this type as !Unpin, miri would @@ -706,7 +708,7 @@ impl Packet { (self.rc_and_waker.acquire_rc()) as usize } - unsafe fn on_output(&self, error: Option<(u64, NonZeroI32)>) -> Option { + unsafe fn on_output(&self, error: Option<(Perms::Bounds, NonZeroI32)>) -> Option { if let Some((start, error)) = error { if self.error_clamp.fetch_min(start, Ordering::AcqRel) > start { self.min_error.store(error.into(), Ordering::Relaxed); @@ -745,7 +747,8 @@ impl Packet { /// /// This function is safe to call only when all packet operations have concluded. pub unsafe fn reset_err(&self) { - self.error_clamp.store(!0u64, Ordering::Release); + self.error_clamp + .store(!::default(), Ordering::Release); self.min_error.store(0, Ordering::Release); } @@ -759,7 +762,7 @@ impl Packet { Packet { vtbl, rc_and_waker: Default::default(), - error_clamp: (!0u64).into(), + error_clamp: (!::default()).into(), min_error: 0.into(), _phantom: PhantomPinned, } @@ -842,7 +845,7 @@ impl Packet { core::slice::from_raw_parts( self.simple_data_ptr(), core::cmp::min( - self.error_clamp.load(Ordering::Acquire) as usize, + self.error_clamp.load(Ordering::Acquire).to_usize().unwrap(), *Self::simple_len(self), ), ) @@ -887,13 +890,13 @@ impl Packet { /// /// If there was an error case, this function will return the length of the first contiguous /// segment. However, if the packet encountered no error cases, `!0` will be returned. - pub fn error_clamp(&self) -> u64 { + pub fn error_clamp(&self) -> Perms::Bounds { self.error_clamp.load(Ordering::Relaxed) } /// Returns [`min_error`](Self::min_error) if 0 leading bytes were processed. pub fn err_on_zero(&self) -> Result<(), Error> { - if self.error_clamp() > 0 { + if self.error_clamp() > Default::default() { Ok(()) } else { Err(self.min_error().expect("No error when error_clamp is 0")) @@ -1398,7 +1401,7 @@ pub type TransferDataFn = for<'a> unsafe extern "C" fn( ); /// Retrieves total length of the packet. -pub type LenFn = unsafe extern "C" fn(packet: &Packet) -> u64; +pub type LenFn = unsafe extern "C" fn(packet: &Packet) -> ::Bounds; /// Packet that may be alloced. /// @@ -1415,16 +1418,19 @@ pub trait PacketPerms: 'static + core::fmt::Debug + Clone + Copy { type DataType: Clone + Copy + core::fmt::Debug; type ReverseDataType: Clone + Copy + core::fmt::Debug; type Alloced: AllocatedPacket; + type Bounds: NumBounds; /// Returns vtable function for getting packet length. fn len_fn(&self) -> LenFn; /// Returns packet length. - fn len(packet: &Packet) -> u64 { + fn len(packet: &Packet) -> Self::Bounds { if let Some(vtbl) = packet.vtbl.vtbl() { unsafe { (vtbl.len_fn())(packet) } } else { - unsafe { *Packet::simple_len(packet) as u64 } + unsafe { + NumCast::from(*Packet::simple_len(packet)).expect("Packet larger than bounds") + } } } @@ -1455,7 +1461,8 @@ pub trait PacketPerms: 'static + core::fmt::Debug + Clone + Copy { .view .pkt() .simple_data_ptr() - .add(packet.view.start as usize) + // This should never panic + .add(packet.view.start.to_usize().unwrap()) }; if data.align_offset(alignment) == 0 { Ok(unsafe { Self::alloced_simple(packet) }) @@ -1497,26 +1504,27 @@ pub trait PacketPerms: 'static + core::fmt::Debug + Clone + Copy { /// of data between the 2 buffers. #[repr(C)] #[derive(Clone, Copy)] -pub struct ReadWrite { - pub len: unsafe extern "C" fn(&Packet) -> u64, +pub struct ReadWrite { + pub len: unsafe extern "C" fn(&Packet) -> Bounds, pub get_mut: for<'a> unsafe extern "C" fn( &mut ManuallyDrop>, usize, - &mut MaybeUninit, + &mut MaybeUninit>, ) -> bool, pub transfer_data: for<'a, 'b> unsafe extern "C" fn(&'a mut PacketView, *mut ()), } -impl core::fmt::Debug for ReadWrite { +impl core::fmt::Debug for ReadWrite { fn fmt(&self, fmt: &mut core::fmt::Formatter) -> core::fmt::Result { write!(fmt, "{:?}", self.get_mut as *const ()) } } -impl PacketPerms for ReadWrite { +impl PacketPerms for ReadWrite { type DataType = *mut (); type ReverseDataType = *mut (); - type Alloced = ReadWritePacketObj; + type Alloced = ReadWritePacketObj; + type Bounds = Bounds; fn len_fn(&self) -> LenFn { self.len @@ -1533,7 +1541,7 @@ impl PacketPerms for ReadWrite { unsafe fn alloced_simple(packet: BoundPacketView) -> Self::Alloced { let data = packet.view.pkt().simple_data_ptr().cast_mut(); ReadWritePacketObj { - alloced_packet: unsafe { data.add(packet.view.start as usize) }, + alloced_packet: unsafe { data.add(packet.view.start.to_usize().unwrap()) }, buffer: packet, } } @@ -1543,8 +1551,8 @@ impl PacketPerms for ReadWrite { // TODO: does this operation even make sense? core::ptr::swap_nonoverlapping( data.cast(), - dst.add(view.start as usize), - view.len() as usize, + dst.add(view.start.to_usize().unwrap()), + view.len().to_usize().unwrap(), ); } } @@ -1554,26 +1562,27 @@ impl PacketPerms for ReadWrite { /// This implies the packet is writeable and may not have valid data beforehand. #[repr(C)] #[derive(Clone, Copy)] -pub struct Write { - pub len: unsafe extern "C" fn(&Packet) -> u64, +pub struct Write { + pub len: unsafe extern "C" fn(&Packet) -> Bounds, pub get_mut: for<'a> unsafe extern "C" fn( &mut ManuallyDrop>, usize, - &mut MaybeUninit, + &mut MaybeUninit>, ) -> bool, pub transfer_data: for<'a, 'b> unsafe extern "C" fn(&'a mut PacketView, *const ()), } -impl core::fmt::Debug for Write { +impl core::fmt::Debug for Write { fn fmt(&self, fmt: &mut core::fmt::Formatter) -> core::fmt::Result { write!(fmt, "{:?}", self.get_mut as *const ()) } } -impl PacketPerms for Write { +impl PacketPerms for Write { type DataType = *mut (); type ReverseDataType = *const (); - type Alloced = WritePacketObj; + type Alloced = WritePacketObj; + type Bounds = Bounds; fn len_fn(&self) -> LenFn { self.len @@ -1595,7 +1604,7 @@ impl PacketPerms for Write { .cast_mut() .cast::>(); WritePacketObj { - alloced_packet: unsafe { data.add(packet.view.start as usize) }, + alloced_packet: unsafe { data.add(packet.view.start.to_usize().unwrap()) }, buffer: packet, } } @@ -1604,8 +1613,8 @@ impl PacketPerms for Write { let dst = Packet::simple_data_ptr_mut(view.pkt_mut()); core::ptr::copy( data.cast(), - dst.add(view.start as usize), - view.len() as usize, + dst.add(view.start.to_usize().unwrap()), + view.len().to_usize().unwrap(), ); } } @@ -1615,26 +1624,27 @@ impl PacketPerms for Write { /// This implies this packet contains valid data and it can be read. #[repr(C)] #[derive(Clone, Copy)] -pub struct Read { - pub len: unsafe extern "C" fn(&Packet) -> u64, +pub struct Read { + pub len: unsafe extern "C" fn(&Packet) -> Bounds, pub get: unsafe extern "C" fn( &mut ManuallyDrop>, usize, - &mut MaybeUninit, + &mut MaybeUninit>, ) -> bool, pub transfer_data: for<'a> unsafe extern "C" fn(&'a mut PacketView, *mut ()), } -impl core::fmt::Debug for Read { +impl core::fmt::Debug for Read { fn fmt(&self, fmt: &mut core::fmt::Formatter) -> core::fmt::Result { write!(fmt, "{:?}", self.get as *const ()) } } -impl PacketPerms for Read { +impl PacketPerms for Read { type DataType = *const (); type ReverseDataType = *mut (); - type Alloced = ReadPacketObj; + type Alloced = ReadPacketObj; + type Bounds = Bounds; fn len_fn(&self) -> LenFn { self.len @@ -1651,7 +1661,7 @@ impl PacketPerms for Read { unsafe fn alloced_simple(packet: BoundPacketView) -> Self::Alloced { let data = packet.view.pkt().simple_data_ptr().cast::(); ReadPacketObj { - alloced_packet: unsafe { data.add(packet.view.start as usize) }, + alloced_packet: unsafe { data.add(packet.view.start.to_usize().unwrap()) }, buffer: packet, } } @@ -1660,30 +1670,65 @@ impl PacketPerms for Read { let src = view.pkt().simple_data_ptr(); core::ptr::copy( src, - data.cast::().add(view.start as usize), - view.len() as usize, + data.cast::().add(view.start.to_usize().unwrap()), + view.len().to_usize().unwrap(), ); } } +pub trait NumBounds: + Integer + + NumCast + + Copy + + Send + + Default + + core::ops::Not + + Saturating + + core::fmt::Debug + + Into + + 'static +{ + type Atomic: atomic_traits::Atomic + atomic_traits::NumOps + core::fmt::Debug; +} + +macro_rules! num_bounds { + ($($ty1:ident => $ty2:ident),*) => { + $( + impl NumBounds for $ty1 { + type Atomic = core::sync::atomic::$ty2; + } + )* + } +} + +num_bounds!(u8 => AtomicU8, u16 => AtomicU16, u32 => AtomicU32, u64 => AtomicU64, usize => AtomicUsize); + +//impl NumBounds for T {} + /// Objects that can be split. /// /// This trait enables splitting objects into non-overlapping parts. -pub trait Splittable: Sized { +pub trait Splittable: Sized { + type Bounds: NumBounds; + /// Splits an object at given position. /// /// # Panics /// /// This function may panic if len is outside the bounds of the given object. - fn split_at(self, len: T) -> (Self, Self); - fn len(&self) -> T; + fn split_at(self, len: impl NumBounds) -> (Self, Self); + fn len(&self) -> Self::Bounds; fn is_empty(&self) -> bool { self.len() == Default::default() } } -impl, B: Splittable> Splittable for Result { - fn split_at(self, len: T) -> (Self, Self) { +impl, B: Splittable> Splittable + for Result +{ + type Bounds = T; + + fn split_at(self, len: impl NumBounds) -> (Self, Self) { match self { Ok(v) => { let (a, b) = v.split_at(len); @@ -1724,7 +1769,7 @@ impl Errorable for Result { /// Packet which has been allocated. /// /// Allocated packets expose direct access to the underlying buffer. -pub trait AllocatedPacket: Splittable + Errorable { +pub trait AllocatedPacket: Splittable + Errorable { type Perms: PacketPerms; type Pointer: Copy; @@ -1733,13 +1778,16 @@ pub trait AllocatedPacket: Splittable + Errorable { /// Represents a simple allocated packet with write permissions. #[repr(C)] -pub struct ReadWritePacketObj { +pub struct ReadWritePacketObj { alloced_packet: *mut u8, - buffer: BoundPacketView, + buffer: BoundPacketView>, } -impl Splittable for ReadWritePacketObj { - fn split_at(self, len: u64) -> (Self, Self) { +impl Splittable for ReadWritePacketObj { + type Bounds = Bounds; + + fn split_at(self, len: impl NumBounds) -> (Self, Self) { + let len_usize = len.to_usize().expect("Input out of range"); let (b1, b2) = self.buffer.split_at(len); ( @@ -1748,25 +1796,25 @@ impl Splittable for ReadWritePacketObj { buffer: b1, }, Self { - alloced_packet: unsafe { self.alloced_packet.add(len as usize) }, + alloced_packet: unsafe { self.alloced_packet.add(len_usize) }, buffer: b2, }, ) } - fn len(&self) -> u64 { + fn len(&self) -> Bounds { self.buffer.view.len() } } -impl Errorable for ReadWritePacketObj { +impl Errorable for ReadWritePacketObj { fn error(self, err: Error) { self.buffer.error(err) } } -impl AllocatedPacket for ReadWritePacketObj { - type Perms = ReadWrite; +impl AllocatedPacket for ReadWritePacketObj { + type Perms = ReadWrite; type Pointer = *mut u8; fn as_ptr(&self) -> Self::Pointer { @@ -1774,21 +1822,29 @@ impl AllocatedPacket for ReadWritePacketObj { } } -unsafe impl Send for ReadWritePacketObj {} -unsafe impl Sync for ReadWritePacketObj {} +unsafe impl Send for ReadWritePacketObj {} +unsafe impl Sync for ReadWritePacketObj {} -impl core::ops::Deref for ReadWritePacketObj { +impl core::ops::Deref for ReadWritePacketObj { type Target = [u8]; fn deref(&self) -> &Self::Target { - unsafe { core::slice::from_raw_parts(self.alloced_packet, self.buffer.view.len() as usize) } + unsafe { + core::slice::from_raw_parts( + self.alloced_packet, + self.buffer.view.len().to_usize().unwrap(), + ) + } } } -impl core::ops::DerefMut for ReadWritePacketObj { +impl core::ops::DerefMut for ReadWritePacketObj { fn deref_mut(&mut self) -> &mut Self::Target { unsafe { - core::slice::from_raw_parts_mut(self.alloced_packet, self.buffer.view.len() as usize) + core::slice::from_raw_parts_mut( + self.alloced_packet, + self.buffer.view.len().to_usize().unwrap(), + ) } } } @@ -1797,13 +1853,16 @@ impl core::ops::DerefMut for ReadWritePacketObj { /// /// The data inside may not be initialized, therefore, this packet should only be written to. #[repr(C)] -pub struct WritePacketObj { +pub struct WritePacketObj { alloced_packet: *mut MaybeUninit, - buffer: BoundPacketView, + buffer: BoundPacketView>, } -impl Splittable for WritePacketObj { - fn split_at(self, len: u64) -> (Self, Self) { +impl Splittable for WritePacketObj { + type Bounds = Bounds; + + fn split_at(self, len: impl NumBounds) -> (Self, Self) { + let len_usize = len.to_usize().expect("Input out of range"); let (b1, b2) = self.buffer.split_at(len); ( @@ -1812,25 +1871,25 @@ impl Splittable for WritePacketObj { buffer: b1, }, Self { - alloced_packet: unsafe { self.alloced_packet.add(len as usize) }, + alloced_packet: unsafe { self.alloced_packet.add(len_usize) }, buffer: b2, }, ) } - fn len(&self) -> u64 { + fn len(&self) -> Bounds { self.buffer.view.len() } } -impl Errorable for WritePacketObj { +impl Errorable for WritePacketObj { fn error(self, err: Error) { self.buffer.error(err) } } -impl AllocatedPacket for WritePacketObj { - type Perms = Write; +impl AllocatedPacket for WritePacketObj { + type Perms = Write; type Pointer = *mut MaybeUninit; fn as_ptr(&self) -> Self::Pointer { @@ -1838,34 +1897,45 @@ impl AllocatedPacket for WritePacketObj { } } -unsafe impl Send for WritePacketObj {} -unsafe impl Sync for WritePacketObj {} +unsafe impl Send for WritePacketObj {} +unsafe impl Sync for WritePacketObj {} -impl core::ops::Deref for WritePacketObj { +impl core::ops::Deref for WritePacketObj { type Target = [MaybeUninit]; fn deref(&self) -> &Self::Target { - unsafe { core::slice::from_raw_parts(self.alloced_packet, self.buffer.view.len() as usize) } + unsafe { + core::slice::from_raw_parts( + self.alloced_packet, + self.buffer.view.len().to_usize().unwrap(), + ) + } } } -impl core::ops::DerefMut for WritePacketObj { +impl core::ops::DerefMut for WritePacketObj { fn deref_mut(&mut self) -> &mut Self::Target { unsafe { - core::slice::from_raw_parts_mut(self.alloced_packet, self.buffer.view.len() as usize) + core::slice::from_raw_parts_mut( + self.alloced_packet, + self.buffer.view.len().to_usize().unwrap(), + ) } } } /// Represents a simple allocated packet with read permissions. #[repr(C)] -pub struct ReadPacketObj { +pub struct ReadPacketObj { alloced_packet: *const u8, - buffer: BoundPacketView, + buffer: BoundPacketView>, } -impl Splittable for ReadPacketObj { - fn split_at(self, len: u64) -> (Self, Self) { +impl Splittable for ReadPacketObj { + type Bounds = Bounds; + + fn split_at(self, len: impl NumBounds) -> (Self, Self) { + let len_usize = len.to_usize().expect("Input out of range"); let (b1, b2) = self.buffer.split_at(len); ( @@ -1874,25 +1944,25 @@ impl Splittable for ReadPacketObj { buffer: b1, }, Self { - alloced_packet: unsafe { self.alloced_packet.add(len as usize) }, + alloced_packet: unsafe { self.alloced_packet.add(len_usize) }, buffer: b2, }, ) } - fn len(&self) -> u64 { + fn len(&self) -> Bounds { self.buffer.view.len() } } -impl Errorable for ReadPacketObj { +impl Errorable for ReadPacketObj { fn error(self, err: Error) { self.buffer.error(err) } } -impl AllocatedPacket for ReadPacketObj { - type Perms = Read; +impl AllocatedPacket for ReadPacketObj { + type Perms = Read; type Pointer = *const u8; fn as_ptr(&self) -> Self::Pointer { @@ -1900,14 +1970,19 @@ impl AllocatedPacket for ReadPacketObj { } } -unsafe impl Send for ReadPacketObj {} -unsafe impl Sync for ReadPacketObj {} +unsafe impl Send for ReadPacketObj {} +unsafe impl Sync for ReadPacketObj {} -impl core::ops::Deref for ReadPacketObj { +impl core::ops::Deref for ReadPacketObj { type Target = [u8]; fn deref(&self) -> &Self::Target { - unsafe { core::slice::from_raw_parts(self.alloced_packet, self.buffer.view.len() as usize) } + unsafe { + core::slice::from_raw_parts( + self.alloced_packet, + self.buffer.view.len().to_usize().unwrap(), + ) + } } } @@ -1921,14 +1996,16 @@ impl core::ops::Deref for ReadPacketObj { #[must_use = "please handle point of drop intentionally"] pub struct TransferredPacket(BoundPacketView); -impl Splittable for TransferredPacket { - fn split_at(self, len: u64) -> (Self, Self) { +impl Splittable for TransferredPacket { + type Bounds = T::Bounds; + + fn split_at(self, len: impl NumBounds) -> (Self, Self) { let (b1, b2) = self.0.split_at(len); (Self(b1), Self(b2)) } - fn len(&self) -> u64 { + fn len(&self) -> T::Bounds { self.0.view.len() } } @@ -2158,7 +2235,7 @@ downgrade_packet!(perms::READ_WRITE, perms::WRITE); /// /// `ReboundPacket` helps in this situation, because it facilitates this rebind process safely. pub struct ReboundPacket { - ranges: RangeSet, + ranges: RangeSet, orig: ManuallyDrop>, unbound: AtomicBool, } @@ -2206,12 +2283,12 @@ impl ReboundPacket { /// # Panics /// /// Whenever an invalid range is provided. - pub fn range_result(&mut self, start: u64, len: u64, err: Option) { + pub fn range_result(&mut self, start: T::Bounds, len: T::Bounds, err: Option) { let range = start..(start + len); let mut o = self.ranges.overlapping(&range); let o = o.next().unwrap(); assert!(o.contains(&start)); - assert!(o.contains(&(start + len.saturating_sub(1)))); + assert!(o.contains(&(start + len.saturating_sub(One::one())))); self.ranges.remove(range); // SAFETY: we verified uniqueness of the range. @@ -2229,7 +2306,7 @@ impl ReboundPacket { } } - pub fn ranges(&self) -> &RangeSet { + pub fn ranges(&self) -> &RangeSet { &self.ranges } } diff --git a/mfio/src/io/packet/view.rs b/mfio/src/io/packet/view.rs index 8ec2554..f836205 100644 --- a/mfio/src/io/packet/view.rs +++ b/mfio/src/io/packet/view.rs @@ -1,5 +1,6 @@ use super::{ - Errorable, MaybeAlloced, OutputRef, Packet, PacketPerms, Splittable, TransferredPacket, + Errorable, MaybeAlloced, NumBounds, OutputRef, Packet, PacketPerms, Splittable, + TransferredPacket, }; use crate::error::Error; use cglue::prelude::v1::*; @@ -7,6 +8,7 @@ use core::marker::PhantomData; use core::mem::{ManuallyDrop, MaybeUninit}; use core::ptr::NonNull; use core::sync::atomic::*; +use num::{NumCast, ToPrimitive}; use tarc::BaseArc; /// Bound Packet View. @@ -29,11 +31,13 @@ impl Drop for BoundPacketView { } } -impl Splittable for BoundPacketView { - fn split_at(self, len: u64) -> (Self, Self) { +impl Splittable for BoundPacketView { + type Bounds = T::Bounds; + + fn split_at(self, len: impl NumBounds) -> (Self, Self) { let mut this = ManuallyDrop::new(self); let view = unsafe { ManuallyDrop::take(&mut this.view) }; - let (v1, v2) = view.split_local(len); + let (v1, v2) = view.split_local(NumCast::from(len).expect("Input bound out of range")); //this.id.size.fetch_add(1, Ordering::Release); @@ -61,7 +65,7 @@ impl Splittable for BoundPacketView { ) } - fn len(&self) -> u64 { + fn len(&self) -> T::Bounds { self.view.len() } } @@ -116,7 +120,7 @@ impl BoundPacketView { /// /// 3. Before the last extracted packet is dropped, `self` must be released with /// [`BoundPacketView::forget`]. - pub unsafe fn extract_packet(&self, pos: u64, len: u64) -> Self { + pub unsafe fn extract_packet(&self, pos: Perms::Bounds, len: Perms::Bounds) -> Self { let b = self.view.extract_packet(pos, len); Self { @@ -162,7 +166,8 @@ impl BoundPacketView { /// /// Please see [`PacketView::extract_packet`] documentation for details. pub unsafe fn unbound(&self) -> PacketView<'static, Perms> { - self.view.extract_packet(0, self.view.len()) + self.view + .extract_packet(Default::default(), self.view.len()) } /// Transfers data between the packet and the `input`. @@ -197,7 +202,7 @@ impl BoundPacketView { self.view .pkt() .simple_data_ptr() - .add(self.view.start as usize) + .add(self.view.start.to_usize().unwrap()) } } } @@ -213,8 +218,8 @@ pub struct PacketView<'a, Perms: PacketPerms> { pub(crate) pkt: NonNull>, /// Right-most bit indicates whether packet is an Arc or a ref. The rest is user-defined. pub(crate) tag: u64, - pub(crate) start: u64, - pub(crate) end: u64, + pub(crate) start: Perms::Bounds, + pub(crate) end: Perms::Bounds, phantom: PhantomData<&'a Packet>, } @@ -264,7 +269,7 @@ impl<'a, Perms: PacketPerms> PacketView<'a, Perms> { Self { pkt, tag: tag << 1 | 1, - start: 0, + start: Default::default(), end, phantom: PhantomData, } @@ -335,7 +340,7 @@ impl<'a, Perms: PacketPerms> PacketView<'a, Perms> { Self { pkt: NonNull::new((pkt as *const Packet).cast_mut()).unwrap(), tag: tag << 1, - start: 0, + start: Default::default(), end, phantom: PhantomData, } @@ -367,27 +372,27 @@ impl<'a, Perms: PacketPerms> PacketView<'a, Perms> { } /// Returns the length this packet view covers. - pub fn len(&self) -> u64 { + pub fn len(&self) -> Perms::Bounds { self.end - self.start } /// Returns the starting offset within the packet that this view represents. - pub fn start(&self) -> u64 { + pub fn start(&self) -> Perms::Bounds { self.start } /// Returns the ending position (+1) within the packet that this view represents. - pub fn end(&self) -> u64 { + pub fn end(&self) -> Perms::Bounds { self.end } /// Returns true if packet length is 0. pub fn is_empty(&self) -> bool { - self.len() == 0 + self.len() == Default::default() } /// Split the packet view into 2 at given offset. - pub fn split_local(self, pos: u64) -> (Self, Self) { + pub fn split_local(self, pos: Perms::Bounds) -> (Self, Self) { assert!(pos < self.len()); // TODO: maybe relaxed is enough here? @@ -434,7 +439,7 @@ impl<'a, Perms: PacketPerms> PacketView<'a, Perms> { /// # Safety /// /// Please see [`BoundPacketView::extract_packet`] documentation for details. - pub unsafe fn extract_packet(&self, offset: u64, len: u64) -> Self { + pub unsafe fn extract_packet(&self, offset: Perms::Bounds, len: Perms::Bounds) -> Self { self.pkt().rc_and_waker.inc_rc(); let Self { @@ -452,8 +457,8 @@ impl<'a, Perms: PacketPerms> PacketView<'a, Perms> { Self { pkt: *pkt, tag: *tag, - start: start + offset, - end: start + offset + len, + start: *start + offset, + end: *start + offset + len, phantom: PhantomData, } } diff --git a/mfio/src/stdeq.rs b/mfio/src/stdeq.rs index 1051510..0b007a8 100644 --- a/mfio/src/stdeq.rs +++ b/mfio/src/stdeq.rs @@ -11,6 +11,7 @@ use core::future::Future; use core::pin::Pin; use core::task::{Context, Poll}; use mfio_derive::*; +use num::ToPrimitive; pub trait StreamPos { fn set_pos(&self, pos: Param); @@ -174,7 +175,7 @@ impl< let hdr = <>::Target as OpaqueStore>::stack_hdr(&pkt); // TODO: put this after error checking Obj::sync_back(hdr, this.sync.take().unwrap()); - let progressed = core::cmp::min(hdr.error_clamp() as usize, this.len); + let progressed = core::cmp::min(hdr.error_clamp().to_usize().unwrap_or(!0), this.len); Param::add_io_pos(this.io, progressed); // TODO: actual error checking Ok(progressed)