Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added WideMul and Sqrt traits. #6017

Merged
merged 1 commit into from
Jul 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions corelib/src/num/traits.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@ pub use ops::overflowing::{OverflowingAdd, OverflowingSub, OverflowingMul};
pub use ops::wrapping::{WrappingAdd, WrappingSub, WrappingMul};
pub use ops::checked::{CheckedAdd, CheckedSub, CheckedMul};
pub use ops::saturating::{SaturatingAdd, SaturatingSub, SaturatingMul};
pub use ops::widemul::WideMul;
pub use ops::sqrt::Sqrt;
2 changes: 2 additions & 0 deletions corelib/src/num/traits/ops.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@ pub mod overflowing;
pub mod wrapping;
pub mod checked;
pub mod saturating;
pub(crate) mod sqrt;
pub(crate) mod widemul;
49 changes: 49 additions & 0 deletions corelib/src/num/traits/ops/sqrt.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/// A trait for computing the square root of a number.
pub trait Sqrt<T> {
/// The type of the result of the square root operation.
type Target;
/// Compute the square root of a number.
fn sqrt(self: T) -> Self::Target;
}

impl SqrtU8 of Sqrt<u8> {
type Target = u8;
fn sqrt(self: u8) -> u8 {
core::integer::u8_sqrt(self)
}
}

impl SqrtU16 of Sqrt<u16> {
type Target = u8;
fn sqrt(self: u16) -> u8 {
core::integer::u16_sqrt(self)
}
}

impl SqrtU32 of Sqrt<u32> {
type Target = u16;
fn sqrt(self: u32) -> u16 {
core::integer::u32_sqrt(self)
}
}

impl SqrtU64 of Sqrt<u64> {
type Target = u32;
fn sqrt(self: u64) -> u32 {
core::integer::u64_sqrt(self)
}
}

impl SqrtU128 of Sqrt<u128> {
type Target = u64;
fn sqrt(self: u128) -> u64 {
core::integer::u128_sqrt(self)
}
}

impl SqrtU256 of Sqrt<u256> {
type Target = u128;
fn sqrt(self: u256) -> u128 {
core::integer::u256_sqrt(self)
}
}
78 changes: 78 additions & 0 deletions corelib/src/num/traits/ops/widemul.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/// A trait for types that can be multiplied together to produce a wider type.
pub trait WideMul<Lhs, Rhs> {
/// The type of the result of the multiplication.
type Target;
/// Multiply two values together, producing a wider type.
fn wide_mul(lhs: Lhs, rhs: Rhs) -> Self::Target;
}

impl WideMulI8 of WideMul<i8, i8> {
type Target = i16;
fn wide_mul(lhs: i8, rhs: i8) -> i16 {
core::integer::i8_wide_mul(lhs, rhs)
}
}

impl WideMulI16 of WideMul<i16, i16> {
type Target = i32;
fn wide_mul(lhs: i16, rhs: i16) -> i32 {
core::integer::i16_wide_mul(lhs, rhs)
}
}

impl WideMulI32 of WideMul<i32, i32> {
type Target = i64;
fn wide_mul(lhs: i32, rhs: i32) -> i64 {
core::integer::i32_wide_mul(lhs, rhs)
}
}

impl WideMulI64 of WideMul<i64, i64> {
type Target = i128;
fn wide_mul(lhs: i64, rhs: i64) -> i128 {
core::integer::i64_wide_mul(lhs, rhs)
}
}

impl WideMulU8 of WideMul<u8, u8> {
type Target = u16;
fn wide_mul(lhs: u8, rhs: u8) -> u16 {
core::integer::u8_wide_mul(lhs, rhs)
}
}

impl WideMulU16 of WideMul<u16, u16> {
type Target = u32;
fn wide_mul(lhs: u16, rhs: u16) -> u32 {
core::integer::u16_wide_mul(lhs, rhs)
}
}

impl WideMulU32 of WideMul<u32, u32> {
type Target = u64;
fn wide_mul(lhs: u32, rhs: u32) -> u64 {
core::integer::u32_wide_mul(lhs, rhs)
}
}

impl WideMulU64 of WideMul<u64, u64> {
type Target = u128;
fn wide_mul(lhs: u64, rhs: u64) -> u128 {
core::integer::u64_wide_mul(lhs, rhs)
}
}

impl WideMulU128 of WideMul<u128, u128> {
type Target = u256;
fn wide_mul(lhs: u128, rhs: u128) -> u256 {
let (high, low) = core::integer::u128_wide_mul(lhs, rhs);
u256 { low, high }
}
}

impl WideMulU256 of WideMul<u256, u256> {
type Target = core::integer::u512;
fn wide_mul(lhs: u256, rhs: u256) -> core::integer::u512 {
core::integer::u256_wide_mul(lhs, rhs)
}
}
126 changes: 52 additions & 74 deletions corelib/src/test/integer_test.cairo
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
use core::{
integer,
integer::{
BoundedInt, u128_sqrt, u128_wrapping_sub, u16_sqrt, u256_sqrt, u256_wide_mul, u32_sqrt,
u512_safe_div_rem_by_u256, u512, u64_sqrt, u8_sqrt
}
};
use core::{integer, integer::{BoundedInt, u512_safe_div_rem_by_u256, u512}};
use core::test::test_utils::{assert_eq, assert_ne, assert_le, assert_lt, assert_gt, assert_ge};
use core::num::traits::{Sqrt, WideMul, WrappingSub};

#[test]
fn test_u8_operators() {
Expand Down Expand Up @@ -35,12 +30,12 @@ fn test_u8_operators() {
assert_ge(5_u8, 2_u8, '5 >= 2');
assert(!(3_u8 > 3_u8), '!(3 > 3)');
assert_ge(3_u8, 3_u8, '3 >= 3');
assert_eq(@u8_sqrt(9), @3, 'u8_sqrt(9) == 3');
assert_eq(@u8_sqrt(10), @3, 'u8_sqrt(10) == 3');
assert_eq(@u8_sqrt(0x40), @0x8, 'u8_sqrt(2^6) == 2^3');
assert_eq(@u8_sqrt(0xff), @0xf, 'Wrong square root result.');
assert_eq(@u8_sqrt(1), @1, 'u8_sqrt(1) == 1');
assert_eq(@u8_sqrt(0), @0, 'u8_sqrt(0) == 0');
assert!(9_u8.sqrt() == 3);
assert!(10_u8.sqrt() == 3);
assert!(0x40_u8.sqrt() == 0x8);
assert!(0xff_u8.sqrt() == 0xf);
assert!(1_u8.sqrt() == 1);
assert!(0_u8.sqrt() == 0);
assert_eq(@~0x00_u8, @0xff, '~0x00 == 0xff');
assert_eq(@~0x81_u8, @0x7e, '~0x81 == 0x7e');
}
Expand Down Expand Up @@ -138,12 +133,12 @@ fn test_u16_operators() {
assert_ge(5_u16, 2_u16, '5 >= 2');
assert(!(3_u16 > 3_u16), '!(3 > 3)');
assert_ge(3_u16, 3_u16, '3 >= 3');
assert_eq(@u16_sqrt(9), @3, 'u16_sqrt(9) == 3');
assert_eq(@u16_sqrt(10), @3, 'u16_sqrt(10) == 3');
assert_eq(@u16_sqrt(0x400), @0x20, 'u16_sqrt(2^10) == 2^5');
assert_eq(@u16_sqrt(0xffff), @0xff, 'Wrong square root result.');
assert_eq(@u16_sqrt(1), @1, 'u64_sqrt(1) == 1');
assert_eq(@u16_sqrt(0), @0, 'u64_sqrt(0) == 0');
assert!(9_u16.sqrt() == 3);
assert!(10_u16.sqrt() == 3);
assert!(0x400_u16.sqrt() == 0x20);
assert!(0xffff_u16.sqrt() == 0xff);
assert!(1_u16.sqrt() == 1);
assert!(0_u16.sqrt() == 0);
assert_eq(@~0x0000_u16, @0xffff, '~0x0000 == 0xffff');
assert_eq(@~0x8421_u16, @0x7bde, '~0x8421 == 0x7bde');
}
Expand Down Expand Up @@ -241,12 +236,12 @@ fn test_u32_operators() {
assert_ge(5_u32, 2_u32, '5 >= 2');
assert(!(3_u32 > 3_u32), '!(3 > 3)');
assert_ge(3_u32, 3_u32, '3 >= 3');
assert_eq(@u32_sqrt(9), @3, 'u32_sqrt(9) == 3');
assert_eq(@u32_sqrt(10), @3, 'u32_sqrt(10) == 3');
assert_eq(@u32_sqrt(0x100000), @0x400, 'u32_sqrt(2^20) == 2^10');
assert_eq(@u32_sqrt(0xffffffff), @0xffff, 'Wrong square root result.');
assert_eq(@u32_sqrt(1), @1, 'u64_sqrt(1) == 1');
assert_eq(@u32_sqrt(0), @0, 'u64_sqrt(0) == 0');
assert!(9_u32.sqrt() == 3);
assert!(10_u32.sqrt() == 3);
assert!(0x100000_u32.sqrt() == 0x400);
assert!(0xffffffff_u32.sqrt() == 0xffff);
assert!(1_u32.sqrt() == 1);
assert!(0_u32.sqrt() == 0);
assert_eq(@~0x00000000_u32, @0xffffffff, '~0x00000000 == 0xffffffff');
assert_eq(@~0x12345678_u32, @0xedcba987, '~0x12345678 == 0xedcba987');
}
Expand Down Expand Up @@ -346,12 +341,12 @@ fn test_u64_operators() {
assert_ge(5_u64, 2_u64, '5 >= 2');
assert(!(3_u64 > 3_u64), '!(3 > 3)');
assert_ge(3_u64, 3_u64, '3 >= 3');
assert_eq(@u64_sqrt(9), @3, 'u64_sqrt(9) == 3');
assert_eq(@u64_sqrt(10), @3, 'u64_sqrt(10) == 3');
assert_eq(@u64_sqrt(0x10000000000), @0x100000, 'u64_sqrt(2^40) == 2^20');
assert_eq(@u64_sqrt(0xffffffffffffffff), @0xffffffff, 'Wrong square root result.');
assert_eq(@u64_sqrt(1), @1, 'u64_sqrt(1) == 1');
assert_eq(@u64_sqrt(0), @0, 'u64_sqrt(0) == 0');
assert!(9_u64.sqrt() == 3);
assert!(10_u64.sqrt() == 3);
assert!(0x10000000000_u64.sqrt() == 0x100000);
assert!(0xffffffffffffffff_u64.sqrt() == 0xffffffff);
assert!(1_u64.sqrt() == 1);
assert!(0_u64.sqrt() == 0);
assert_eq(@~0x0000000000000000_u64, @0xffffffffffffffff, '~0x0..0 == 0xf..f');
assert_eq(@~0x123456789abcdef1_u64, @0xedcba9876543210e, '~0x12..ef1 == 0xed..10e');
}
Expand Down Expand Up @@ -451,18 +446,12 @@ fn test_u128_operators() {
assert_eq(@((2_u128 & 2_u128)), @2_u128, '2 & 2 == 2');
assert_eq(@((2_u128 & 3_u128)), @2_u128, '2 & 3 == 2');
assert_eq(@((3_u128 ^ 6_u128)), @5_u128, '3 ^ 6 == 5');
assert_eq(@u128_sqrt(9), @3, 'u128_sqrt(9) == 3');
assert_eq(@u128_sqrt(10), @3, 'u128_sqrt(10) == 3');
assert_eq(
@u128_sqrt(0x10000000000000000000000000), @0x4000000000000, 'u128_sqrt(2^100) == 2^50'
);
assert_eq(
@u128_sqrt(0xffffffffffffffffffffffffffffffff),
@0xffffffffffffffff,
'Wrong square root result.'
);
assert_eq(@u128_sqrt(1), @1, 'u128_sqrt(1) == 1');
assert_eq(@u128_sqrt(0), @0, 'u128_sqrt(0) == 0');
assert!(9_u128.sqrt() == 3);
assert!(10_u128.sqrt() == 3);
assert!(0x10000000000000000000000000_u128.sqrt() == 0x4000000000000);
assert!(0xffffffffffffffffffffffffffffffff_u128.sqrt() == 0xffffffffffffffff);
assert!(1_u128.sqrt() == 1);
assert!(0_u128.sqrt() == 0);
assert_eq(
@~0x00000000000000000000000000000000_u128,
@0xffffffffffffffffffffffffffffffff,
Expand Down Expand Up @@ -506,27 +495,27 @@ fn test_u128_sub_overflow_4() {
#[test]
fn test_u128_wrapping_sub_1() {
let max_u128: u128 = BoundedInt::max();
let should_be_max = u128_wrapping_sub(0_u128, 1_u128);
let should_be_max = WrappingSub::wrapping_sub(0_u128, 1_u128);
assert_eq(@max_u128, @should_be_max, 'Should be max u128')
}

#[test]
fn test_u128_wrapping_sub_2() {
let max_u128_minus_two: u128 = BoundedInt::max() - 2;
let should_be_max = u128_wrapping_sub(0_u128, 3_u128);
let should_be_max = WrappingSub::wrapping_sub(0_u128, 3_u128);
assert_eq(@max_u128_minus_two, @should_be_max, 'Should be max u128 - 2')
}

#[test]
fn test_u128_wrapping_sub_3() {
let max_u128_minus_899: u128 = BoundedInt::max() - 899;
let should_be_max = u128_wrapping_sub(100, 1000);
let should_be_max = WrappingSub::wrapping_sub(100, 1000);
assert_eq(@max_u128_minus_899, @should_be_max, 'Should be max u128 - 899')
}

#[test]
fn test_u128_wrapping_sub_4() {
let should_be_zero = u128_wrapping_sub(0_u128, 0_u128);
let should_be_zero = WrappingSub::wrapping_sub(0_u128, 0_u128);
assert_eq(@should_be_zero, @0, 'Should be 0')
}

Expand Down Expand Up @@ -745,19 +734,17 @@ fn test_u256_mul_overflow_2() {

#[test]
fn test_u256_wide_mul() {
assert_eq(@u256_wide_mul(0, 0), @u512 { limb0: 0, limb1: 0, limb2: 0, limb3: 0 }, '0 * 0 != 0');
assert_eq(
@u256_wide_mul(
0x1001001001001001001001001001001001001001001001001001,
0x1000100010001000100010001000100010001000100010001000100010001
),
@u512 {
assert!(WideMul::wide_mul(0_u256, 0_u256) == u512 { limb0: 0, limb1: 0, limb2: 0, limb3: 0 });
assert!(
WideMul::wide_mul(
0x1001001001001001001001001001001001001001001001001001_u256,
0x1000100010001000100010001000100010001000100010001000100010001_u256
) == u512 {
limb0: 0x33233223222222122112111111011001,
limb1: 0x54455445544554454444443443343333,
limb2: 0x21222222322332333333433443444444,
limb3: 0x1001101111112112
},
'long calculation failed'
}
);
}

Expand Down Expand Up @@ -905,24 +892,15 @@ fn test_default_felt252dict_values() {

#[test]
fn test_u256_sqrt() {
assert_eq(@u256_sqrt(9.into()), @3, 'u256_sqrt(9) == 3');
assert_eq(@u256_sqrt(10.into()), @3, 'u256_sqrt(10) == 3');
assert_eq(
@u256_sqrt(1267650600228229401496703205376.into()),
@1125899906842624,
'u256_sqrt(2^100) == 2^50'
);
assert_eq(
@u256_sqrt(340282366920938463463374607431768211455.into()),
@18446744073709551615,
'Wrong square root result.'
);
assert_eq(@u256_sqrt(1.into()), @1, 'u256_sqrt(1) == 1');
assert_eq(@u256_sqrt(0.into()), @0, 'u256_sqrt(0) == 0');

assert_eq(@u256_sqrt(BoundedInt::max()), @BoundedInt::max(), 'u256::MAX**0.5==u128::MAX');
let (high, low) = integer::u128_wide_mul(BoundedInt::max(), BoundedInt::max());
assert_eq(@u256_sqrt(u256 { low, high }), @BoundedInt::max(), '(u128::MAX**2)**0.5==u128::MAX');
assert!(9_u256.sqrt() == 3);
assert!(10_u256.sqrt() == 3);
assert!(1267650600228229401496703205376_u256.sqrt() == 1125899906842624);
assert!(340282366920938463463374607431768211455_u256.sqrt() == 18446744073709551615);
assert!(1_u256.sqrt() == 1);
assert!(0_u256.sqrt() == 0);
assert!(BoundedInt::<u256>::max().sqrt() == BoundedInt::<u128>::max());
let max_u128: u128 = BoundedInt::max();
assert!(WideMul::wide_mul(max_u128, max_u128).sqrt() == max_u128);
}

#[test]
Expand Down