Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small optimization for integers Display implementation #128204

Merged
merged 3 commits into from
Oct 2, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 131 additions & 81 deletions library/core/src/fmt/num.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,75 +208,119 @@ static DEC_DIGITS_LUT: &[u8; 200] = b"0001020304050607080910111213141516171819\
8081828384858687888990919293949596979899";

macro_rules! impl_Display {
($($t:ident),* as $u:ident via $conv_fn:ident named $name:ident) => {
#[cfg(not(feature = "optimize_for_size"))]
fn $name(mut n: $u, is_nonnegative: bool, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// 2^128 is about 3*10^38, so 39 gives an extra byte of space
let mut buf = [MaybeUninit::<u8>::uninit(); 39];
let mut curr = buf.len();
let buf_ptr = MaybeUninit::slice_as_mut_ptr(&mut buf);
let lut_ptr = DEC_DIGITS_LUT.as_ptr();
($($t:ident $(as $positive:ident)? named $name:ident,)* ; as $u:ident via $conv_fn:ident named $gen_name:ident) => {

// SAFETY: Since `d1` and `d2` are always less than or equal to `198`, we
// can copy from `lut_ptr[d1..d1 + 1]` and `lut_ptr[d2..d2 + 1]`. To show
// that it's OK to copy into `buf_ptr`, notice that at the beginning
// `curr == buf.len() == 39 > log(n)` since `n < 2^128 < 10^39`, and at
// each step this is kept the same as `n` is divided. Since `n` is always
// non-negative, this means that `curr > 0` so `buf_ptr[curr..curr + 1]`
// is safe to access.
unsafe {
// need at least 16 bits for the 4-characters-at-a-time to work.
assert!(crate::mem::size_of::<$u>() >= 2);
$(
#[stable(feature = "rust1", since = "1.0.0")]
impl fmt::Display for $t {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// If it's a signed integer.
$(
let is_nonnegative = *self >= 0;

// eagerly decode 4 characters at a time
while n >= 10000 {
let rem = (n % 10000) as usize;
n /= 10000;
#[cfg(not(feature = "optimize_for_size"))]
{
if !is_nonnegative {
// convert the negative num to positive by summing 1 to its 2s complement
return (!self as $positive).wrapping_add(1)._fmt(false, f);
}
}
#[cfg(feature = "optimize_for_size")]
{
if !is_nonnegative {
// convert the negative num to positive by summing 1 to its 2s complement
return $gen_name((!self.$conv_fn()).wrapping_add(1), false, f);
}
}
)?
// If it's a non-negative integer (this includes zero and every unsigned type).
#[cfg(not(feature = "optimize_for_size"))]
{
self._fmt(true, f)
}
#[cfg(feature = "optimize_for_size")]
{
$gen_name(self.$conv_fn(), true, f)
}
}
}

let d1 = (rem / 100) << 1;
let d2 = (rem % 100) << 1;
curr -= 4;
#[cfg(not(feature = "optimize_for_size"))]
impl $t {
fn _fmt(mut self: $t, is_nonnegative: bool, f: &mut fmt::Formatter<'_>) -> fmt::Result {
const SIZE: usize = $t::MAX.ilog(10) as usize + 1;
let mut buf = [MaybeUninit::<u8>::uninit(); SIZE];
let mut curr = SIZE;
let buf_ptr = MaybeUninit::slice_as_mut_ptr(&mut buf);
let lut_ptr = DEC_DIGITS_LUT.as_ptr();

// SAFETY: Since `d1` and `d2` are always less than or equal to `198`, we
// can copy from `lut_ptr[d1..d1 + 1]` and `lut_ptr[d2..d2 + 1]`. To show
// that it's OK to copy into `buf_ptr`, notice that at the beginning
// `curr == buf.len() == SIZE`, where `SIZE` is the number of decimal digits
// of `$t::MAX`, so `buf` has room for every digit of `self`. Each step
// decreases `curr` by exactly the number of digits it writes, so `curr`
// never underflows and `buf_ptr[curr..curr + 1]` is safe to access.
unsafe {
// need at least 16 bits for the 4-characters-at-a-time to work.
#[allow(overflowing_literals)]
#[allow(unused_comparisons)]
// This block will be removed for smaller types at compile time and, in the worst
// case, it will prevent the `10000` literal from overflowing for `i8` and `u8`.
if core::mem::size_of::<$t>() >= 2 {
// eagerly decode 4 characters at a time
while self >= 10000 {
let rem = (self % 10000) as usize;
self /= 10000;

let d1 = (rem / 100) << 1;
let d2 = (rem % 100) << 1;
curr -= 4;

// We are allowed to copy to `buf_ptr[curr..curr + 3]` here since
// otherwise `curr < 0`. But then `self` would need more decimal digits
// than `$t::MAX` has, and `SIZE` is exactly that digit count.
ptr::copy_nonoverlapping(lut_ptr.add(d1 as usize), buf_ptr.add(curr), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d2 as usize), buf_ptr.add(curr + 2), 2);
}
}

// We are allowed to copy to `buf_ptr[curr..curr + 3]` here since
// otherwise `curr < 0`. But then `n` was originally at least `10000^10`
// which is `10^40 > 2^128 > n`.
ptr::copy_nonoverlapping(lut_ptr.add(d1), buf_ptr.add(curr), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d2), buf_ptr.add(curr + 2), 2);
}
// if we reach here numbers are <= 9999, so at most 4 chars long
let mut n = self as usize; // possibly reduce 64bit math

// if we reach here numbers are <= 9999, so at most 4 chars long
let mut n = n as usize; // possibly reduce 64bit math
// decode 2 more chars, if > 2 chars
if n >= 100 {
let d1 = (n % 100) << 1;
n /= 100;
curr -= 2;
ptr::copy_nonoverlapping(lut_ptr.add(d1), buf_ptr.add(curr), 2);
}

// decode 2 more chars, if > 2 chars
if n >= 100 {
let d1 = (n % 100) << 1;
n /= 100;
curr -= 2;
ptr::copy_nonoverlapping(lut_ptr.add(d1), buf_ptr.add(curr), 2);
// if we reach here numbers are < 100, so at most 2 chars long
// The biggest it can be is 99, and 99 << 1 == 198, so a `u8` is enough.
// decode last 1 or 2 chars
if n < 10 {
curr -= 1;
*buf_ptr.add(curr) = (n as u8) + b'0';
} else {
let d1 = n << 1;
curr -= 2;
ptr::copy_nonoverlapping(lut_ptr.add(d1), buf_ptr.add(curr), 2);
}
}

// decode last 1 or 2 chars
if n < 10 {
curr -= 1;
*buf_ptr.add(curr) = (n as u8) + b'0';
} else {
let d1 = n << 1;
curr -= 2;
ptr::copy_nonoverlapping(lut_ptr.add(d1), buf_ptr.add(curr), 2);
}
// SAFETY: `curr <= buf.len()` and `buf[curr..]` has been written (since we made `buf`
// large enough for every digit), and all the chars are valid UTF-8 since `DEC_DIGITS_LUT` is
let buf_slice = unsafe {
str::from_utf8_unchecked(
slice::from_raw_parts(buf_ptr.add(curr), buf.len() - curr))
};
f.pad_integral(is_nonnegative, "", buf_slice)
}

// SAFETY: `curr` > 0 (since we made `buf` large enough), and all the chars are valid
// UTF-8 since `DEC_DIGITS_LUT` is
let buf_slice = unsafe {
str::from_utf8_unchecked(
slice::from_raw_parts(buf_ptr.add(curr), buf.len() - curr))
};
f.pad_integral(is_nonnegative, "", buf_slice)
}
})*

#[cfg(feature = "optimize_for_size")]
fn $name(mut n: $u, is_nonnegative: bool, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fn $gen_name(mut n: $u, is_nonnegative: bool, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// 2^128 is about 3*10^38, so 39 gives an extra byte of space
let mut buf = [MaybeUninit::<u8>::uninit(); 39];
let mut curr = buf.len();
Expand Down Expand Up @@ -306,21 +350,6 @@ macro_rules! impl_Display {
};
f.pad_integral(is_nonnegative, "", buf_slice)
}

$(#[stable(feature = "rust1", since = "1.0.0")]
impl fmt::Display for $t {
#[allow(unused_comparisons)]
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let is_nonnegative = *self >= 0;
let n = if is_nonnegative {
self.$conv_fn()
} else {
// convert the negative num to positive by summing 1 to it's 2 complement
(!self.$conv_fn()).wrapping_add(1)
};
$name(n, is_nonnegative, f)
}
})*
};
}

Expand Down Expand Up @@ -374,7 +403,6 @@ macro_rules! impl_Exp {
(n, exponent, exponent, added_precision)
};

// 39 digits (worst case u128) + . = 40
// Since `curr` always decreases by the number of digits copied, this means
// that `curr >= 0`.
let mut buf = [MaybeUninit::<u8>::uninit(); 40];
Expand Down Expand Up @@ -469,7 +497,7 @@ macro_rules! impl_Exp {
let n = if is_nonnegative {
self.$conv_fn()
} else {
// convert the negative num to positive by summing 1 to it's 2 complement
// convert the negative num to positive by summing 1 to its 2s complement
(!self.$conv_fn()).wrapping_add(1)
};
$name(n, is_nonnegative, false, f)
Expand All @@ -484,7 +512,7 @@ macro_rules! impl_Exp {
let n = if is_nonnegative {
self.$conv_fn()
} else {
// convert the negative num to positive by summing 1 to it's 2 complement
// convert the negative num to positive by summing 1 to its 2s complement
(!self.$conv_fn()).wrapping_add(1)
};
$name(n, is_nonnegative, true, f)
Expand All @@ -499,8 +527,17 @@ macro_rules! impl_Exp {
mod imp {
use super::*;
impl_Display!(
i8, u8, i16, u16, i32, u32, i64, u64, usize, isize
as u64 via to_u64 named fmt_u64
i8 as u8 named fmt_i8,
u8 named fmt_u8,
i16 as u16 named fmt_i16,
u16 named fmt_u16,
i32 as u32 named fmt_i32,
u32 named fmt_u32,
i64 as u64 named fmt_i64,
u64 named fmt_u64,
isize as usize named fmt_isize,
usize named fmt_usize,
; as u64 via to_u64 named fmt_u64
);
impl_Exp!(
i8, u8, i16, u16, i32, u32, i64, u64, usize, isize
Expand All @@ -511,8 +548,21 @@ mod imp {
#[cfg(not(any(target_pointer_width = "64", target_arch = "wasm32")))]
mod imp {
use super::*;
impl_Display!(i8, u8, i16, u16, i32, u32, isize, usize as u32 via to_u32 named fmt_u32);
impl_Display!(i64, u64 as u64 via to_u64 named fmt_u64);
impl_Display!(
i8 as u8 named fmt_i8,
u8 named fmt_u8,
i16 as u16 named fmt_i16,
u16 named fmt_u16,
i32 as u32 named fmt_i32,
u32 named fmt_u32,
isize as usize named fmt_isize,
usize named fmt_usize,
; as u32 via to_u32 named fmt_u32);
impl_Display!(
i64 as u64 named fmt_i64,
u64 named fmt_u64,
; as u64 via to_u64 named fmt_u64);

impl_Exp!(i8, u8, i16, u16, i32, u32, isize, usize as u32 via to_u32 named exp_u32);
impl_Exp!(i64, u64 as u64 via to_u64 named exp_u64);
}
Expand Down Expand Up @@ -619,7 +669,7 @@ impl fmt::Display for i128 {
let n = if is_nonnegative {
self.to_u128()
} else {
// convert the negative num to positive by summing 1 to it's 2 complement
// convert the negative num to positive by summing 1 to its 2s complement
(!self.to_u128()).wrapping_add(1)
};
fmt_u128(n, is_nonnegative, f)
Expand Down
Loading