Skip to content

Commit

Permalink
Added fast trig methods (lower precision
Browse files Browse the repository at this point in the history
  • Loading branch information
clexmond committed Jul 8, 2023
1 parent 90c7825 commit 1a2eb93
Show file tree
Hide file tree
Showing 5 changed files with 745 additions and 186 deletions.
1 change: 1 addition & 0 deletions src/math.cairo
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod comp;
mod core;
mod hyp;
mod lut;
mod trig;
175 changes: 6 additions & 169 deletions src/math/core.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,9 @@ use result::{ResultTrait, ResultTraitImpl};
use traits::{Into, TryInto};
use integer::{u256_safe_divmod, u256_as_non_zero, upcast};

use cubit::math::lut;
use cubit::types::fixed::{
HALF_u128,
MAX_u128,
ONE_u128,
Fixed,
FixedInto,
FixedTrait,
FixedAdd,
FixedDiv,
FixedMul,
FixedNeg
HALF_u128, MAX_u128, ONE_u128, Fixed, FixedInto, FixedTrait, FixedAdd, FixedDiv, FixedMul, FixedNeg
};

// PUBLIC
Expand Down Expand Up @@ -80,7 +72,7 @@ fn exp2(a: Fixed) -> Fixed {
}

let (int_part, frac_part) = _split_unsigned(a);
let int_res = FixedTrait::new_unscaled(_exp2(int_part), false);
let int_res = FixedTrait::new_unscaled(lut::exp2(int_part), false);
let mut res_u = int_res;

if frac_part > 0 {
Expand All @@ -104,7 +96,7 @@ fn exp2(a: Fixed) -> Fixed {
}

fn exp2_int(exp: u128) -> Fixed {
return FixedTrait::new_unscaled(_exp2(exp), false);
return FixedTrait::new_unscaled(lut::exp2(exp), false);
}

fn floor(a: Fixed) -> Fixed {
Expand Down Expand Up @@ -162,7 +154,7 @@ fn log2(a: Fixed) -> Fixed {
return -log2(div);
}

let (msb, div) = msb(a.mag);
let (msb, div) = lut::msb(a.mag / ONE_u128);
let norm = a / FixedTrait::new_unscaled(div, false);

let r8 = FixedTrait::new(167660832607149504, true) * norm;
Expand Down Expand Up @@ -293,161 +285,6 @@ fn sub(a: Fixed, b: Fixed) -> Fixed {
return add(a, -b);
}

// Calculates the most significant bit
fn msb(a: u128) -> (u128, u128) {
let whole = a / ONE_u128;

if whole < 256 {
if whole < 2 { return (0, 1); }
if whole < 4 { return (1, 2); }
if whole < 8 { return (2, 4); }
if whole < 16 { return (3, 8); }
if whole < 32 { return (4, 16); }
if whole < 64 { return (5, 32); }
if whole < 128 { return (6, 64); }
if whole < 256 { return (7, 128); }
} else if whole < 65536 {
if whole < 512 { return (8, 256); }
if whole < 1024 { return (9, 512); }
if whole < 2048 { return (10, 1024); }
if whole < 4096 { return (11, 2048); }
if whole < 8192 { return (12, 4096); }
if whole < 16384 { return (13, 8192); }
if whole < 32768 { return (14, 16384); }
if whole < 65536 { return (15, 32768); }
} else if whole < 16777216 {
if whole < 131072 { return (16, 65536); }
if whole < 262144 { return (17, 131072); }
if whole < 524288 { return (18, 262144); }
if whole < 1048576 { return (19, 524288); }
if whole < 2097152 { return (20, 1048576); }
if whole < 4194304 { return (21, 2097152); }
if whole < 8388608 { return (22, 4194304); }
if whole < 16777216 { return (23, 8388608); }
} else if whole < 4294967296 {
if whole < 33554432 { return (24, 16777216); }
if whole < 67108864 { return (25, 33554432); }
if whole < 134217728 { return (26, 67108864); }
if whole < 268435456 { return (27, 134217728); }
if whole < 536870912 { return (28, 268435456); }
if whole < 1073741824 { return (29, 536870912); }
if whole < 2147483648 { return (30, 1073741824); }
if whole < 4294967296 { return (31, 2147483648); }
} else if whole < 1099511627776 {
if whole < 8589934592 { return (32, 4294967296); }
if whole < 17179869184 { return (33, 8589934592); }
if whole < 34359738368 { return (34, 17179869184); }
if whole < 68719476736 { return (35, 34359738368); }
if whole < 137438953472 { return (36, 68719476736); }
if whole < 274877906944 { return (37, 137438953472); }
if whole < 549755813888 { return (38, 274877906944); }
if whole < 1099511627776 { return (39, 549755813888); }
} else if whole < 281474976710656 {
if whole < 2199023255552 { return (40, 1099511627776); }
if whole < 4398046511104 { return (41, 2199023255552); }
if whole < 8796093022208 { return (42, 4398046511104); }
if whole < 17592186044416 { return (43, 8796093022208); }
if whole < 35184372088832 { return (44, 17592186044416); }
if whole < 70368744177664 { return (45, 35184372088832); }
if whole < 140737488355328 { return (46, 70368744177664); }
if whole < 281474976710656 { return (47, 140737488355328); }
} else if whole < 72057594037927936 {
if whole < 562949953421312 { return (48, 281474976710656); }
if whole < 1125899906842624 { return (49, 562949953421312); }
if whole < 2251799813685248 { return (50, 1125899906842624); }
if whole < 4503599627370496 { return (51, 2251799813685248); }
if whole < 9007199254740992 { return (52, 4503599627370496); }
if whole < 18014398509481984 { return (53, 9007199254740992); }
if whole < 36028797018963968 { return (54, 18014398509481984); }
if whole < 72057594037927936 { return (55, 36028797018963968); }
} else {
if whole < 144115188075855872 { return (56, 72057594037927936); }
if whole < 288230376151711744 { return (57, 144115188075855872); }
if whole < 576460752303423488 { return (58, 288230376151711744); }
if whole < 1152921504606846976 { return (59, 576460752303423488); }
if whole < 2305843009213693952 { return (60, 1152921504606846976); }
if whole < 4611686018427387904 { return (61, 2305843009213693952); }
if whole < 9223372036854775808 { return (62, 4611686018427387904); }
if whole < 18446744073709551616 { return (63, 9223372036854775808); }
}

return (64, 18446744073709551616);
}

fn _exp2(exp: u128) -> u128 {
if exp <= 16 {
if exp == 0 { return 1; }
if exp == 1 { return 2; }
if exp == 2 { return 4; }
if exp == 3 { return 8; }
if exp == 4 { return 16; }
if exp == 5 { return 32; }
if exp == 6 { return 64; }
if exp == 7 { return 128; }
if exp == 8 { return 256; }
if exp == 9 { return 512; }
if exp == 10 { return 1024; }
if exp == 11 { return 2048; }
if exp == 12 { return 4096; }
if exp == 13 { return 8192; }
if exp == 14 { return 16384; }
if exp == 15 { return 32768; }
if exp == 16 { return 65536; }
} else if exp <= 32 {
if exp == 17 { return 131072; }
if exp == 18 { return 262144; }
if exp == 19 { return 524288; }
if exp == 20 { return 1048576; }
if exp == 21 { return 2097152; }
if exp == 22 { return 4194304; }
if exp == 23 { return 8388608; }
if exp == 24 { return 16777216; }
if exp == 25 { return 33554432; }
if exp == 26 { return 67108864; }
if exp == 27 { return 134217728; }
if exp == 28 { return 268435456; }
if exp == 29 { return 536870912; }
if exp == 30 { return 1073741824; }
if exp == 31 { return 2147483648; }
if exp == 32 { return 4294967296; }
} else if exp <= 48 {
if exp == 33 { return 8589934592; }
if exp == 34 { return 17179869184; }
if exp == 35 { return 34359738368; }
if exp == 36 { return 68719476736; }
if exp == 37 { return 137438953472; }
if exp == 38 { return 274877906944; }
if exp == 39 { return 549755813888; }
if exp == 40 { return 1099511627776; }
if exp == 41 { return 2199023255552; }
if exp == 42 { return 4398046511104; }
if exp == 43 { return 8796093022208; }
if exp == 44 { return 17592186044416; }
if exp == 45 { return 35184372088832; }
if exp == 46 { return 70368744177664; }
if exp == 47 { return 140737488355328; }
if exp == 48 { return 281474976710656; }
} else {
if exp == 49 { return 562949953421312; }
if exp == 50 { return 1125899906842624; }
if exp == 51 { return 2251799813685248; }
if exp == 52 { return 4503599627370496; }
if exp == 53 { return 9007199254740992; }
if exp == 54 { return 18014398509481984; }
if exp == 55 { return 36028797018963968; }
if exp == 56 { return 72057594037927936; }
if exp == 57 { return 144115188075855872; }
if exp == 58 { return 288230376151711744; }
if exp == 59 { return 576460752303423488; }
if exp == 60 { return 1152921504606846976; }
if exp == 61 { return 2305843009213693952; }
if exp == 62 { return 4611686018427387904; }
if exp == 63 { return 9223372036854775808; }
}

return 18446744073709551616;
}

// Ignores sign and always returns false
fn _split_unsigned(a: Fixed) -> (u128, u128) {
return integer::u128_safe_divmod(a.mag, integer::u128_as_non_zero(ONE_u128));
Expand Down Expand Up @@ -598,7 +435,7 @@ fn test_sqrt() {
#[available_gas(100000)]
fn test_msb() {
let a = FixedTrait::new_unscaled(4503599627370495, false);
let (msb, div) = msb(a.mag);
let (msb, div) = lut::msb(a.mag / ONE_u128);
assert(msb == 51, 'invalid msb');
assert(div == 2251799813685248, 'invalid msb ceil');
}
Expand Down
Loading

0 comments on commit 1a2eb93

Please sign in to comment.