Skip to content

Commit

Permalink
Rollup merge of rust-lang#126199 - ivan-shrimp:nonzero_isqrt, r=tgross35
Browse files Browse the repository at this point in the history
Add `isqrt` to `NonZero<uN>`

Implements [rust-lang#70887 (comment)](rust-lang#116226 (comment)), with the following signature:

```rust
impl NonZero<uN> {
    const fn isqrt(self) -> Self;
}
```

Unintended benefits include one fewer panicking branch in `ilog2` for LLVM to optimize away, and one fewer `assume_unchecked` as `NonZero` already does that.

The fast path for `self == 1` is dropped, but the current implementation is very slow anyways compared to hardware. Performance improvements can always come later.

(I didn't add the function to `NonZero<iN>`, since _every_ existing `NonZero` method is non-panicking, and it might be nice to leave it that way.)
  • Loading branch information
matthiaskrgr authored Jul 19, 2024
2 parents f06530c + f3b1c8a commit 9f20a0f
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 45 deletions.
1 change: 0 additions & 1 deletion core/src/num/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#![stable(feature = "rust1", since = "1.0.0")]

use crate::ascii;
use crate::hint;
use crate::intrinsics;
use crate::mem;
use crate::str::FromStr;
Expand Down
60 changes: 54 additions & 6 deletions core/src/num/nonzero.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use crate::cmp::Ordering;
use crate::fmt;
use crate::hash::{Hash, Hasher};
use crate::hint;
use crate::intrinsics;
use crate::marker::{Freeze, StructuralPartialEq};
use crate::ops::{BitOr, BitOrAssign, Div, DivAssign, Neg, Rem, RemAssign};
Expand Down Expand Up @@ -604,7 +605,6 @@ macro_rules! nonzero_integer {
}

nonzero_integer_signedness_dependent_methods! {
Self = $Ty,
Primitive = $signedness $Int,
UnsignedPrimitive = $Uint,
}
Expand Down Expand Up @@ -823,7 +823,7 @@ macro_rules! nonzero_integer {
}
}

nonzero_integer_signedness_dependent_impls!($Ty $signedness $Int);
nonzero_integer_signedness_dependent_impls!($signedness $Int);
};

(Self = $Ty:ident, Primitive = unsigned $Int:ident $(,)?) => {
Expand All @@ -849,7 +849,7 @@ macro_rules! nonzero_integer {

macro_rules! nonzero_integer_signedness_dependent_impls {
// Impls for unsigned nonzero types only.
($Ty:ident unsigned $Int:ty) => {
(unsigned $Int:ty) => {
#[stable(feature = "nonzero_div", since = "1.51.0")]
impl Div<NonZero<$Int>> for $Int {
type Output = $Int;
Expand Down Expand Up @@ -897,7 +897,7 @@ macro_rules! nonzero_integer_signedness_dependent_impls {
}
};
// Impls for signed nonzero types only.
($Ty:ident signed $Int:ty) => {
(signed $Int:ty) => {
#[stable(feature = "signed_nonzero_neg", since = "1.71.0")]
impl Neg for NonZero<$Int> {
type Output = Self;
Expand All @@ -918,7 +918,6 @@ macro_rules! nonzero_integer_signedness_dependent_impls {
macro_rules! nonzero_integer_signedness_dependent_methods {
// Associated items for unsigned nonzero types only.
(
Self = $Ty:ident,
Primitive = unsigned $Int:ident,
UnsignedPrimitive = $Uint:ty,
) => {
Expand Down Expand Up @@ -1224,11 +1223,60 @@ macro_rules! nonzero_integer_signedness_dependent_methods {

intrinsics::ctpop(self.get()) < 2
}

/// Returns the square root of the number, rounded down.
///
/// # Examples
///
/// Basic usage:
/// ```
/// #![feature(isqrt)]
/// # use std::num::NonZero;
/// #
/// # fn main() { test().unwrap(); }
/// # fn test() -> Option<()> {
#[doc = concat!("let ten = NonZero::new(10", stringify!($Int), ")?;")]
#[doc = concat!("let three = NonZero::new(3", stringify!($Int), ")?;")]
///
/// assert_eq!(ten.isqrt(), three);
/// # Some(())
/// # }
#[unstable(feature = "isqrt", issue = "116226")]
#[rustc_const_unstable(feature = "isqrt", issue = "116226")]
#[must_use = "this returns the result of the operation, \
without modifying the original"]
#[inline]
pub const fn isqrt(self) -> Self {
// The algorithm is based on the one presented in
// <https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Binary_numeral_system_(base_2)>
// which cites as source the following C code:
// <https://web.archive.org/web/20120306040058/http://medialab.freaknet.org/martin/src/sqrt/sqrt.c>.

let mut op = self.get();
let mut res = 0;
let mut one = 1 << (self.ilog2() & !1);

while one != 0 {
if op >= res + one {
op -= res + one;
res = (res >> 1) + one;
} else {
res >>= 1;
}
one >>= 2;
}

// SAFETY: The result fits in an integer with half as many bits.
// Inform the optimizer about it.
unsafe { hint::assert_unchecked(res < 1 << (Self::BITS / 2)) };

// SAFETY: The square root of an integer >= 1 is always >= 1.
unsafe { Self::new_unchecked(res) }
}
};

// Associated items for signed nonzero types only.
(
Self = $Ty:ident,
Primitive = signed $Int:ident,
UnsignedPrimitive = $Uint:ty,
) => {
Expand Down
47 changes: 9 additions & 38 deletions core/src/num/uint_macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1226,10 +1226,9 @@ macro_rules! uint_impl {
without modifying the original"]
#[inline]
pub const fn checked_ilog2(self) -> Option<u32> {
if let Some(x) = NonZero::new(self) {
Some(x.ilog2())
} else {
None
match NonZero::new(self) {
Some(x) => Some(x.ilog2()),
None => None,
}
}

Expand All @@ -1248,10 +1247,9 @@ macro_rules! uint_impl {
without modifying the original"]
#[inline]
pub const fn checked_ilog10(self) -> Option<u32> {
if let Some(x) = NonZero::new(self) {
Some(x.ilog10())
} else {
None
match NonZero::new(self) {
Some(x) => Some(x.ilog10()),
None => None,
}
}

Expand Down Expand Up @@ -2590,37 +2588,10 @@ macro_rules! uint_impl {
without modifying the original"]
#[inline]
pub const fn isqrt(self) -> Self {
if self < 2 {
return self;
}

// The algorithm is based on the one presented in
// <https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Binary_numeral_system_(base_2)>
// which cites as source the following C code:
// <https://web.archive.org/web/20120306040058/http://medialab.freaknet.org/martin/src/sqrt/sqrt.c>.

let mut op = self;
let mut res = 0;
let mut one = 1 << (self.ilog2() & !1);

while one != 0 {
if op >= res + one {
op -= res + one;
res = (res >> 1) + one;
} else {
res >>= 1;
}
one >>= 2;
match NonZero::new(self) {
Some(x) => x.isqrt().get(),
None => 0,
}

// SAFETY: the result is positive and fits in an integer with half as many bits.
// Inform the optimizer about it.
unsafe {
hint::assert_unchecked(0 < res);
hint::assert_unchecked(res < 1 << (Self::BITS / 2));
}

res
}

/// Performs Euclidean division.
Expand Down

0 comments on commit 9f20a0f

Please sign in to comment.