diff --git a/plonky2/src/hash/arch/x86_64/goldilocks_avx2.rs b/plonky2/src/hash/arch/x86_64/goldilocks_avx2.rs
new file mode 100644
index 0000000000..5bfd8465c6
--- /dev/null
+++ b/plonky2/src/hash/arch/x86_64/goldilocks_avx2.rs
@@ -0,0 +1,189 @@
+use core::arch::x86_64::*;
+
+const MSB_1: i64 = 0x8000000000000000u64 as i64;
+const P_N_1: i64 = 0xFFFFFFFF;
+
+#[inline(always)]
+pub fn shift_avx(a: &__m256i) -> __m256i {
+    unsafe {
+        let msb = _mm256_set_epi64x(MSB_1, MSB_1, MSB_1, MSB_1);
+        _mm256_xor_si256(*a, msb)
+    }
+}
+
+#[inline(always)]
+pub fn add_avx_a_sc(a_sc: &__m256i, b: &__m256i) -> __m256i {
+    unsafe {
+        let c0_s = _mm256_add_epi64(*a_sc, *b);
+        let p_n = _mm256_set_epi64x(P_N_1, P_N_1, P_N_1, P_N_1);
+        let mask_ = _mm256_cmpgt_epi64(*a_sc, c0_s);
+        let corr_ = _mm256_and_si256(mask_, p_n);
+        let c_s = _mm256_add_epi64(c0_s, corr_);
+        shift_avx(&c_s)
+    }
+}
+
+#[inline(always)]
+pub fn add_avx(a: &__m256i, b: &__m256i) -> __m256i {
+    let a_sc = shift_avx(a);
+    add_avx_a_sc(&a_sc, b)
+}
+
+#[inline(always)]
+pub fn add_avx_s_b_small(a_s: &__m256i, b_small: &__m256i) -> __m256i {
+    unsafe {
+        let c0_s = _mm256_add_epi64(*a_s, *b_small);
+        let mask_ = _mm256_cmpgt_epi32(*a_s, c0_s);
+        let corr_ = _mm256_srli_epi64(mask_, 32);
+        _mm256_add_epi64(c0_s, corr_)
+    }
+}
+
+#[inline(always)]
+pub fn sub_avx_s_b_small(a_s: &__m256i, b: &__m256i) -> __m256i {
+    unsafe {
+        let c0_s = _mm256_sub_epi64(*a_s, *b);
+        let mask_ = _mm256_cmpgt_epi32(c0_s, *a_s);
+        let corr_ = _mm256_srli_epi64(mask_, 32);
+        _mm256_sub_epi64(c0_s, corr_)
+    }
+}
+
+#[inline(always)]
+pub fn reduce_avx_128_64(c_h: &__m256i, c_l: &__m256i) -> __m256i {
+    unsafe {
+        let msb = _mm256_set_epi64x(MSB_1, MSB_1, MSB_1, MSB_1);
+        let c_hh = _mm256_srli_epi64(*c_h, 32);
+        let c_ls = _mm256_xor_si256(*c_l, msb);
+        let c1_s = sub_avx_s_b_small(&c_ls, &c_hh);
+        let p_n = _mm256_set_epi64x(P_N_1, P_N_1, P_N_1, P_N_1);
+        let c2 = _mm256_mul_epu32(*c_h, p_n);
+        let c_s = add_avx_s_b_small(&c1_s, &c2);
+        _mm256_xor_si256(c_s, msb)
+    }
+}
+
+// Here we assume c_h < 2^32.
+#[inline(always)]
+pub fn reduce_avx_96_64(c_h: &__m256i, c_l: &__m256i) -> __m256i {
+    unsafe {
+        let msb = _mm256_set_epi64x(MSB_1, MSB_1, MSB_1, MSB_1);
+        let p_n = _mm256_set_epi64x(P_N_1, P_N_1, P_N_1, P_N_1);
+        let c_ls = _mm256_xor_si256(*c_l, msb);
+        let c2 = _mm256_mul_epu32(*c_h, p_n);
+        let c_s = add_avx_s_b_small(&c_ls, &c2);
+        _mm256_xor_si256(c_s, msb)
+    }
+}
+
+#[inline(always)]
+pub fn mult_avx_128(a: &__m256i, b: &__m256i) -> (__m256i, __m256i) {
+    unsafe {
+        let a_h = _mm256_srli_epi64(*a, 32);
+        let b_h = _mm256_srli_epi64(*b, 32);
+        let c_hh = _mm256_mul_epu32(a_h, b_h);
+        let c_hl = _mm256_mul_epu32(a_h, *b);
+        let c_lh = _mm256_mul_epu32(*a, b_h);
+        let c_ll = _mm256_mul_epu32(*a, *b);
+        let c_ll_h = _mm256_srli_epi64(c_ll, 32);
+        let r0 = _mm256_add_epi64(c_hl, c_ll_h);
+        let p_n = _mm256_set_epi64x(P_N_1, P_N_1, P_N_1, P_N_1);
+        let r0_l = _mm256_and_si256(r0, p_n);
+        let r0_h = _mm256_srli_epi64(r0, 32);
+        let r1 = _mm256_add_epi64(c_lh, r0_l);
+        let r1_l = _mm256_slli_epi64(r1, 32);
+        let c_l = _mm256_blend_epi32(c_ll, r1_l, 0xaa);
+        let r2 = _mm256_add_epi64(c_hh, r0_h);
+        let r1_h = _mm256_srli_epi64(r1, 32);
+        let c_h = _mm256_add_epi64(r2, r1_h);
+        (c_h, c_l)
+    }
+}
+
+#[inline(always)]
+pub fn mult_avx(a: &__m256i, b: &__m256i) -> __m256i {
+    let (c_h, c_l) = mult_avx_128(a, b);
+    reduce_avx_128_64(&c_h, &c_l)
+}
+
+// Multiply two 64-bit numbers with the assumption that the product does not overflow.
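For orientation, here is a scalar, per-lane model of the two multiply paths in this file: the full multiply-and-reduce (`mult_avx_128` followed by `reduce_avx_128_64`) defined above, and the truncating `mul64_no_overflow` whose definition follows below. This is an editorial sketch, not part of the patch; it assumes the Goldilocks prime p = 2^64 - 2^32 + 1 (so 2^64 ≡ 2^32 - 1 and 2^96 ≡ -1 mod p), and the function names are illustrative.

```rust
// Editorial sketch (not part of the patch): scalar per-lane reference.
const EPSILON: u64 = 0xFFFF_FFFF; // 2^32 - 1, the same value as P_N_1 above

/// Model of `mult_avx` = `mult_avx_128` + `reduce_avx_128_64`: multiply two
/// field elements and reduce the 128-bit product. Like the vector code, the
/// result is a valid residue mod p but not necessarily canonical.
fn mult_reduce_scalar(a: u64, b: u64) -> u64 {
    let c = (a as u128) * (b as u128);
    let (c_l, c_h) = (c as u64, (c >> 64) as u64);
    let (c_hh, c_hl) = (c_h >> 32, c_h & EPSILON);
    // c = c_l + 2^64 * c_hl + 2^96 * c_hh  ≡  c_l - c_hh + (2^32 - 1) * c_hl  (mod p)
    let (mut t, borrow) = c_l.overflowing_sub(c_hh);
    if borrow {
        // Wrapped below 0: add p back, i.e. subtract 2^32 - 1 from the wrapped value.
        t = t.wrapping_sub(EPSILON);
    }
    let (mut r, carry) = t.overflowing_add(c_hl * EPSILON);
    if carry {
        // Wrapped past 2^64: subtract p, i.e. add 2^32 - 1 to the wrapped value.
        r = r.wrapping_add(EPSILON);
    }
    r
}

/// Model of one lane of `mul64_no_overflow` (defined just below): AVX2 has no
/// 64x64->64 vector multiply, so the code rebuilds the product from 32x32->64
/// partial products; per lane the result is simply a truncating 64-bit multiply.
fn mul64_no_overflow_scalar(a: u64, b: u64) -> u64 {
    a.wrapping_mul(b)
}
```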
+#[inline] +pub unsafe fn mul64_no_overflow(a: &__m256i, b: &__m256i) -> __m256i { + let r = _mm256_mul_epu32(*a, *b); + let ah = _mm256_srli_epi64(*a, 32); + let bh = _mm256_srli_epi64(*b, 32); + let r1 = _mm256_mul_epu32(*a, bh); + let r1 = _mm256_slli_epi64(r1, 32); + let r = _mm256_add_epi64(r, r1); + let r1 = _mm256_mul_epu32(ah, *b); + let r1 = _mm256_slli_epi64(r1, 32); + let r = _mm256_add_epi64(r, r1); + r +} + +#[inline] +pub unsafe fn add64_no_carry(a: &__m256i, b: &__m256i) -> (__m256i, __m256i) { + /* + * a and b are signed 4 x i64. Suppose a and b represent only one i64, then: + * - (test 1): if a < 2^63 and b < 2^63 (this means a >= 0 and b >= 0) => sum does not overflow => cout = 0 + * - if a >= 2^63 and b >= 2^63 => sum overflows so sum = a + b and cout = 1 + * - (test 2): if (a < 2^63 and b >= 2^63) or (a >= 2^63 and b < 2^63) + * - (test 3): if a + b < 2^64 (this means a + b is negative in signed representation) => no overflow so cout = 0 + * - (test 3): if a + b >= 2^64 (this means a + b becomes positive in signed representation, that is, a + b >= 0) => there is overflow so cout = 1 + */ + let ones = _mm256_set_epi64x(1, 1, 1, 1); + let zeros = _mm256_set_epi64x(0, 0, 0, 0); + let r = _mm256_add_epi64(*a, *b); + let ma = _mm256_cmpgt_epi64(zeros, *a); + let mb = _mm256_cmpgt_epi64(zeros, *b); + let m1 = _mm256_and_si256(ma, mb); // test 1 + let m2 = _mm256_xor_si256(ma, mb); // test 2 + let m23 = _mm256_cmpgt_epi64(zeros, r); // test 3 + let m2 = _mm256_andnot_si256(m23, m2); + let m = _mm256_or_si256(m1, m2); + let co = _mm256_and_si256(m, ones); + (r, co) +} + +#[inline(always)] +pub fn sqr_avx_128(a: &__m256i) -> (__m256i, __m256i) { + unsafe { + let a_h = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(*a))); + let c_ll = _mm256_mul_epu32(*a, *a); + let c_lh = _mm256_mul_epu32(*a, a_h); + let c_hh = _mm256_mul_epu32(a_h, a_h); + let c_ll_hi = _mm256_srli_epi64(c_ll, 33); + let t0 = _mm256_add_epi64(c_lh, c_ll_hi); + let t0_hi = _mm256_srli_epi64(t0, 31); + let res_hi = _mm256_add_epi64(c_hh, t0_hi); + let c_lh_lo = _mm256_slli_epi64(c_lh, 33); + let res_lo = _mm256_add_epi64(c_ll, c_lh_lo); + (res_hi, res_lo) + } +} + +#[inline(always)] +pub fn sqr_avx(a: &__m256i) -> __m256i { + let (c_h, c_l) = sqr_avx_128(a); + reduce_avx_128_64(&c_h, &c_l) +} + +#[inline(always)] +pub fn sbox_avx(s0: &mut __m256i, s1: &mut __m256i, s2: &mut __m256i) { + // x^2 + let p10 = sqr_avx(s0); + let p11 = sqr_avx(s1); + let p12 = sqr_avx(s2); + // x^3 + let p30 = mult_avx(&p10, s0); + let p31 = mult_avx(&p11, s1); + let p32 = mult_avx(&p12, s2); + // x^4 = (x^2)^2 + let p40 = sqr_avx(&p10); + let p41 = sqr_avx(&p11); + let p42 = sqr_avx(&p12); + // x^7 + *s0 = mult_avx(&p40, &p30); + *s1 = mult_avx(&p41, &p31); + *s2 = mult_avx(&p42, &p32); +} diff --git a/plonky2/src/hash/arch/x86_64/mod.rs b/plonky2/src/hash/arch/x86_64/mod.rs index 0730b62614..25eb0ed699 100644 --- a/plonky2/src/hash/arch/x86_64/mod.rs +++ b/plonky2/src/hash/arch/x86_64/mod.rs @@ -1,5 +1,6 @@ // // Requires: // // - AVX2 -// // - BMI2 (for MULX and SHRX) -// #[cfg(all(target_feature = "avx2", target_feature = "bmi2"))] -// pub(crate) mod poseidon_goldilocks_avx2_bmi2; +#[cfg(target_feature = "avx2")] +pub(crate) mod goldilocks_avx2; +#[cfg(target_feature = "avx2")] +pub(crate) mod poseidon_goldilocks_avx2; diff --git a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2.rs b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2.rs new file mode 100644 index 0000000000..9814b79515 --- /dev/null +++ 
b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2.rs @@ -0,0 +1,1471 @@ +use core::arch::x86_64::*; + +use unroll::unroll_for_loops; + +use super::goldilocks_avx2::{add64_no_carry, mul64_no_overflow, mult_avx_128, reduce_avx_96_64}; +use crate::field::types::PrimeField64; +use crate::hash::arch::x86_64::goldilocks_avx2::{add_avx, mult_avx, reduce_avx_128_64, sbox_avx}; +use crate::hash::poseidon::{ + add_u160_u128, reduce_u160, Poseidon, ALL_ROUND_CONSTANTS, HALF_N_FULL_ROUNDS, + N_PARTIAL_ROUNDS, SPONGE_WIDTH, +}; +use crate::hash::poseidon_goldilocks::poseidon12_mds::block2; + +#[allow(dead_code)] +const MDS_MATRIX_CIRC: [u64; 12] = [17, 15, 41, 16, 2, 28, 13, 13, 39, 18, 34, 20]; + +#[allow(dead_code)] +const MDS_MATRIX_DIAG: [u64; 12] = [8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; + +const FAST_PARTIAL_FIRST_ROUND_CONSTANT: [u64; 12] = [ + 0x3cc3f892184df408, + 0xe993fd841e7e97f1, + 0xf2831d3575f0f3af, + 0xd2500e0a350994ca, + 0xc5571f35d7288633, + 0x91d89c5184109a02, + 0xf37f925d04e5667b, + 0x2d6e448371955a69, + 0x740ef19ce01398a1, + 0x694d24c0752fdf45, + 0x60936af96ee2f148, + 0xc33448feadc78f0c, +]; + +const FAST_PARTIAL_ROUND_CONSTANTS: [u64; N_PARTIAL_ROUNDS] = [ + 0x74cb2e819ae421ab, + 0xd2559d2370e7f663, + 0x62bf78acf843d17c, + 0xd5ab7b67e14d1fb4, + 0xb9fe2ae6e0969bdc, + 0xe33fdf79f92a10e8, + 0x0ea2bb4c2b25989b, + 0xca9121fbf9d38f06, + 0xbdd9b0aa81f58fa4, + 0x83079fa4ecf20d7e, + 0x650b838edfcc4ad3, + 0x77180c88583c76ac, + 0xaf8c20753143a180, + 0xb8ccfe9989a39175, + 0x954a1729f60cc9c5, + 0xdeb5b550c4dca53b, + 0xf01bb0b00f77011e, + 0xa1ebb404b676afd9, + 0x860b6e1597a0173e, + 0x308bb65a036acbce, + 0x1aca78f31c97c876, + 0x0, +]; + +const FAST_PARTIAL_ROUND_INITIAL_MATRIX: [[u64; 12]; 12] = [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ + 0, + 0x80772dc2645b280b, + 0xdc927721da922cf8, + 0xc1978156516879ad, + 0x90e80c591f48b603, + 0x3a2432625475e3ae, + 0x00a2d4321cca94fe, + 0x77736f524010c932, + 0x904d3f2804a36c54, + 0xbf9b39e28a16f354, + 0x3a1ded54a6cd058b, + 0x42392870da5737cf, + ], + [ + 0, + 0xe796d293a47a64cb, + 0xb124c33152a2421a, + 0x0ee5dc0ce131268a, + 0xa9032a52f930fae6, + 0x7e33ca8c814280de, + 0xad11180f69a8c29e, + 0xc75ac6d5b5a10ff3, + 0xf0674a8dc5a387ec, + 0xb36d43120eaa5e2b, + 0x6f232aab4b533a25, + 0x3a1ded54a6cd058b, + ], + [ + 0, + 0xdcedab70f40718ba, + 0x14a4a64da0b2668f, + 0x4715b8e5ab34653b, + 0x1e8916a99c93a88e, + 0xbba4b5d86b9a3b2c, + 0xe76649f9bd5d5c2e, + 0xaf8e2518a1ece54d, + 0xdcda1344cdca873f, + 0xcd080204256088e5, + 0xb36d43120eaa5e2b, + 0xbf9b39e28a16f354, + ], + [ + 0, + 0xf4a437f2888ae909, + 0xc537d44dc2875403, + 0x7f68007619fd8ba9, + 0xa4911db6a32612da, + 0x2f7e9aade3fdaec1, + 0xe7ffd578da4ea43d, + 0x43a608e7afa6b5c2, + 0xca46546aa99e1575, + 0xdcda1344cdca873f, + 0xf0674a8dc5a387ec, + 0x904d3f2804a36c54, + ], + [ + 0, + 0xf97abba0dffb6c50, + 0x5e40f0c9bb82aab5, + 0x5996a80497e24a6b, + 0x07084430a7307c9a, + 0xad2f570a5b8545aa, + 0xab7f81fef4274770, + 0xcb81f535cf98c9e9, + 0x43a608e7afa6b5c2, + 0xaf8e2518a1ece54d, + 0xc75ac6d5b5a10ff3, + 0x77736f524010c932, + ], + [ + 0, + 0x7f8e41e0b0a6cdff, + 0x4b1ba8d40afca97d, + 0x623708f28fca70e8, + 0xbf150dc4914d380f, + 0xc26a083554767106, + 0x753b8b1126665c22, + 0xab7f81fef4274770, + 0xe7ffd578da4ea43d, + 0xe76649f9bd5d5c2e, + 0xad11180f69a8c29e, + 0x00a2d4321cca94fe, + ], + [ + 0, + 0x726af914971c1374, + 0x1d7f8a2cce1a9d00, + 0x18737784700c75cd, + 0x7fb45d605dd82838, + 0x862361aeab0f9b6e, + 0xc26a083554767106, + 0xad2f570a5b8545aa, + 0x2f7e9aade3fdaec1, + 0xbba4b5d86b9a3b2c, + 0x7e33ca8c814280de, + 0x3a2432625475e3ae, + 
], + [ + 0, + 0x64dd936da878404d, + 0x4db9a2ead2bd7262, + 0xbe2e19f6d07f1a83, + 0x02290fe23c20351a, + 0x7fb45d605dd82838, + 0xbf150dc4914d380f, + 0x07084430a7307c9a, + 0xa4911db6a32612da, + 0x1e8916a99c93a88e, + 0xa9032a52f930fae6, + 0x90e80c591f48b603, + ], + [ + 0, + 0x85418a9fef8a9890, + 0xd8a2eb7ef5e707ad, + 0xbfe85ababed2d882, + 0xbe2e19f6d07f1a83, + 0x18737784700c75cd, + 0x623708f28fca70e8, + 0x5996a80497e24a6b, + 0x7f68007619fd8ba9, + 0x4715b8e5ab34653b, + 0x0ee5dc0ce131268a, + 0xc1978156516879ad, + ], + [ + 0, + 0x156048ee7a738154, + 0x91f7562377e81df5, + 0xd8a2eb7ef5e707ad, + 0x4db9a2ead2bd7262, + 0x1d7f8a2cce1a9d00, + 0x4b1ba8d40afca97d, + 0x5e40f0c9bb82aab5, + 0xc537d44dc2875403, + 0x14a4a64da0b2668f, + 0xb124c33152a2421a, + 0xdc927721da922cf8, + ], + [ + 0, + 0xd841e8ef9dde8ba0, + 0x156048ee7a738154, + 0x85418a9fef8a9890, + 0x64dd936da878404d, + 0x726af914971c1374, + 0x7f8e41e0b0a6cdff, + 0xf97abba0dffb6c50, + 0xf4a437f2888ae909, + 0xdcedab70f40718ba, + 0xe796d293a47a64cb, + 0x80772dc2645b280b, + ], +]; + +const FAST_PARTIAL_ROUND_W_HATS: [[u64; 12 - 1]; N_PARTIAL_ROUNDS] = [ + [ + 0x3d999c961b7c63b0, + 0x814e82efcd172529, + 0x2421e5d236704588, + 0x887af7d4dd482328, + 0xa5e9c291f6119b27, + 0xbdc52b2676a4b4aa, + 0x64832009d29bcf57, + 0x09c4155174a552cc, + 0x463f9ee03d290810, + 0xc810936e64982542, + 0x043b1c289f7bc3ac, + ], + [ + 0x673655aae8be5a8b, + 0xd510fe714f39fa10, + 0x2c68a099b51c9e73, + 0xa667bfa9aa96999d, + 0x4d67e72f063e2108, + 0xf84dde3e6acda179, + 0x40f9cc8c08f80981, + 0x5ead032050097142, + 0x6591b02092d671bb, + 0x00e18c71963dd1b7, + 0x8a21bcd24a14218a, + ], + [ + 0x202800f4addbdc87, + 0xe4b5bdb1cc3504ff, + 0xbe32b32a825596e7, + 0x8e0f68c5dc223b9a, + 0x58022d9e1c256ce3, + 0x584d29227aa073ac, + 0x8b9352ad04bef9e7, + 0xaead42a3f445ecbf, + 0x3c667a1d833a3cca, + 0xda6f61838efa1ffe, + 0xe8f749470bd7c446, + ], + [ + 0xc5b85bab9e5b3869, + 0x45245258aec51cf7, + 0x16e6b8e68b931830, + 0xe2ae0f051418112c, + 0x0470e26a0093a65b, + 0x6bef71973a8146ed, + 0x119265be51812daf, + 0xb0be7356254bea2e, + 0x8584defff7589bd7, + 0x3c5fe4aeb1fb52ba, + 0x9e7cd88acf543a5e, + ], + [ + 0x179be4bba87f0a8c, + 0xacf63d95d8887355, + 0x6696670196b0074f, + 0xd99ddf1fe75085f9, + 0xc2597881fef0283b, + 0xcf48395ee6c54f14, + 0x15226a8e4cd8d3b6, + 0xc053297389af5d3b, + 0x2c08893f0d1580e2, + 0x0ed3cbcff6fcc5ba, + 0xc82f510ecf81f6d0, + ], + [ + 0x94b06183acb715cc, + 0x500392ed0d431137, + 0x861cc95ad5c86323, + 0x05830a443f86c4ac, + 0x3b68225874a20a7c, + 0x10b3309838e236fb, + 0x9b77fc8bcd559e2c, + 0xbdecf5e0cb9cb213, + 0x30276f1221ace5fa, + 0x7935dd342764a144, + 0xeac6db520bb03708, + ], + [ + 0x7186a80551025f8f, + 0x622247557e9b5371, + 0xc4cbe326d1ad9742, + 0x55f1523ac6a23ea2, + 0xa13dfe77a3d52f53, + 0xe30750b6301c0452, + 0x08bd488070a3a32b, + 0xcd800caef5b72ae3, + 0x83329c90f04233ce, + 0xb5b99e6664a0a3ee, + 0x6b0731849e200a7f, + ], + [ + 0xec3fabc192b01799, + 0x382b38cee8ee5375, + 0x3bfb6c3f0e616572, + 0x514abd0cf6c7bc86, + 0x47521b1361dcc546, + 0x178093843f863d14, + 0xad1003c5d28918e7, + 0x738450e42495bc81, + 0xaf947c59af5e4047, + 0x4653fb0685084ef2, + 0x057fde2062ae35bf, + ], + [ + 0xe376678d843ce55e, + 0x66f3860d7514e7fc, + 0x7817f3dfff8b4ffa, + 0x3929624a9def725b, + 0x0126ca37f215a80a, + 0xfce2f5d02762a303, + 0x1bc927375febbad7, + 0x85b481e5243f60bf, + 0x2d3c5f42a39c91a0, + 0x0811719919351ae8, + 0xf669de0add993131, + ], + [ + 0x7de38bae084da92d, + 0x5b848442237e8a9b, + 0xf6c705da84d57310, + 0x31e6a4bdb6a49017, + 0x889489706e5c5c0f, + 0x0e4a205459692a1b, + 0xbac3fa75ee26f299, + 0x5f5894f4057d755e, + 
0xb0dc3ecd724bb076, + 0x5e34d8554a6452ba, + 0x04f78fd8c1fdcc5f, + ], + [ + 0x4dd19c38779512ea, + 0xdb79ba02704620e9, + 0x92a29a3675a5d2be, + 0xd5177029fe495166, + 0xd32b3298a13330c1, + 0x251c4a3eb2c5f8fd, + 0xe1c48b26e0d98825, + 0x3301d3362a4ffccb, + 0x09bb6c88de8cd178, + 0xdc05b676564f538a, + 0x60192d883e473fee, + ], + [ + 0x16b9774801ac44a0, + 0x3cb8411e786d3c8e, + 0xa86e9cf505072491, + 0x0178928152e109ae, + 0x5317b905a6e1ab7b, + 0xda20b3be7f53d59f, + 0xcb97dedecebee9ad, + 0x4bd545218c59f58d, + 0x77dc8d856c05a44a, + 0x87948589e4f243fd, + 0x7e5217af969952c2, + ], + [ + 0xbc58987d06a84e4d, + 0x0b5d420244c9cae3, + 0xa3c4711b938c02c0, + 0x3aace640a3e03990, + 0x865a0f3249aacd8a, + 0x8d00b2a7dbed06c7, + 0x6eacb905beb7e2f8, + 0x045322b216ec3ec7, + 0xeb9de00d594828e6, + 0x088c5f20df9e5c26, + 0xf555f4112b19781f, + ], + [ + 0xa8cedbff1813d3a7, + 0x50dcaee0fd27d164, + 0xf1cb02417e23bd82, + 0xfaf322786e2abe8b, + 0x937a4315beb5d9b6, + 0x1b18992921a11d85, + 0x7d66c4368b3c497b, + 0x0e7946317a6b4e99, + 0xbe4430134182978b, + 0x3771e82493ab262d, + 0xa671690d8095ce82, + ], + [ + 0xb035585f6e929d9d, + 0xba1579c7e219b954, + 0xcb201cf846db4ba3, + 0x287bf9177372cf45, + 0xa350e4f61147d0a6, + 0xd5d0ecfb50bcff99, + 0x2e166aa6c776ed21, + 0xe1e66c991990e282, + 0x662b329b01e7bb38, + 0x8aa674b36144d9a9, + 0xcbabf78f97f95e65, + ], + [ + 0xeec24b15a06b53fe, + 0xc8a7aa07c5633533, + 0xefe9c6fa4311ad51, + 0xb9173f13977109a1, + 0x69ce43c9cc94aedc, + 0xecf623c9cd118815, + 0x28625def198c33c7, + 0xccfc5f7de5c3636a, + 0xf5e6c40f1621c299, + 0xcec0e58c34cb64b1, + 0xa868ea113387939f, + ], + [ + 0xd8dddbdc5ce4ef45, + 0xacfc51de8131458c, + 0x146bb3c0fe499ac0, + 0x9e65309f15943903, + 0x80d0ad980773aa70, + 0xf97817d4ddbf0607, + 0xe4626620a75ba276, + 0x0dfdc7fd6fc74f66, + 0xf464864ad6f2bb93, + 0x02d55e52a5d44414, + 0xdd8de62487c40925, + ], + [ + 0xc15acf44759545a3, + 0xcbfdcf39869719d4, + 0x33f62042e2f80225, + 0x2599c5ead81d8fa3, + 0x0b306cb6c1d7c8d0, + 0x658c80d3df3729b1, + 0xe8d1b2b21b41429c, + 0xa1b67f09d4b3ccb8, + 0x0e1adf8b84437180, + 0x0d593a5e584af47b, + 0xa023d94c56e151c7, + ], + [ + 0x49026cc3a4afc5a6, + 0xe06dff00ab25b91b, + 0x0ab38c561e8850ff, + 0x92c3c8275e105eeb, + 0xb65256e546889bd0, + 0x3c0468236ea142f6, + 0xee61766b889e18f2, + 0xa206f41b12c30415, + 0x02fe9d756c9f12d1, + 0xe9633210630cbf12, + 0x1ffea9fe85a0b0b1, + ], + [ + 0x81d1ae8cc50240f3, + 0xf4c77a079a4607d7, + 0xed446b2315e3efc1, + 0x0b0a6b70915178c3, + 0xb11ff3e089f15d9a, + 0x1d4dba0b7ae9cc18, + 0x65d74e2f43b48d05, + 0xa2df8c6b8ae0804a, + 0xa4e6f0a8c33348a6, + 0xc0a26efc7be5669b, + 0xa6b6582c547d0d60, + ], + [ + 0x84afc741f1c13213, + 0x2f8f43734fc906f3, + 0xde682d72da0a02d9, + 0x0bb005236adb9ef2, + 0x5bdf35c10a8b5624, + 0x0739a8a343950010, + 0x52f515f44785cfbc, + 0xcbaf4e5d82856c60, + 0xac9ea09074e3e150, + 0x8f0fa011a2035fb0, + 0x1a37905d8450904a, + ], + [ + 0x3abeb80def61cc85, + 0x9d19c9dd4eac4133, + 0x075a652d9641a985, + 0x9daf69ae1b67e667, + 0x364f71da77920a18, + 0x50bd769f745c95b1, + 0xf223d1180dbbf3fc, + 0x2f885e584e04aa99, + 0xb69a0fa70aea684a, + 0x09584acaa6e062a0, + 0x0bc051640145b19b, + ], +]; + +const FAST_PARTIAL_ROUND_VS: [[u64; 12]; N_PARTIAL_ROUNDS] = [ + [ + 0x0, + 0x94877900674181c3, + 0xc6c67cc37a2a2bbd, + 0xd667c2055387940f, + 0x0ba63a63e94b5ff0, + 0x99460cc41b8f079f, + 0x7ff02375ed524bb3, + 0xea0870b47a8caf0e, + 0xabcad82633b7bc9d, + 0x3b8d135261052241, + 0xfb4515f5e5b0d539, + 0x3ee8011c2b37f77c, + ], + [ + 0x0, + 0x0adef3740e71c726, + 0xa37bf67c6f986559, + 0xc6b16f7ed4fa1b00, + 0x6a065da88d8bfc3c, + 0x4cabc0916844b46f, + 0x407faac0f02e78d1, + 
0x07a786d9cf0852cf, + 0x42433fb6949a629a, + 0x891682a147ce43b0, + 0x26cfd58e7b003b55, + 0x2bbf0ed7b657acb3, + ], + [ + 0x0, + 0x481ac7746b159c67, + 0xe367de32f108e278, + 0x73f260087ad28bec, + 0x5cfc82216bc1bdca, + 0xcaccc870a2663a0e, + 0xdb69cd7b4298c45d, + 0x7bc9e0c57243e62d, + 0x3cc51c5d368693ae, + 0x366b4e8cc068895b, + 0x2bd18715cdabbca4, + 0xa752061c4f33b8cf, + ], + [ + 0x0, + 0xb22d2432b72d5098, + 0x9e18a487f44d2fe4, + 0x4b39e14ce22abd3c, + 0x9e77fde2eb315e0d, + 0xca5e0385fe67014d, + 0x0c2cb99bf1b6bddb, + 0x99ec1cd2a4460bfe, + 0x8577a815a2ff843f, + 0x7d80a6b4fd6518a5, + 0xeb6c67123eab62cb, + 0x8f7851650eca21a5, + ], + [ + 0x0, + 0x11ba9a1b81718c2a, + 0x9f7d798a3323410c, + 0xa821855c8c1cf5e5, + 0x535e8d6fac0031b2, + 0x404e7c751b634320, + 0xa729353f6e55d354, + 0x4db97d92e58bb831, + 0xb53926c27897bf7d, + 0x965040d52fe115c5, + 0x9565fa41ebd31fd7, + 0xaae4438c877ea8f4, + ], + [ + 0x0, + 0x37f4e36af6073c6e, + 0x4edc0918210800e9, + 0xc44998e99eae4188, + 0x9f4310d05d068338, + 0x9ec7fe4350680f29, + 0xc5b2c1fdc0b50874, + 0xa01920c5ef8b2ebe, + 0x59fa6f8bd91d58ba, + 0x8bfc9eb89b515a82, + 0xbe86a7a2555ae775, + 0xcbb8bbaa3810babf, + ], + [ + 0x0, + 0x577f9a9e7ee3f9c2, + 0x88c522b949ace7b1, + 0x82f07007c8b72106, + 0x8283d37c6675b50e, + 0x98b074d9bbac1123, + 0x75c56fb7758317c1, + 0xfed24e206052bc72, + 0x26d7c3d1bc07dae5, + 0xf88c5e441e28dbb4, + 0x4fe27f9f96615270, + 0x514d4ba49c2b14fe, + ], + [ + 0x0, + 0xf02a3ac068ee110b, + 0x0a3630dafb8ae2d7, + 0xce0dc874eaf9b55c, + 0x9a95f6cff5b55c7e, + 0x626d76abfed00c7b, + 0xa0c1cf1251c204ad, + 0xdaebd3006321052c, + 0x3d4bd48b625a8065, + 0x7f1e584e071f6ed2, + 0x720574f0501caed3, + 0xe3260ba93d23540a, + ], + [ + 0x0, + 0xab1cbd41d8c1e335, + 0x9322ed4c0bc2df01, + 0x51c3c0983d4284e5, + 0x94178e291145c231, + 0xfd0f1a973d6b2085, + 0xd427ad96e2b39719, + 0x8a52437fecaac06b, + 0xdc20ee4b8c4c9a80, + 0xa2c98e9549da2100, + 0x1603fe12613db5b6, + 0x0e174929433c5505, + ], + [ + 0x0, + 0x3d4eab2b8ef5f796, + 0xcfff421583896e22, + 0x4143cb32d39ac3d9, + 0x22365051b78a5b65, + 0x6f7fd010d027c9b6, + 0xd9dd36fba77522ab, + 0xa44cf1cb33e37165, + 0x3fc83d3038c86417, + 0xc4588d418e88d270, + 0xce1320f10ab80fe2, + 0xdb5eadbbec18de5d, + ], + [ + 0x0, + 0x1183dfce7c454afd, + 0x21cea4aa3d3ed949, + 0x0fce6f70303f2304, + 0x19557d34b55551be, + 0x4c56f689afc5bbc9, + 0xa1e920844334f944, + 0xbad66d423d2ec861, + 0xf318c785dc9e0479, + 0x99e2032e765ddd81, + 0x400ccc9906d66f45, + 0xe1197454db2e0dd9, + ], + [ + 0x0, + 0x84d1ecc4d53d2ff1, + 0xd8af8b9ceb4e11b6, + 0x335856bb527b52f4, + 0xc756f17fb59be595, + 0xc0654e4ea5553a78, + 0x9e9a46b61f2ea942, + 0x14fc8b5b3b809127, + 0xd7009f0f103be413, + 0x3e0ee7b7a9fb4601, + 0xa74e888922085ed7, + 0xe80a7cde3d4ac526, + ], + [ + 0x0, + 0x238aa6daa612186d, + 0x9137a5c630bad4b4, + 0xc7db3817870c5eda, + 0x217e4f04e5718dc9, + 0xcae814e2817bd99d, + 0xe3292e7ab770a8ba, + 0x7bb36ef70b6b9482, + 0x3c7835fb85bca2d3, + 0xfe2cdf8ee3c25e86, + 0x61b3915ad7274b20, + 0xeab75ca7c918e4ef, + ], + [ + 0x0, + 0xd6e15ffc055e154e, + 0xec67881f381a32bf, + 0xfbb1196092bf409c, + 0xdc9d2e07830ba226, + 0x0698ef3245ff7988, + 0x194fae2974f8b576, + 0x7a5d9bea6ca4910e, + 0x7aebfea95ccdd1c9, + 0xf9bd38a67d5f0e86, + 0xfa65539de65492d8, + 0xf0dfcbe7653ff787, + ], + [ + 0x0, + 0x0bd87ad390420258, + 0x0ad8617bca9e33c8, + 0x0c00ad377a1e2666, + 0x0ac6fc58b3f0518f, + 0x0c0cc8a892cc4173, + 0x0c210accb117bc21, + 0x0b73630dbb46ca18, + 0x0c8be4920cbd4a54, + 0x0bfe877a21be1690, + 0x0ae790559b0ded81, + 0x0bf50db2f8d6ce31, + ], + [ + 0x0, + 0x000cf29427ff7c58, + 0x000bd9b3cf49eec8, + 0x000d1dc8aa81fb26, + 
0x000bc792d5c394ef, + 0x000d2ae0b2266453, + 0x000d413f12c496c1, + 0x000c84128cfed618, + 0x000db5ebd48fc0d4, + 0x000d1b77326dcb90, + 0x000beb0ccc145421, + 0x000d10e5b22b11d1, + ], + [ + 0x0, + 0x00000e24c99adad8, + 0x00000cf389ed4bc8, + 0x00000e580cbf6966, + 0x00000cde5fd7e04f, + 0x00000e63628041b3, + 0x00000e7e81a87361, + 0x00000dabe78f6d98, + 0x00000efb14cac554, + 0x00000e5574743b10, + 0x00000d05709f42c1, + 0x00000e4690c96af1, + ], + [ + 0x0, + 0x0000000f7157bc98, + 0x0000000e3006d948, + 0x0000000fa65811e6, + 0x0000000e0d127e2f, + 0x0000000fc18bfe53, + 0x0000000fd002d901, + 0x0000000eed6461d8, + 0x0000001068562754, + 0x0000000fa0236f50, + 0x0000000e3af13ee1, + 0x0000000fa460f6d1, + ], + [ + 0x0, + 0x0000000011131738, + 0x000000000f56d588, + 0x0000000011050f86, + 0x000000000f848f4f, + 0x00000000111527d3, + 0x00000000114369a1, + 0x00000000106f2f38, + 0x0000000011e2ca94, + 0x00000000110a29f0, + 0x000000000fa9f5c1, + 0x0000000010f625d1, + ], + [ + 0x0, + 0x000000000011f718, + 0x000000000010b6c8, + 0x0000000000134a96, + 0x000000000010cf7f, + 0x0000000000124d03, + 0x000000000013f8a1, + 0x0000000000117c58, + 0x0000000000132c94, + 0x0000000000134fc0, + 0x000000000010a091, + 0x0000000000128961, + ], + [ + 0x0, + 0x0000000000001300, + 0x0000000000001750, + 0x000000000000114e, + 0x000000000000131f, + 0x000000000000167b, + 0x0000000000001371, + 0x0000000000001230, + 0x000000000000182c, + 0x0000000000001368, + 0x0000000000000f31, + 0x00000000000015c9, + ], + [ + 0x0, + 0x0000000000000014, + 0x0000000000000022, + 0x0000000000000012, + 0x0000000000000027, + 0x000000000000000d, + 0x000000000000000d, + 0x000000000000001c, + 0x0000000000000002, + 0x0000000000000010, + 0x0000000000000029, + 0x000000000000000f, + ], +]; + +const MDS_FREQ_BLOCK_ONE: [i64; 3] = [16, 32, 16]; +const MDS_FREQ_BLOCK_TWO: [(i64, i64); 3] = [(2, -1), (-4, 1), (16, 1)]; +const MDS_FREQ_BLOCK_THREE: [i64; 3] = [-1, -8, 2]; + +#[allow(dead_code)] +#[inline(always)] +#[unroll_for_loops] +fn mds_row_shf(r: usize, v: &[u64; SPONGE_WIDTH]) -> (u64, u64) { + let mut res = 0u128; + + // This is a hacky way of fully unrolling the loop. 
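+    // (The literal bound `0..12` is what allows `#[unroll_for_loops]` to unroll
+    // this loop; the `i < SPONGE_WIDTH` guard is always true, since SPONGE_WIDTH == 12.)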
+ for i in 0..12 { + if i < SPONGE_WIDTH { + res += (v[(i + r) % SPONGE_WIDTH] as u128) * (MDS_MATRIX_CIRC[i] as u128); + } + } + res += (v[r] as u128) * (MDS_MATRIX_DIAG[r] as u128); + + ((res >> 64) as u64, res as u64) +} + +#[allow(dead_code)] +#[inline(always)] +#[unroll_for_loops] +unsafe fn mds_layer_avx_v1( + s0: &__m256i, + s1: &__m256i, + s2: &__m256i, +) -> (__m256i, __m256i, __m256i) { + let mut st64 = [0u64; SPONGE_WIDTH]; + + _mm256_storeu_si256((&mut st64[0..4]).as_mut_ptr().cast::<__m256i>(), *s0); + _mm256_storeu_si256((&mut st64[4..8]).as_mut_ptr().cast::<__m256i>(), *s1); + _mm256_storeu_si256((&mut st64[8..12]).as_mut_ptr().cast::<__m256i>(), *s2); + + let mut sumh: [u64; 12] = [0; 12]; + let mut suml: [u64; 12] = [0; 12]; + for r in 0..12 { + if r < SPONGE_WIDTH { + (sumh[r], suml[r]) = mds_row_shf(r, &st64); + } + } + + let ss0h = _mm256_loadu_si256((&sumh[0..4]).as_ptr().cast::<__m256i>()); + let ss0l = _mm256_loadu_si256((&suml[0..4]).as_ptr().cast::<__m256i>()); + let ss1h = _mm256_loadu_si256((&sumh[4..8]).as_ptr().cast::<__m256i>()); + let ss1l = _mm256_loadu_si256((&suml[4..8]).as_ptr().cast::<__m256i>()); + let ss2h = _mm256_loadu_si256((&sumh[8..12]).as_ptr().cast::<__m256i>()); + let ss2l = _mm256_loadu_si256((&suml[8..12]).as_ptr().cast::<__m256i>()); + let r0 = reduce_avx_128_64(&ss0h, &ss0l); + let r1 = reduce_avx_128_64(&ss1h, &ss1l); + let r2 = reduce_avx_128_64(&ss2h, &ss2l); + + (r0, r1, r2) +} + +#[allow(dead_code)] +#[inline(always)] +#[unroll_for_loops] +unsafe fn mds_layer_avx_v2( + s0: &__m256i, + s1: &__m256i, + s2: &__m256i, +) -> (__m256i, __m256i, __m256i) +where + F: PrimeField64, +{ + let mut st64 = [0u64; SPONGE_WIDTH]; + + _mm256_storeu_si256((&mut st64[0..4]).as_mut_ptr().cast::<__m256i>(), *s0); + _mm256_storeu_si256((&mut st64[4..8]).as_mut_ptr().cast::<__m256i>(), *s1); + _mm256_storeu_si256((&mut st64[8..12]).as_mut_ptr().cast::<__m256i>(), *s2); + + let mut result = [F::ZERO; SPONGE_WIDTH]; + // This is a hacky way of fully unrolling the loop. 
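+    // (Each 128-bit row sum is at most (8 + sum of MDS_MATRIX_CIRC) * (2^64 - 1) < 2^73,
+    // so `sum_hi` below fits in 32 bits and the `try_into().unwrap()` cannot fail.)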
+ for r in 0..12 { + if r < SPONGE_WIDTH { + let (sum_hi, sum_lo) = mds_row_shf(r, &st64); + result[r] = F::from_noncanonical_u96((sum_lo, sum_hi.try_into().unwrap())); + } + } + + let r0 = _mm256_loadu_si256((&result[0..4]).as_ptr().cast::<__m256i>()); + let r1 = _mm256_loadu_si256((&result[4..8]).as_ptr().cast::<__m256i>()); + let r2 = _mm256_loadu_si256((&result[8..12]).as_ptr().cast::<__m256i>()); + + (r0, r1, r2) +} + +#[inline(always)] +unsafe fn block1_avx(x: &__m256i, y: [i64; 3]) -> __m256i { + let x0 = _mm256_permute4x64_epi64(*x, 0x0); + let x1 = _mm256_permute4x64_epi64(*x, 0x55); + let x2 = _mm256_permute4x64_epi64(*x, 0xAA); + + let f0 = _mm256_set_epi64x(0, y[2], y[1], y[0]); + let f1 = _mm256_set_epi64x(0, y[1], y[0], y[2]); + let f2 = _mm256_set_epi64x(0, y[0], y[2], y[1]); + + let t0 = mul64_no_overflow(&x0, &f0); + let t1 = mul64_no_overflow(&x1, &f1); + let t2 = mul64_no_overflow(&x2, &f2); + + let t0 = _mm256_add_epi64(t0, t1); + _mm256_add_epi64(t0, t2) +} + +#[allow(dead_code)] +#[inline(always)] +unsafe fn block2_full_avx(xr: &__m256i, xi: &__m256i, y: [(i64, i64); 3]) -> (__m256i, __m256i) { + let yr = _mm256_set_epi64x(0, y[2].0, y[1].0, y[0].0); + let yi = _mm256_set_epi64x(0, y[2].1, y[1].1, y[0].1); + let ys = _mm256_add_epi64(yr, yi); + let xs = _mm256_add_epi64(*xr, *xi); + + // z0 + // z0r = dif2[0] + prod[1] - sum[1] + prod[2] - sum[2] + // z0i = prod[0] - sum[0] + dif1[1] + dif1[2] + let yy = _mm256_permute4x64_epi64(yr, 0x18); + let mr_z0 = mul64_no_overflow(xr, &yy); + let yy = _mm256_permute4x64_epi64(yi, 0x18); + let mi_z0 = mul64_no_overflow(xi, &yy); + let sum = _mm256_add_epi64(mr_z0, mi_z0); + let dif1 = _mm256_sub_epi64(mi_z0, mr_z0); + let dif2 = _mm256_sub_epi64(mr_z0, mi_z0); + let yy = _mm256_permute4x64_epi64(ys, 0x18); + let prod = mul64_no_overflow(&xs, &yy); + let dif3 = _mm256_sub_epi64(prod, sum); + let dif3perm1 = _mm256_permute4x64_epi64(dif3, 0x1); + let dif3perm2 = _mm256_permute4x64_epi64(dif3, 0x2); + let z0r = _mm256_add_epi64(dif2, dif3perm1); + let z0r = _mm256_add_epi64(z0r, dif3perm2); + let dif1perm1 = _mm256_permute4x64_epi64(dif1, 0x1); + let dif1perm2 = _mm256_permute4x64_epi64(dif1, 0x2); + let z0i = _mm256_add_epi64(dif3, dif1perm1); + let z0i = _mm256_add_epi64(z0i, dif1perm2); + let mask = _mm256_set_epi64x(0, 0, 0, 0xFFFFFFFFFFFFFFFFu64 as i64); + let z0r = _mm256_and_si256(z0r, mask); + let z0i = _mm256_and_si256(z0i, mask); + + // z1 + // z1r = dif2[0] + dif2[1] + prod[2] - sum[2]; + // z1i = prod[0] - sum[0] + prod[1] - sum[1] + dif1[2]; + let yy = _mm256_permute4x64_epi64(yr, 0x21); + let mr_z1 = mul64_no_overflow(xr, &yy); + let yy = _mm256_permute4x64_epi64(yi, 0x21); + let mi_z1 = mul64_no_overflow(xi, &yy); + let sum = _mm256_add_epi64(mr_z1, mi_z1); + let dif1 = _mm256_sub_epi64(mi_z1, mr_z1); + let dif2 = _mm256_sub_epi64(mr_z1, mi_z1); + let yy = _mm256_permute4x64_epi64(ys, 0x21); + let prod = mul64_no_overflow(&xs, &yy); + let dif3 = _mm256_sub_epi64(prod, sum); + let dif2perm = _mm256_permute4x64_epi64(dif2, 0x0); + let dif3perm = _mm256_permute4x64_epi64(dif3, 0x8); + let z1r = _mm256_add_epi64(dif2, dif2perm); + let z1r = _mm256_add_epi64(z1r, dif3perm); + let dif3perm = _mm256_permute4x64_epi64(dif3, 0x0); + let dif1perm = _mm256_permute4x64_epi64(dif1, 0x8); + let z1i = _mm256_add_epi64(dif3, dif3perm); + let z1i = _mm256_add_epi64(z1i, dif1perm); + let mask = _mm256_set_epi64x(0, 0, 0xFFFFFFFFFFFFFFFFu64 as i64, 0); + let z1r = _mm256_and_si256(z1r, mask); + let z1i = _mm256_and_si256(z1i, mask); + 
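+    // (Naming used above and below: for each elementwise complex product,
+    // `dif2` is the real part (mr - mi), `dif3` the imaginary part (prod - sum),
+    // and `dif1` the negated real part (mi - mr); each z_k is masked into its own
+    // 64-bit lane and the three lanes are OR-combined at the end.)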
+ // z2 + // z2r = dif2[0] + dif2[1] + dif2[2]; + // z2i = prod[0] - sum[0] + prod[1] - sum[1] + prod[2] - sum[2] + let yy = _mm256_permute4x64_epi64(yr, 0x6); + let mr_z2 = mul64_no_overflow(xr, &yy); + let yy = _mm256_permute4x64_epi64(yi, 0x6); + let mi_z2 = mul64_no_overflow(xi, &yy); + let sum = _mm256_add_epi64(mr_z2, mi_z2); + let dif2 = _mm256_sub_epi64(mr_z2, mi_z2); + let yy = _mm256_permute4x64_epi64(ys, 0x6); + let prod = mul64_no_overflow(&xs, &yy); + let dif3 = _mm256_sub_epi64(prod, sum); + let dif2perm1 = _mm256_permute4x64_epi64(dif2, 0x0); + let dif2perm2 = _mm256_permute4x64_epi64(dif2, 0x10); + let z2r = _mm256_add_epi64(dif2, dif2perm1); + let z2r = _mm256_add_epi64(z2r, dif2perm2); + let dif3perm1 = _mm256_permute4x64_epi64(dif3, 0x0); + let dif3perm2 = _mm256_permute4x64_epi64(dif3, 0x10); + let z2i = _mm256_add_epi64(dif3, dif3perm1); + let z2i = _mm256_add_epi64(z2i, dif3perm2); + let mask = _mm256_set_epi64x(0, 0xFFFFFFFFFFFFFFFFu64 as i64, 0, 0); + let z2r = _mm256_and_si256(z2r, mask); + let z2i = _mm256_and_si256(z2i, mask); + + let zr = _mm256_or_si256(z0r, z1r); + let zr = _mm256_or_si256(zr, z2r); + let zi = _mm256_or_si256(z0i, z1i); + let zi = _mm256_or_si256(zi, z2i); + (zr, zi) +} + +#[inline(always)] +unsafe fn block2_avx(xr: &__m256i, xi: &__m256i, y: [(i64, i64); 3]) -> (__m256i, __m256i) { + let mut vxr: [i64; 4] = [0; 4]; + let mut vxi: [i64; 4] = [0; 4]; + _mm256_storeu_si256(vxr.as_mut_ptr().cast::<__m256i>(), *xr); + _mm256_storeu_si256(vxi.as_mut_ptr().cast::<__m256i>(), *xi); + let x: [(i64, i64); 3] = [(vxr[0], vxi[0]), (vxr[1], vxi[1]), (vxr[2], vxi[2])]; + let b = block2(x, y); + vxr = [b[0].0, b[1].0, b[2].0, 0]; + vxi = [b[0].1, b[1].1, b[2].1, 0]; + let rr = _mm256_loadu_si256(vxr.as_ptr().cast::<__m256i>()); + let ri = _mm256_loadu_si256(vxi.as_ptr().cast::<__m256i>()); + (rr, ri) +} + +#[inline(always)] +unsafe fn block3_avx(x: &__m256i, y: [i64; 3]) -> __m256i { + let x0 = _mm256_permute4x64_epi64(*x, 0x0); + let x1 = _mm256_permute4x64_epi64(*x, 0x55); + let x2 = _mm256_permute4x64_epi64(*x, 0xAA); + + let f0 = _mm256_set_epi64x(0, y[2], y[1], y[0]); + let f1 = _mm256_set_epi64x(0, y[1], y[0], -y[2]); + let f2 = _mm256_set_epi64x(0, y[0], -y[2], -y[1]); + + let t0 = mul64_no_overflow(&x0, &f0); + let t1 = mul64_no_overflow(&x1, &f1); + let t2 = mul64_no_overflow(&x2, &f2); + + let t0 = _mm256_add_epi64(t0, t1); + _mm256_add_epi64(t0, t2) +} + +#[inline(always)] +unsafe fn fft2_real_avx(x0: &__m256i, x1: &__m256i) -> (__m256i, __m256i) { + let y0 = _mm256_add_epi64(*x0, *x1); + let y1 = _mm256_sub_epi64(*x0, *x1); + (y0, y1) +} + +#[inline(always)] +unsafe fn fft4_real_avx( + x0: &__m256i, + x1: &__m256i, + x2: &__m256i, + x3: &__m256i, +) -> (__m256i, __m256i, __m256i, __m256i) { + let zeros = _mm256_set_epi64x(0, 0, 0, 0); + let (z0, z2) = fft2_real_avx(x0, x2); + let (z1, z3) = fft2_real_avx(x1, x3); + let y0 = _mm256_add_epi64(z0, z1); + let y2 = _mm256_sub_epi64(z0, z1); + let y3 = _mm256_sub_epi64(zeros, z3); + (y0, z2, y3, y2) +} + +#[inline(always)] +unsafe fn ifft2_real_unreduced_avx(y0: &__m256i, y1: &__m256i) -> (__m256i, __m256i) { + let x0 = _mm256_add_epi64(*y0, *y1); + let x1 = _mm256_sub_epi64(*y0, *y1); + (x0, x1) +} + +#[inline(always)] +unsafe fn ifft4_real_unreduced_avx( + y: (__m256i, (__m256i, __m256i), __m256i), +) -> (__m256i, __m256i, __m256i, __m256i) { + let zeros = _mm256_set_epi64x(0, 0, 0, 0); + let z0 = _mm256_add_epi64(y.0, y.2); + let z1 = _mm256_sub_epi64(y.0, y.2); + let z2 = y.1 .0; + let z3 = 
_mm256_sub_epi64(zeros, y.1 .1); + let (x0, x2) = ifft2_real_unreduced_avx(&z0, &z2); + let (x1, x3) = ifft2_real_unreduced_avx(&z1, &z3); + (x0, x1, x2, x3) +} + +#[inline] +unsafe fn mds_multiply_freq_avx(s0: &mut __m256i, s1: &mut __m256i, s2: &mut __m256i) { + /* + // Alternative code using store and set. + let mut s: [i64; 12] = [0; 12]; + _mm256_storeu_si256(s[0..4].as_mut_ptr().cast::<__m256i>(), *s0); + _mm256_storeu_si256(s[4..8].as_mut_ptr().cast::<__m256i>(), *s1); + _mm256_storeu_si256(s[8..12].as_mut_ptr().cast::<__m256i>(), *s2); + let f0 = _mm256_set_epi64x(0, s[2], s[1], s[0]); + let f1 = _mm256_set_epi64x(0, s[5], s[4], s[3]); + let f2 = _mm256_set_epi64x(0, s[8], s[7], s[6]); + let f3 = _mm256_set_epi64x(0, s[11], s[10], s[9]); + */ + + // Alternative code using permute and blend (it is faster). + let f0 = *s0; + let f11 = _mm256_permute4x64_epi64(*s0, 0x3); + let f12 = _mm256_permute4x64_epi64(*s1, 0x10); + let f1 = _mm256_blend_epi32(f11, f12, 0x3C); + let f21 = _mm256_permute4x64_epi64(*s1, 0xE); + let f22 = _mm256_permute4x64_epi64(*s2, 0x0); + let f2 = _mm256_blend_epi32(f21, f22, 0x30); + let f3 = _mm256_permute4x64_epi64(*s2, 0x39); + + let (u0, u1, u2, u3) = fft4_real_avx(&f0, &f1, &f2, &f3); + + // let [v0, v4, v8] = block1_avx([u[0], u[1], u[2]], MDS_FREQ_BLOCK_ONE); + // [u[0], u[1], u[2]] are all in u0 + let f0 = block1_avx(&u0, MDS_FREQ_BLOCK_ONE); + + // let [v1, v5, v9] = block2([(u[0], v[0]), (u[1], v[1]), (u[2], v[2])], MDS_FREQ_BLOCK_TWO); + let (f1, f2) = block2_avx(&u1, &u2, MDS_FREQ_BLOCK_TWO); + + // let [v2, v6, v10] = block3_avx([u[0], u[1], u[2]], MDS_FREQ_BLOCK_ONE); + // [u[0], u[1], u[2]] are all in u3 + let f3 = block3_avx(&u3, MDS_FREQ_BLOCK_THREE); + + let (r0, r3, r6, r9) = ifft4_real_unreduced_avx((f0, (f1, f2), f3)); + let t = _mm256_permute4x64_epi64(r3, 0x0); + *s0 = _mm256_blend_epi32(r0, t, 0xC0); + let t1 = _mm256_permute4x64_epi64(r3, 0x9); + let t2 = _mm256_permute4x64_epi64(r6, 0x40); + *s1 = _mm256_blend_epi32(t1, t2, 0xF0); + let t1 = _mm256_permute4x64_epi64(r6, 0x2); + let t2 = _mm256_permute4x64_epi64(r9, 0x90); + *s2 = _mm256_blend_epi32(t1, t2, 0xFC); +} + +#[allow(dead_code)] +#[inline(always)] +#[unroll_for_loops] +unsafe fn mds_layer_avx(s0: &mut __m256i, s1: &mut __m256i, s2: &mut __m256i) { + let mask = _mm256_set_epi64x(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF); + let mut sl0 = _mm256_and_si256(*s0, mask); + let mut sl1 = _mm256_and_si256(*s1, mask); + let mut sl2 = _mm256_and_si256(*s2, mask); + let mut sh0 = _mm256_srli_epi64(*s0, 32); + let mut sh1 = _mm256_srli_epi64(*s1, 32); + let mut sh2 = _mm256_srli_epi64(*s2, 32); + + mds_multiply_freq_avx(&mut sl0, &mut sl1, &mut sl2); + mds_multiply_freq_avx(&mut sh0, &mut sh1, &mut sh2); + + let shl0 = _mm256_slli_epi64(sh0, 32); + let shl1 = _mm256_slli_epi64(sh1, 32); + let shl2 = _mm256_slli_epi64(sh2, 32); + let shh0 = _mm256_srli_epi64(sh0, 32); + let shh1 = _mm256_srli_epi64(sh1, 32); + let shh2 = _mm256_srli_epi64(sh2, 32); + + let (rl0, c0) = add64_no_carry(&sl0, &shl0); + let (rh0, _) = add64_no_carry(&shh0, &c0); + let r0 = reduce_avx_128_64(&rh0, &rl0); + + let (rl1, c1) = add64_no_carry(&sl1, &shl1); + let (rh1, _) = add64_no_carry(&shh1, &c1); + *s1 = reduce_avx_128_64(&rh1, &rl1); + + let (rl2, c2) = add64_no_carry(&sl2, &shl2); + let (rh2, _) = add64_no_carry(&shh2, &c2); + *s2 = reduce_avx_128_64(&rh2, &rl2); + + let rl = _mm256_slli_epi64(*s0, 3); // * 8 (low part) + let rh = _mm256_srli_epi64(*s0, 61); // * 8 (high part, only 3 bits) + let rx = 
reduce_avx_96_64(&rh, &rl); + let rx = add_avx(&r0, &rx); + *s0 = _mm256_blend_epi32(r0, rx, 0x3); +} + +#[allow(dead_code)] +#[inline(always)] +#[unroll_for_loops] +fn mds_partial_layer_init_avx(state: &mut [F; SPONGE_WIDTH]) +where + F: PrimeField64, +{ + let mut result = [F::ZERO; SPONGE_WIDTH]; + let res0 = state[0]; + unsafe { + let mut r0 = _mm256_loadu_si256((&mut result[0..4]).as_mut_ptr().cast::<__m256i>()); + let mut r1 = _mm256_loadu_si256((&mut result[0..4]).as_mut_ptr().cast::<__m256i>()); + let mut r2 = _mm256_loadu_si256((&mut result[0..4]).as_mut_ptr().cast::<__m256i>()); + for r in 1..12 { + let sr = _mm256_set_epi64x( + state[r].to_canonical_u64() as i64, + state[r].to_canonical_u64() as i64, + state[r].to_canonical_u64() as i64, + state[r].to_canonical_u64() as i64, + ); + let t0 = _mm256_loadu_si256( + (&FAST_PARTIAL_ROUND_INITIAL_MATRIX[r][0..4]) + .as_ptr() + .cast::<__m256i>(), + ); + let t1 = _mm256_loadu_si256( + (&FAST_PARTIAL_ROUND_INITIAL_MATRIX[r][4..8]) + .as_ptr() + .cast::<__m256i>(), + ); + let t2 = _mm256_loadu_si256( + (&FAST_PARTIAL_ROUND_INITIAL_MATRIX[r][8..12]) + .as_ptr() + .cast::<__m256i>(), + ); + let m0 = mult_avx(&sr, &t0); + let m1 = mult_avx(&sr, &t1); + let m2 = mult_avx(&sr, &t2); + r0 = add_avx(&r0, &m0); + r1 = add_avx(&r1, &m1); + r2 = add_avx(&r2, &m2); + } + _mm256_storeu_si256((state[0..4]).as_mut_ptr().cast::<__m256i>(), r0); + _mm256_storeu_si256((state[4..8]).as_mut_ptr().cast::<__m256i>(), r1); + _mm256_storeu_si256((state[8..12]).as_mut_ptr().cast::<__m256i>(), r2); + state[0] = res0; + } +} + +#[inline(always)] +#[unroll_for_loops] +unsafe fn mds_partial_layer_fast_avx( + s0: &mut __m256i, + s1: &mut __m256i, + s2: &mut __m256i, + state: &mut [F; SPONGE_WIDTH], + r: usize, +) where + F: PrimeField64, +{ + let mut d_sum = (0u128, 0u32); // u160 accumulator + for i in 1..12 { + if i < SPONGE_WIDTH { + let t = FAST_PARTIAL_ROUND_W_HATS[r][i - 1] as u128; + let si = state[i].to_noncanonical_u64() as u128; + d_sum = add_u160_u128(d_sum, si * t); + } + } + let x0 = state[0].to_noncanonical_u64() as u128; + let mds0to0 = (MDS_MATRIX_CIRC[0] + MDS_MATRIX_DIAG[0]) as u128; + d_sum = add_u160_u128(d_sum, x0 * mds0to0); + let d = reduce_u160::(d_sum); + + // result = [d] concat [state[0] * v + state[shift up by 1]] + let ss0 = _mm256_set_epi64x( + state[0].to_noncanonical_u64() as i64, + state[0].to_noncanonical_u64() as i64, + state[0].to_noncanonical_u64() as i64, + state[0].to_noncanonical_u64() as i64, + ); + let rc0 = _mm256_loadu_si256((&FAST_PARTIAL_ROUND_VS[r][0..4]).as_ptr().cast::<__m256i>()); + let rc1 = _mm256_loadu_si256((&FAST_PARTIAL_ROUND_VS[r][4..8]).as_ptr().cast::<__m256i>()); + let rc2 = _mm256_loadu_si256( + (&FAST_PARTIAL_ROUND_VS[r][8..12]) + .as_ptr() + .cast::<__m256i>(), + ); + let (mh, ml) = mult_avx_128(&ss0, &rc0); + let m = reduce_avx_128_64(&mh, &ml); + let r0 = add_avx(s0, &m); + let d0 = _mm256_set_epi64x(0, 0, 0, d.to_canonical_u64() as i64); + *s0 = _mm256_blend_epi32(r0, d0, 0x3); + + let (mh, ml) = mult_avx_128(&ss0, &rc1); + let m = reduce_avx_128_64(&mh, &ml); + *s1 = add_avx(s1, &m); + + let (mh, ml) = mult_avx_128(&ss0, &rc2); + let m = reduce_avx_128_64(&mh, &ml); + *s2 = add_avx(s2, &m); + + _mm256_storeu_si256((state[0..4]).as_mut_ptr().cast::<__m256i>(), *s0); + _mm256_storeu_si256((state[4..8]).as_mut_ptr().cast::<__m256i>(), *s1); + _mm256_storeu_si256((state[8..12]).as_mut_ptr().cast::<__m256i>(), *s2); +} + +#[inline(always)] +#[unroll_for_loops] +unsafe fn 
mds_partial_layer_init_avx_m256i(s0: &mut __m256i, s1: &mut __m256i, s2: &mut __m256i) +where + F: PrimeField64, +{ + let mut result = [F::ZERO; SPONGE_WIDTH]; + let res0 = *s0; + + let mut r0 = _mm256_loadu_si256((&mut result[0..4]).as_mut_ptr().cast::<__m256i>()); + let mut r1 = _mm256_loadu_si256((&mut result[0..4]).as_mut_ptr().cast::<__m256i>()); + let mut r2 = _mm256_loadu_si256((&mut result[0..4]).as_mut_ptr().cast::<__m256i>()); + for r in 1..12 { + let sr = match r { + 1 => _mm256_permute4x64_epi64(*s0, 0x55), + 2 => _mm256_permute4x64_epi64(*s0, 0xAA), + 3 => _mm256_permute4x64_epi64(*s0, 0xFF), + 4 => _mm256_permute4x64_epi64(*s1, 0x0), + 5 => _mm256_permute4x64_epi64(*s1, 0x55), + 6 => _mm256_permute4x64_epi64(*s1, 0xAA), + 7 => _mm256_permute4x64_epi64(*s1, 0xFF), + 8 => _mm256_permute4x64_epi64(*s2, 0x0), + 9 => _mm256_permute4x64_epi64(*s2, 0x55), + 10 => _mm256_permute4x64_epi64(*s2, 0xAA), + 11 => _mm256_permute4x64_epi64(*s2, 0xFF), + _ => _mm256_permute4x64_epi64(*s0, 0x55), + }; + let t0 = _mm256_loadu_si256( + (&FAST_PARTIAL_ROUND_INITIAL_MATRIX[r][0..4]) + .as_ptr() + .cast::<__m256i>(), + ); + let t1 = _mm256_loadu_si256( + (&FAST_PARTIAL_ROUND_INITIAL_MATRIX[r][4..8]) + .as_ptr() + .cast::<__m256i>(), + ); + let t2 = _mm256_loadu_si256( + (&FAST_PARTIAL_ROUND_INITIAL_MATRIX[r][8..12]) + .as_ptr() + .cast::<__m256i>(), + ); + let m0 = mult_avx(&sr, &t0); + let m1 = mult_avx(&sr, &t1); + let m2 = mult_avx(&sr, &t2); + r0 = add_avx(&r0, &m0); + r1 = add_avx(&r1, &m1); + r2 = add_avx(&r2, &m2); + } + *s0 = _mm256_blend_epi32(r0, res0, 0x3); + *s1 = r1; + *s2 = r2; +} + +#[allow(dead_code)] +#[inline(always)] +#[unroll_for_loops] +fn partial_first_constant_layer_avx(state: &mut [F; SPONGE_WIDTH]) +where + F: PrimeField64, +{ + unsafe { + let c0 = _mm256_loadu_si256( + (&FAST_PARTIAL_FIRST_ROUND_CONSTANT[0..4]) + .as_ptr() + .cast::<__m256i>(), + ); + let c1 = _mm256_loadu_si256( + (&FAST_PARTIAL_FIRST_ROUND_CONSTANT[4..8]) + .as_ptr() + .cast::<__m256i>(), + ); + let c2 = _mm256_loadu_si256( + (&FAST_PARTIAL_FIRST_ROUND_CONSTANT[8..12]) + .as_ptr() + .cast::<__m256i>(), + ); + + let mut s0 = _mm256_loadu_si256((state[0..4]).as_ptr().cast::<__m256i>()); + let mut s1 = _mm256_loadu_si256((state[4..8]).as_ptr().cast::<__m256i>()); + let mut s2 = _mm256_loadu_si256((state[8..12]).as_ptr().cast::<__m256i>()); + s0 = add_avx(&s0, &c0); + s1 = add_avx(&s1, &c1); + s2 = add_avx(&s2, &c2); + _mm256_storeu_si256((state[0..4]).as_mut_ptr().cast::<__m256i>(), s0); + _mm256_storeu_si256((state[4..8]).as_mut_ptr().cast::<__m256i>(), s1); + _mm256_storeu_si256((state[8..12]).as_mut_ptr().cast::<__m256i>(), s2); + } +} + +#[inline(always)] +fn sbox_monomial(x: F) -> F +where + F: PrimeField64, +{ + // x |--> x^7 + let x2 = x.square(); + let x4 = x2.square(); + let x3 = x * x2; + x3 * x4 +} + +pub fn poseidon_avx(input: &[F; SPONGE_WIDTH]) -> [F; SPONGE_WIDTH] +where + F: PrimeField64 + Poseidon, +{ + let mut state = &mut input.clone(); + let mut round_ctr = 0; + + unsafe { + // load state + let mut s0 = _mm256_loadu_si256((&state[0..4]).as_ptr().cast::<__m256i>()); + let mut s1 = _mm256_loadu_si256((&state[4..8]).as_ptr().cast::<__m256i>()); + let mut s2 = _mm256_loadu_si256((&state[8..12]).as_ptr().cast::<__m256i>()); + + for _ in 0..HALF_N_FULL_ROUNDS { + let rc: &[u64; 12] = &ALL_ROUND_CONSTANTS[SPONGE_WIDTH * round_ctr..][..SPONGE_WIDTH] + .try_into() + .unwrap(); + let rc0 = _mm256_loadu_si256((&rc[0..4]).as_ptr().cast::<__m256i>()); + let rc1 = 
_mm256_loadu_si256((&rc[4..8]).as_ptr().cast::<__m256i>()); + let rc2 = _mm256_loadu_si256((&rc[8..12]).as_ptr().cast::<__m256i>()); + s0 = add_avx(&s0, &rc0); + s1 = add_avx(&s1, &rc1); + s2 = add_avx(&s2, &rc2); + sbox_avx(&mut s0, &mut s1, &mut s2); + mds_layer_avx(&mut s0, &mut s1, &mut s2); + round_ctr += 1; + } + + // this does partial_first_constant_layer_avx(&mut state); + let c0 = _mm256_loadu_si256( + (&FAST_PARTIAL_FIRST_ROUND_CONSTANT[0..4]) + .as_ptr() + .cast::<__m256i>(), + ); + let c1 = _mm256_loadu_si256( + (&FAST_PARTIAL_FIRST_ROUND_CONSTANT[4..8]) + .as_ptr() + .cast::<__m256i>(), + ); + let c2 = _mm256_loadu_si256( + (&FAST_PARTIAL_FIRST_ROUND_CONSTANT[8..12]) + .as_ptr() + .cast::<__m256i>(), + ); + s0 = add_avx(&s0, &c0); + s1 = add_avx(&s1, &c1); + s2 = add_avx(&s2, &c2); + + mds_partial_layer_init_avx_m256i::(&mut s0, &mut s1, &mut s2); + + _mm256_storeu_si256((state[0..4]).as_mut_ptr().cast::<__m256i>(), s0); + _mm256_storeu_si256((state[4..8]).as_mut_ptr().cast::<__m256i>(), s1); + _mm256_storeu_si256((state[8..12]).as_mut_ptr().cast::<__m256i>(), s2); + + for i in 0..N_PARTIAL_ROUNDS { + state[0] = sbox_monomial(state[0]); + state[0] = state[0].add_canonical_u64(FAST_PARTIAL_ROUND_CONSTANTS[i]); + mds_partial_layer_fast_avx(&mut s0, &mut s1, &mut s2, &mut state, i); + } + round_ctr += N_PARTIAL_ROUNDS; + + // here state is already loaded in s0, s1, s2 + for _ in 0..HALF_N_FULL_ROUNDS { + let rc: &[u64; 12] = &ALL_ROUND_CONSTANTS[SPONGE_WIDTH * round_ctr..][..SPONGE_WIDTH] + .try_into() + .unwrap(); + let rc0 = _mm256_loadu_si256((&rc[0..4]).as_ptr().cast::<__m256i>()); + let rc1 = _mm256_loadu_si256((&rc[4..8]).as_ptr().cast::<__m256i>()); + let rc2 = _mm256_loadu_si256((&rc[8..12]).as_ptr().cast::<__m256i>()); + s0 = add_avx(&s0, &rc0); + s1 = add_avx(&s1, &rc1); + s2 = add_avx(&s2, &rc2); + sbox_avx(&mut s0, &mut s1, &mut s2); + mds_layer_avx(&mut s0, &mut s1, &mut s2); + round_ctr += 1; + } + + // store state + _mm256_storeu_si256((state[0..4]).as_mut_ptr().cast::<__m256i>(), s0); + _mm256_storeu_si256((state[4..8]).as_mut_ptr().cast::<__m256i>(), s1); + _mm256_storeu_si256((state[8..12]).as_mut_ptr().cast::<__m256i>(), s2); + }; + *state +} diff --git a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs deleted file mode 100644 index 7046a7fdd2..0000000000 --- a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs +++ /dev/null @@ -1,981 +0,0 @@ -use core::arch::asm; -use core::arch::x86_64::*; -use core::mem::size_of; - -use static_assertions::const_assert; - -use crate::field::goldilocks_field::GoldilocksField; -use crate::field::types::Field; -use crate::hash::poseidon::{ - Poseidon, ALL_ROUND_CONSTANTS, HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS, N_ROUNDS, -}; -use crate::util::branch_hint; - -// WARNING: This code contains tricks that work for the current MDS matrix and round constants, but -// are not guaranteed to work if those are changed. - -// * Constant definitions * - -const WIDTH: usize = 12; - -// These transformed round constants are used where the constant layer is fused with the preceding -// MDS layer. The FUSED_ROUND_CONSTANTS for round i are the ALL_ROUND_CONSTANTS for round i + 1. -// The FUSED_ROUND_CONSTANTS for the very last round are 0, as it is not followed by a constant -// layer. On top of that, all FUSED_ROUND_CONSTANTS are shifted by 2 ** 63 to save a few XORs per -// round. 
-const fn make_fused_round_constants() -> [u64; WIDTH * N_ROUNDS] { - let mut res = [0x8000000000000000u64; WIDTH * N_ROUNDS]; - let mut i: usize = WIDTH; - while i < WIDTH * N_ROUNDS { - res[i - WIDTH] ^= ALL_ROUND_CONSTANTS[i]; - i += 1; - } - res -} -const FUSED_ROUND_CONSTANTS: [u64; WIDTH * N_ROUNDS] = make_fused_round_constants(); - -// This is the top row of the MDS matrix. Concretely, it's the MDS exps vector at the following -// indices: [0, 11, ..., 1]. -static TOP_ROW_EXPS: [usize; 12] = [0, 10, 16, 3, 12, 8, 1, 5, 3, 0, 1, 0]; - -// * Compile-time checks * - -/// The MDS matrix multiplication ASM is specific to the MDS matrix below. We want this file to -/// fail to compile if it has been changed. -#[allow(dead_code)] -const fn check_mds_matrix() -> bool { - // Can't == two arrays in a const_assert! (: - let mut i = 0; - let wanted_matrix_exps = [0, 0, 1, 0, 3, 5, 1, 8, 12, 3, 16, 10]; - while i < WIDTH { - if ::MDS_MATRIX_EXPS[i] != wanted_matrix_exps[i] { - return false; - } - i += 1; - } - true -} -const_assert!(check_mds_matrix()); - -/// The maximum amount by which the MDS matrix will multiply the input. -/// i.e. max(MDS(state)) <= mds_matrix_inf_norm() * max(state). -const fn mds_matrix_inf_norm() -> u64 { - let mut cumul = 0; - let mut i = 0; - while i < WIDTH { - cumul += 1 << ::MDS_MATRIX_EXPS[i]; - i += 1; - } - cumul -} - -/// Ensure that adding round constants to the low result of the MDS multiplication can never -/// overflow. -#[allow(dead_code)] -const fn check_round_const_bounds_mds() -> bool { - let max_mds_res = mds_matrix_inf_norm() * (u32::MAX as u64); - let mut i = WIDTH; // First const layer is handled specially. - while i < WIDTH * N_ROUNDS { - if ALL_ROUND_CONSTANTS[i].overflowing_add(max_mds_res).1 { - return false; - } - i += 1; - } - true -} -const_assert!(check_round_const_bounds_mds()); - -/// Ensure that the first WIDTH round constants are in canonical form for the vpcmpgtd trick. -#[allow(dead_code)] -const fn check_round_const_bounds_init() -> bool { - let max_permitted_round_const = 0xffffffff00000000; - let mut i = 0; // First const layer is handled specially. - while i < WIDTH { - if ALL_ROUND_CONSTANTS[i] > max_permitted_round_const { - return false; - } - i += 1; - } - true -} -const_assert!(check_round_const_bounds_init()); - -// Preliminary notes: -// 1. AVX does not support addition with carry but 128-bit (2-word) addition can be easily -// emulated. The method recognizes that for a + b overflowed iff (a + b) < a: -// i. res_lo = a_lo + b_lo -// ii. carry_mask = res_lo < a_lo -// iii. res_hi = a_hi + b_hi - carry_mask -// Notice that carry_mask is subtracted, not added. This is because AVX comparison instructions -// return -1 (all bits 1) for true and 0 for false. -// -// 2. AVX does not have unsigned 64-bit comparisons. 
Those can be emulated with signed comparisons -// by recognizing that a , $v:ident) => { - ($f::<$l>($v.0), $f::<$l>($v.1), $f::<$l>($v.2)) - }; - ($f:ident::<$l:literal>, $v1:ident, $v2:ident) => { - ( - $f::<$l>($v1.0, $v2.0), - $f::<$l>($v1.1, $v2.1), - $f::<$l>($v1.2, $v2.2), - ) - }; - ($f:ident, $v:ident) => { - ($f($v.0), $f($v.1), $f($v.2)) - }; - ($f:ident, $v0:ident, $v1:ident) => { - ($f($v0.0, $v1.0), $f($v0.1, $v1.1), $f($v0.2, $v1.2)) - }; - ($f:ident, $v0:ident, rep $v1:ident) => { - ($f($v0.0, $v1), $f($v0.1, $v1), $f($v0.2, $v1)) - }; -} - -#[inline(always)] -unsafe fn const_layer( - state: (__m256i, __m256i, __m256i), - round_const_arr: &[u64; 12], -) -> (__m256i, __m256i, __m256i) { - let sign_bit = _mm256_set1_epi64x(i64::MIN); - let round_const = ( - _mm256_loadu_si256((&round_const_arr[0..4]).as_ptr().cast::<__m256i>()), - _mm256_loadu_si256((&round_const_arr[4..8]).as_ptr().cast::<__m256i>()), - _mm256_loadu_si256((&round_const_arr[8..12]).as_ptr().cast::<__m256i>()), - ); - let state_s = map3!(_mm256_xor_si256, state, rep sign_bit); // Shift by 2**63. - let res_maybe_wrapped_s = map3!(_mm256_add_epi64, state_s, round_const); - // 32-bit compare is much faster than 64-bit compare on Intel. We can use 32-bit compare here - // as long as we can guarantee that state > res_maybe_wrapped iff state >> 32 > - // res_maybe_wrapped >> 32. Clearly, if state >> 32 > res_maybe_wrapped >> 32, then state > - // res_maybe_wrapped, and similarly for <. - // It remains to show that we can't have state >> 32 == res_maybe_wrapped >> 32 with state > - // res_maybe_wrapped. If state >> 32 == res_maybe_wrapped >> 32, then round_const >> 32 = - // 0xffffffff and the addition of the low doubleword generated a carry bit. This can never - // occur if all round constants are < 0xffffffff00000001 = ORDER: if the high bits are - // 0xffffffff, then the low bits are 0, so the carry bit cannot occur. So this trick is valid - // as long as all the round constants are in canonical form. - // The mask contains 0xffffffff in the high doubleword if wraparound occurred and 0 otherwise. - // We will ignore the low doubleword. - let wraparound_mask = map3!(_mm256_cmpgt_epi32, state_s, res_maybe_wrapped_s); - // wraparound_adjustment contains 0xffffffff = EPSILON if wraparound occurred and 0 otherwise. - let wraparound_adjustment = map3!(_mm256_srli_epi64::<32>, wraparound_mask); - // XOR commutes with the addition below. Placing it here helps mask latency. - let res_maybe_wrapped = map3!(_mm256_xor_si256, res_maybe_wrapped_s, rep sign_bit); - // Add EPSILON = subtract ORDER. - let res = map3!(_mm256_add_epi64, res_maybe_wrapped, wraparound_adjustment); - res -} - -#[inline(always)] -unsafe fn square3( - x: (__m256i, __m256i, __m256i), -) -> ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)) { - let x_hi = { - // Move high bits to low position. The high bits of x_hi are ignored. Swizzle is faster than - // bitshift. This instruction only has a floating-point flavor, so we cast to/from float. - // This is safe and free. - let x_ps = map3!(_mm256_castsi256_ps, x); - let x_hi_ps = map3!(_mm256_movehdup_ps, x_ps); - map3!(_mm256_castps_si256, x_hi_ps) - }; - - // All pairwise multiplications. - let mul_ll = map3!(_mm256_mul_epu32, x, x); - let mul_lh = map3!(_mm256_mul_epu32, x, x_hi); - let mul_hh = map3!(_mm256_mul_epu32, x_hi, x_hi); - - // Bignum addition, but mul_lh is shifted by 33 bits (not 32). 
- let mul_ll_hi = map3!(_mm256_srli_epi64::<33>, mul_ll); - let t0 = map3!(_mm256_add_epi64, mul_lh, mul_ll_hi); - let t0_hi = map3!(_mm256_srli_epi64::<31>, t0); - let res_hi = map3!(_mm256_add_epi64, mul_hh, t0_hi); - - // Form low result by adding the mul_ll and the low 31 bits of mul_lh (shifted to the high - // position). - let mul_lh_lo = map3!(_mm256_slli_epi64::<33>, mul_lh); - let res_lo = map3!(_mm256_add_epi64, mul_ll, mul_lh_lo); - - (res_lo, res_hi) -} - -#[inline(always)] -unsafe fn mul3( - x: (__m256i, __m256i, __m256i), - y: (__m256i, __m256i, __m256i), -) -> ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)) { - let epsilon = _mm256_set1_epi64x(0xffffffff); - let x_hi = { - // Move high bits to low position. The high bits of x_hi are ignored. Swizzle is faster than - // bitshift. This instruction only has a floating-point flavor, so we cast to/from float. - // This is safe and free. - let x_ps = map3!(_mm256_castsi256_ps, x); - let x_hi_ps = map3!(_mm256_movehdup_ps, x_ps); - map3!(_mm256_castps_si256, x_hi_ps) - }; - let y_hi = { - let y_ps = map3!(_mm256_castsi256_ps, y); - let y_hi_ps = map3!(_mm256_movehdup_ps, y_ps); - map3!(_mm256_castps_si256, y_hi_ps) - }; - - // All four pairwise multiplications - let mul_ll = map3!(_mm256_mul_epu32, x, y); - let mul_lh = map3!(_mm256_mul_epu32, x, y_hi); - let mul_hl = map3!(_mm256_mul_epu32, x_hi, y); - let mul_hh = map3!(_mm256_mul_epu32, x_hi, y_hi); - - // Bignum addition - // Extract high 32 bits of mul_ll and add to mul_hl. This cannot overflow. - let mul_ll_hi = map3!(_mm256_srli_epi64::<32>, mul_ll); - let t0 = map3!(_mm256_add_epi64, mul_hl, mul_ll_hi); - // Extract low 32 bits of t0 and add to mul_lh. Again, this cannot overflow. - // Also, extract high 32 bits of t0 and add to mul_hh. - let t0_lo = map3!(_mm256_and_si256, t0, rep epsilon); - let t0_hi = map3!(_mm256_srli_epi64::<32>, t0); - let t1 = map3!(_mm256_add_epi64, mul_lh, t0_lo); - let t2 = map3!(_mm256_add_epi64, mul_hh, t0_hi); - // Lastly, extract the high 32 bits of t1 and add to t2. - let t1_hi = map3!(_mm256_srli_epi64::<32>, t1); - let res_hi = map3!(_mm256_add_epi64, t2, t1_hi); - - // Form res_lo by combining the low half of mul_ll with the low half of t1 (shifted into high - // position). - let t1_lo = { - let t1_ps = map3!(_mm256_castsi256_ps, t1); - let t1_lo_ps = map3!(_mm256_moveldup_ps, t1_ps); - map3!(_mm256_castps_si256, t1_lo_ps) - }; - let res_lo = map3!(_mm256_blend_epi32::<0xaa>, mul_ll, t1_lo); - - (res_lo, res_hi) -} - -/// Addition, where the second operand is `0 <= y < 0xffffffff00000001`. -#[inline(always)] -unsafe fn add_small( - x_s: (__m256i, __m256i, __m256i), - y: (__m256i, __m256i, __m256i), -) -> (__m256i, __m256i, __m256i) { - let res_wrapped_s = map3!(_mm256_add_epi64, x_s, y); - let mask = map3!(_mm256_cmpgt_epi32, x_s, res_wrapped_s); - let wrapback_amt = map3!(_mm256_srli_epi64::<32>, mask); // EPSILON if overflowed else 0. - let res_s = map3!(_mm256_add_epi64, res_wrapped_s, wrapback_amt); - res_s -} - -#[inline(always)] -unsafe fn maybe_adj_sub(res_wrapped_s: __m256i, mask: __m256i) -> __m256i { - // The subtraction is very unlikely to overflow so we're best off branching. - // The even u32s in `mask` are meaningless, so we want to ignore them. `_mm256_testz_pd` - // branches depending on the sign bit of double-precision (64-bit) floats. Bit cast `mask` to - // floating-point (this is free). 
- let mask_pd = _mm256_castsi256_pd(mask); - // `_mm256_testz_pd(mask_pd, mask_pd) == 1` iff all sign bits are 0, meaning that underflow - // did not occur for any of the vector elements. - if _mm256_testz_pd(mask_pd, mask_pd) == 1 { - res_wrapped_s - } else { - branch_hint(); - // Highly unlikely: underflow did occur. Find adjustment per element and apply it. - let adj_amount = _mm256_srli_epi64::<32>(mask); // EPSILON if underflow. - _mm256_sub_epi64(res_wrapped_s, adj_amount) - } -} - -/// Addition, where the second operand is much smaller than `0xffffffff00000001`. -#[inline(always)] -unsafe fn sub_tiny( - x_s: (__m256i, __m256i, __m256i), - y: (__m256i, __m256i, __m256i), -) -> (__m256i, __m256i, __m256i) { - let res_wrapped_s = map3!(_mm256_sub_epi64, x_s, y); - let mask = map3!(_mm256_cmpgt_epi32, res_wrapped_s, x_s); - let res_s = map3!(maybe_adj_sub, res_wrapped_s, mask); - res_s -} - -#[inline(always)] -unsafe fn reduce3( - (lo0, hi0): ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)), -) -> (__m256i, __m256i, __m256i) { - let sign_bit = _mm256_set1_epi64x(i64::MIN); - let epsilon = _mm256_set1_epi64x(0xffffffff); - let lo0_s = map3!(_mm256_xor_si256, lo0, rep sign_bit); - let hi_hi0 = map3!(_mm256_srli_epi64::<32>, hi0); - let lo1_s = sub_tiny(lo0_s, hi_hi0); - let t1 = map3!(_mm256_mul_epu32, hi0, rep epsilon); - let lo2_s = add_small(lo1_s, t1); - let lo2 = map3!(_mm256_xor_si256, lo2_s, rep sign_bit); - lo2 -} - -#[inline(always)] -unsafe fn sbox_layer_full(state: (__m256i, __m256i, __m256i)) -> (__m256i, __m256i, __m256i) { - let state2_unreduced = square3(state); - let state2 = reduce3(state2_unreduced); - let state4_unreduced = square3(state2); - let state3_unreduced = mul3(state2, state); - let state4 = reduce3(state4_unreduced); - let state3 = reduce3(state3_unreduced); - let state7_unreduced = mul3(state3, state4); - let state7 = reduce3(state7_unreduced); - state7 -} - -#[inline(always)] -unsafe fn mds_layer_reduce( - lo_s: (__m256i, __m256i, __m256i), - hi: (__m256i, __m256i, __m256i), -) -> (__m256i, __m256i, __m256i) { - // This is done in assembly because, frankly, it's cleaner than intrinsics. We also don't have - // to worry about whether the compiler is doing weird things. This entire routine needs proper - // pipelining so there's no point rewriting this, only to have to rewrite it again. - let res0: __m256i; - let res1: __m256i; - let res2: __m256i; - let epsilon = _mm256_set1_epi64x(0xffffffff); - let sign_bit = _mm256_set1_epi64x(i64::MIN); - asm!( - // The high results are in ymm3, ymm4, ymm5. - // The low results (shifted by 2**63) are in ymm0, ymm1, ymm2 - - // We want to do: ymm0 := ymm0 + (ymm3 * 2**32) in modulo P. - // This can be computed by ymm0 + (ymm3 << 32) + (ymm3 >> 32) * EPSILON, - // where the additions must correct for over/underflow. - - // First, do ymm0 + (ymm3 << 32) (first chain) - "vpsllq ymm6, ymm3, 32", - "vpsllq ymm7, ymm4, 32", - "vpsllq ymm8, ymm5, 32", - "vpaddq ymm6, ymm6, ymm0", - "vpaddq ymm7, ymm7, ymm1", - "vpaddq ymm8, ymm8, ymm2", - "vpcmpgtd ymm0, ymm0, ymm6", - "vpcmpgtd ymm1, ymm1, ymm7", - "vpcmpgtd ymm2, ymm2, ymm8", - - // Now we interleave the chains so this gets a bit uglier. 
- // Form ymm3 := (ymm3 >> 32) * EPSILON (second chain) - "vpsrlq ymm9, ymm3, 32", - "vpsrlq ymm10, ymm4, 32", - "vpsrlq ymm11, ymm5, 32", - // (first chain again) - "vpsrlq ymm0, ymm0, 32", - "vpsrlq ymm1, ymm1, 32", - "vpsrlq ymm2, ymm2, 32", - // (second chain again) - "vpandn ymm3, ymm14, ymm3", - "vpandn ymm4, ymm14, ymm4", - "vpandn ymm5, ymm14, ymm5", - "vpsubq ymm3, ymm3, ymm9", - "vpsubq ymm4, ymm4, ymm10", - "vpsubq ymm5, ymm5, ymm11", - // (first chain again) - "vpaddq ymm0, ymm6, ymm0", - "vpaddq ymm1, ymm7, ymm1", - "vpaddq ymm2, ymm8, ymm2", - - // Merge two chains (second addition) - "vpaddq ymm3, ymm0, ymm3", - "vpaddq ymm4, ymm1, ymm4", - "vpaddq ymm5, ymm2, ymm5", - "vpcmpgtd ymm0, ymm0, ymm3", - "vpcmpgtd ymm1, ymm1, ymm4", - "vpcmpgtd ymm2, ymm2, ymm5", - "vpsrlq ymm6, ymm0, 32", - "vpsrlq ymm7, ymm1, 32", - "vpsrlq ymm8, ymm2, 32", - "vpxor ymm3, ymm15, ymm3", - "vpxor ymm4, ymm15, ymm4", - "vpxor ymm5, ymm15, ymm5", - "vpaddq ymm0, ymm6, ymm3", - "vpaddq ymm1, ymm7, ymm4", - "vpaddq ymm2, ymm8, ymm5", - inout("ymm0") lo_s.0 => res0, - inout("ymm1") lo_s.1 => res1, - inout("ymm2") lo_s.2 => res2, - inout("ymm3") hi.0 => _, - inout("ymm4") hi.1 => _, - inout("ymm5") hi.2 => _, - out("ymm6") _, out("ymm7") _, out("ymm8") _, out("ymm9") _, out("ymm10") _, out("ymm11") _, - in("ymm14") epsilon, in("ymm15") sign_bit, - options(pure, nomem, preserves_flags, nostack), - ); - (res0, res1, res2) -} - -#[inline(always)] -unsafe fn mds_multiply_and_add_round_const_s( - state: (__m256i, __m256i, __m256i), - (base, index): (*const u64, usize), -) -> ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)) { - // TODO: Would it be faster to save the input to memory and do unaligned - // loads instead of swizzling? It would reduce pressure on port 5 but it - // would also have high latency (no store forwarding). - // TODO: Would it be faster to store the lo and hi inputs and outputs on one - // vector? I.e., we currently operate on [lo(s[0]), lo(s[1]), lo(s[2]), - // lo(s[3])] and [hi(s[0]), hi(s[1]), hi(s[2]), hi(s[3])] separately. Using - // [lo(s[0]), lo(s[1]), hi(s[0]), hi(s[1])] and [lo(s[2]), lo(s[3]), - // hi(s[2]), hi(s[3])] would save us a few swizzles but would also need more - // registers. - // TODO: Plain-vanilla matrix-vector multiplication might also work. We take - // one element of the input (a scalar), multiply a column by it, and - // accumulate. It would require shifts by amounts loaded from memory, but - // would eliminate all swizzles. The downside is that we can no longer - // special-case MDS == 0 and MDS == 1, so we end up with more shifts. - // TODO: Building on the above: FMA? It has high latency (4 cycles) but we - // have enough operands to mask it. The main annoyance will be conversion - // to/from floating-point. - // TODO: Try taking the complex Fourier transform and doing the convolution - // with elementwise Fourier multiplication. Alternatively, try a Fourier - // transform modulo Q, such that the prime field fits the result without - // wraparound (i.e. Q > 0x1_1536_fffe_eac9) and has fast multiplication/- - // reduction. - - // At the end of the matrix-vector multiplication r = Ms, - // - ymm3 holds r[0:4] - // - ymm4 holds r[4:8] - // - ymm5 holds r[8:12] - // - ymm6 holds r[2:6] - // - ymm7 holds r[6:10] - // - ymm8 holds concat(r[10:12], r[0:2]) - // Note that there are duplicates. E.g. r[0] is represented by ymm3[0] and - // ymm8[2]. 
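// Illustrative scalar sketch (not from the patch) of what the assembly above
// computes per lane: combine the low accumulator `lo` and the high accumulator
// `hi` as lo + hi * 2**32 (mod p), using
// hi * 2**32 = (hi << 32) + (hi >> 32) * EPSILON (mod p). The assembly builds
// (hi >> 32) * EPSILON with a mask and a subtraction instead of a multiply, and
// keeps `lo` offset by 2**63 so unsigned comparisons can be done with signed
// instructions; both details are elided here. Names are ours.
fn combine_lo_hi(lo: u64, hi: u64) -> u64 {
    const EPSILON: u64 = 0xffff_ffff;
    // First chain: lo + (hi << 32), wrapping forward by EPSILON on carry.
    let (t, carry) = lo.overflowing_add(hi << 32);
    let t = if carry { t.wrapping_add(EPSILON) } else { t };
    // Second chain: + (hi >> 32) * EPSILON, again correcting on carry.
    let (res, carry) = t.overflowing_add((hi >> 32) * EPSILON);
    if carry {
        res.wrapping_add(EPSILON)
    } else {
        res
    }
}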
To obtain the final result, we must sum the duplicate entries: - // ymm3[0:2] += ymm8[2:4] - // ymm3[2:4] += ymm6[0:2] - // ymm4[0:2] += ymm6[2:4] - // ymm4[2:4] += ymm7[0:2] - // ymm5[0:2] += ymm7[2:4] - // ymm5[2:4] += ymm8[0:2] - // Thus, the final result resides in ymm3, ymm4, ymm5. - - // WARNING: This code assumes that sum(1 << exp for exp in MDS_EXPS) * 0xffffffff fits in a - // u64. If this guarantee ceases to hold, then it will no longer be correct. - let (unreduced_lo0_s, unreduced_lo1_s, unreduced_lo2_s): (__m256i, __m256i, __m256i); - let (unreduced_hi0, unreduced_hi1, unreduced_hi2): (__m256i, __m256i, __m256i); - let epsilon = _mm256_set1_epi64x(0xffffffff); - asm!( - // Extract low 32 bits of the word - "vpand ymm9, ymm14, ymm0", - "vpand ymm10, ymm14, ymm1", - "vpand ymm11, ymm14, ymm2", - - "mov eax, 1", - - // Fall through for MDS matrix multiplication on low 32 bits - - // This is a GCC _local label_. For details, see - // https://doc.rust-lang.org/rust-by-example/unsafe/asm.html#labels - // In short, the assembler makes sure to assign a unique name to replace `2:` with a unique - // name, so the label does not clash with any compiler-generated label. `2:` can appear - // multiple times; to disambiguate, we must refer to it as `2b` or `2f`, specifying the - // direction as _backward_ or _forward_. - "2:", - // NB: This block is run twice: once on the low 32 bits and once for the - // high 32 bits. The 32-bit -> 64-bit matrix multiplication is responsible - // for the majority of the instructions in this routine. By reusing them, - // we decrease the burden on instruction caches by over one third. - - // 32-bit -> 64-bit MDS matrix multiplication - // The scalar loop goes: - // for r in 0..WIDTH { - // let mut res = 0u128; - // for i in 0..WIDTH { - // res += (state[(i + r) % WIDTH] as u128) << MDS_MATRIX_EXPS[i]; - // } - // result[r] = reduce(res); - // } - // - // Here, we swap the loops. Equivalent to: - // let mut res = [0u128; WIDTH]; - // for i in 0..WIDTH { - // let mds_matrix_exp = MDS_MATRIX_EXPS[i]; - // for r in 0..WIDTH { - // res[r] += (state[(i + r) % WIDTH] as u128) << mds_matrix_exp; - // } - // } - // for r in 0..WIDTH { - // result[r] = reduce(res[r]); - // } - // - // Notice that in the lower version, all iterations of the inner loop - // shift by the same amount. In vector, we perform multiple iterations of - // the loop at once, and vector shifts are cheaper when all elements are - // shifted by the same amount. - // - // We use a trick to avoid rotating the state vector many times. We - // have as input the state vector and the state vector rotated by one. We - // also have two accumulators: an unrotated one and one that's rotated by - // two. Rotations by three are achieved by matching an input rotated by - // one with an accumulator rotated by two. Rotations by four are free: - // they are done by using a different register. - - // mds[0 - 0] = 0 not done; would be a move from in0 to ymm3 - // ymm3 not set - // mds[0 - 4] = 12 - "vpsllq ymm4, ymm9, 12", - // mds[0 - 8] = 3 - "vpsllq ymm5, ymm9, 3", - // mds[0 - 2] = 16 - "vpsllq ymm6, ymm9, 16", - // mds[0 - 6] = mds[0 - 10] = 1 - "vpaddq ymm7, ymm9, ymm9", - // ymm8 not written - // ymm3 and ymm8 have not been written to, because those would be unnecessary - // copies. Implicitly, ymm3 := in0 and ymm8 := ymm7. 
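// Illustrative runnable version (not from the patch) of the swapped-loop MDS
// accumulation described in the comment above, parameterized over the shift
// table so no constants are asserted here. Each res[r] ends up as the sum over
// i of state[(i + r) % WIDTH] << exps[i], held in a u128 accumulator exactly as
// in the scalar pseudocode.
fn mds_accumulate<const WIDTH: usize>(state: &[u64; WIDTH], exps: &[u8; WIDTH]) -> [u128; WIDTH] {
    let mut res = [0u128; WIDTH];
    for i in 0..WIDTH {
        let exp = exps[i];
        for r in 0..WIDTH {
            res[r] += (state[(i + r) % WIDTH] as u128) << exp;
        }
    }
    res
}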
- - // ymm12 := [ymm9[1], ymm9[2], ymm9[3], ymm10[0]] - "vperm2i128 ymm13, ymm9, ymm10, 0x21", - "vshufpd ymm12, ymm9, ymm13, 0x5", - - // ymm3 and ymm8 are not read because they have not been written to - // earlier. Instead, the "current value" of ymm3 is read from ymm9 and the - // "current value" of ymm8 is read from ymm7. - // mds[4 - 0] = 3 - "vpsllq ymm13, ymm10, 3", - "vpaddq ymm3, ymm9, ymm13", - // mds[4 - 4] = 0 - "vpaddq ymm4, ymm4, ymm10", - // mds[4 - 8] = 12 - "vpsllq ymm13, ymm10, 12", - "vpaddq ymm5, ymm5, ymm13", - // mds[4 - 2] = mds[4 - 10] = 1 - "vpaddq ymm13, ymm10, ymm10", - "vpaddq ymm6, ymm6, ymm13", - "vpaddq ymm8, ymm7, ymm13", - // mds[4 - 6] = 16 - "vpsllq ymm13, ymm10, 16", - "vpaddq ymm7, ymm7, ymm13", - - // mds[1 - 0] = 0 - "vpaddq ymm3, ymm3, ymm12", - // mds[1 - 4] = 3 - "vpsllq ymm13, ymm12, 3", - "vpaddq ymm4, ymm4, ymm13", - // mds[1 - 8] = 5 - "vpsllq ymm13, ymm12, 5", - "vpaddq ymm5, ymm5, ymm13", - // mds[1 - 2] = 10 - "vpsllq ymm13, ymm12, 10", - "vpaddq ymm6, ymm6, ymm13", - // mds[1 - 6] = 8 - "vpsllq ymm13, ymm12, 8", - "vpaddq ymm7, ymm7, ymm13", - // mds[1 - 10] = 0 - "vpaddq ymm8, ymm8, ymm12", - - // ymm10 := [ymm10[1], ymm10[2], ymm10[3], ymm11[0]] - "vperm2i128 ymm13, ymm10, ymm11, 0x21", - "vshufpd ymm10, ymm10, ymm13, 0x5", - - // mds[8 - 0] = 12 - "vpsllq ymm13, ymm11, 12", - "vpaddq ymm3, ymm3, ymm13", - // mds[8 - 4] = 3 - "vpsllq ymm13, ymm11, 3", - "vpaddq ymm4, ymm4, ymm13", - // mds[8 - 8] = 0 - "vpaddq ymm5, ymm5, ymm11", - // mds[8 - 2] = mds[8 - 6] = 1 - "vpaddq ymm13, ymm11, ymm11", - "vpaddq ymm6, ymm6, ymm13", - "vpaddq ymm7, ymm7, ymm13", - // mds[8 - 10] = 16 - "vpsllq ymm13, ymm11, 16", - "vpaddq ymm8, ymm8, ymm13", - - // ymm9 := [ymm11[1], ymm11[2], ymm11[3], ymm9[0]] - "vperm2i128 ymm13, ymm11, ymm9, 0x21", - "vshufpd ymm9, ymm11, ymm13, 0x5", - - // mds[5 - 0] = 5 - "vpsllq ymm13, ymm10, 5", - "vpaddq ymm3, ymm3, ymm13", - // mds[5 - 4] = 0 - "vpaddq ymm4, ymm4, ymm10", - // mds[5 - 8] = 3 - "vpsllq ymm13, ymm10, 3", - "vpaddq ymm5, ymm5, ymm13", - // mds[5 - 2] = 0 - "vpaddq ymm6, ymm6, ymm10", - // mds[5 - 6] = 10 - "vpsllq ymm13, ymm10, 10", - "vpaddq ymm7, ymm7, ymm13", - // mds[5 - 10] = 8 - "vpsllq ymm13, ymm10, 8", - "vpaddq ymm8, ymm8, ymm13", - - // mds[9 - 0] = 3 - "vpsllq ymm13, ymm9, 3", - "vpaddq ymm3, ymm3, ymm13", - // mds[9 - 4] = 5 - "vpsllq ymm13, ymm9, 5", - "vpaddq ymm4, ymm4, ymm13", - // mds[9 - 8] = 0 - "vpaddq ymm5, ymm5, ymm9", - // mds[9 - 2] = 8 - "vpsllq ymm13, ymm9, 8", - "vpaddq ymm6, ymm6, ymm13", - // mds[9 - 6] = 0 - "vpaddq ymm7, ymm7, ymm9", - // mds[9 - 10] = 10 - "vpsllq ymm13, ymm9, 10", - "vpaddq ymm8, ymm8, ymm13", - - // Rotate ymm6-ymm8 and add to the corresponding elements of ymm3-ymm5 - "vperm2i128 ymm13, ymm8, ymm6, 0x21", - "vpaddq ymm3, ymm3, ymm13", - "vperm2i128 ymm13, ymm6, ymm7, 0x21", - "vpaddq ymm4, ymm4, ymm13", - "vperm2i128 ymm13, ymm7, ymm8, 0x21", - "vpaddq ymm5, ymm5, ymm13", - - // If this is the first time we have run 2: (low 32 bits) then continue. - // If second time (high 32 bits), then jump to 3:. - "dec eax", - // Jump to the _local label_ (see above) `3:`. `f` for _forward_ specifies the direction. - "jnz 3f", - - // Extract high 32 bits - "vpsrlq ymm9, ymm0, 32", - "vpsrlq ymm10, ymm1, 32", - "vpsrlq ymm11, ymm2, 32", - - // Need to move the low result from ymm3-ymm5 to ymm0-13 so it is not - // overwritten. Save three instructions by combining the move with the constant layer, - // which would otherwise be done in 3:. 
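// Why the block labelled `2:` runs twice, as an illustrative scalar sketch (not
// from the patch): each state word is split as s = s_lo + 2**32 * s_hi with
// 32-bit halves, and the matrix-vector product is linear in the halves, so
// M*s = M*s_lo + 2**32 * (M*s_hi). The same shift-and-add code therefore serves
// both halves, and `mds_layer_reduce` merges the two results. Names are ours.
fn mds_two_pass<const WIDTH: usize>(state: &[u64; WIDTH], exps: &[u8; WIDTH]) -> [(u64, u64); WIDTH] {
    let lo_half: [u64; WIDTH] = core::array::from_fn(|i| state[i] & 0xffff_ffff);
    let hi_half: [u64; WIDTH] = core::array::from_fn(|i| state[i] >> 32);
    // Same accumulation as the `2:` block, run once per half. Each sum fits in
    // a u64 as long as sum(1 << exp for exp in exps) * 0xffffffff does, which
    // is the guarantee the assembly also relies on.
    let accumulate = |input: &[u64; WIDTH]| -> [u64; WIDTH] {
        let mut res = [0u64; WIDTH];
        for i in 0..WIDTH {
            for r in 0..WIDTH {
                res[r] += input[(i + r) % WIDTH] << exps[i];
            }
        }
        res
    };
    let lo_acc = accumulate(&lo_half);
    let hi_acc = accumulate(&hi_half);
    // Each (lo, hi) pair is then combined as lo + 2**32 * hi (mod p).
    core::array::from_fn(|r| (lo_acc[r], hi_acc[r]))
}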
The round constants include the shift by 2**63, so - // the resulting ymm0,1,2 are also shifted by 2**63. - // It is safe to add the round constants here without checking for overflow. The values in - // ymm3,4,5 are guaranteed to be <= 0x11536fffeeac9. All round constants are < 2**64 - // - 0x11536fffeeac9. - // WARNING: If this guarantee ceases to hold due to a change in the MDS matrix or round - // constants, then this code will no longer be correct. - "vpaddq ymm0, ymm3, [{base} + {index}]", - "vpaddq ymm1, ymm4, [{base} + {index} + 32]", - "vpaddq ymm2, ymm5, [{base} + {index} + 64]", - - // MDS matrix multiplication, again. This time on high 32 bits. - // Jump to the _local label_ (see above) `2:`. `b` for _backward_ specifies the direction. - "jmp 2b", - - // `3:` is a _local label_ (see above). - "3:", - // Just done the MDS matrix multiplication on high 32 bits. - // The high results are in ymm3, ymm4, ymm5. - // The low results (shifted by 2**63 and including the following constant layer) are in - // ymm0, ymm1, ymm2. - base = in(reg) base, - index = in(reg) index, - inout("ymm0") state.0 => unreduced_lo0_s, - inout("ymm1") state.1 => unreduced_lo1_s, - inout("ymm2") state.2 => unreduced_lo2_s, - out("ymm3") unreduced_hi0, - out("ymm4") unreduced_hi1, - out("ymm5") unreduced_hi2, - out("ymm6") _,out("ymm7") _, out("ymm8") _, out("ymm9") _, - out("ymm10") _, out("ymm11") _, out("ymm12") _, out("ymm13") _, - in("ymm14") epsilon, - out("rax") _, - options(pure, nomem, nostack), - ); - ( - (unreduced_lo0_s, unreduced_lo1_s, unreduced_lo2_s), - (unreduced_hi0, unreduced_hi1, unreduced_hi2), - ) -} - -#[inline(always)] -unsafe fn mds_const_layers_full( - state: (__m256i, __m256i, __m256i), - round_constants: (*const u64, usize), -) -> (__m256i, __m256i, __m256i) { - let (unreduced_lo_s, unreduced_hi) = mds_multiply_and_add_round_const_s(state, round_constants); - mds_layer_reduce(unreduced_lo_s, unreduced_hi) -} - -/// Compute x ** 7 -#[inline(always)] -unsafe fn sbox_partial(mut x: u64) -> u64 { - // This is done in assembly to fix LLVM's poor treatment of wraparound addition/subtraction - // and to ensure that multiplication by EPSILON is done with bitshifts, leaving port 1 for - // vector operations. - // TODO: Interleave with MDS multiplication. 
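// Illustrative scalar sketch (not from the patch) of the x**7 S-box that
// `sbox_partial` implements below with MULX/SHRX: four modular multiplications
// along the addition chain x**2, x**3, x**4, x**7. For clarity the modular
// multiply here just uses `u128 % p`; the assembly instead uses the
// EPSILON-based reduction.
fn sbox_monomial(x: u64) -> u64 {
    const P: u128 = 0xffff_ffff_0000_0001; // Goldilocks prime, 2**64 - 2**32 + 1
    let mul = |a: u64, b: u64| ((a as u128 * b as u128) % P) as u64;
    let x2 = mul(x, x);   // x**2
    let x3 = mul(x2, x);  // x**3
    let x4 = mul(x2, x2); // x**4 = (x**2)**2
    mul(x3, x4)           // x**7
}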
- asm!( - "mov r9, rdx", - - // rdx := rdx ^ 2 - "mulx rdx, rax, rdx", - "shrx r8, rdx, r15", - "mov r12d, edx", - "shl rdx, 32", - "sub rdx, r12", - // rax - r8, with underflow - "sub rax, r8", - "sbb r8d, r8d", // sets r8 to 2^32 - 1 if subtraction underflowed - "sub rax, r8", - // rdx + rax, with overflow - "add rdx, rax", - "sbb eax, eax", - "add rdx, rax", - - // rax := rdx * r9, rdx := rdx ** 2 - "mulx rax, r11, r9", - "mulx rdx, r12, rdx", - - "shrx r9, rax, r15", - "shrx r10, rdx, r15", - - "sub r11, r9", - "sbb r9d, r9d", - "sub r12, r10", - "sbb r10d, r10d", - "sub r11, r9", - "sub r12, r10", - - "mov r9d, eax", - "mov r10d, edx", - "shl rax, 32", - "shl rdx, 32", - "sub rax, r9", - "sub rdx, r10", - - "add rax, r11", - "sbb r11d, r11d", - "add rdx, r12", - "sbb r12d, r12d", - "add rax, r11", - "add rdx, r12", - - // rax := rax * rdx - "mulx rax, rdx, rax", - "shrx r11, rax, r15", - "mov r12d, eax", - "shl rax, 32", - "sub rax, r12", - // rdx - r11, with underflow - "sub rdx, r11", - "sbb r11d, r11d", // sets r11 to 2^32 - 1 if subtraction underflowed - "sub rdx, r11", - // rdx + rax, with overflow - "add rdx, rax", - "sbb eax, eax", - "add rdx, rax", - inout("rdx") x, - out("rax") _, - out("r8") _, - out("r9") _, - out("r10") _, - out("r11") _, - out("r12") _, - in("r15") 32, - options(pure, nomem, nostack), - ); - x -} - -#[inline(always)] -unsafe fn partial_round( - (state0, state1, state2): (__m256i, __m256i, __m256i), - round_constants: (*const u64, usize), -) -> (__m256i, __m256i, __m256i) { - // Extract the low quadword - let state0ab: __m128i = _mm256_castsi256_si128(state0); - let mut state0a = _mm_cvtsi128_si64(state0ab) as u64; - - // Zero the low quadword - let zero = _mm256_setzero_si256(); - let state0bcd = _mm256_blend_epi32::<0x3>(state0, zero); - - // Scalar exponentiation - state0a = sbox_partial(state0a); - - let epsilon = _mm256_set1_epi64x(0xffffffff); - let ( - (mut unreduced_lo0_s, mut unreduced_lo1_s, mut unreduced_lo2_s), - (mut unreduced_hi0, mut unreduced_hi1, mut unreduced_hi2), - ) = mds_multiply_and_add_round_const_s((state0bcd, state1, state2), round_constants); - asm!( - // Just done the MDS matrix multiplication on high 32 bits. - // The high results are in ymm3, ymm4, ymm5. - // The low results (shifted by 2**63) are in ymm0, ymm1, ymm2 - - // The MDS matrix multiplication was done with state[0] set to 0. - // We must: - // 1. propagate the vector product to state[0], which is stored in rdx. - // 2. offset state[1..12] by the appropriate multiple of rdx - // 3. zero the lowest quadword in the vector registers - "vmovq xmm12, {state0a}", - "vpbroadcastq ymm12, xmm12", - "vpsrlq ymm13, ymm12, 32", - "vpand ymm12, ymm14, ymm12", - - // The current matrix-vector product goes not include state[0] as an input. (Imagine Mv - // multiplication where we've set the first element to 0.) Add the remaining bits now. - // TODO: This is a bit of an afterthought, which is why these constants are loaded 22 - // times... There's likely a better way of merging those results. 
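// Illustrative scalar sketch (not from the patch) of the partial-round shape
// handled above: only state[0] goes through the S-box; the vector code runs the
// MDS product with state[0] zeroed (so it can overlap with the scalar
// exponentiation) and then patches in state[0]'s column, which contributes
// state[0] << exps[(WIDTH - r) % WIDTH] to output r. `sbox_monomial` and
// `mds_accumulate` refer to the earlier sketches; the round-constant addition,
// which the vector code fuses into the MDS output, is omitted here.
fn partial_round_sketch<const WIDTH: usize>(state: &mut [u64; WIDTH], exps: &[u8; WIDTH]) {
    const P: u128 = 0xffff_ffff_0000_0001;
    // 1. Scalar S-box on the first element only.
    let s0 = sbox_monomial(state[0]);
    // 2. MDS accumulation with element 0 set to zero.
    let mut masked = *state;
    masked[0] = 0;
    let mut acc = mds_accumulate(&masked, exps);
    // 3. Patch in the contribution of the new state[0] to every output.
    for r in 0..WIDTH {
        acc[r] += (s0 as u128) << exps[(WIDTH - r) % WIDTH];
    }
    // 4. Reduce each accumulator back to a field element.
    for r in 0..WIDTH {
        state[r] = (acc[r] % P) as u64;
    }
}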
- "vmovdqu ymm6, [{mds_matrix}]", - "vmovdqu ymm7, [{mds_matrix} + 32]", - "vmovdqu ymm8, [{mds_matrix} + 64]", - "vpsllvq ymm9, ymm13, ymm6", - "vpsllvq ymm10, ymm13, ymm7", - "vpsllvq ymm11, ymm13, ymm8", - "vpsllvq ymm6, ymm12, ymm6", - "vpsllvq ymm7, ymm12, ymm7", - "vpsllvq ymm8, ymm12, ymm8", - "vpaddq ymm3, ymm9, ymm3", - "vpaddq ymm4, ymm10, ymm4", - "vpaddq ymm5, ymm11, ymm5", - "vpaddq ymm0, ymm6, ymm0", - "vpaddq ymm1, ymm7, ymm1", - "vpaddq ymm2, ymm8, ymm2", - // Reduction required. - - state0a = in(reg) state0a, - mds_matrix = in(reg) &TOP_ROW_EXPS, - inout("ymm0") unreduced_lo0_s, - inout("ymm1") unreduced_lo1_s, - inout("ymm2") unreduced_lo2_s, - inout("ymm3") unreduced_hi0, - inout("ymm4") unreduced_hi1, - inout("ymm5") unreduced_hi2, - out("ymm6") _,out("ymm7") _, out("ymm8") _, out("ymm9") _, - out("ymm10") _, out("ymm11") _, out("ymm12") _, out("ymm13") _, - in("ymm14") epsilon, - options(pure, nomem, preserves_flags, nostack), - ); - mds_layer_reduce( - (unreduced_lo0_s, unreduced_lo1_s, unreduced_lo2_s), - (unreduced_hi0, unreduced_hi1, unreduced_hi2), - ) -} - -#[inline(always)] -unsafe fn full_round( - state: (__m256i, __m256i, __m256i), - round_constants: (*const u64, usize), -) -> (__m256i, __m256i, __m256i) { - let state = sbox_layer_full(state); - let state = mds_const_layers_full(state, round_constants); - state -} - -#[inline] // Called twice; permit inlining but don't _require_ it -unsafe fn half_full_rounds( - mut state: (__m256i, __m256i, __m256i), - start_round: usize, -) -> (__m256i, __m256i, __m256i) { - let base = (&FUSED_ROUND_CONSTANTS - [WIDTH * start_round..WIDTH * start_round + WIDTH * HALF_N_FULL_ROUNDS]) - .as_ptr(); - - for i in 0..HALF_N_FULL_ROUNDS { - state = full_round(state, (base, i * WIDTH * size_of::())); - } - state -} - -#[inline(always)] -unsafe fn all_partial_rounds( - mut state: (__m256i, __m256i, __m256i), - start_round: usize, -) -> (__m256i, __m256i, __m256i) { - let base = (&FUSED_ROUND_CONSTANTS - [WIDTH * start_round..WIDTH * start_round + WIDTH * N_PARTIAL_ROUNDS]) - .as_ptr(); - - for i in 0..N_PARTIAL_ROUNDS { - state = partial_round(state, (base, i * WIDTH * size_of::())); - } - state -} - -#[inline(always)] -unsafe fn load_state(state: &[GoldilocksField; 12]) -> (__m256i, __m256i, __m256i) { - ( - _mm256_loadu_si256((&state[0..4]).as_ptr().cast::<__m256i>()), - _mm256_loadu_si256((&state[4..8]).as_ptr().cast::<__m256i>()), - _mm256_loadu_si256((&state[8..12]).as_ptr().cast::<__m256i>()), - ) -} - -#[inline(always)] -unsafe fn store_state(buf: &mut [GoldilocksField; 12], state: (__m256i, __m256i, __m256i)) { - _mm256_storeu_si256((&mut buf[0..4]).as_mut_ptr().cast::<__m256i>(), state.0); - _mm256_storeu_si256((&mut buf[4..8]).as_mut_ptr().cast::<__m256i>(), state.1); - _mm256_storeu_si256((&mut buf[8..12]).as_mut_ptr().cast::<__m256i>(), state.2); -} - -#[inline] -pub unsafe fn poseidon(state: &[GoldilocksField; 12]) -> [GoldilocksField; 12] { - let state = load_state(state); - - // The first constant layer must be done explicitly. The remaining constant layers are fused - // with the preceding MDS layer. 
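// Safe-Rust equivalent (names ours, not from the patch) of the pointer
// arithmetic used by `half_full_rounds` and `all_partial_rounds` above: the
// fused round constants are laid out as WIDTH consecutive u64s per round, so
// the constants for round `start_round + i` occupy the slice
// [WIDTH * (start_round + i) .. WIDTH * (start_round + i + 1)], which matches
// a byte offset of i * WIDTH times the size of a u64 (8 bytes) relative to `base`.
fn round_constants_window(all: &[u64], width: usize, round: usize) -> &[u64] {
    &all[width * round..width * (round + 1)]
}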
- let state = const_layer(state, &ALL_ROUND_CONSTANTS[0..WIDTH].try_into().unwrap()); - - let state = half_full_rounds(state, 0); - let state = all_partial_rounds(state, HALF_N_FULL_ROUNDS); - let state = half_full_rounds(state, HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS); - - let mut res = [GoldilocksField::ZERO; 12]; - store_state(&mut res, state); - res -} - -#[inline(always)] -pub unsafe fn constant_layer(state_arr: &mut [GoldilocksField; WIDTH], round_ctr: usize) { - let state = load_state(state_arr); - let round_consts = &ALL_ROUND_CONSTANTS[WIDTH * round_ctr..][..WIDTH] - .try_into() - .unwrap(); - let state = const_layer(state, round_consts); - store_state(state_arr, state); -} - -#[inline(always)] -pub unsafe fn sbox_layer(state_arr: &mut [GoldilocksField; WIDTH]) { - let state = load_state(state_arr); - let state = sbox_layer_full(state); - store_state(state_arr, state); -} - -#[inline(always)] -pub unsafe fn mds_layer(state: &[GoldilocksField; WIDTH]) -> [GoldilocksField; WIDTH] { - let state = load_state(state); - // We want to do an MDS layer without the constant layer. - // The FUSED_ROUND_CONSTANTS for the last round are all 0 (shifted by 2**63 as required). - let round_consts = FUSED_ROUND_CONSTANTS[WIDTH * (N_ROUNDS - 1)..].as_ptr(); - let state = mds_const_layers_full(state, (round_consts, 0)); - let mut res = [GoldilocksField::ZERO; 12]; - store_state(&mut res, state); - res -} diff --git a/plonky2/src/hash/poseidon.rs b/plonky2/src/hash/poseidon.rs index a7c763252e..6fdf9ed66f 100644 --- a/plonky2/src/hash/poseidon.rs +++ b/plonky2/src/hash/poseidon.rs @@ -8,6 +8,8 @@ use core::fmt::Debug; use plonky2_field::packed::PackedField; use unroll::unroll_for_loops; +#[cfg(target_feature = "avx2")] +use super::arch::x86_64::poseidon_goldilocks_avx2::poseidon_avx; use crate::field::extension::{Extendable, FieldExtension}; use crate::field::types::{Field, PrimeField64}; use crate::gates::gate::Gate; @@ -37,14 +39,14 @@ pub const N_ROUNDS: usize = N_FULL_ROUNDS_TOTAL + N_PARTIAL_ROUNDS; const MAX_WIDTH: usize = 12; // we only have width 8 and 12, and 12 is bigger. :) #[inline(always)] -const fn add_u160_u128((x_lo, x_hi): (u128, u32), y: u128) -> (u128, u32) { +pub(crate) const fn add_u160_u128((x_lo, x_hi): (u128, u32), y: u128) -> (u128, u32) { let (res_lo, over) = x_lo.overflowing_add(y); let res_hi = x_hi + (over as u32); (res_lo, res_hi) } #[inline(always)] -fn reduce_u160((n_lo, n_hi): (u128, u32)) -> F { +pub(crate) fn reduce_u160((n_lo, n_hi): (u128, u32)) -> F { let n_lo_hi = (n_lo >> 64) as u64; let n_lo_lo = n_lo as u64; let reduced_hi: u64 = F::from_noncanonical_u96((n_lo_hi, n_hi)).to_noncanonical_u64(); @@ -763,6 +765,7 @@ pub trait Poseidon: PrimeField64 { *round_ctr += N_PARTIAL_ROUNDS; } + #[cfg(not(target_feature = "avx2"))] #[inline] fn poseidon(input: [Self; SPONGE_WIDTH]) -> [Self; SPONGE_WIDTH] { let mut state = input; @@ -776,6 +779,12 @@ pub trait Poseidon: PrimeField64 { state } + #[cfg(target_feature = "avx2")] + #[inline] + fn poseidon(input: [Self; SPONGE_WIDTH]) -> [Self; SPONGE_WIDTH] { + poseidon_avx(&input) + } + // For testing only, to ensure that various tricks are correct. 
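// Note on the dispatch just above (illustration, not from the patch): the choice
// between the two `poseidon` definitions is made entirely at compile time by
// `cfg(target_feature = "avx2")`, e.g. when building with
// RUSTFLAGS="-C target-feature=+avx2" or -C target-cpu=native; there is no
// runtime feature detection. A minimal standalone example of the same pattern:
#[cfg(target_feature = "avx2")]
fn permutation_backend() -> &'static str {
    "avx2" // this definition exists only in AVX2-enabled builds
}

#[cfg(not(target_feature = "avx2"))]
fn permutation_backend() -> &'static str {
    "generic" // portable fallback compiled otherwise
}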
#[inline] fn partial_rounds_naive(state: &mut [Self; SPONGE_WIDTH], round_ctr: &mut usize) { diff --git a/plonky2/src/hash/poseidon_goldilocks.rs b/plonky2/src/hash/poseidon_goldilocks.rs index 12d061265e..300b737c56 100644 --- a/plonky2/src/hash/poseidon_goldilocks.rs +++ b/plonky2/src/hash/poseidon_goldilocks.rs @@ -308,7 +308,7 @@ impl Poseidon for GoldilocksField { // The following code has been adapted from winterfell/crypto/src/hash/mds/mds_f64_12x12.rs // located at https://github.com/facebook/winterfell. #[cfg(not(all(target_arch = "aarch64", target_feature = "neon")))] -mod poseidon12_mds { +pub(crate) mod poseidon12_mds { const MDS_FREQ_BLOCK_ONE: [i64; 3] = [16, 32, 16]; const MDS_FREQ_BLOCK_TWO: [(i64, i64); 3] = [(2, -1), (-4, 1), (16, 1)]; const MDS_FREQ_BLOCK_THREE: [i64; 3] = [-1, -8, 2]; @@ -354,7 +354,7 @@ mod poseidon12_mds { } #[inline(always)] - const fn block2(x: [(i64, i64); 3], y: [(i64, i64); 3]) -> [(i64, i64); 3] { + pub(crate) const fn block2(x: [(i64, i64); 3], y: [(i64, i64); 3]) -> [(i64, i64); 3] { let [(x0r, x0i), (x1r, x1i), (x2r, x2i)] = x; let [(y0r, y0i), (y1r, y1i), (y2r, y2i)] = y; let x0s = x0r + x0i;
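// Background sketch (not from the patch): sums such as `x0s = x0r + x0i` are the
// usual setup for the three-multiplication (Karatsuba-style) complex product,
// shown here in isolation. Whether `block2` combines its three blocks exactly
// this way is not visible in this hunk, so treat this only as context.
fn complex_mul_3m(a: (i64, i64), b: (i64, i64)) -> (i64, i64) {
    let (ar, ai) = a;
    let (br, bi) = b;
    let rr = ar * br;
    let ii = ai * bi;
    let ss = (ar + ai) * (br + bi);
    (rr - ii, ss - rr - ii) // (real part, imaginary part)
}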