Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mark internal functions and traits unsafe to reflect preconditions #111609

Merged
merged 1 commit on May 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions library/alloc/src/vec/in_place_collect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ where
)
};

let len = SpecInPlaceCollect::collect_in_place(&mut iterator, dst_buf, dst_end);
// SAFETY: `dst_buf` and `dst_end` are the start and end of the buffer.
let len = unsafe { SpecInPlaceCollect::collect_in_place(&mut iterator, dst_buf, dst_end) };

let src = unsafe { iterator.as_inner().as_into_iter() };
// check if SourceIter contract was upheld
Expand Down Expand Up @@ -239,15 +240,15 @@ trait SpecInPlaceCollect<T, I>: Iterator<Item = T> {
/// `Iterator::__iterator_get_unchecked` calls with a `TrustedRandomAccessNoCoerce` bound
/// on `I` which means the caller of this method must take the safety conditions
/// of that trait into consideration.
fn collect_in_place(&mut self, dst: *mut T, end: *const T) -> usize;
unsafe fn collect_in_place(&mut self, dst: *mut T, end: *const T) -> usize;
}

impl<T, I> SpecInPlaceCollect<T, I> for I
where
I: Iterator<Item = T>,
{
#[inline]
default fn collect_in_place(&mut self, dst_buf: *mut T, end: *const T) -> usize {
default unsafe fn collect_in_place(&mut self, dst_buf: *mut T, end: *const T) -> usize {
// use try-fold since
// - it vectorizes better for some iterator adapters
// - unlike most internal iteration methods, it only takes a &mut self
Expand All @@ -265,7 +266,7 @@ where
I: Iterator<Item = T> + TrustedRandomAccessNoCoerce,
{
#[inline]
fn collect_in_place(&mut self, dst_buf: *mut T, end: *const T) -> usize {
unsafe fn collect_in_place(&mut self, dst_buf: *mut T, end: *const T) -> usize {
let len = self.size();
let mut drop_guard = InPlaceDrop { inner: dst_buf, dst: dst_buf };
for i in 0..len {
Expand Down
12 changes: 8 additions & 4 deletions library/core/src/fmt/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ where
&mut buf,
&mut parts,
);
fmt.pad_formatted_parts(&formatted)
// SAFETY: `to_exact_fixed_str` and `format_exact` produce only ASCII characters.
unsafe { fmt.pad_formatted_parts(&formatted) }
}

// Don't inline this so callers that call both this and the above won't wind
Expand All @@ -71,7 +72,8 @@ where
&mut buf,
&mut parts,
);
fmt.pad_formatted_parts(&formatted)
// SAFETY: `to_shortest_str` and `format_shortest` produce only ASCII characters.
unsafe { fmt.pad_formatted_parts(&formatted) }
}

fn float_to_decimal_display<T>(fmt: &mut Formatter<'_>, num: &T) -> Result
Expand Down Expand Up @@ -116,7 +118,8 @@ where
&mut buf,
&mut parts,
);
fmt.pad_formatted_parts(&formatted)
// SAFETY: `to_exact_exp_str` and `format_exact` produce only ASCII characters.
unsafe { fmt.pad_formatted_parts(&formatted) }
}

// Don't inline this so callers that call both this and the above won't wind
Expand All @@ -143,7 +146,8 @@ where
&mut buf,
&mut parts,
);
fmt.pad_formatted_parts(&formatted)
// SAFETY: `to_shortest_exp_str` and `format_shortest` produce only ASCII characters.
unsafe { fmt.pad_formatted_parts(&formatted) }
}

// Common code of floating point LowerExp and UpperExp.
Expand Down
42 changes: 27 additions & 15 deletions library/core/src/fmt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1415,7 +1415,11 @@ impl<'a> Formatter<'a> {
/// Takes the formatted parts and applies the padding.
/// Assumes that the caller already has rendered the parts with required precision,
/// so that `self.precision` can be ignored.
fn pad_formatted_parts(&mut self, formatted: &numfmt::Formatted<'_>) -> Result {
///
/// # Safety
///
/// Any `numfmt::Part::Copy` parts in `formatted` must contain valid UTF-8.
unsafe fn pad_formatted_parts(&mut self, formatted: &numfmt::Formatted<'_>) -> Result {
if let Some(mut width) = self.width {
// for the sign-aware zero padding, we render the sign first and
// behave as if we had no sign from the beginning.
Expand All @@ -1438,31 +1442,35 @@ impl<'a> Formatter<'a> {
let len = formatted.len();
let ret = if width <= len {
// no padding
self.write_formatted_parts(&formatted)
// SAFETY: Per the precondition.
unsafe { self.write_formatted_parts(&formatted) }
} else {
let post_padding = self.padding(width - len, Alignment::Right)?;
self.write_formatted_parts(&formatted)?;
// SAFETY: Per the precondition.
unsafe {
self.write_formatted_parts(&formatted)?;
}
post_padding.write(self)
};
self.fill = old_fill;
self.align = old_align;
ret
} else {
// this is the common case and we take a shortcut
self.write_formatted_parts(formatted)
// SAFETY: Per the precondition.
unsafe { self.write_formatted_parts(formatted) }
}
}

fn write_formatted_parts(&mut self, formatted: &numfmt::Formatted<'_>) -> Result {
fn write_bytes(buf: &mut dyn Write, s: &[u8]) -> Result {
/// # Safety
///
/// Any `numfmt::Part::Copy` parts in `formatted` must contain valid UTF-8.
unsafe fn write_formatted_parts(&mut self, formatted: &numfmt::Formatted<'_>) -> Result {
unsafe fn write_bytes(buf: &mut dyn Write, s: &[u8]) -> Result {
// SAFETY: This is used for `numfmt::Part::Num` and `numfmt::Part::Copy`.
// It's safe to use for `numfmt::Part::Num` since every char `c` is between
// `b'0'` and `b'9'`, which means `s` is valid UTF-8.
// It's also probably safe in practice to use for `numfmt::Part::Copy(buf)`
// since `buf` should be plain ASCII, but it's possible for someone to pass
// in a bad value for `buf` into `numfmt::to_shortest_str` since it is a
// public function.
// FIXME: Determine whether this could result in UB.
// `b'0'` and `b'9'`, which means `s` is valid UTF-8. It's safe to use for
// `numfmt::Part::Copy` due to this function's precondition.
buf.write_str(unsafe { str::from_utf8_unchecked(s) })
}

Expand All @@ -1489,11 +1497,15 @@ impl<'a> Formatter<'a> {
*c = b'0' + (v % 10) as u8;
v /= 10;
}
write_bytes(self.buf, &s[..len])?;
// SAFETY: Per the precondition.
unsafe {
write_bytes(self.buf, &s[..len])?;
}
}
numfmt::Part::Copy(buf) => {
// SAFETY: Per the precondition.
numfmt::Part::Copy(buf) => unsafe {
write_bytes(self.buf, buf)?;
}
},
}
}
Ok(())
Expand Down
15 changes: 10 additions & 5 deletions library/core/src/fmt/num.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,12 @@ impl_int! { i8 i16 i32 i64 i128 isize }
impl_uint! { u8 u16 u32 u64 u128 usize }

/// A type that represents a specific radix
///
/// # Safety
///
/// `digit` must return an ASCII character.
#[doc(hidden)]
trait GenericRadix: Sized {
unsafe trait GenericRadix: Sized {
/// The number of digits.
const BASE: u8;

Expand Down Expand Up @@ -129,7 +133,7 @@ struct UpperHex;

macro_rules! radix {
($T:ident, $base:expr, $prefix:expr, $($x:pat => $conv:expr),+) => {
impl GenericRadix for $T {
unsafe impl GenericRadix for $T {
const BASE: u8 = $base;
const PREFIX: &'static str = $prefix;
fn digit(x: u8) -> u8 {
Expand Down Expand Up @@ -407,7 +411,7 @@ macro_rules! impl_Exp {
let parts = &[
numfmt::Part::Copy(buf_slice),
numfmt::Part::Zero(added_precision),
numfmt::Part::Copy(exp_slice)
numfmt::Part::Copy(exp_slice),
];
let sign = if !is_nonnegative {
"-"
Expand All @@ -416,8 +420,9 @@ macro_rules! impl_Exp {
} else {
""
};
let formatted = numfmt::Formatted{sign, parts};
f.pad_formatted_parts(&formatted)
let formatted = numfmt::Formatted { sign, parts };
// SAFETY: `buf_slice` and `exp_slice` contain only ASCII characters.
unsafe { f.pad_formatted_parts(&formatted) }
}

$(
Expand Down
11 changes: 7 additions & 4 deletions library/std/src/path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -733,8 +733,9 @@ impl<'a> Components<'a> {
}
}

// parse a given byte sequence into the corresponding path component
fn parse_single_component<'b>(&self, comp: &'b [u8]) -> Option<Component<'b>> {
// parse a given byte sequence following the OsStr encoding into the
// corresponding path component
unsafe fn parse_single_component<'b>(&self, comp: &'b [u8]) -> Option<Component<'b>> {
match comp {
b"." if self.prefix_verbatim() => Some(Component::CurDir),
b"." => None, // . components are normalized away, except at
Expand All @@ -754,7 +755,8 @@ impl<'a> Components<'a> {
None => (0, self.path),
Some(i) => (1, &self.path[..i]),
};
(comp.len() + extra, self.parse_single_component(comp))
// SAFETY: `comp` is a valid substring, since it is split on a separator.
(comp.len() + extra, unsafe { self.parse_single_component(comp) })
}

// parse a component from the right, saying how many bytes to consume to
Expand All @@ -766,7 +768,8 @@ impl<'a> Components<'a> {
None => (0, &self.path[start..]),
Some(i) => (1, &self.path[start + i + 1..]),
};
(comp.len() + extra, self.parse_single_component(comp))
// SAFETY: `comp` is a valid substring, since it is split on a separator.
(comp.len() + extra, unsafe { self.parse_single_component(comp) })
}

// trim away repeated separators (i.e., empty components) on the left
Expand Down