From 6153969286a597f72377cb9b4d0fd9a6e763e4a7 Mon Sep 17 00:00:00 2001 From: kilic Date: Wed, 1 Feb 2023 20:57:34 +0300 Subject: [PATCH 1/7] feat: add endo scalar decomposition --- src/arithmetic.rs | 286 ++++++++++++++++++++++++++++++++++++++++++++ src/bn256/curve.rs | 41 ++++++- src/derive/curve.rs | 51 ++++++++ src/pasta/mod.rs | 60 ++++++++++ 4 files changed, 436 insertions(+), 2 deletions(-) diff --git a/src/arithmetic.rs b/src/arithmetic.rs index 42b84772..70748ead 100644 --- a/src/arithmetic.rs +++ b/src/arithmetic.rs @@ -4,6 +4,19 @@ //! This module is temporary, and the extension traits defined here are expected to be //! upstreamed into the `ff` and `group` crates after some refactoring. +use crate::CurveExt; + +pub(crate) struct EndoParameters { + pub(crate) gamma1: [u64; 4], + pub(crate) gamma2: [u64; 4], + pub(crate) b1: [u64; 4], + pub(crate) b2: [u64; 4], +} + +pub trait CurveEndo: CurveExt { + fn decompose_scalar(e: &Self::ScalarExt) -> (u128, bool, u128, bool); +} + pub trait CurveAffineExt: pasta_curves::arithmetic::CurveAffine { fn batch_add( points: &mut [Self], @@ -42,3 +55,276 @@ pub(crate) const fn mac(a: u64, b: u64, c: u64, carry: u64) -> (u64, u64) { let ret = (a as u128) + ((b as u128) * (c as u128)) + (carry as u128); (ret as u64, (ret >> 64) as u64) } + +/// Compute a + (b * c), returning the result and the new carry over. +#[inline(always)] +pub(crate) const fn macx(a: u64, b: u64, c: u64) -> (u64, u64) { + let res = (a as u128) + ((b as u128) * (c as u128)); + (res as u64, (res >> 64) as u64) +} + +/// Compute a * b, returning the result. +#[inline(always)] +pub(crate) fn mul_512(a: [u64; 4], b: [u64; 4]) -> [u64; 8] { + let (r0, carry) = macx(0, a[0], b[0]); + let (r1, carry) = macx(carry, a[0], b[1]); + let (r2, carry) = macx(carry, a[0], b[2]); + let (r3, carry_out) = macx(carry, a[0], b[3]); + + let (r1, carry) = macx(r1, a[1], b[0]); + let (r2, carry) = mac(r2, a[1], b[1], carry); + let (r3, carry) = mac(r3, a[1], b[2], carry); + let (r4, carry_out) = mac(carry_out, a[1], b[3], carry); + + let (r2, carry) = macx(r2, a[2], b[0]); + let (r3, carry) = mac(r3, a[2], b[1], carry); + let (r4, carry) = mac(r4, a[2], b[2], carry); + let (r5, carry_out) = mac(carry_out, a[2], b[3], carry); + + let (r3, carry) = macx(r3, a[3], b[0]); + let (r4, carry) = mac(r4, a[3], b[1], carry); + let (r5, carry) = mac(r5, a[3], b[2], carry); + let (r6, carry_out) = mac(carry_out, a[3], b[3], carry); + + [r0, r1, r2, r3, r4, r5, r6, carry_out] +} + +#[cfg(test)] +mod test { + use super::CurveEndo; + use crate::bn256::G1; + use ff::Field; + use pasta_curves::Ep; + use pasta_curves::Eq; + use rand_core::OsRng; + + // naive glv multiplication implementation + fn glv_mul(point: C, scalar: &C::ScalarExt) -> C { + const WINDOW: usize = 3; + // decompose scalar and convert to wnaf representation + let (k1, k1_neg, k2, k2_neg) = C::decompose_scalar(scalar); + + let mut k1_wnaf: Vec = Vec::new(); + let mut k2_wnaf: Vec = Vec::new(); + wnaf::form(&mut k1_wnaf, k1.to_le_bytes(), WINDOW); + wnaf::form(&mut k2_wnaf, k2.to_le_bytes(), WINDOW); + + let n = std::cmp::max(k1_wnaf.len(), k2_wnaf.len()); + k1_wnaf.resize(n, 0); + k2_wnaf.resize(n, 0); + + // prepare tables + let two_p = point.double(); + // T1 = {P, 3P, 5P, ...} + let mut table_k1 = vec![point.clone()]; + // T2 = {λP, 3λP, 5λP, ...} + let mut table_k2 = vec![point.endo()]; + for i in 1..WINDOW - 1 { + table_k1.push(table_k1[i - 1] + two_p); + table_k2.push(table_k1[i].endo()) + } + if !k2_neg { + table_k2.iter_mut().for_each(|p| *p = -p.clone()); + } + if k1_neg { + table_k1.iter_mut().for_each(|p| *p = -p.clone()); + } + // TODO: batch affine tables for mixed add? + + macro_rules! add { + ($acc:expr, $e:expr, $table:expr) => { + let idx = ($e.abs() >> 1) as usize; + $acc += if $e.is_positive() { + $table[idx] + } else if $e.is_negative() { + -$table[idx] + } else { + C::identity() + }; + }; + } + + // apply simultaneus double add + k1_wnaf + .iter() + .rev() + .zip(k2_wnaf.iter().rev()) + .fold(C::identity(), |acc, (e1, e2)| { + let mut acc = acc.double(); + add!(acc, e1, table_k1); + add!(acc, e2, table_k2); + acc + }) + } + + fn run_glv_mul_test() { + for _ in 0..10000 { + let point = C::random(OsRng); + let scalar = C::ScalarExt::random(OsRng); + let r0 = point * scalar; + let r1 = glv_mul(point, &scalar); + assert_eq!(r0, r1); + } + } + + #[test] + fn test_glv_mul() { + run_glv_mul_test::(); + run_glv_mul_test::(); + run_glv_mul_test::(); + } + + #[test] + fn test_wnaf_form() { + use rand::Rng; + fn from_wnaf(wnaf: &Vec) -> u128 { + wnaf.iter().rev().fold(0, |acc, next| { + let mut acc = acc * 2; + acc += *next as u128; + acc + }) + } + for w in 2..64 { + for e in 0..=u16::MAX { + let mut wnaf = vec![]; + wnaf::form(&mut wnaf, e.to_le_bytes(), w); + assert_eq!(e as u128, from_wnaf(&wnaf)); + } + } + let mut wnaf = vec![]; + for w in 2..64 { + for e in u128::MAX - 10000..=u128::MAX { + wnaf::form(&mut wnaf, e.to_le_bytes(), w); + assert_eq!(e, from_wnaf(&wnaf)); + } + } + for w in 2..10 { + for _ in 0..10000 { + let e: u128 = OsRng.gen(); + wnaf::form(&mut wnaf, e.to_le_bytes(), w); + assert_eq!(e as u128, from_wnaf(&wnaf)); + } + } + } + + // taken from zkcrypto/group + mod wnaf { + use std::convert::TryInto; + + #[derive(Debug, Clone)] + struct LimbBuffer<'a> { + buf: &'a [u8], + cur_idx: usize, + cur_limb: u64, + next_limb: u64, + } + + impl<'a> LimbBuffer<'a> { + fn new(buf: &'a [u8]) -> Self { + let mut ret = Self { + buf, + cur_idx: 0, + cur_limb: 0, + next_limb: 0, + }; + + // Initialise the limb buffers. + ret.increment_limb(); + ret.increment_limb(); + ret.cur_idx = 0usize; + ret + } + + fn increment_limb(&mut self) { + self.cur_idx += 1; + self.cur_limb = self.next_limb; + match self.buf.len() { + // There are no more bytes in the buffer; zero-extend. + 0 => self.next_limb = 0, + + // There are fewer bytes in the buffer than a u64 limb; zero-extend. + x @ 1..=7 => { + let mut next_limb = [0; 8]; + next_limb[..x].copy_from_slice(self.buf); + self.next_limb = u64::from_le_bytes(next_limb); + self.buf = &[]; + } + + // There are at least eight bytes in the buffer; read the next u64 limb. + _ => { + let (next_limb, rest) = self.buf.split_at(8); + self.next_limb = u64::from_le_bytes(next_limb.try_into().unwrap()); + self.buf = rest; + } + } + } + + fn get(&mut self, idx: usize) -> (u64, u64) { + assert!([self.cur_idx, self.cur_idx + 1].contains(&idx)); + if idx > self.cur_idx { + self.increment_limb(); + } + (self.cur_limb, self.next_limb) + } + } + + /// Replaces the contents of `wnaf` with the w-NAF representation of a little-endian + /// scalar. + pub(crate) fn form>(wnaf: &mut Vec, c: S, window: usize) { + // Required by the NAF definition + debug_assert!(window >= 2); + // Required so that the NAF digits fit in i64 + debug_assert!(window < 64); + + let bit_len = c.as_ref().len() * 8; + + wnaf.truncate(0); + wnaf.reserve(bit_len + 1); + + // Initialise the current and next limb buffers. + let mut limbs = LimbBuffer::new(c.as_ref()); + + let width = 1u64 << window; + let window_mask = width - 1; + + let mut pos = 0; + let mut carry = 0; + while pos <= bit_len { + // Construct a buffer of bits of the scalar, starting at bit `pos` + let u64_idx = pos / 64; + let bit_idx = pos % 64; + let (cur_u64, next_u64) = limbs.get(u64_idx); + let bit_buf = if bit_idx + window < 64 { + // This window's bits are contained in a single u64 + cur_u64 >> bit_idx + } else { + // Combine the current u64's bits with the bits from the next u64 + (cur_u64 >> bit_idx) | (next_u64 << (64 - bit_idx)) + }; + + // Add the carry into the current window + let window_val = carry + (bit_buf & window_mask); + + if window_val & 1 == 0 { + // If the window value is even, preserve the carry and emit 0. + // Why is the carry preserved? + // If carry == 0 and window_val & 1 == 0, then the next carry should be 0 + // If carry == 1 and window_val & 1 == 0, then bit_buf & 1 == 1 so the next carry should be 1 + wnaf.push(0); + pos += 1; + } else { + wnaf.push(if window_val < width / 2 { + carry = 0; + window_val as i64 + } else { + carry = 1; + (window_val as i64).wrapping_sub(width as i64) + }); + wnaf.extend(std::iter::repeat(0).take(window - 1)); + pos += window; + } + } + wnaf.truncate(wnaf.len().saturating_sub(window - 1)); + } + } +} diff --git a/src/bn256/curve.rs b/src/bn256/curve.rs index 7c359233..981eedd8 100644 --- a/src/bn256/curve.rs +++ b/src/bn256/curve.rs @@ -1,6 +1,11 @@ +use crate::arithmetic::mul_512; +use crate::arithmetic::sbb; +use crate::arithmetic::CurveEndo; +use crate::arithmetic::EndoParameters; use crate::bn256::Fq; use crate::bn256::Fq2; use crate::bn256::Fr; +use crate::endo; use crate::{Coordinates, CurveAffine, CurveAffineExt, CurveExt}; use core::cmp; use core::fmt::Debug; @@ -11,6 +16,7 @@ use ff::{Field, PrimeField}; use group::Curve; use group::{cofactor::CofactorGroup, prime::PrimeCurveAffine, Group, GroupEncoding}; use rand::RngCore; +use std::convert::TryInto; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; use crate::{ @@ -109,6 +115,20 @@ const G2_GENERATOR_Y: Fq2 = Fq2 { ]), }; +const ENDO_PARAMS: EndoParameters = EndoParameters { + gamma1: [ + 0x7a7bd9d4391eb18du64, + 0x4ccef014a773d2cfu64, + 0x0000000000000002u64, + 0u64, + ], + gamma2: [0xd91d232ec7e0b3d7u64, 0x0000000000000002u64, 0u64, 0u64], + b1: [0x8211bbeb7d4f1128u64, 0x6f4d8248eeb859fcu64, 0u64, 0u64], + b2: [0x89d3256894d213e3u64, 0u64, 0u64, 0u64], +}; + +endo!(G1, Fr, ENDO_PARAMS); + impl group::cofactor::CofactorGroup for G1 { type Subgroup = G1; @@ -178,9 +198,14 @@ impl CofactorGroup for G2 { #[cfg(test)] mod tests { + + use crate::arithmetic::CurveEndo; use crate::bn256::{Fr, G1, G2}; use crate::CurveExt; + use ff::Field; + use ff::PrimeField; use ff::WithSmallOrderMulGroup; + use rand_core::OsRng; #[test] fn test_curve() { @@ -189,12 +214,24 @@ mod tests { } #[test] - fn test_endo_consistency() { + fn test_endo() { let g = G1::generator(); assert_eq!(g * Fr::ZETA, g.endo()); - let g = G2::generator(); assert_eq!(g * Fr::ZETA, g.endo()); + for _ in 0..100000 { + let k = Fr::random(OsRng); + let (k1, k1_neg, k2, k2_neg) = G1::decompose_scalar(&k); + if k1_neg & k2_neg { + assert_eq!(k, -Fr::from_u128(k1) + Fr::ZETA * Fr::from_u128(k2)) + } else if k1_neg { + assert_eq!(k, -Fr::from_u128(k1) - Fr::ZETA * Fr::from_u128(k2)) + } else if k2_neg { + assert_eq!(k, Fr::from_u128(k1) + Fr::ZETA * Fr::from_u128(k2)) + } else { + assert_eq!(k, Fr::from_u128(k1) - Fr::ZETA * Fr::from_u128(k2)) + } + } } #[test] diff --git a/src/derive/curve.rs b/src/derive/curve.rs index ba065fe2..b1e1daca 100644 --- a/src/derive/curve.rs +++ b/src/derive/curve.rs @@ -140,6 +140,57 @@ macro_rules! batch_add { }; } +#[macro_export] +macro_rules! endo { + ($name:ident, $field:ident, $params:expr) => { + impl CurveEndo for $name { + fn decompose_scalar(k: &$field) -> (u128, bool, u128, bool) { + let to_limbs = |e: &$field| { + let repr = e.to_repr(); + let repr = repr.as_ref(); + let tmp0 = u64::from_le_bytes(repr[0..8].try_into().unwrap()); + let tmp1 = u64::from_le_bytes(repr[8..16].try_into().unwrap()); + let tmp2 = u64::from_le_bytes(repr[16..24].try_into().unwrap()); + let tmp3 = u64::from_le_bytes(repr[24..32].try_into().unwrap()); + [tmp0, tmp1, tmp2, tmp3] + }; + + let get_lower_128 = |e: &$field| { + let e = to_limbs(e); + u128::from(e[0]) | (u128::from(e[1]) << 64) + }; + + let is_neg = |e: &$field| { + let e = to_limbs(e); + let (_, borrow) = sbb(0xffffffffffffffff, e[0], 0); + let (_, borrow) = sbb(0xffffffffffffffff, e[1], borrow); + let (_, borrow) = sbb(0xffffffffffffffff, e[2], borrow); + let (_, borrow) = sbb(0x00, e[3], borrow); + borrow & 1 != 0 + }; + + let input = to_limbs(&k); + let c1 = mul_512($params.gamma2, input); + let c2 = mul_512($params.gamma1, input); + let c1 = [c1[4], c1[5], c1[6], c1[7]]; + let c2 = [c2[4], c2[5], c2[6], c2[7]]; + let q1 = mul_512(c1, $params.b1); + let q2 = mul_512(c2, $params.b2); + let q1 = $field::from_raw([q1[0], q1[1], q1[2], q1[3]]); + let q2 = $field::from_raw([q2[0], q2[1], q2[2], q2[3]]); + let k2 = q2 - q1; + let k1 = k + k2 * $field::ZETA; + let k1_neg = is_neg(&k1); + let k2_neg = is_neg(&k2); + let k1 = if k1_neg { -k1 } else { k1 }; + let k2 = if k2_neg { -k2 } else { k2 }; + + (get_lower_128(&k1), k1_neg, get_lower_128(&k2), k2_neg) + } + } + }; +} + #[macro_export] macro_rules! new_curve_impl { (($($privacy:tt)*), diff --git a/src/pasta/mod.rs b/src/pasta/mod.rs index f6aee547..0252b199 100644 --- a/src/pasta/mod.rs +++ b/src/pasta/mod.rs @@ -1,4 +1,13 @@ +use crate::arithmetic::mul_512; +use crate::arithmetic::sbb; +use crate::{ + arithmetic::{CurveEndo, EndoParameters}, + endo, +}; +use ff::PrimeField; +use ff::WithSmallOrderMulGroup; pub use pasta_curves::{pallas, vesta, Ep, EpAffine, Eq, EqAffine, Fp, Fq}; +use std::convert::TryInto; impl crate::CurveAffineExt for EpAffine { fn batch_add( @@ -25,3 +34,54 @@ impl crate::CurveAffineExt for EqAffine { unimplemented!(); } } + +const ENDO_PARAMS_EQ: EndoParameters = EndoParameters { + gamma1: [0x32c49e4c00000003, 0x279a745902a2654e, 0x1, 0x0], + gamma2: [0x31f0256800000002, 0x4f34e8b2066389a4, 0x2, 0x0], + b1: [0x8cb1279300000001, 0x49e69d1640a89953, 0x0, 0x0], + b2: [0x0c7c095a00000001, 0x93cd3a2c8198e269, 0x0, 0x0], +}; + +const ENDO_PARAMS_EP: EndoParameters = EndoParameters { + gamma1: [0x32c49e4bffffffff, 0x279a745902a2654e, 0x1, 0x0], + gamma2: [0x31f0256800000002, 0x4f34e8b2066389a4, 0x2, 0x0], + b1: [0x8cb1279300000000, 0x49e69d1640a89953, 0x0, 0x0], + b2: [0x0c7c095a00000001, 0x93cd3a2c8198e269, 0x0, 0x0], +}; + +endo!(Eq, Fp, ENDO_PARAMS_EQ); +endo!(Ep, Fq, ENDO_PARAMS_EP); + +#[test] +fn test_endo() { + use ff::Field; + use rand_core::OsRng; + + for _ in 0..100000 { + let k = Fp::random(OsRng); + let (k1, k1_neg, k2, k2_neg) = Eq::decompose_scalar(&k); + if k1_neg & k2_neg { + assert_eq!(k, -Fp::from_u128(k1) + Fp::ZETA * Fp::from_u128(k2)) + } else if k1_neg { + assert_eq!(k, -Fp::from_u128(k1) - Fp::ZETA * Fp::from_u128(k2)) + } else if k2_neg { + assert_eq!(k, Fp::from_u128(k1) + Fp::ZETA * Fp::from_u128(k2)) + } else { + assert_eq!(k, Fp::from_u128(k1) - Fp::ZETA * Fp::from_u128(k2)) + } + } + + for _ in 0..100000 { + let k = Fp::random(OsRng); + let (k1, k1_neg, k2, k2_neg) = Eq::decompose_scalar(&k); + if k1_neg & k2_neg { + assert_eq!(k, -Fp::from_u128(k1) + Fp::ZETA * Fp::from_u128(k2)) + } else if k1_neg { + assert_eq!(k, -Fp::from_u128(k1) - Fp::ZETA * Fp::from_u128(k2)) + } else if k2_neg { + assert_eq!(k, Fp::from_u128(k1) + Fp::ZETA * Fp::from_u128(k2)) + } else { + assert_eq!(k, Fp::from_u128(k1) - Fp::ZETA * Fp::from_u128(k2)) + } + } +} From abb34c3e503f2b8f01f6f607e0d7cc7e579c5ff1 Mon Sep 17 00:00:00 2001 From: kilic Date: Wed, 1 Feb 2023 22:48:42 +0300 Subject: [PATCH 2/7] fix: clippy --- src/arithmetic.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/arithmetic.rs b/src/arithmetic.rs index 70748ead..277a24ff 100644 --- a/src/arithmetic.rs +++ b/src/arithmetic.rs @@ -116,7 +116,7 @@ mod test { // prepare tables let two_p = point.double(); // T1 = {P, 3P, 5P, ...} - let mut table_k1 = vec![point.clone()]; + let mut table_k1 = vec![point]; // T2 = {λP, 3λP, 5λP, ...} let mut table_k2 = vec![point.endo()]; for i in 1..WINDOW - 1 { @@ -124,10 +124,10 @@ mod test { table_k2.push(table_k1[i].endo()) } if !k2_neg { - table_k2.iter_mut().for_each(|p| *p = -p.clone()); + table_k2.iter_mut().for_each(|p| *p = -*p); } if k1_neg { - table_k1.iter_mut().for_each(|p| *p = -p.clone()); + table_k1.iter_mut().for_each(|p| *p = -*p); } // TODO: batch affine tables for mixed add? @@ -177,7 +177,7 @@ mod test { #[test] fn test_wnaf_form() { use rand::Rng; - fn from_wnaf(wnaf: &Vec) -> u128 { + fn from_wnaf(wnaf: &[i64]) -> u128 { wnaf.iter().rev().fold(0, |acc, next| { let mut acc = acc * 2; acc += *next as u128; From 1fd2e54142992b70892dda64ccc2a79645312045 Mon Sep 17 00:00:00 2001 From: kilic Date: Tue, 28 Feb 2023 02:08:06 +0300 Subject: [PATCH 3/7] feat: add msm function with other reference impls --- .gitignore | 3 +- Cargo.toml | 1 + src/arithmetic.rs | 17 +- src/bn256/curve.rs | 22 +- src/bn256/mod.rs | 1 + src/bn256/msm/mod.rs | 257 +++++++++++ src/bn256/msm/pr40.rs | 982 +++++++++++++++++++++++++++++++++++++++++ src/bn256/msm/round.rs | 361 +++++++++++++++ src/bn256/msm/zcash.rs | 123 ++++++ src/derive/curve.rs | 82 ++-- src/pasta/mod.rs | 63 --- src/secp256k1/curve.rs | 8 - 12 files changed, 1790 insertions(+), 130 deletions(-) create mode 100644 src/bn256/msm/mod.rs create mode 100644 src/bn256/msm/pr40.rs create mode 100644 src/bn256/msm/round.rs create mode 100644 src/bn256/msm/zcash.rs diff --git a/.gitignore b/.gitignore index bceff729..de51da8b 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ Cargo.lock **/*.rs.bk .vscode -**/*.html \ No newline at end of file +**/*.html +*.tmp \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 822c6acc..2effac82 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ rand_xorshift = "0.3" ark-std = { version = "0.3" } [dependencies] +rayon = "1.5.1" subtle = "2.4" ff = "0.13.0" group = "0.13.0" diff --git a/src/arithmetic.rs b/src/arithmetic.rs index 277a24ff..2882161b 100644 --- a/src/arithmetic.rs +++ b/src/arithmetic.rs @@ -4,6 +4,8 @@ //! This module is temporary, and the extension traits defined here are expected to be //! upstreamed into the `ff` and `group` crates after some refactoring. +use pasta_curves::arithmetic::CurveAffine; + use crate::CurveExt; pub(crate) struct EndoParameters { @@ -17,7 +19,9 @@ pub trait CurveEndo: CurveExt { fn decompose_scalar(e: &Self::ScalarExt) -> (u128, bool, u128, bool); } -pub trait CurveAffineExt: pasta_curves::arithmetic::CurveAffine { +pub trait CurveAffineExt: CurveAffine { + fn decompose_scalar(k: &Self::ScalarExt) -> (u128, bool, u128, bool); + fn endo(&self) -> Self; fn batch_add( points: &mut [Self], output_indices: &[u32], @@ -26,13 +30,6 @@ pub trait CurveAffineExt: pasta_curves::arithmetic::CurveAffine { bases: &[Self], base_positions: &[u32], ); - - /// Unlike the `Coordinates` trait, this just returns the raw affine coordinates without checking `is_on_curve` - fn into_coordinates(self) -> (Self::Base, Self::Base) { - // fallback implementation - let coordinates = self.coordinates().unwrap(); - (*coordinates.x(), *coordinates.y()) - } } /// Compute a + b + carry, returning the result and the new carry over. @@ -170,8 +167,8 @@ mod test { #[test] fn test_glv_mul() { run_glv_mul_test::(); - run_glv_mul_test::(); - run_glv_mul_test::(); + // run_glv_mul_test::(); + // run_glv_mul_test::(git); } #[test] diff --git a/src/bn256/curve.rs b/src/bn256/curve.rs index 981eedd8..61485d77 100644 --- a/src/bn256/curve.rs +++ b/src/bn256/curve.rs @@ -51,17 +51,29 @@ new_curve_impl!( impl CurveAffineExt for G1Affine { batch_add!(); + endo!(ENDO_PARAMS); - fn into_coordinates(self) -> (Self::Base, Self::Base) { - (self.x, self.y) + fn endo(&self) -> Self { + Self { + x: self.x * Self::Base::ZETA, + y: self.y, + } } } +impl CurveEndo for G1 { + endo!(ENDO_PARAMS); +} + impl CurveAffineExt for G2Affine { batch_add!(); + endo!(ENDO_PARAMS); - fn into_coordinates(self) -> (Self::Base, Self::Base) { - (self.x, self.y) + fn endo(&self) -> Self { + Self { + x: self.x * Self::Base::ZETA, + y: self.y, + } } } @@ -127,8 +139,6 @@ const ENDO_PARAMS: EndoParameters = EndoParameters { b2: [0x89d3256894d213e3u64, 0u64, 0u64, 0u64], }; -endo!(G1, Fr, ENDO_PARAMS); - impl group::cofactor::CofactorGroup for G1 { type Subgroup = G1; diff --git a/src/bn256/mod.rs b/src/bn256/mod.rs index 9cd08946..cd8c5bad 100644 --- a/src/bn256/mod.rs +++ b/src/bn256/mod.rs @@ -5,6 +5,7 @@ mod fq12; mod fq2; mod fq6; mod fr; +pub mod msm; #[cfg(feature = "asm")] mod assembly; diff --git a/src/bn256/msm/mod.rs b/src/bn256/msm/mod.rs new file mode 100644 index 00000000..b1688cdf --- /dev/null +++ b/src/bn256/msm/mod.rs @@ -0,0 +1,257 @@ +use super::{Fr, G1Affine}; +use crate::bn256::{msm::round::Round, G1}; +use crate::group::Group; +use ff::PrimeField; +use rayon::{current_num_threads, scope}; + +#[cfg(test)] +mod pr40; +mod round; +#[cfg(test)] +mod zcash; + +macro_rules! div_ceil { + ($a:expr, $b:expr) => { + (($a - 1) / $b) + 1 + }; +} + +macro_rules! double_n { + ($acc:expr, $n:expr) => { + (0..$n).fold($acc, |acc, _| acc.double()) + }; +} + +macro_rules! range { + ($index:expr, $n_items:expr) => { + $index * $n_items..($index + 1) * $n_items + }; +} + +pub struct MSM { + bucket_sizes: Vec, + sorted_positions: Vec, + bucket_indexes: Vec, + bucket_offsets: Vec, + n_windows: usize, + window: usize, + n_buckets: usize, + n_points: usize, + round: Round, +} + +impl MSM { + pub fn alloacate(n_points: usize) -> Self { + fn best_window(n: usize) -> usize { + if n >= 262144 { + 15 + } else if n >= 65536 { + 12 + } else if n >= 16384 { + 11 + } else if n >= 8192 { + 10 + } else if n >= 1024 { + 9 + } else { + 7 + } + } + let window = best_window(n_points); + let n_windows = div_ceil!(Fr::NUM_BITS as usize, window); + let n_buckets = 1 << window; + let round = Round::new(n_buckets, n_points); + MSM { + bucket_indexes: vec![0usize; n_windows * n_points], + bucket_sizes: vec![0usize; n_windows * n_buckets], + sorted_positions: vec![0usize; n_windows * n_points], + bucket_offsets: vec![0; n_buckets], + n_windows, + window, + n_buckets, + n_points, + round, + } + } + + fn decompose(&mut self, scalars: &[Fr]) { + pub(crate) fn get_bits(segment: usize, c: usize, bytes: &[u8]) -> u64 { + let skip_bits = segment * c; + let skip_bytes = skip_bits / 8; + if skip_bytes >= 32 { + return 0; + } + let mut v = [0; 8]; + for (v, o) in v.iter_mut().zip(bytes[skip_bytes..].iter()) { + *v = *o; + } + let mut tmp = u64::from_le_bytes(v); + tmp >>= skip_bits - (skip_bytes * 8); + tmp %= 1 << c; + tmp as u64 + } + let scalars = scalars + .iter() + .map(|scalar| scalar.to_repr()) + .collect::>(); + for window_idx in 0..self.n_windows { + for (point_index, scalar) in scalars.iter().enumerate() { + let bucket_index = get_bits(window_idx, self.window, scalar.as_ref()) as usize; + self.bucket_sizes[window_idx * self.n_buckets + bucket_index] += 1; + self.bucket_indexes[window_idx * self.n_points + point_index] = bucket_index; + } + } + self.sort(); + } + + fn sort(&mut self) { + for w_i in 0..self.n_windows { + let sorted_positions = &mut self.sorted_positions[range!(w_i, self.n_points)]; + let bucket_sizes = &self.bucket_sizes[range!(w_i, self.n_buckets)]; + let bucket_indexes = &self.bucket_indexes[range!(w_i, self.n_points)]; + let mut offset = 0; + for (i, size) in bucket_sizes.iter().enumerate() { + self.bucket_offsets[i] = offset; + offset += size; + } + for (idx, bucket_index) in bucket_indexes.iter().enumerate() { + sorted_positions[self.bucket_offsets[*bucket_index]] = idx; + self.bucket_offsets[*bucket_index] += 1; + } + } + } + + pub fn evalulate(scalars: &[Fr], bases: &[G1Affine], acc: &mut G1) { + let mut msm = Self::alloacate(bases.len()); + msm.decompose(scalars); + for w_i in (0..msm.n_windows).rev() { + if w_i != msm.n_windows - 1 { + *acc = double_n!(*acc, msm.window); + } + msm.round.init( + bases, + &msm.sorted_positions[range!(w_i, msm.n_points)], + &msm.bucket_sizes[range!(w_i, msm.n_buckets)], + ); + let buckets = msm.round.evaluate(); + let mut running_sum = G1::identity(); + for bucket in buckets.into_iter().skip(1).rev() { + running_sum += bucket; + *acc += &running_sum; + } + } + } + + pub fn best(coeffs: &[Fr], bases: &[G1Affine]) -> G1 { + assert_eq!(coeffs.len(), bases.len()); + let num_threads = current_num_threads(); + if coeffs.len() > num_threads { + let chunk = coeffs.len() / num_threads; + let num_chunks = coeffs.chunks(chunk).len(); + let mut results = vec![G1::identity(); num_chunks]; + scope(|scope| { + let chunk = coeffs.len() / num_threads; + + for ((coeffs, bases), acc) in coeffs + .chunks(chunk) + .zip(bases.chunks(chunk)) + .zip(results.iter_mut()) + { + scope.spawn(move |_| { + Self::evalulate(coeffs, bases, acc); + }); + } + }); + results.iter().fold(G1::identity(), |a, b| a + b) + } else { + let mut acc = G1::identity(); + Self::evalulate(coeffs, bases, &mut acc); + acc + } + } +} + +#[cfg(test)] +mod test { + use crate::bn256::msm::pr40::{MultiExp, MultiExpContext}; + use crate::bn256::msm::zcash::{best_multiexp_zcash, msm_zcash}; + use crate::bn256::{Fr, G1Affine, G1}; + use crate::group::Group; + use crate::serde::SerdeObject; + use ff::Field; + use group::Curve; + use rand_core::OsRng; + use std::fs::File; + use std::path::Path; + + fn read_data(n: usize) -> (Vec, Vec) { + let mut file = File::open("data.tmp").unwrap(); + (0..n) + .map(|_| { + let point = G1Affine::read_raw_unchecked(&mut file); + let scalar = Fr::read_raw_unchecked(&mut file); + (point, scalar) + }) + .unzip() + } + + pub(crate) fn get_data(n: usize) -> (Vec, Vec) { + const MAX_N: usize = 1 << 22; + assert!(n <= MAX_N); + if Path::new("data.tmp").is_file() { + read_data(n) + } else { + let mut file = File::create("data.tmp").unwrap(); + (0..MAX_N) + .map(|_| { + let point = G1::random(OsRng).to_affine(); + let scalar = Fr::random(OsRng); + point.write_raw(&mut file).unwrap(); + scalar.write_raw(&mut file).unwrap(); + (point, scalar) + }) + .take(n) + .unzip() + } + } + + #[test] + + fn test_msm() { + let (points, scalars) = get_data(1 << 22); + + for k in 10..=22 { + let n_points = 1 << k; + let scalars = &scalars[..n_points]; + let points = &points[..n_points]; + println!("------ {}", k); + + let mut r0 = G1::identity(); + let time = std::time::Instant::now(); + msm_zcash(scalars, points, &mut r0); + println!("zcash serial {:?}", time.elapsed()); + + let time = std::time::Instant::now(); + let r0 = best_multiexp_zcash(scalars, points); + println!("zcash parallel {:?}", time.elapsed()); + + let time = std::time::Instant::now(); + let mut r1 = G1::identity(); + super::MSM::evalulate(scalars, points, &mut r1); + assert_eq!(r0, r1); + println!("this {:?}", time.elapsed()); + + let time = std::time::Instant::now(); + let r1 = super::MSM::best(scalars, points); + assert_eq!(r0, r1); + println!("this parallel {:?}", time.elapsed()); + + let time = std::time::Instant::now(); + let msm = MultiExp::new(&points); + let mut ctx = MultiExpContext::default(); + let _ = msm.evaluate(&mut ctx, scalars, false); + // assert_eq!(r0, r1); // fails + println!("pr40 {:?}", time.elapsed()); + } + } +} diff --git a/src/bn256/msm/pr40.rs b/src/bn256/msm/pr40.rs new file mode 100644 index 00000000..56908ff4 --- /dev/null +++ b/src/bn256/msm/pr40.rs @@ -0,0 +1,982 @@ +//! This module implements a fast method for multi-scalar multiplications. +//! +//! Generally it works like pippenger with a couple of tricks to make if faster. +//! +//! - First the coefficients are split into two parts (using the endomorphism). This +//! reduces the number of rounds by half, but doubles the number of points per round. +//! This is faster because half the rounds also means only needing to add all bucket +//! results together half the number of times. +//! +//! - The coefficients are then sorted in buckets. Instead of using +//! the binary representation to do this, a signed digit representation is +//! used instead (WNAF). Unfortunately this doesn't directly reduce the number of additions +//! in a bucket, but it does reduce the number of buckets in half, which halves the +//! work required to accumulate the results of the buckets. +//! +//! - We then need to add all the points in each bucket together. To do this +//! the affine addition formulas are used. If the points are linearly independent the +//! incomplete version of the formula can be used which is quite a bit faster than +//! the full one because some checks can be skipped. +//! The affine formula is only fast if a lot of independent points can be added +//! together. This is because to get the actual result of an addition an inversion is +//! needed which is very expensive, but it's cheap when batched inversion can be used. +//! So the idea is to add a lot of pairs of points together using a single batched inversion. +//! We then have the results of all those additions, and can do a new batch of additions on those +//! results. This process is repeated as many times as needed until all additions for each bucket +//! are done. To do this efficiently we first build up an addition tree that sets everything +//! up correctly per round. We then process each addition tree per round. + +use core::slice; +pub use ff::Field; +use group::{ff::PrimeField, Group as _}; +pub use rayon::{current_num_threads, scope, Scope}; + +fn num_bits(value: usize) -> usize { + (0usize.leading_zeros() - value.leading_zeros()) as usize +} + +fn div_up(a: usize, b: usize) -> usize { + (a + (b - 1)) / b +} + +fn get_wnaf_size_bits(num_bits: usize, w: usize) -> usize { + div_up(num_bits, w) +} + +fn get_wnaf_size(w: usize) -> usize { + get_wnaf_size_bits(div_up(C::Scalar::NUM_BITS as usize, 2), w) +} + +fn get_num_rounds(c: usize) -> usize { + get_wnaf_size::(c + 1) +} + +fn get_num_buckets(c: usize) -> usize { + (1 << c) + 1 +} + +fn get_max_tree_size(num_points: usize, c: usize) -> usize { + num_points * 2 + get_num_buckets(c) +} + +fn get_num_tree_levels(num_points: usize) -> usize { + 1 + num_bits(num_points - 1) +} + +/// Returns the signed digit representation of value with the specified window size. +/// The result is written to the wnaf slice with the specified stride. +fn get_wnaf(value: u128, w: usize, num_rounds: usize, wnaf: &mut [u32], stride: usize) { + fn get_bits_at(v: u128, pos: usize, num: usize) -> usize { + ((v >> pos) & ((1 << num) - 1)) as usize + } + + let mut borrow = 0; + let max = 1 << (w - 1); + for idx in 0..num_rounds { + let b = get_bits_at(value, idx * w, w) + borrow; + if b >= max { + // Set the highest bit to 1 to represent a negative value. + // This way the lower bits directly represent the bucket index. + wnaf[idx * stride] = (0x80000000 | ((1 << w) - b)) as u32; + borrow = 1; + } else { + wnaf[idx * stride] = b as u32; + borrow = 0; + } + } + assert_eq!(borrow, 0); +} + +/// Returns the best bucket width for the given number of points. +fn get_best_c(num_points: usize) -> usize { + if num_points >= 262144 { + 15 + } else if num_points >= 65536 { + 12 + } else if num_points >= 16384 { + 11 + } else if num_points >= 8192 { + 10 + } else if num_points >= 1024 { + 9 + } else { + 7 + } +} + +/// MultiExp +#[derive(Clone, Debug, Default)] +pub struct MultiExp { + /// The bases + bases: Vec, +} + +/// MultiExp context object +#[derive(Clone, Debug, Default)] +pub struct MultiExpContext { + /// Memory to store the points in the addition tree + points: Vec, + /// Memory to store wnafs + wnafs: Vec, + /// Memory split up between rounds + rounds: SharedRoundData, +} + +/// SharedRoundData +#[derive(Clone, Debug, Default)] +struct SharedRoundData { + /// Memory to store bucket sizes + bucket_sizes: Vec, + /// Memory to store bucket offsets + bucket_offsets: Vec, + /// Memory to store the point data + point_data: Vec, + /// Memory to store the output indices + output_indices: Vec, + /// Memory to store the base positions (on the first level) + base_positions: Vec, + /// Memory to store the scatter maps + scatter_map: Vec, +} + +/// RoundData +#[derive(Debug, Default)] +struct RoundData<'a> { + /// Number of levels in the addition tree + pub num_levels: usize, + /// The length of each level in the addition tree + pub level_sizes: Vec, + /// The offset to each level in the addition tree + pub level_offset: Vec, + /// The size of each bucket + pub bucket_sizes: &'a mut [usize], + /// The offset of each bucket + pub bucket_offsets: &'a mut [usize], + /// The point to use for each coefficient + pub point_data: &'a mut [u32], + /// The output index in the point array for each pair addition + pub output_indices: &'a mut [u32], + /// The point to use on the first level in the addition tree + pub base_positions: &'a mut [u32], + /// List of points that are scattered to the addition tree + pub scatter_map: &'a mut [ScatterData], + /// The length of scatter_map + pub scatter_map_len: usize, +} + +/// ScatterData +#[derive(Default, Debug, Clone)] +struct ScatterData { + /// The position in the addition tree to store the point + pub position: u32, + /// The point to write + pub point_data: u32, +} + +impl MultiExp { + /// Create a new MultiExp instance with the specified bases + pub fn new(bases: &[C]) -> Self { + let mut endo_bases = vec![C::identity(); bases.len() * 2]; + + // Generate the endomorphism bases + let num_threads = current_num_threads(); + scope(|scope| { + let num_points_per_thread = div_up(bases.len(), num_threads); + for (endo_bases, bases) in endo_bases + .chunks_mut(num_points_per_thread * 2) + .zip(bases.chunks(num_points_per_thread)) + { + scope.spawn(move |_| { + for (idx, base) in bases.iter().enumerate() { + endo_bases[idx * 2] = *base; + endo_bases[idx * 2 + 1] = C::endo(base); + } + }); + } + }); + + Self { bases: endo_bases } + } + + /// Performs a multi-exponentiation operation. + /// Set complete to true if the bases are not guaranteed linearly independent. + pub fn evaluate( + &self, + ctx: &mut MultiExpContext, + coeffs: &[C::Scalar], + complete: bool, + ) -> C::Curve { + self.evaluate_with(ctx, coeffs, complete, get_best_c(coeffs.len())) + } + + /// Performs a multi-exponentiation operation with the given bucket width. + /// Set complete to true if the bases are not guaranteed linearly independent. + pub fn evaluate_with( + &self, + ctx: &mut MultiExpContext, + coeffs: &[C::Scalar], + complete: bool, + c: usize, + ) -> C::Curve { + assert!(coeffs.len() * 2 <= self.bases.len()); + assert!(c >= 4); + + // Allocate more memory if required + ctx.allocate(coeffs.len(), c); + + // Get the data for each round + let mut rounds = ctx.rounds.get_rounds::(coeffs.len(), c); + + // Get the bases for the coefficients + let bases = &self.bases[..coeffs.len() * 2]; + + let num_threads = current_num_threads(); + let start = start_measure( + format!("msm {} ({}) ({} threads)", coeffs.len(), c, num_threads), + false, + ); + // if coeffs.len() >= 16 { + let num_points = coeffs.len() * 2; + let w = c + 1; + let num_rounds = get_num_rounds::(c); + + // Prepare WNAFs of all coefficients for all rounds + calculate_wnafs::(coeffs, &mut ctx.wnafs, c); + // Sort WNAFs into buckets for all rounds + sort::(&mut ctx.wnafs[0..num_rounds * num_points], &mut rounds, c); + // Calculate addition trees for all rounds + create_addition_trees(&mut rounds); + + // Now process each round individually + let mut partials = vec![C::Curve::identity(); num_rounds]; + for (round, acc) in rounds.iter().zip(partials.iter_mut()) { + // Scatter the odd points in the odd length buckets to the addition tree + do_point_scatter::(round, bases, &mut ctx.points); + // Do all bucket additions + do_batch_additions::(round, bases, &mut ctx.points, complete); + // Get the final result of the round + *acc = accumulate_buckets::(round, &mut ctx.points, c); + } + + // Accumulate round results + let res = partials + .iter() + .rev() + .skip(1) + .fold(partials[num_rounds - 1], |acc, partial| { + let mut res = acc; + for _ in 0..w { + res = res.double(); + } + res + partial + }); + stop_measure(start); + + res + // } else { + // // Just do a naive msm + // let mut acc = C::Curve::identity(); + // for (idx, coeff) in coeffs.iter().enumerate() { + // // Skip over endomorphism bases + // acc += bases[idx * 2] * coeff; + // } + // stop_measure(start); + // acc + // } + } +} + +impl MultiExpContext { + /// Allocate memory for the evalution + pub fn allocate(&mut self, num_points: usize, c: usize) { + let num_points = num_points * 2; + let num_buckets = get_num_buckets(c); + let num_rounds = get_num_rounds::(c); + let tree_size = get_max_tree_size(num_points, c); + let num_points_total = num_rounds * num_points; + let num_buckets_total = num_rounds * num_buckets; + let tree_size_total = num_rounds * tree_size; + + // Allocate memory when necessary + if self.points.len() < tree_size { + self.points.resize(tree_size, C::identity()); + } + if self.wnafs.len() < num_points_total { + self.wnafs.resize(num_points_total, 0u32); + } + if self.rounds.bucket_sizes.len() < num_buckets_total { + self.rounds.bucket_sizes.resize(num_buckets_total, 0usize); + } + if self.rounds.bucket_offsets.len() < num_buckets_total { + self.rounds.bucket_offsets.resize(num_buckets_total, 0usize); + } + if self.rounds.point_data.len() < num_points_total { + self.rounds.point_data.resize(num_points_total, 0u32); + } + if self.rounds.output_indices.len() < tree_size_total / 2 { + self.rounds.output_indices.resize(tree_size_total / 2, 0u32); + } + if self.rounds.base_positions.len() < num_points_total { + self.rounds.base_positions.resize(num_points_total, 0u32); + } + if self.rounds.scatter_map.len() < num_buckets_total { + self.rounds + .scatter_map + .resize(num_buckets_total, ScatterData::default()); + } + } +} + +impl SharedRoundData { + fn get_rounds(&mut self, num_points: usize, c: usize) -> Vec { + let num_points = num_points * 2; + let num_buckets = get_num_buckets(c); + let num_rounds = get_num_rounds::(c); + let tree_size = num_points * 2 + num_buckets; + + let mut bucket_sizes_rest = self.bucket_sizes.as_mut_slice(); + let mut bucket_offsets_rest = self.bucket_offsets.as_mut_slice(); + let mut point_data_rest = self.point_data.as_mut_slice(); + let mut output_indices_rest = self.output_indices.as_mut_slice(); + let mut base_positions_rest = self.base_positions.as_mut_slice(); + let mut scatter_map_rest = self.scatter_map.as_mut_slice(); + + // Use the allocated memory above to init the memory used for each round. + // This way the we don't need to reallocate memory for each msm with + // a different configuration (different number of points or different bucket width) + let mut rounds: Vec = Vec::with_capacity(num_rounds); + for _ in 0..num_rounds { + let (bucket_sizes, rest) = bucket_sizes_rest.split_at_mut(num_buckets); + bucket_sizes_rest = rest; + let (bucket_offsets, rest) = bucket_offsets_rest.split_at_mut(num_buckets); + bucket_offsets_rest = rest; + let (point_data, rest) = point_data_rest.split_at_mut(num_points); + point_data_rest = rest; + let (output_indices, rest) = output_indices_rest.split_at_mut(tree_size / 2); + output_indices_rest = rest; + let (base_positions, rest) = base_positions_rest.split_at_mut(num_points); + base_positions_rest = rest; + let (scatter_map, rest) = scatter_map_rest.split_at_mut(num_buckets); + scatter_map_rest = rest; + + rounds.push(RoundData { + num_levels: 0, + level_sizes: vec![], + level_offset: vec![], + bucket_sizes, + bucket_offsets, + point_data, + output_indices, + base_positions, + scatter_map, + scatter_map_len: 0, + }); + } + rounds + } +} + +#[derive(Clone, Copy)] +struct ThreadBox(*mut T, usize); +#[allow(unsafe_code)] +unsafe impl Send for ThreadBox {} +#[allow(unsafe_code)] +unsafe impl Sync for ThreadBox {} + +/// Wraps a mutable slice so it can be passed into a thread without +/// hard to fix borrow checks caused by difficult data access patterns. +impl ThreadBox { + fn wrap(data: &mut [T]) -> Self { + Self(data.as_mut_ptr(), data.len()) + } + + fn unwrap(&mut self) -> &mut [T] { + #[allow(unsafe_code)] + unsafe { + slice::from_raw_parts_mut(self.0, self.1) + } + } +} + +fn calculate_wnafs(coeffs: &[C::Scalar], wnafs: &mut [u32], c: usize) { + let num_threads = current_num_threads(); + let num_points = coeffs.len() * 2; + let num_rounds = get_num_rounds::(c); + let w = c + 1; + + let start = start_measure("calculate wnafs".to_string(), false); + let mut wnafs_box = ThreadBox::wrap(wnafs); + let chunk_size = div_up(coeffs.len(), num_threads); + scope(|scope| { + for (thread_idx, coeffs) in coeffs.chunks(chunk_size).enumerate() { + scope.spawn(move |_| { + let wnafs = &mut wnafs_box.unwrap()[thread_idx * chunk_size * 2..]; + for (idx, coeff) in coeffs.iter().enumerate() { + let (p0, _, p1, _) = C::decompose_scalar(coeff); + get_wnaf(p0, w, num_rounds, &mut wnafs[idx * 2..], num_points); + get_wnaf(p1, w, num_rounds, &mut wnafs[idx * 2 + 1..], num_points); + } + }); + } + }); + stop_measure(start); +} + +fn radix_sort(wnafs: &mut [u32], round: &mut RoundData) { + let bucket_sizes = &mut round.bucket_sizes; + let bucket_offsets = &mut round.bucket_offsets; + + // Calculate bucket sizes, first resetting all sizes to 0 + bucket_sizes.fill_with(|| 0); + for wnaf in wnafs.iter() { + bucket_sizes[(wnaf & 0x7FFFFFFF) as usize] += 1; + } + + // Calculate bucket offsets + let mut offset = 0; + let mut max_bucket_size = 0; + bucket_offsets[0] = offset; + offset += bucket_sizes[0]; + for (bucket_offset, bucket_size) in bucket_offsets + .iter_mut() + .skip(1) + .zip(bucket_sizes.iter().skip(1)) + { + *bucket_offset = offset; + offset += bucket_size; + max_bucket_size = max_bucket_size.max(*bucket_size); + } + // Number of levels we need in our addition tree + round.num_levels = get_num_tree_levels(max_bucket_size); + + // Fill in point data grouped in buckets + let point_data = &mut round.point_data; + for (idx, wnaf) in wnafs.iter().enumerate() { + let bucket_idx = (wnaf & 0x7FFFFFFF) as usize; + point_data[bucket_offsets[bucket_idx]] = (wnaf & 0x80000000) | (idx as u32); + bucket_offsets[bucket_idx] += 1; + } +} + +/// Sorts the points so they are grouped per bucket +fn sort(wnafs: &mut [u32], rounds: &mut [RoundData], c: usize) { + let num_rounds = get_num_rounds::(c); + let num_points = wnafs.len() / num_rounds; + + // Sort per bucket for each round separately + let start = start_measure("radix sort".to_string(), false); + scope(|scope| { + for (round, wnafs) in rounds.chunks_mut(1).zip(wnafs.chunks_mut(num_points)) { + scope.spawn(move |_| { + radix_sort(wnafs, &mut round[0]); + }); + } + }); + stop_measure(start); +} + +/// Creates the addition tree. +/// When PREPROCESS is false we just calculate the size of each level. +/// All points in a bucket need to be added to each other. Because the affine formulas +/// are used we need to add points together in pairs. So we have to make sure that +/// on each level we have an even number of points for each level. Odd points are +/// added to lower levels where the length of the addition results is odd (which then +/// makes the length even). +fn process_addition_tree(round: &mut RoundData) { + let num_levels = round.num_levels; + let bucket_sizes = &round.bucket_sizes; + let point_data = &round.point_data; + + let mut level_sizes = vec![0usize; num_levels]; + let mut level_offset = vec![0usize; num_levels]; + let output_indices = &mut round.output_indices; + let scatter_map = &mut round.scatter_map; + let base_positions = &mut round.base_positions; + let mut point_idx = bucket_sizes[0]; + + if !PREPROCESS { + // Set the offsets to the different levels in the tree + level_offset[0] = 0; + for idx in 1..level_offset.len() { + level_offset[idx] = level_offset[idx - 1] + round.level_sizes[idx - 1]; + } + } + + // The level where all bucket results will be stored + let bucket_level = num_levels - 1; + + // Run over all buckets + for bucket_size in bucket_sizes.iter().skip(1) { + let mut size = *bucket_size; + if size == 0 { + level_sizes[bucket_level] += 1; + } else if size == 1 { + if !PREPROCESS { + scatter_map[round.scatter_map_len] = ScatterData { + position: (level_offset[bucket_level] + level_sizes[bucket_level]) as u32, + point_data: point_data[point_idx], + }; + round.scatter_map_len += 1; + point_idx += 1; + } + level_sizes[bucket_level] += 1; + } else { + #[derive(Clone, Copy, PartialEq)] + enum State { + Even, + OddPoint(usize), + OddResult(usize), + } + let mut state = State::Even; + let num_levels_bucket = get_num_tree_levels(size); + + let mut start_level_size = level_sizes[0]; + for level in 0..num_levels_bucket - 1 { + let is_level_odd = size & 1; + let first_level = level == 0; + let last_level = level == num_levels_bucket - 2; + + // If this level has odd size we have to handle it + if is_level_odd == 1 { + // If we already have a point saved from a previous odd level, use it + // to make the current level even + if state != State::Even { + if !PREPROCESS { + let pos = (level_offset[level] + level_sizes[level]) as u32; + match state { + State::OddPoint(point_idx) => { + scatter_map[round.scatter_map_len] = ScatterData { + position: pos, + point_data: point_data[point_idx], + }; + round.scatter_map_len += 1; + } + State::OddResult(output_idx) => { + output_indices[output_idx] = pos; + } + _ => unreachable!(), + }; + } + level_sizes[level] += 1; + size += 1; + state = State::Even; + } else { + // Not odd yet, so the state is now odd + // Store the point we have to add later + if !PREPROCESS { + if first_level { + state = State::OddPoint(point_idx + size - 1); + } else { + state = State::OddResult( + (level_offset[level] + level_sizes[level] + size) >> 1, + ); + } + } else { + // Just mark it as odd, we won't use the actual value anywhere + state = State::OddPoint(0); + } + size -= 1; + } + } + + // Write initial points on the first level + if first_level { + if !PREPROCESS { + // Just write all points (except the odd size one) + let pos = level_offset[level] + level_sizes[level]; + base_positions[pos..pos + size] + .copy_from_slice(&point_data[point_idx..point_idx + size]); + point_idx += size + is_level_odd; + } + level_sizes[level] += size; + } + + // Write output indices + // If the next level would be odd, we have to make it even + // by writing the last result of this level to the next level that is odd + // (unless we are writing the final result to the bucket level) + let next_level_size = size >> 1; + let next_level_odd = next_level_size & 1 == 1; + let redirect = + if next_level_odd && state == State::Even && level < num_levels_bucket - 2 { + 1usize + } else { + 0usize + }; + // An addition works on two points and has one result, so this takes only half the size + let sub_level_offset = (level_offset[level] + start_level_size) >> 1; + // Cache the start position of the next level + start_level_size = level_sizes[level + 1]; + if !PREPROCESS { + // Write the destination positions of the addition results in the tree + let dst_pos = level_offset[level + 1] + level_sizes[level + 1]; + for (idx, output_index) in output_indices + [sub_level_offset..sub_level_offset + next_level_size] + .iter_mut() + .enumerate() + { + *output_index = (dst_pos + idx) as u32; + } + } + if last_level { + // The result of the last addition for this bucket is written + // to the last level (so all bucket results are nicely after each other). + // Overwrite the output locations of the last result here. + if !PREPROCESS { + output_indices[sub_level_offset] = + (level_offset[bucket_level] + level_sizes[bucket_level]) as u32; + } + level_sizes[bucket_level] += 1; + } else { + // Update the sizes + level_sizes[level + 1] += next_level_size - redirect; + size -= redirect; + // We have to redirect the last result to a lower level + if redirect == 1 { + state = State::OddResult(sub_level_offset + next_level_size - 1); + } + } + + // We added pairs of points together so the next level has half the size + size >>= 1; + } + } + } + + // Store the tree level data + round.level_sizes = level_sizes; + round.level_offset = level_offset; +} + +/// The affine formula is only efficient for independent point additions +/// (using the result of the addition requires an inversion which needs to be avoided as much as possible). +/// And so we try to add as many points together on each level of the tree, writing the result of the addition +/// to a lower level. Each level thus contains independent point additions, with only requiring a single inversion +/// per level in the tree. +fn create_addition_trees(rounds: &mut [RoundData]) { + let start = start_measure("create addition trees".to_string(), false); + scope(|scope| { + for round in rounds.chunks_mut(1) { + scope.spawn(move |_| { + // Collect tree levels sizes + process_addition_tree::(&mut round[0]); + // Construct the tree + process_addition_tree::(&mut round[0]); + }); + } + }); + stop_measure(start); +} + +/// Here we write the odd points in odd length buckets (the other points are loaded on the fly). +/// This will do random reads AND random writes, which is normally terrible for performance. +/// Luckily this doesn't really matter because we only have to write at most num_buckets points. +fn do_point_scatter(round: &RoundData, bases: &[C], points: &mut [C]) { + let num_threads = current_num_threads(); + let scatter_map = &round.scatter_map[..round.scatter_map_len]; + let mut points_box = ThreadBox::wrap(points); + let start = start_measure("point scatter".to_string(), false); + if !scatter_map.is_empty() { + scope(|scope| { + let num_copies_per_thread = div_up(scatter_map.len(), num_threads); + for scatter_map in scatter_map.chunks(num_copies_per_thread) { + scope.spawn(move |_| { + let points = points_box.unwrap(); + for scatter_data in scatter_map.iter() { + let target_idx = scatter_data.position as usize; + let negate = scatter_data.point_data & 0x80000000 != 0; + let base_idx = (scatter_data.point_data & 0x7FFFFFFF) as usize; + if negate { + points[target_idx] = bases[base_idx].neg(); + } else { + points[target_idx] = bases[base_idx]; + } + } + }); + } + }); + } + stop_measure(start); +} + +/// Finally do all additions using the addition tree we've setup. +fn do_batch_additions( + round: &RoundData, + bases: &[C], + points: &mut [C], + complete: bool, +) { + let num_threads = current_num_threads(); + + let num_levels = round.num_levels; + let level_counter = &round.level_sizes; + let level_offset = &round.level_offset; + let output_indices = &round.output_indices; + let base_positions = &round.base_positions; + let mut points_box = ThreadBox::wrap(points); + + let start = start_measure("batch additions".to_string(), false); + for i in 0..num_levels - 1 { + let start = level_offset[i]; + let num_points = level_counter[i]; + scope(|scope| { + // We have to make sure we have an even amount here so we don't split within a pair + let num_points_per_thread = div_up(num_points / 2, num_threads) * 2; + for thread_idx in 0..num_threads { + scope.spawn(move |_| { + let points = points_box.unwrap(); + + let thread_start = thread_idx * num_points_per_thread; + let mut thread_num_points = num_points_per_thread; + + if thread_start < num_points { + if thread_start + thread_num_points > num_points { + thread_num_points = num_points - thread_start; + } + + let points = &mut points[(start + thread_start)..]; + let output_indices = &output_indices[(start + thread_start) / 2..]; + let offset = start + thread_start; + if i == 0 { + let base_positions = &base_positions[(start + thread_start)..]; + if complete { + C::batch_add::( + points, + output_indices, + thread_num_points, + offset, + bases, + base_positions, + ); + } else { + C::batch_add::( + points, + output_indices, + thread_num_points, + offset, + bases, + base_positions, + ); + } + } else { + #[allow(collapsible-else-if)] + if complete { + C::batch_add::( + points, + output_indices, + thread_num_points, + offset, + &[], + &[], + ); + } else { + C::batch_add::( + points, + output_indices, + thread_num_points, + offset, + &[], + &[], + ); + } + } + } + }); + } + }); + } + stop_measure(start); +} + +/// Accumulate all bucket results to get the result of the round +fn accumulate_buckets( + round: &RoundData, + points: &mut [C], + c: usize, +) -> C::Curve { + let num_threads = current_num_threads(); + let num_buckets = get_num_buckets(c); + + let num_levels = round.num_levels; + let bucket_sizes = &round.bucket_sizes; + let level_offset = &round.level_offset; + + let start_time = start_measure("accumulate buckets".to_string(), false); + let start = level_offset[num_levels - 1]; + let buckets = &mut points[start..(start + num_buckets)]; + let mut results: Vec = vec![C::Curve::identity(); num_threads]; + scope(|scope| { + let chunk_size = num_buckets / num_threads; + for (thread_idx, ((bucket_sizes, buckets), result)) in bucket_sizes[1..] + .chunks(chunk_size) + .zip(buckets[..].chunks_mut(chunk_size)) + .zip(results.chunks_mut(1)) + .enumerate() + { + scope.spawn(move |_| { + // Accumulate all bucket results + let num_buckets_thread = bucket_sizes.len(); + let mut acc = C::Curve::identity(); + let mut running_sum = C::Curve::identity(); + for b in (0..num_buckets_thread).rev() { + if bucket_sizes[b] > 0 { + running_sum = running_sum + buckets[b]; + } + acc = acc + &running_sum; + } + + // Each thread started at a different bucket location + // so correct for that here + let bucket_start = thread_idx * chunk_size; + let num_bits = num_bits(bucket_start); + let mut accumulator = C::Curve::identity(); + for idx in (0..num_bits).rev() { + accumulator = accumulator.double(); + if (bucket_start >> idx) & 1 != 0 { + accumulator += running_sum; + } + } + acc += accumulator; + + // Store the result + result[0] = acc; + }); + } + }); + stop_measure(start_time); + + // Add the results of all threads together + results + .iter() + .fold(C::Curve::identity(), |acc, result| acc + result) +} + +use crate::bn256::{Fr, G1Affine}; +use crate::group::prime::PrimeCurveAffine; +use crate::CurveAffineExt; + +fn get_random_data(n: usize) -> (Vec, Vec) { + use rand::SeedableRng; + use rand_xorshift::XorShiftRng; + + let mut bases = vec![G1Affine::identity(); n]; + // parallelize(&mut bases, |bases, _| { + // let mut rng = rand::thread_rng(); + let mut rng = XorShiftRng::from_seed([ + 0x59, 0x62, 0xbe, 0x5d, 0x76, 0x3d, 0x31, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06, 0xbc, + 0xe5, + ]); + // let base_rnd = G1Affine::random(&mut rng); + for base in bases.iter_mut() { + if INDEPENDENT { + *base = G1Affine::random(&mut rng); + } else { + unreachable!() + // *base = base_rnd; + } + } + // }); + + let mut coeffs = vec![Fr::zero(); n]; + // parallelize(&mut coeffs, |coeffs, _| { + for coeff in coeffs.iter_mut() { + *coeff = Fr::random(&mut rng); + // *coeff = Fr::from(1u64); + } + // }); + + (bases, coeffs) +} + +#[test] +// #[ignore] +fn test_multiexp_bench() { + let min_k = 10; + let max_k = 20; + let n = 1 << max_k; + let (bases, coeffs) = get_random_data::(n); + + let msm = MultiExp::new(&bases); + let mut ctx = MultiExpContext::default(); + for k in min_k..=max_k { + let n = 1 << k; + let coeffs = &coeffs[..n]; + + let start = start_measure("msm".to_string(), false); + msm.evaluate(&mut ctx, coeffs, false); + let duration = stop_measure(start); + + println!("{} {}: {}s", k, n, (duration as f32) / 1000000.0); + } +} + +use std::{ + env::var, + sync::atomic::{AtomicUsize, Ordering}, + time::Instant, +}; + +#[allow(missing_debug_implementations)] +pub struct MeasurementInfo { + /// Show measurement + pub show: bool, + /// The start time + pub time: Instant, + /// What is being measured + pub message: String, + /// The indent + pub indent: usize, +} + +/// Global indent counter +pub static NUM_INDENT: AtomicUsize = AtomicUsize::new(0); + +/// Gets the time difference between the current time and the passed in time +pub fn get_duration(start: Instant) -> usize { + let final_time = Instant::now() - start; + let secs = final_time.as_secs() as usize; + let millis = final_time.subsec_millis() as usize; + let micros = (final_time.subsec_micros() % 1000) as usize; + secs * 1000000 + millis * 1000 + micros +} + +/// Prints a measurement on screen +pub fn log_measurement(indent: Option, msg: String, duration: usize) { + let indent = indent.unwrap_or(0); + println!( + "{}{} ........ {}s", + "*".repeat(indent), + msg, + (duration as f32) / 1000000.0 + ); +} + +/// Starts a measurement +pub fn start_measure(msg: String, always: bool) -> MeasurementInfo { + let measure = env_value("MEASURE", 0); + let indent = NUM_INDENT.fetch_add(1, Ordering::Relaxed); + MeasurementInfo { + show: always || measure == 1, + time: Instant::now(), + message: msg, + indent, + } +} + +/// Stops a measurement, returns the duration +pub fn stop_measure(info: MeasurementInfo) -> usize { + NUM_INDENT.fetch_sub(1, Ordering::Relaxed); + let duration = get_duration(info.time); + if info.show { + log_measurement(Some(info.indent), info.message, duration); + } + duration +} + +/// Gets the ENV variable if defined, otherwise returns the default value +pub fn env_value(key: &str, default: usize) -> usize { + match var(key) { + Ok(val) => val.parse().unwrap(), + Err(_) => default, + } +} diff --git a/src/bn256/msm/round.rs b/src/bn256/msm/round.rs new file mode 100644 index 00000000..430d65c3 --- /dev/null +++ b/src/bn256/msm/round.rs @@ -0,0 +1,361 @@ +use super::super::{Fq, G1Affine}; +use crate::group::prime::PrimeCurveAffine; +use ff::Field; + +macro_rules! log_ceil { + ($a:expr) => { + $a.next_power_of_two().trailing_zeros() + }; +} + +pub(crate) struct Round { + t: Vec, + odd_points: Vec, + bucket_sizes: Vec, + n_buckets: usize, + n_points: usize, + out_off: usize, + in_off: usize, +} + +macro_rules! first_phase { + ($out:expr, $p0:expr, $p1:expr, $acc:expr) => { + $out.x = $p0.x + $p1.x; + $p1.x = $p1.x - $p0.x; + $p1.y = ($p1.y - $p0.y) * $acc; + $acc = $acc * $p1.x; + }; +} +macro_rules! second_phase { + ($out:expr, $p0:expr, $p1:expr, $acc:expr) => { + $p1.y = $p1.y * $acc; + $acc = $acc * $p1.x; + $out.x = $p1.y.square() - $out.x; + $out.y = ($p1.y * ($p0.x - $out.x)) - $p0.y; + }; +} + +impl Round { + pub(crate) fn new(n_buckets: usize, n_points: usize) -> Self { + let odd_points = vec![G1Affine::identity(); n_buckets]; + let bucket_sizes = vec![0; n_buckets]; + // TODO: requires less than allocated + let t = vec![G1Affine::identity(); n_points * 2]; + Self { + t, + odd_points, + bucket_sizes, + n_buckets, + n_points, + out_off: 0, + in_off: 0, + } + } + pub(crate) fn init(&mut self, bases: &[G1Affine], positions: &[usize], bucket_sizes: &[usize]) { + { + assert_eq!(self.n_points, positions.len()); + assert_eq!(self.n_points, bases.len()); + assert_eq!(self.bucket_sizes.len(), self.n_buckets); + } + let n_additions = bucket_sizes + .iter() + .map(|bucket_sizes| bucket_sizes / 2) + .collect::>(); + let mut out_off = n_additions.iter().sum::(); + let mut tmp_off = 0; + let mut position_off = 0; + let mut acc = Fq::ONE; + for (bucket_index, bucket_size) in bucket_sizes.iter().enumerate() { + let positions = &positions[position_off..position_off + bucket_size]; + position_off += bucket_size; + let bucket_size_pre = positions.len(); + if bucket_size_pre == 0 { + self.odd_points[bucket_index] = G1Affine::identity(); + self.bucket_sizes[bucket_index] = 0; + } else { + let mut in_off = 0; + let bucket_size_post = (bucket_size_pre + 1) / 2; + let n_additions = bucket_size_pre / 2; + // process even number of additions + for _ in 0..n_additions & (usize::MAX - 1) { + // second operand must be mutable + self.t[tmp_off] = bases[positions[in_off]]; + first_phase!( + self.t[out_off], + bases[positions[in_off + 1]], + self.t[tmp_off], + acc + ); + tmp_off += 1; + out_off += 1; + in_off += 2; + } + // process the latest elements if there are odd number of additions + match (bucket_size_pre & 1 == 1, bucket_size_post & 1 == 1) { + // 1 base point left + // move to odd-point cache + (true, true) => { + assert_eq!(positions.len() - 1, in_off); + self.odd_points[bucket_index] = bases[positions[in_off]]; + } + // 2 base point left + // move addition result to odd-point cache + (false, true) => { + self.t[tmp_off] = bases[positions[in_off]]; + first_phase!( + self.odd_points[bucket_index], + bases[positions[in_off + 1]], + self.t[tmp_off], + acc + ); + tmp_off += 1; + } + // 3 base point left + // move addition of first two to intermediate and last to odd-point cache + (true, false) => { + self.t[tmp_off] = bases[positions[in_off]]; + first_phase!( + self.t[out_off], + bases[positions[in_off + 1]], + self.t[tmp_off], + acc + ); + self.t[out_off + 1] = bases[positions[in_off + 2]]; + tmp_off += 1; + out_off += 2; + } + _ => { /* 0 base point left */ } + } + self.bucket_sizes[bucket_index] = bucket_size_post; + } + } + self.in_off = tmp_off; + self.out_off = out_off; + tmp_off -= 1; + out_off -= 1; + acc = acc.invert().unwrap(); + for (bucket_index, bucket_size_pre) in bucket_sizes.iter().enumerate().rev() { + let positions = &positions[position_off - bucket_size_pre..position_off]; + position_off -= bucket_size_pre; + if *bucket_size_pre == 0 { + // already updated in first phase + } else { + if positions.len() != 0 { + let bucket_size_post = (bucket_size_pre + 1) / 2; + let n_additions = bucket_size_pre / 2; + let mut in_off = positions.len() - 1; + // process the latest elements if there are odd number of additions + match (bucket_size_pre & 1 == 1, bucket_size_post & 1 == 1) { + // 1 base point left + // move to odd-point cache + (true, true) => { + in_off -= 1; + } + // 2 base point left + // move addition result to odd-point cache + (false, true) => { + second_phase!( + self.odd_points[bucket_index], + bases[positions[in_off]], + self.t[tmp_off], + acc + ); + tmp_off -= 1; + in_off -= 2; + } + // 3 base point left + // move addition of first two to intermediate and last to odd-point cache + (true, false) => { + in_off -= 1; + out_off -= 1; + second_phase!( + self.t[out_off], + bases[positions[in_off]], + self.t[tmp_off], + acc + ); + tmp_off -= 1; + in_off -= 2; + out_off -= 1; + } + _ => { /* 0 base point left */ } + } + // process even number of additions + for _ in (0..n_additions & (usize::MAX - 1)).rev() { + second_phase!( + self.t[out_off], + bases[positions[in_off]], + self.t[tmp_off], + acc + ); + tmp_off -= 1; + out_off -= 1; + in_off -= 2; + } + } + } + } + } + fn batch_add(&mut self) { + let (mut out_off, mut in_off) = (self.out_off, self.in_off); + let mut acc = Fq::ONE; + for bucket_index in 0..self.n_buckets { + let bucket_size_pre = self.bucket_sizes[bucket_index]; + let n_additions = bucket_size_pre / 2; + let bucket_size_post = (bucket_size_pre + 1) / 2; + for _ in 0..n_additions & (usize::MAX - 1) { + first_phase!(self.t[out_off], self.t[in_off], self.t[in_off + 1], acc); + (out_off, in_off) = (out_off + 1, in_off + 2); + } + match (bucket_size_pre & 1 == 1, bucket_size_post & 1 == 1) { + (true, false) => { + first_phase!(self.t[out_off], self.t[in_off], self.t[in_off + 1], acc); + self.t[out_off + 1] = self.odd_points[bucket_index]; + out_off += 2; + in_off += 2; + } + (false, true) => { + first_phase!( + self.odd_points[bucket_index], + self.t[in_off], + self.t[in_off + 1], + acc + ); + in_off += 2; + } + _ => { /* clean sheets */ } + } + } + self.out_off = out_off; + self.in_off = in_off; + out_off -= 1; + in_off -= 2; + acc = acc.invert().unwrap(); + // process second phase + for bucket_index in (0..self.n_buckets).rev() { + let bucket_size_pre = self.bucket_sizes[bucket_index]; + let n_additions = bucket_size_pre / 2; + let bucket_size_post = (bucket_size_pre + 1) / 2; + + match (bucket_size_pre & 1 == 1, bucket_size_post & 1 == 1) { + (true, false) => { + out_off -= 1; + second_phase!(self.t[out_off], self.t[in_off], self.t[in_off + 1], acc); + out_off -= 1; + in_off -= 2; + } + (false, true) => { + second_phase!( + self.odd_points[bucket_index], + self.t[in_off], + self.t[in_off + 1], + acc + ); + in_off -= 2; + } + _ => { /* clean sheets */ } + } + + for _ in 0..n_additions & (usize::MAX - 1) { + second_phase!(self.t[out_off], self.t[in_off], self.t[in_off + 1], acc); + out_off -= 1; + in_off -= 2; + } + self.bucket_sizes[bucket_index] = bucket_size_post; + } + } + + fn max_tree_height(&self) -> usize { + *self.tree_heights().iter().max().unwrap() + } + fn tree_heights(&self) -> Vec { + self.bucket_sizes + .iter() + .map(|bucket_size| log_ceil!(bucket_size) as usize) + .collect() + } + pub fn evaluate(&mut self) -> &[G1Affine] { + for _ in 0..self.max_tree_height() { + self.batch_add(); + } + &self.odd_points + } +} + +#[cfg(test)] +mod tests { + use super::Round; + use crate::bn256::msm::test::get_data; + use crate::bn256::{G1Affine, G1}; + use crate::group::Group; + use crate::CurveAffine; + use group::Curve; + use rand::seq::SliceRandom; + use rand::Rng; + use rand_core::OsRng; + + pub(crate) fn rand_positions( + rng: &mut impl Rng, + n_buckets: usize, + n_points: usize, + ) -> (Vec, Vec) { + let mut positions: Vec> = vec![vec![]; n_buckets]; + (0..n_points).for_each(|i| { + let index = rng.gen_range(0..n_buckets); + positions[index].push(i); + }); + positions.iter_mut().for_each(|positions| { + positions.shuffle(rng); + }); + let bucket_sizes = positions.iter().map(|positions| positions.len()).collect(); + let positions: Vec = positions.iter().flatten().cloned().collect(); + (positions, bucket_sizes) + } + + impl Round { + fn sanity_check(&self, bases: &[G1Affine], positions: &[usize], bucket_sizes: &[usize]) { + { + let mut off = 0; + let sums: Vec<_> = bucket_sizes + .iter() + .map(|bucket_size| { + let sum = (off..off + bucket_size) + .map(|i| bases[positions[i]]) + .fold(G1::identity(), |acc, next| acc + next); + off += bucket_size; + sum + }) + .collect(); + self.odd_points + .iter() + .for_each(|r| assert!(bool::from(r.is_on_curve()))); + sums.iter() + .zip(self.odd_points.iter()) + .for_each(|(r0, r1)| assert_eq!(r0.to_affine(), *r1)); + self.bucket_sizes + .iter() + .for_each(|bucket_size| assert_eq!(bucket_size & 1, *bucket_size)); + assert_eq!(self.max_tree_height(), 0); + self.tree_heights() + .iter() + .for_each(|tree_height| assert_eq!(tree_height, &0)); + } + } + } + + #[test] + fn test_round() { + let mut rng = OsRng; + + let n_points = 1 << 12; + let window_size = 5; + let n_buckets = (1 << window_size) - 1; + + let (bases, _) = get_data(n_points); + let (positions, bucket_sizes) = rand_positions(&mut rng, n_buckets, n_points); + let mut round = Round::new(n_buckets, n_points); + round.init(&bases[..n_points], &positions, &bucket_sizes); + round.evaluate(); + round.sanity_check(&bases, &positions, &bucket_sizes); + } +} diff --git a/src/bn256/msm/zcash.rs b/src/bn256/msm/zcash.rs new file mode 100644 index 00000000..1e5d2b59 --- /dev/null +++ b/src/bn256/msm/zcash.rs @@ -0,0 +1,123 @@ +use crate::group::Group; +use ff::PrimeField; +use pasta_curves::arithmetic::CurveAffine; +use rayon::{current_num_threads, scope}; + +pub(crate) fn best_multiexp_zcash(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { + assert_eq!(coeffs.len(), bases.len()); + + let num_threads = current_num_threads(); + if coeffs.len() > num_threads { + let chunk = coeffs.len() / num_threads; + let num_chunks = coeffs.chunks(chunk).len(); + let mut results = vec![C::Curve::identity(); num_chunks]; + scope(|scope| { + let chunk = coeffs.len() / num_threads; + + for ((coeffs, bases), acc) in coeffs + .chunks(chunk) + .zip(bases.chunks(chunk)) + .zip(results.iter_mut()) + { + scope.spawn(move |_| { + msm_zcash(coeffs, bases, acc); + }); + } + }); + results.iter().fold(C::Curve::identity(), |a, b| a + b) + } else { + let mut acc = C::Curve::identity(); + msm_zcash(coeffs, bases, &mut acc); + acc + } +} + +pub(crate) fn msm_zcash(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) { + let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect(); + + let c = if bases.len() < 4 { + 1 + } else if bases.len() < 32 { + 3 + } else { + (f64::from(bases.len() as u32)).ln().ceil() as usize + }; + + fn get_at(segment: usize, c: usize, bytes: &F::Repr) -> usize { + let skip_bits = segment * c; + let skip_bytes = skip_bits / 8; + + if skip_bytes >= 32 { + return 0; + } + + let mut v = [0; 8]; + for (v, o) in v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) { + *v = *o; + } + + let mut tmp = u64::from_le_bytes(v); + tmp >>= skip_bits - (skip_bytes * 8); + tmp %= 1 << c; + + tmp as usize + } + + let segments = (256 / c) + 1; + + for current_segment in (0..segments).rev() { + for _ in 0..c { + *acc = acc.double(); + } + + #[derive(Clone, Copy)] + enum Bucket { + None, + Affine(C), + Projective(C::Curve), + } + + impl Bucket { + fn add_assign(&mut self, other: &C) { + *self = match *self { + Bucket::None => Bucket::Affine(*other), + Bucket::Affine(a) => Bucket::Projective(a + *other), + Bucket::Projective(mut a) => { + a += *other; + Bucket::Projective(a) + } + } + } + + fn add(self, mut other: C::Curve) -> C::Curve { + match self { + Bucket::None => other, + Bucket::Affine(a) => { + other += a; + other + } + Bucket::Projective(a) => other + &a, + } + } + } + + let mut buckets: Vec> = vec![Bucket::None; (1 << c) - 1]; + + for (coeff, base) in coeffs.iter().zip(bases.iter()) { + let coeff = get_at::(current_segment, c, coeff); + if coeff != 0 { + buckets[coeff - 1].add_assign(base); + } + } + + // Summation by parts + // e.g. 3a + 2b + 1c = a + + // (a) + b + + // ((a) + b) + c + let mut running_sum = C::Curve::identity(); + for exp in buckets.into_iter().rev() { + running_sum = exp.add(running_sum); + *acc += &running_sum; + } + } +} diff --git a/src/derive/curve.rs b/src/derive/curve.rs index b1e1daca..d42384ba 100644 --- a/src/derive/curve.rs +++ b/src/derive/curve.rs @@ -142,51 +142,49 @@ macro_rules! batch_add { #[macro_export] macro_rules! endo { - ($name:ident, $field:ident, $params:expr) => { - impl CurveEndo for $name { - fn decompose_scalar(k: &$field) -> (u128, bool, u128, bool) { - let to_limbs = |e: &$field| { - let repr = e.to_repr(); - let repr = repr.as_ref(); - let tmp0 = u64::from_le_bytes(repr[0..8].try_into().unwrap()); - let tmp1 = u64::from_le_bytes(repr[8..16].try_into().unwrap()); - let tmp2 = u64::from_le_bytes(repr[16..24].try_into().unwrap()); - let tmp3 = u64::from_le_bytes(repr[24..32].try_into().unwrap()); - [tmp0, tmp1, tmp2, tmp3] - }; + ($params:expr) => { + fn decompose_scalar(k: &Self::ScalarExt) -> (u128, bool, u128, bool) { + let to_limbs = |e: &Self::ScalarExt| { + let repr = e.to_repr(); + let repr = repr.as_ref(); + let tmp0 = u64::from_le_bytes(repr[0..8].try_into().unwrap()); + let tmp1 = u64::from_le_bytes(repr[8..16].try_into().unwrap()); + let tmp2 = u64::from_le_bytes(repr[16..24].try_into().unwrap()); + let tmp3 = u64::from_le_bytes(repr[24..32].try_into().unwrap()); + [tmp0, tmp1, tmp2, tmp3] + }; - let get_lower_128 = |e: &$field| { - let e = to_limbs(e); - u128::from(e[0]) | (u128::from(e[1]) << 64) - }; + let get_lower_128 = |e: &Self::ScalarExt| { + let e = to_limbs(e); + u128::from(e[0]) | (u128::from(e[1]) << 64) + }; - let is_neg = |e: &$field| { - let e = to_limbs(e); - let (_, borrow) = sbb(0xffffffffffffffff, e[0], 0); - let (_, borrow) = sbb(0xffffffffffffffff, e[1], borrow); - let (_, borrow) = sbb(0xffffffffffffffff, e[2], borrow); - let (_, borrow) = sbb(0x00, e[3], borrow); - borrow & 1 != 0 - }; + let is_neg = |e: &Self::ScalarExt| { + let e = to_limbs(e); + let (_, borrow) = sbb(0xffffffffffffffff, e[0], 0); + let (_, borrow) = sbb(0xffffffffffffffff, e[1], borrow); + let (_, borrow) = sbb(0xffffffffffffffff, e[2], borrow); + let (_, borrow) = sbb(0x00, e[3], borrow); + borrow & 1 != 0 + }; - let input = to_limbs(&k); - let c1 = mul_512($params.gamma2, input); - let c2 = mul_512($params.gamma1, input); - let c1 = [c1[4], c1[5], c1[6], c1[7]]; - let c2 = [c2[4], c2[5], c2[6], c2[7]]; - let q1 = mul_512(c1, $params.b1); - let q2 = mul_512(c2, $params.b2); - let q1 = $field::from_raw([q1[0], q1[1], q1[2], q1[3]]); - let q2 = $field::from_raw([q2[0], q2[1], q2[2], q2[3]]); - let k2 = q2 - q1; - let k1 = k + k2 * $field::ZETA; - let k1_neg = is_neg(&k1); - let k2_neg = is_neg(&k2); - let k1 = if k1_neg { -k1 } else { k1 }; - let k2 = if k2_neg { -k2 } else { k2 }; - - (get_lower_128(&k1), k1_neg, get_lower_128(&k2), k2_neg) - } + let input = to_limbs(&k); + let c1 = mul_512($params.gamma2, input); + let c2 = mul_512($params.gamma1, input); + let c1 = [c1[4], c1[5], c1[6], c1[7]]; + let c2 = [c2[4], c2[5], c2[6], c2[7]]; + let q1 = mul_512(c1, $params.b1); + let q2 = mul_512(c2, $params.b2); + let q1 = Self::ScalarExt::from_raw([q1[0], q1[1], q1[2], q1[3]]); + let q2 = Self::ScalarExt::from_raw([q2[0], q2[1], q2[2], q2[3]]); + let k2 = q2 - q1; + let k1 = k + k2 * Self::ScalarExt::ZETA; + let k1_neg = is_neg(&k1); + let k2_neg = is_neg(&k2); + let k1 = if k1_neg { -k1 } else { k1 }; + let k2 = if k2_neg { -k2 } else { k2 }; + + (get_lower_128(&k1), k1_neg, get_lower_128(&k2), k2_neg) } }; } diff --git a/src/pasta/mod.rs b/src/pasta/mod.rs index 0252b199..6831372c 100644 --- a/src/pasta/mod.rs +++ b/src/pasta/mod.rs @@ -9,32 +9,6 @@ use ff::WithSmallOrderMulGroup; pub use pasta_curves::{pallas, vesta, Ep, EpAffine, Eq, EqAffine, Fp, Fq}; use std::convert::TryInto; -impl crate::CurveAffineExt for EpAffine { - fn batch_add( - _: &mut [Self], - _: &[u32], - _: usize, - _: usize, - _: &[Self], - _: &[u32], - ) { - unimplemented!(); - } -} - -impl crate::CurveAffineExt for EqAffine { - fn batch_add( - _: &mut [Self], - _: &[u32], - _: usize, - _: usize, - _: &[Self], - _: &[u32], - ) { - unimplemented!(); - } -} - const ENDO_PARAMS_EQ: EndoParameters = EndoParameters { gamma1: [0x32c49e4c00000003, 0x279a745902a2654e, 0x1, 0x0], gamma2: [0x31f0256800000002, 0x4f34e8b2066389a4, 0x2, 0x0], @@ -48,40 +22,3 @@ const ENDO_PARAMS_EP: EndoParameters = EndoParameters { b1: [0x8cb1279300000000, 0x49e69d1640a89953, 0x0, 0x0], b2: [0x0c7c095a00000001, 0x93cd3a2c8198e269, 0x0, 0x0], }; - -endo!(Eq, Fp, ENDO_PARAMS_EQ); -endo!(Ep, Fq, ENDO_PARAMS_EP); - -#[test] -fn test_endo() { - use ff::Field; - use rand_core::OsRng; - - for _ in 0..100000 { - let k = Fp::random(OsRng); - let (k1, k1_neg, k2, k2_neg) = Eq::decompose_scalar(&k); - if k1_neg & k2_neg { - assert_eq!(k, -Fp::from_u128(k1) + Fp::ZETA * Fp::from_u128(k2)) - } else if k1_neg { - assert_eq!(k, -Fp::from_u128(k1) - Fp::ZETA * Fp::from_u128(k2)) - } else if k2_neg { - assert_eq!(k, Fp::from_u128(k1) + Fp::ZETA * Fp::from_u128(k2)) - } else { - assert_eq!(k, Fp::from_u128(k1) - Fp::ZETA * Fp::from_u128(k2)) - } - } - - for _ in 0..100000 { - let k = Fp::random(OsRng); - let (k1, k1_neg, k2, k2_neg) = Eq::decompose_scalar(&k); - if k1_neg & k2_neg { - assert_eq!(k, -Fp::from_u128(k1) + Fp::ZETA * Fp::from_u128(k2)) - } else if k1_neg { - assert_eq!(k, -Fp::from_u128(k1) - Fp::ZETA * Fp::from_u128(k2)) - } else if k2_neg { - assert_eq!(k, Fp::from_u128(k1) + Fp::ZETA * Fp::from_u128(k2)) - } else { - assert_eq!(k, Fp::from_u128(k1) - Fp::ZETA * Fp::from_u128(k2)) - } - } -} diff --git a/src/secp256k1/curve.rs b/src/secp256k1/curve.rs index 89c197b5..3f644249 100644 --- a/src/secp256k1/curve.rs +++ b/src/secp256k1/curve.rs @@ -61,14 +61,6 @@ new_curve_impl!( "secp256k1", ); -impl CurveAffineExt for Secp256k1Affine { - batch_add!(); - - fn into_coordinates(self) -> (Self::Base, Self::Base) { - (self.x, self.y) - } -} - #[test] fn test_curve() { crate::tests::curve::curve_tests::(); From d81d67705cd45a2ab99261e62b1319c106c4423a Mon Sep 17 00:00:00 2001 From: kilic Date: Wed, 1 Mar 2023 14:30:05 +0300 Subject: [PATCH 4/7] implement signed digit representation Co-Authored-By: Brechtpd --- src/bn256/msm/mod.rs | 110 ++++++++++++++++++++++++++--------------- src/bn256/msm/pr40.rs | 58 ---------------------- src/bn256/msm/round.rs | 97 ++++++++++++++++++------------------ 3 files changed, 119 insertions(+), 146 deletions(-) diff --git a/src/bn256/msm/mod.rs b/src/bn256/msm/mod.rs index b1688cdf..7e9a1e20 100644 --- a/src/bn256/msm/mod.rs +++ b/src/bn256/msm/mod.rs @@ -4,34 +4,58 @@ use crate::group::Group; use ff::PrimeField; use rayon::{current_num_threads, scope}; -#[cfg(test)] -mod pr40; -mod round; -#[cfg(test)] -mod zcash; - +#[macro_export] macro_rules! div_ceil { ($a:expr, $b:expr) => { (($a - 1) / $b) + 1 }; } +#[macro_export] macro_rules! double_n { ($acc:expr, $n:expr) => { (0..$n).fold($acc, |acc, _| acc.double()) }; } +#[macro_export] macro_rules! range { ($index:expr, $n_items:expr) => { $index * $n_items..($index + 1) * $n_items }; } +#[macro_export] +macro_rules! index { + ($digit:expr) => { + ($digit & 0x7fffffff) as usize + }; +} + +#[macro_export] +macro_rules! is_neg { + ($digit:expr) => { + sign_bit!($digit) != 0 + }; +} + +#[macro_export] +macro_rules! sign_bit { + ($digit:expr) => { + $digit & 0x80000000 + }; +} + +#[cfg(test)] +mod pr40; +mod round; +#[cfg(test)] +mod zcash; + pub struct MSM { + signed_digits: Vec, + sorted_positions: Vec, bucket_sizes: Vec, - sorted_positions: Vec, - bucket_indexes: Vec, bucket_offsets: Vec, n_windows: usize, window: usize, @@ -59,12 +83,12 @@ impl MSM { } let window = best_window(n_points); let n_windows = div_ceil!(Fr::NUM_BITS as usize, window); - let n_buckets = 1 << window; + let n_buckets = (1 << (window - 1)) + 1; let round = Round::new(n_buckets, n_points); MSM { - bucket_indexes: vec![0usize; n_windows * n_points], + signed_digits: vec![0u32; n_windows * n_points], + sorted_positions: vec![0u32; n_windows * n_points], bucket_sizes: vec![0usize; n_windows * n_buckets], - sorted_positions: vec![0usize; n_windows * n_points], bucket_offsets: vec![0; n_buckets], n_windows, window, @@ -75,30 +99,36 @@ impl MSM { } fn decompose(&mut self, scalars: &[Fr]) { - pub(crate) fn get_bits(segment: usize, c: usize, bytes: &[u8]) -> u64 { - let skip_bits = segment * c; + pub(crate) fn get_bits(segment: usize, window: usize, bytes: &[u8]) -> u32 { + let skip_bits = segment * window; let skip_bytes = skip_bits / 8; if skip_bytes >= 32 { return 0; } - let mut v = [0; 8]; + let mut v = [0; 4]; for (v, o) in v.iter_mut().zip(bytes[skip_bytes..].iter()) { *v = *o; } - let mut tmp = u64::from_le_bytes(v); + let mut tmp = u32::from_le_bytes(v); tmp >>= skip_bits - (skip_bytes * 8); - tmp %= 1 << c; - tmp as u64 + tmp %= 1 << window; + tmp } - let scalars = scalars - .iter() - .map(|scalar| scalar.to_repr()) - .collect::>(); - for window_idx in 0..self.n_windows { - for (point_index, scalar) in scalars.iter().enumerate() { - let bucket_index = get_bits(window_idx, self.window, scalar.as_ref()) as usize; - self.bucket_sizes[window_idx * self.n_buckets + bucket_index] += 1; - self.bucket_indexes[window_idx * self.n_points + point_index] = bucket_index; + let max = 1 << (self.window - 1); + for (point_idx, scalar) in scalars.iter().enumerate() { + let repr = scalar.to_repr(); + let mut borrow = 0u32; + for window_idx in 0..self.n_windows { + let windowed_digit = get_bits(window_idx, self.window, repr.as_ref()) + borrow; + let signed_digit = if windowed_digit >= max { + borrow = 1; + ((1 << self.window) - windowed_digit) | 0x80000000 + } else { + borrow = 0; + windowed_digit + }; + self.bucket_sizes[window_idx * self.n_buckets + index!(signed_digit)] += 1; + self.signed_digits[window_idx * self.n_points + point_idx] = signed_digit; } } self.sort(); @@ -108,20 +138,22 @@ impl MSM { for w_i in 0..self.n_windows { let sorted_positions = &mut self.sorted_positions[range!(w_i, self.n_points)]; let bucket_sizes = &self.bucket_sizes[range!(w_i, self.n_buckets)]; - let bucket_indexes = &self.bucket_indexes[range!(w_i, self.n_points)]; + let signed_digits = &self.signed_digits[range!(w_i, self.n_points)]; let mut offset = 0; for (i, size) in bucket_sizes.iter().enumerate() { self.bucket_offsets[i] = offset; offset += size; } - for (idx, bucket_index) in bucket_indexes.iter().enumerate() { - sorted_positions[self.bucket_offsets[*bucket_index]] = idx; - self.bucket_offsets[*bucket_index] += 1; + for (sorted_idx, signed_digit) in signed_digits.iter().enumerate() { + let bucket_idx = index!(signed_digit); + sorted_positions[self.bucket_offsets[bucket_idx]] = + sign_bit!(signed_digit) | (sorted_idx as u32); + self.bucket_offsets[bucket_idx] += 1; } } } - pub fn evalulate(scalars: &[Fr], bases: &[G1Affine], acc: &mut G1) { + pub fn evaluate(scalars: &[Fr], bases: &[G1Affine], acc: &mut G1) { let mut msm = Self::alloacate(bases.len()); msm.decompose(scalars); for w_i in (0..msm.n_windows).rev() { @@ -158,14 +190,14 @@ impl MSM { .zip(results.iter_mut()) { scope.spawn(move |_| { - Self::evalulate(coeffs, bases, acc); + Self::evaluate(coeffs, bases, acc); }); } }); results.iter().fold(G1::identity(), |a, b| a + b) } else { let mut acc = G1::identity(); - Self::evalulate(coeffs, bases, &mut acc); + Self::evaluate(coeffs, bases, &mut acc); acc } } @@ -218,14 +250,14 @@ mod test { #[test] fn test_msm() { - let (points, scalars) = get_data(1 << 22); + let (min_k, max_k) = (10, 22); + let (points, scalars) = get_data(1 << max_k); - for k in 10..=22 { + for k in min_k..=max_k { + println!("k = {}", k); let n_points = 1 << k; let scalars = &scalars[..n_points]; let points = &points[..n_points]; - println!("------ {}", k); - let mut r0 = G1::identity(); let time = std::time::Instant::now(); msm_zcash(scalars, points, &mut r0); @@ -237,9 +269,9 @@ mod test { let time = std::time::Instant::now(); let mut r1 = G1::identity(); - super::MSM::evalulate(scalars, points, &mut r1); + super::MSM::evaluate(scalars, points, &mut r1); assert_eq!(r0, r1); - println!("this {:?}", time.elapsed()); + println!("this serial 1 {:?}", time.elapsed()); let time = std::time::Instant::now(); let r1 = super::MSM::best(scalars, points); diff --git a/src/bn256/msm/pr40.rs b/src/bn256/msm/pr40.rs index 56908ff4..c1778f14 100644 --- a/src/bn256/msm/pr40.rs +++ b/src/bn256/msm/pr40.rs @@ -851,65 +851,7 @@ fn accumulate_buckets( .fold(C::Curve::identity(), |acc, result| acc + result) } -use crate::bn256::{Fr, G1Affine}; -use crate::group::prime::PrimeCurveAffine; use crate::CurveAffineExt; - -fn get_random_data(n: usize) -> (Vec, Vec) { - use rand::SeedableRng; - use rand_xorshift::XorShiftRng; - - let mut bases = vec![G1Affine::identity(); n]; - // parallelize(&mut bases, |bases, _| { - // let mut rng = rand::thread_rng(); - let mut rng = XorShiftRng::from_seed([ - 0x59, 0x62, 0xbe, 0x5d, 0x76, 0x3d, 0x31, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06, 0xbc, - 0xe5, - ]); - // let base_rnd = G1Affine::random(&mut rng); - for base in bases.iter_mut() { - if INDEPENDENT { - *base = G1Affine::random(&mut rng); - } else { - unreachable!() - // *base = base_rnd; - } - } - // }); - - let mut coeffs = vec![Fr::zero(); n]; - // parallelize(&mut coeffs, |coeffs, _| { - for coeff in coeffs.iter_mut() { - *coeff = Fr::random(&mut rng); - // *coeff = Fr::from(1u64); - } - // }); - - (bases, coeffs) -} - -#[test] -// #[ignore] -fn test_multiexp_bench() { - let min_k = 10; - let max_k = 20; - let n = 1 << max_k; - let (bases, coeffs) = get_random_data::(n); - - let msm = MultiExp::new(&bases); - let mut ctx = MultiExpContext::default(); - for k in min_k..=max_k { - let n = 1 << k; - let coeffs = &coeffs[..n]; - - let start = start_measure("msm".to_string(), false); - msm.evaluate(&mut ctx, coeffs, false); - let duration = stop_measure(start); - - println!("{} {}: {}s", k, n, (duration as f32) / 1000000.0); - } -} - use std::{ env::var, sync::atomic::{AtomicUsize, Ordering}, diff --git a/src/bn256/msm/round.rs b/src/bn256/msm/round.rs index 430d65c3..6fb03384 100644 --- a/src/bn256/msm/round.rs +++ b/src/bn256/msm/round.rs @@ -51,12 +51,21 @@ impl Round { in_off: 0, } } - pub(crate) fn init(&mut self, bases: &[G1Affine], positions: &[usize], bucket_sizes: &[usize]) { + pub(crate) fn init(&mut self, bases: &[G1Affine], positions: &[u32], bucket_sizes: &[usize]) { { assert_eq!(self.n_points, positions.len()); assert_eq!(self.n_points, bases.len()); assert_eq!(self.bucket_sizes.len(), self.n_buckets); } + macro_rules! get_base { + ($positions:expr, $off:expr) => { + if is_neg!($positions[$off]) { + -bases[index!($positions[$off])] + } else { + bases[index!($positions[$off])] + } + }; + } let n_additions = bucket_sizes .iter() .map(|bucket_sizes| bucket_sizes / 2) @@ -79,13 +88,9 @@ impl Round { // process even number of additions for _ in 0..n_additions & (usize::MAX - 1) { // second operand must be mutable - self.t[tmp_off] = bases[positions[in_off]]; - first_phase!( - self.t[out_off], - bases[positions[in_off + 1]], - self.t[tmp_off], - acc - ); + self.t[tmp_off] = get_base!(positions, in_off); + let lhs = get_base!(positions, in_off + 1); + first_phase!(self.t[out_off], lhs, self.t[tmp_off], acc); tmp_off += 1; out_off += 1; in_off += 2; @@ -96,31 +101,23 @@ impl Round { // move to odd-point cache (true, true) => { assert_eq!(positions.len() - 1, in_off); - self.odd_points[bucket_index] = bases[positions[in_off]]; + self.odd_points[bucket_index] = get_base!(positions, in_off); } // 2 base point left // move addition result to odd-point cache (false, true) => { - self.t[tmp_off] = bases[positions[in_off]]; - first_phase!( - self.odd_points[bucket_index], - bases[positions[in_off + 1]], - self.t[tmp_off], - acc - ); + self.t[tmp_off] = get_base!(positions, in_off); + let lhs = get_base!(positions, in_off + 1); + first_phase!(self.odd_points[bucket_index], lhs, self.t[tmp_off], acc); tmp_off += 1; } // 3 base point left // move addition of first two to intermediate and last to odd-point cache (true, false) => { - self.t[tmp_off] = bases[positions[in_off]]; - first_phase!( - self.t[out_off], - bases[positions[in_off + 1]], - self.t[tmp_off], - acc - ); - self.t[out_off + 1] = bases[positions[in_off + 2]]; + self.t[tmp_off] = get_base!(positions, in_off); + let lhs = get_base!(positions, in_off + 1); + first_phase!(self.t[out_off], lhs, self.t[tmp_off], acc); + self.t[out_off + 1] = get_base!(positions, in_off + 2); tmp_off += 1; out_off += 2; } @@ -154,12 +151,8 @@ impl Round { // 2 base point left // move addition result to odd-point cache (false, true) => { - second_phase!( - self.odd_points[bucket_index], - bases[positions[in_off]], - self.t[tmp_off], - acc - ); + let lhs = get_base!(positions, in_off); + second_phase!(self.odd_points[bucket_index], lhs, self.t[tmp_off], acc); tmp_off -= 1; in_off -= 2; } @@ -168,12 +161,8 @@ impl Round { (true, false) => { in_off -= 1; out_off -= 1; - second_phase!( - self.t[out_off], - bases[positions[in_off]], - self.t[tmp_off], - acc - ); + let lhs = get_base!(positions, in_off); + second_phase!(self.t[out_off], lhs, self.t[tmp_off], acc); tmp_off -= 1; in_off -= 2; out_off -= 1; @@ -182,12 +171,8 @@ impl Round { } // process even number of additions for _ in (0..n_additions & (usize::MAX - 1)).rev() { - second_phase!( - self.t[out_off], - bases[positions[in_off]], - self.t[tmp_off], - acc - ); + let lhs = get_base!(positions, in_off); + second_phase!(self.t[out_off], lhs, self.t[tmp_off], acc); tmp_off -= 1; out_off -= 1; in_off -= 2; @@ -298,29 +283,43 @@ mod tests { rng: &mut impl Rng, n_buckets: usize, n_points: usize, - ) -> (Vec, Vec) { - let mut positions: Vec> = vec![vec![]; n_buckets]; + ) -> (Vec, Vec) { + let mut positions: Vec> = vec![vec![]; n_buckets]; (0..n_points).for_each(|i| { - let index = rng.gen_range(0..n_buckets); - positions[index].push(i); + let bucket_index = rng.gen_range(0..n_buckets); + let is_neg: bool = rng.gen(); + let signed_index = if is_neg { + i as u32 | 0x80000000 + } else { + i as u32 + }; + positions[bucket_index].push(signed_index); }); positions.iter_mut().for_each(|positions| { positions.shuffle(rng); }); let bucket_sizes = positions.iter().map(|positions| positions.len()).collect(); - let positions: Vec = positions.iter().flatten().cloned().collect(); + let positions: Vec = positions.iter().flatten().cloned().collect(); (positions, bucket_sizes) } impl Round { - fn sanity_check(&self, bases: &[G1Affine], positions: &[usize], bucket_sizes: &[usize]) { + fn sanity_check(&self, bases: &[G1Affine], positions: &[u32], bucket_sizes: &[usize]) { { let mut off = 0; let sums: Vec<_> = bucket_sizes .iter() .map(|bucket_size| { let sum = (off..off + bucket_size) - .map(|i| bases[positions[i]]) + .map(|i| { + let index = index!(positions[i]); + let is_neg = is_neg!(positions[i]); + if is_neg { + -bases[index] + } else { + bases[index] + } + }) .fold(G1::identity(), |acc, next| acc + next); off += bucket_size; sum From b28c47388943fd48066ec0be548f9f52481caa38 Mon Sep 17 00:00:00 2001 From: kilic Date: Wed, 1 Mar 2023 15:03:19 +0300 Subject: [PATCH 5/7] use zcash msm when number of points less than 256 --- src/arithmetic.rs | 93 +++++++++++++++++++++++++++++++ src/bn256/msm/mod.rs | 84 ++++++++++++++-------------- src/bn256/msm/zcash.rs | 123 ----------------------------------------- 3 files changed, 134 insertions(+), 166 deletions(-) delete mode 100644 src/bn256/msm/zcash.rs diff --git a/src/arithmetic.rs b/src/arithmetic.rs index 2882161b..f148f723 100644 --- a/src/arithmetic.rs +++ b/src/arithmetic.rs @@ -4,6 +4,8 @@ //! This module is temporary, and the extension traits defined here are expected to be //! upstreamed into the `ff` and `group` crates after some refactoring. +use ff::PrimeField; +use group::Group; use pasta_curves::arithmetic::CurveAffine; use crate::CurveExt; @@ -86,6 +88,97 @@ pub(crate) fn mul_512(a: [u64; 4], b: [u64; 4]) -> [u64; 8] { [r0, r1, r2, r3, r4, r5, r6, carry_out] } +// taken from zcash/halo2::aritmetic +pub(crate) fn msm_zcash(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) { + let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect(); + + let c = if bases.len() < 4 { + 1 + } else if bases.len() < 32 { + 3 + } else { + (f64::from(bases.len() as u32)).ln().ceil() as usize + }; + + fn get_at(segment: usize, c: usize, bytes: &F::Repr) -> usize { + let skip_bits = segment * c; + let skip_bytes = skip_bits / 8; + + if skip_bytes >= 32 { + return 0; + } + + let mut v = [0; 8]; + for (v, o) in v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) { + *v = *o; + } + + let mut tmp = u64::from_le_bytes(v); + tmp >>= skip_bits - (skip_bytes * 8); + tmp %= 1 << c; + + tmp as usize + } + + let segments = (256 / c) + 1; + + for current_segment in (0..segments).rev() { + for _ in 0..c { + *acc = acc.double(); + } + + #[derive(Clone, Copy)] + enum Bucket { + None, + Affine(C), + Projective(C::Curve), + } + + impl Bucket { + fn add_assign(&mut self, other: &C) { + *self = match *self { + Bucket::None => Bucket::Affine(*other), + Bucket::Affine(a) => Bucket::Projective(a + *other), + Bucket::Projective(mut a) => { + a += *other; + Bucket::Projective(a) + } + } + } + + fn add(self, mut other: C::Curve) -> C::Curve { + match self { + Bucket::None => other, + Bucket::Affine(a) => { + other += a; + other + } + Bucket::Projective(a) => other + &a, + } + } + } + + let mut buckets: Vec> = vec![Bucket::None; (1 << c) - 1]; + + for (coeff, base) in coeffs.iter().zip(bases.iter()) { + let coeff = get_at::(current_segment, c, coeff); + if coeff != 0 { + buckets[coeff - 1].add_assign(base); + } + } + + // Summation by parts + // e.g. 3a + 2b + 1c = a + + // (a) + b + + // ((a) + b) + c + let mut running_sum = C::Curve::identity(); + for exp in buckets.into_iter().rev() { + running_sum = exp.add(running_sum); + *acc += &running_sum; + } + } +} + #[cfg(test)] mod test { use super::CurveEndo; diff --git a/src/bn256/msm/mod.rs b/src/bn256/msm/mod.rs index 7e9a1e20..282fbbbd 100644 --- a/src/bn256/msm/mod.rs +++ b/src/bn256/msm/mod.rs @@ -1,4 +1,5 @@ use super::{Fr, G1Affine}; +use crate::arithmetic::msm_zcash; use crate::bn256::{msm::round::Round, G1}; use crate::group::Group; use ff::PrimeField; @@ -49,8 +50,6 @@ macro_rules! sign_bit { #[cfg(test)] mod pr40; mod round; -#[cfg(test)] -mod zcash; pub struct MSM { signed_digits: Vec, @@ -65,7 +64,7 @@ pub struct MSM { } impl MSM { - pub fn alloacate(n_points: usize) -> Self { + pub fn alloacate(n_points: usize, override_window: Option) -> Self { fn best_window(n: usize) -> usize { if n >= 262144 { 15 @@ -81,7 +80,14 @@ impl MSM { 7 } } - let window = best_window(n_points); + let window = match override_window { + Some(window) => { + let overriden = best_window(n_points); + println!("override window from {} to {}", overriden, window); + window + } + None => best_window(n_points), + }; let n_windows = div_ceil!(Fr::NUM_BITS as usize, window); let n_buckets = (1 << (window - 1)) + 1; let round = Round::new(n_buckets, n_points); @@ -153,15 +159,20 @@ impl MSM { } } - pub fn evaluate(scalars: &[Fr], bases: &[G1Affine], acc: &mut G1) { - let mut msm = Self::alloacate(bases.len()); + pub fn evaluate_with( + scalars: &[Fr], + points: &[G1Affine], + acc: &mut G1, + override_window: Option, + ) { + let mut msm = Self::alloacate(points.len(), override_window); msm.decompose(scalars); for w_i in (0..msm.n_windows).rev() { if w_i != msm.n_windows - 1 { *acc = double_n!(*acc, msm.window); } msm.round.init( - bases, + points, &msm.sorted_positions[range!(w_i, msm.n_points)], &msm.bucket_sizes[range!(w_i, msm.n_buckets)], ); @@ -174,30 +185,38 @@ impl MSM { } } - pub fn best(coeffs: &[Fr], bases: &[G1Affine]) -> G1 { - assert_eq!(coeffs.len(), bases.len()); + pub fn evaluate(scalars: &[Fr], points: &[G1Affine], acc: &mut G1) { + Self::evaluate_with(scalars, points, acc, None) + } + + pub fn best(scalars: &[Fr], points: &[G1Affine]) -> G1 { + assert_eq!(scalars.len(), points.len()); let num_threads = current_num_threads(); - if coeffs.len() > num_threads { - let chunk = coeffs.len() / num_threads; - let num_chunks = coeffs.chunks(chunk).len(); + if scalars.len() > num_threads { + let chunk = scalars.len() / num_threads; + let num_chunks = scalars.chunks(chunk).len(); let mut results = vec![G1::identity(); num_chunks]; scope(|scope| { - let chunk = coeffs.len() / num_threads; + let chunk = scalars.len() / num_threads; - for ((coeffs, bases), acc) in coeffs + for ((scalars, points), acc) in scalars .chunks(chunk) - .zip(bases.chunks(chunk)) + .zip(points.chunks(chunk)) .zip(results.iter_mut()) { scope.spawn(move |_| { - Self::evaluate(coeffs, bases, acc); + if points.len() < 1 << 8 { + msm_zcash(scalars, points, acc); + } else { + Self::evaluate(scalars, points, acc); + } }); } }); results.iter().fold(G1::identity(), |a, b| a + b) } else { let mut acc = G1::identity(); - Self::evaluate(coeffs, bases, &mut acc); + Self::evaluate(scalars, points, &mut acc); acc } } @@ -205,13 +224,14 @@ impl MSM { #[cfg(test)] mod test { + use crate::arithmetic::msm_zcash; use crate::bn256::msm::pr40::{MultiExp, MultiExpContext}; - use crate::bn256::msm::zcash::{best_multiexp_zcash, msm_zcash}; use crate::bn256::{Fr, G1Affine, G1}; use crate::group::Group; use crate::serde::SerdeObject; use ff::Field; use group::Curve; + use rand::Rng; use rand_core::OsRng; use std::fs::File; use std::path::Path; @@ -248,42 +268,20 @@ mod test { } #[test] - fn test_msm() { - let (min_k, max_k) = (10, 22); + let (min_k, max_k) = (4, 20); let (points, scalars) = get_data(1 << max_k); for k in min_k..=max_k { - println!("k = {}", k); - let n_points = 1 << k; + let mut rng = OsRng; + let n_points = rng.gen_range(1 << (k - 1)..1 << k); let scalars = &scalars[..n_points]; let points = &points[..n_points]; let mut r0 = G1::identity(); - let time = std::time::Instant::now(); msm_zcash(scalars, points, &mut r0); - println!("zcash serial {:?}", time.elapsed()); - - let time = std::time::Instant::now(); - let r0 = best_multiexp_zcash(scalars, points); - println!("zcash parallel {:?}", time.elapsed()); - - let time = std::time::Instant::now(); let mut r1 = G1::identity(); super::MSM::evaluate(scalars, points, &mut r1); assert_eq!(r0, r1); - println!("this serial 1 {:?}", time.elapsed()); - - let time = std::time::Instant::now(); - let r1 = super::MSM::best(scalars, points); - assert_eq!(r0, r1); - println!("this parallel {:?}", time.elapsed()); - - let time = std::time::Instant::now(); - let msm = MultiExp::new(&points); - let mut ctx = MultiExpContext::default(); - let _ = msm.evaluate(&mut ctx, scalars, false); - // assert_eq!(r0, r1); // fails - println!("pr40 {:?}", time.elapsed()); } } } diff --git a/src/bn256/msm/zcash.rs b/src/bn256/msm/zcash.rs deleted file mode 100644 index 1e5d2b59..00000000 --- a/src/bn256/msm/zcash.rs +++ /dev/null @@ -1,123 +0,0 @@ -use crate::group::Group; -use ff::PrimeField; -use pasta_curves::arithmetic::CurveAffine; -use rayon::{current_num_threads, scope}; - -pub(crate) fn best_multiexp_zcash(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { - assert_eq!(coeffs.len(), bases.len()); - - let num_threads = current_num_threads(); - if coeffs.len() > num_threads { - let chunk = coeffs.len() / num_threads; - let num_chunks = coeffs.chunks(chunk).len(); - let mut results = vec![C::Curve::identity(); num_chunks]; - scope(|scope| { - let chunk = coeffs.len() / num_threads; - - for ((coeffs, bases), acc) in coeffs - .chunks(chunk) - .zip(bases.chunks(chunk)) - .zip(results.iter_mut()) - { - scope.spawn(move |_| { - msm_zcash(coeffs, bases, acc); - }); - } - }); - results.iter().fold(C::Curve::identity(), |a, b| a + b) - } else { - let mut acc = C::Curve::identity(); - msm_zcash(coeffs, bases, &mut acc); - acc - } -} - -pub(crate) fn msm_zcash(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) { - let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect(); - - let c = if bases.len() < 4 { - 1 - } else if bases.len() < 32 { - 3 - } else { - (f64::from(bases.len() as u32)).ln().ceil() as usize - }; - - fn get_at(segment: usize, c: usize, bytes: &F::Repr) -> usize { - let skip_bits = segment * c; - let skip_bytes = skip_bits / 8; - - if skip_bytes >= 32 { - return 0; - } - - let mut v = [0; 8]; - for (v, o) in v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) { - *v = *o; - } - - let mut tmp = u64::from_le_bytes(v); - tmp >>= skip_bits - (skip_bytes * 8); - tmp %= 1 << c; - - tmp as usize - } - - let segments = (256 / c) + 1; - - for current_segment in (0..segments).rev() { - for _ in 0..c { - *acc = acc.double(); - } - - #[derive(Clone, Copy)] - enum Bucket { - None, - Affine(C), - Projective(C::Curve), - } - - impl Bucket { - fn add_assign(&mut self, other: &C) { - *self = match *self { - Bucket::None => Bucket::Affine(*other), - Bucket::Affine(a) => Bucket::Projective(a + *other), - Bucket::Projective(mut a) => { - a += *other; - Bucket::Projective(a) - } - } - } - - fn add(self, mut other: C::Curve) -> C::Curve { - match self { - Bucket::None => other, - Bucket::Affine(a) => { - other += a; - other - } - Bucket::Projective(a) => other + &a, - } - } - } - - let mut buckets: Vec> = vec![Bucket::None; (1 << c) - 1]; - - for (coeff, base) in coeffs.iter().zip(bases.iter()) { - let coeff = get_at::(current_segment, c, coeff); - if coeff != 0 { - buckets[coeff - 1].add_assign(base); - } - } - - // Summation by parts - // e.g. 3a + 2b + 1c = a + - // (a) + b + - // ((a) + b) + c - let mut running_sum = C::Curve::identity(); - for exp in buckets.into_iter().rev() { - running_sum = exp.add(running_sum); - *acc += &running_sum; - } - } -} From 165c1da1bcc091c0b783ad63ae7881789ae6470b Mon Sep 17 00:00:00 2001 From: kilic Date: Wed, 1 Mar 2023 15:41:06 +0300 Subject: [PATCH 6/7] remove pse/halo2/#40 --- src/arithmetic.rs | 8 - src/bn256/curve.rs | 8 +- src/bn256/msm/mod.rs | 3 - src/bn256/msm/pr40.rs | 924 ----------------------------------------- src/derive/curve.rs | 142 ------- src/secp256k1/curve.rs | 8 +- 6 files changed, 7 insertions(+), 1086 deletions(-) delete mode 100644 src/bn256/msm/pr40.rs diff --git a/src/arithmetic.rs b/src/arithmetic.rs index f148f723..6c1c0527 100644 --- a/src/arithmetic.rs +++ b/src/arithmetic.rs @@ -24,14 +24,6 @@ pub trait CurveEndo: CurveExt { pub trait CurveAffineExt: CurveAffine { fn decompose_scalar(k: &Self::ScalarExt) -> (u128, bool, u128, bool); fn endo(&self) -> Self; - fn batch_add( - points: &mut [Self], - output_indices: &[u32], - num_points: usize, - offset: usize, - bases: &[Self], - base_positions: &[u32], - ); } /// Compute a + b + carry, returning the result and the new carry over. diff --git a/src/bn256/curve.rs b/src/bn256/curve.rs index 61485d77..e288f4fb 100644 --- a/src/bn256/curve.rs +++ b/src/bn256/curve.rs @@ -20,9 +20,9 @@ use std::convert::TryInto; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; use crate::{ - batch_add, impl_add_binop_specify_output, impl_binops_additive, - impl_binops_additive_specify_output, impl_binops_multiplicative, - impl_binops_multiplicative_mixed, impl_sub_binop_specify_output, new_curve_impl, + impl_add_binop_specify_output, impl_binops_additive, impl_binops_additive_specify_output, + impl_binops_multiplicative, impl_binops_multiplicative_mixed, impl_sub_binop_specify_output, + new_curve_impl, }; new_curve_impl!( @@ -50,7 +50,6 @@ new_curve_impl!( ); impl CurveAffineExt for G1Affine { - batch_add!(); endo!(ENDO_PARAMS); fn endo(&self) -> Self { @@ -66,7 +65,6 @@ impl CurveEndo for G1 { } impl CurveAffineExt for G2Affine { - batch_add!(); endo!(ENDO_PARAMS); fn endo(&self) -> Self { diff --git a/src/bn256/msm/mod.rs b/src/bn256/msm/mod.rs index 282fbbbd..2ce91643 100644 --- a/src/bn256/msm/mod.rs +++ b/src/bn256/msm/mod.rs @@ -47,8 +47,6 @@ macro_rules! sign_bit { }; } -#[cfg(test)] -mod pr40; mod round; pub struct MSM { @@ -225,7 +223,6 @@ impl MSM { #[cfg(test)] mod test { use crate::arithmetic::msm_zcash; - use crate::bn256::msm::pr40::{MultiExp, MultiExpContext}; use crate::bn256::{Fr, G1Affine, G1}; use crate::group::Group; use crate::serde::SerdeObject; diff --git a/src/bn256/msm/pr40.rs b/src/bn256/msm/pr40.rs deleted file mode 100644 index c1778f14..00000000 --- a/src/bn256/msm/pr40.rs +++ /dev/null @@ -1,924 +0,0 @@ -//! This module implements a fast method for multi-scalar multiplications. -//! -//! Generally it works like pippenger with a couple of tricks to make if faster. -//! -//! - First the coefficients are split into two parts (using the endomorphism). This -//! reduces the number of rounds by half, but doubles the number of points per round. -//! This is faster because half the rounds also means only needing to add all bucket -//! results together half the number of times. -//! -//! - The coefficients are then sorted in buckets. Instead of using -//! the binary representation to do this, a signed digit representation is -//! used instead (WNAF). Unfortunately this doesn't directly reduce the number of additions -//! in a bucket, but it does reduce the number of buckets in half, which halves the -//! work required to accumulate the results of the buckets. -//! -//! - We then need to add all the points in each bucket together. To do this -//! the affine addition formulas are used. If the points are linearly independent the -//! incomplete version of the formula can be used which is quite a bit faster than -//! the full one because some checks can be skipped. -//! The affine formula is only fast if a lot of independent points can be added -//! together. This is because to get the actual result of an addition an inversion is -//! needed which is very expensive, but it's cheap when batched inversion can be used. -//! So the idea is to add a lot of pairs of points together using a single batched inversion. -//! We then have the results of all those additions, and can do a new batch of additions on those -//! results. This process is repeated as many times as needed until all additions for each bucket -//! are done. To do this efficiently we first build up an addition tree that sets everything -//! up correctly per round. We then process each addition tree per round. - -use core::slice; -pub use ff::Field; -use group::{ff::PrimeField, Group as _}; -pub use rayon::{current_num_threads, scope, Scope}; - -fn num_bits(value: usize) -> usize { - (0usize.leading_zeros() - value.leading_zeros()) as usize -} - -fn div_up(a: usize, b: usize) -> usize { - (a + (b - 1)) / b -} - -fn get_wnaf_size_bits(num_bits: usize, w: usize) -> usize { - div_up(num_bits, w) -} - -fn get_wnaf_size(w: usize) -> usize { - get_wnaf_size_bits(div_up(C::Scalar::NUM_BITS as usize, 2), w) -} - -fn get_num_rounds(c: usize) -> usize { - get_wnaf_size::(c + 1) -} - -fn get_num_buckets(c: usize) -> usize { - (1 << c) + 1 -} - -fn get_max_tree_size(num_points: usize, c: usize) -> usize { - num_points * 2 + get_num_buckets(c) -} - -fn get_num_tree_levels(num_points: usize) -> usize { - 1 + num_bits(num_points - 1) -} - -/// Returns the signed digit representation of value with the specified window size. -/// The result is written to the wnaf slice with the specified stride. -fn get_wnaf(value: u128, w: usize, num_rounds: usize, wnaf: &mut [u32], stride: usize) { - fn get_bits_at(v: u128, pos: usize, num: usize) -> usize { - ((v >> pos) & ((1 << num) - 1)) as usize - } - - let mut borrow = 0; - let max = 1 << (w - 1); - for idx in 0..num_rounds { - let b = get_bits_at(value, idx * w, w) + borrow; - if b >= max { - // Set the highest bit to 1 to represent a negative value. - // This way the lower bits directly represent the bucket index. - wnaf[idx * stride] = (0x80000000 | ((1 << w) - b)) as u32; - borrow = 1; - } else { - wnaf[idx * stride] = b as u32; - borrow = 0; - } - } - assert_eq!(borrow, 0); -} - -/// Returns the best bucket width for the given number of points. -fn get_best_c(num_points: usize) -> usize { - if num_points >= 262144 { - 15 - } else if num_points >= 65536 { - 12 - } else if num_points >= 16384 { - 11 - } else if num_points >= 8192 { - 10 - } else if num_points >= 1024 { - 9 - } else { - 7 - } -} - -/// MultiExp -#[derive(Clone, Debug, Default)] -pub struct MultiExp { - /// The bases - bases: Vec, -} - -/// MultiExp context object -#[derive(Clone, Debug, Default)] -pub struct MultiExpContext { - /// Memory to store the points in the addition tree - points: Vec, - /// Memory to store wnafs - wnafs: Vec, - /// Memory split up between rounds - rounds: SharedRoundData, -} - -/// SharedRoundData -#[derive(Clone, Debug, Default)] -struct SharedRoundData { - /// Memory to store bucket sizes - bucket_sizes: Vec, - /// Memory to store bucket offsets - bucket_offsets: Vec, - /// Memory to store the point data - point_data: Vec, - /// Memory to store the output indices - output_indices: Vec, - /// Memory to store the base positions (on the first level) - base_positions: Vec, - /// Memory to store the scatter maps - scatter_map: Vec, -} - -/// RoundData -#[derive(Debug, Default)] -struct RoundData<'a> { - /// Number of levels in the addition tree - pub num_levels: usize, - /// The length of each level in the addition tree - pub level_sizes: Vec, - /// The offset to each level in the addition tree - pub level_offset: Vec, - /// The size of each bucket - pub bucket_sizes: &'a mut [usize], - /// The offset of each bucket - pub bucket_offsets: &'a mut [usize], - /// The point to use for each coefficient - pub point_data: &'a mut [u32], - /// The output index in the point array for each pair addition - pub output_indices: &'a mut [u32], - /// The point to use on the first level in the addition tree - pub base_positions: &'a mut [u32], - /// List of points that are scattered to the addition tree - pub scatter_map: &'a mut [ScatterData], - /// The length of scatter_map - pub scatter_map_len: usize, -} - -/// ScatterData -#[derive(Default, Debug, Clone)] -struct ScatterData { - /// The position in the addition tree to store the point - pub position: u32, - /// The point to write - pub point_data: u32, -} - -impl MultiExp { - /// Create a new MultiExp instance with the specified bases - pub fn new(bases: &[C]) -> Self { - let mut endo_bases = vec![C::identity(); bases.len() * 2]; - - // Generate the endomorphism bases - let num_threads = current_num_threads(); - scope(|scope| { - let num_points_per_thread = div_up(bases.len(), num_threads); - for (endo_bases, bases) in endo_bases - .chunks_mut(num_points_per_thread * 2) - .zip(bases.chunks(num_points_per_thread)) - { - scope.spawn(move |_| { - for (idx, base) in bases.iter().enumerate() { - endo_bases[idx * 2] = *base; - endo_bases[idx * 2 + 1] = C::endo(base); - } - }); - } - }); - - Self { bases: endo_bases } - } - - /// Performs a multi-exponentiation operation. - /// Set complete to true if the bases are not guaranteed linearly independent. - pub fn evaluate( - &self, - ctx: &mut MultiExpContext, - coeffs: &[C::Scalar], - complete: bool, - ) -> C::Curve { - self.evaluate_with(ctx, coeffs, complete, get_best_c(coeffs.len())) - } - - /// Performs a multi-exponentiation operation with the given bucket width. - /// Set complete to true if the bases are not guaranteed linearly independent. - pub fn evaluate_with( - &self, - ctx: &mut MultiExpContext, - coeffs: &[C::Scalar], - complete: bool, - c: usize, - ) -> C::Curve { - assert!(coeffs.len() * 2 <= self.bases.len()); - assert!(c >= 4); - - // Allocate more memory if required - ctx.allocate(coeffs.len(), c); - - // Get the data for each round - let mut rounds = ctx.rounds.get_rounds::(coeffs.len(), c); - - // Get the bases for the coefficients - let bases = &self.bases[..coeffs.len() * 2]; - - let num_threads = current_num_threads(); - let start = start_measure( - format!("msm {} ({}) ({} threads)", coeffs.len(), c, num_threads), - false, - ); - // if coeffs.len() >= 16 { - let num_points = coeffs.len() * 2; - let w = c + 1; - let num_rounds = get_num_rounds::(c); - - // Prepare WNAFs of all coefficients for all rounds - calculate_wnafs::(coeffs, &mut ctx.wnafs, c); - // Sort WNAFs into buckets for all rounds - sort::(&mut ctx.wnafs[0..num_rounds * num_points], &mut rounds, c); - // Calculate addition trees for all rounds - create_addition_trees(&mut rounds); - - // Now process each round individually - let mut partials = vec![C::Curve::identity(); num_rounds]; - for (round, acc) in rounds.iter().zip(partials.iter_mut()) { - // Scatter the odd points in the odd length buckets to the addition tree - do_point_scatter::(round, bases, &mut ctx.points); - // Do all bucket additions - do_batch_additions::(round, bases, &mut ctx.points, complete); - // Get the final result of the round - *acc = accumulate_buckets::(round, &mut ctx.points, c); - } - - // Accumulate round results - let res = partials - .iter() - .rev() - .skip(1) - .fold(partials[num_rounds - 1], |acc, partial| { - let mut res = acc; - for _ in 0..w { - res = res.double(); - } - res + partial - }); - stop_measure(start); - - res - // } else { - // // Just do a naive msm - // let mut acc = C::Curve::identity(); - // for (idx, coeff) in coeffs.iter().enumerate() { - // // Skip over endomorphism bases - // acc += bases[idx * 2] * coeff; - // } - // stop_measure(start); - // acc - // } - } -} - -impl MultiExpContext { - /// Allocate memory for the evalution - pub fn allocate(&mut self, num_points: usize, c: usize) { - let num_points = num_points * 2; - let num_buckets = get_num_buckets(c); - let num_rounds = get_num_rounds::(c); - let tree_size = get_max_tree_size(num_points, c); - let num_points_total = num_rounds * num_points; - let num_buckets_total = num_rounds * num_buckets; - let tree_size_total = num_rounds * tree_size; - - // Allocate memory when necessary - if self.points.len() < tree_size { - self.points.resize(tree_size, C::identity()); - } - if self.wnafs.len() < num_points_total { - self.wnafs.resize(num_points_total, 0u32); - } - if self.rounds.bucket_sizes.len() < num_buckets_total { - self.rounds.bucket_sizes.resize(num_buckets_total, 0usize); - } - if self.rounds.bucket_offsets.len() < num_buckets_total { - self.rounds.bucket_offsets.resize(num_buckets_total, 0usize); - } - if self.rounds.point_data.len() < num_points_total { - self.rounds.point_data.resize(num_points_total, 0u32); - } - if self.rounds.output_indices.len() < tree_size_total / 2 { - self.rounds.output_indices.resize(tree_size_total / 2, 0u32); - } - if self.rounds.base_positions.len() < num_points_total { - self.rounds.base_positions.resize(num_points_total, 0u32); - } - if self.rounds.scatter_map.len() < num_buckets_total { - self.rounds - .scatter_map - .resize(num_buckets_total, ScatterData::default()); - } - } -} - -impl SharedRoundData { - fn get_rounds(&mut self, num_points: usize, c: usize) -> Vec { - let num_points = num_points * 2; - let num_buckets = get_num_buckets(c); - let num_rounds = get_num_rounds::(c); - let tree_size = num_points * 2 + num_buckets; - - let mut bucket_sizes_rest = self.bucket_sizes.as_mut_slice(); - let mut bucket_offsets_rest = self.bucket_offsets.as_mut_slice(); - let mut point_data_rest = self.point_data.as_mut_slice(); - let mut output_indices_rest = self.output_indices.as_mut_slice(); - let mut base_positions_rest = self.base_positions.as_mut_slice(); - let mut scatter_map_rest = self.scatter_map.as_mut_slice(); - - // Use the allocated memory above to init the memory used for each round. - // This way the we don't need to reallocate memory for each msm with - // a different configuration (different number of points or different bucket width) - let mut rounds: Vec = Vec::with_capacity(num_rounds); - for _ in 0..num_rounds { - let (bucket_sizes, rest) = bucket_sizes_rest.split_at_mut(num_buckets); - bucket_sizes_rest = rest; - let (bucket_offsets, rest) = bucket_offsets_rest.split_at_mut(num_buckets); - bucket_offsets_rest = rest; - let (point_data, rest) = point_data_rest.split_at_mut(num_points); - point_data_rest = rest; - let (output_indices, rest) = output_indices_rest.split_at_mut(tree_size / 2); - output_indices_rest = rest; - let (base_positions, rest) = base_positions_rest.split_at_mut(num_points); - base_positions_rest = rest; - let (scatter_map, rest) = scatter_map_rest.split_at_mut(num_buckets); - scatter_map_rest = rest; - - rounds.push(RoundData { - num_levels: 0, - level_sizes: vec![], - level_offset: vec![], - bucket_sizes, - bucket_offsets, - point_data, - output_indices, - base_positions, - scatter_map, - scatter_map_len: 0, - }); - } - rounds - } -} - -#[derive(Clone, Copy)] -struct ThreadBox(*mut T, usize); -#[allow(unsafe_code)] -unsafe impl Send for ThreadBox {} -#[allow(unsafe_code)] -unsafe impl Sync for ThreadBox {} - -/// Wraps a mutable slice so it can be passed into a thread without -/// hard to fix borrow checks caused by difficult data access patterns. -impl ThreadBox { - fn wrap(data: &mut [T]) -> Self { - Self(data.as_mut_ptr(), data.len()) - } - - fn unwrap(&mut self) -> &mut [T] { - #[allow(unsafe_code)] - unsafe { - slice::from_raw_parts_mut(self.0, self.1) - } - } -} - -fn calculate_wnafs(coeffs: &[C::Scalar], wnafs: &mut [u32], c: usize) { - let num_threads = current_num_threads(); - let num_points = coeffs.len() * 2; - let num_rounds = get_num_rounds::(c); - let w = c + 1; - - let start = start_measure("calculate wnafs".to_string(), false); - let mut wnafs_box = ThreadBox::wrap(wnafs); - let chunk_size = div_up(coeffs.len(), num_threads); - scope(|scope| { - for (thread_idx, coeffs) in coeffs.chunks(chunk_size).enumerate() { - scope.spawn(move |_| { - let wnafs = &mut wnafs_box.unwrap()[thread_idx * chunk_size * 2..]; - for (idx, coeff) in coeffs.iter().enumerate() { - let (p0, _, p1, _) = C::decompose_scalar(coeff); - get_wnaf(p0, w, num_rounds, &mut wnafs[idx * 2..], num_points); - get_wnaf(p1, w, num_rounds, &mut wnafs[idx * 2 + 1..], num_points); - } - }); - } - }); - stop_measure(start); -} - -fn radix_sort(wnafs: &mut [u32], round: &mut RoundData) { - let bucket_sizes = &mut round.bucket_sizes; - let bucket_offsets = &mut round.bucket_offsets; - - // Calculate bucket sizes, first resetting all sizes to 0 - bucket_sizes.fill_with(|| 0); - for wnaf in wnafs.iter() { - bucket_sizes[(wnaf & 0x7FFFFFFF) as usize] += 1; - } - - // Calculate bucket offsets - let mut offset = 0; - let mut max_bucket_size = 0; - bucket_offsets[0] = offset; - offset += bucket_sizes[0]; - for (bucket_offset, bucket_size) in bucket_offsets - .iter_mut() - .skip(1) - .zip(bucket_sizes.iter().skip(1)) - { - *bucket_offset = offset; - offset += bucket_size; - max_bucket_size = max_bucket_size.max(*bucket_size); - } - // Number of levels we need in our addition tree - round.num_levels = get_num_tree_levels(max_bucket_size); - - // Fill in point data grouped in buckets - let point_data = &mut round.point_data; - for (idx, wnaf) in wnafs.iter().enumerate() { - let bucket_idx = (wnaf & 0x7FFFFFFF) as usize; - point_data[bucket_offsets[bucket_idx]] = (wnaf & 0x80000000) | (idx as u32); - bucket_offsets[bucket_idx] += 1; - } -} - -/// Sorts the points so they are grouped per bucket -fn sort(wnafs: &mut [u32], rounds: &mut [RoundData], c: usize) { - let num_rounds = get_num_rounds::(c); - let num_points = wnafs.len() / num_rounds; - - // Sort per bucket for each round separately - let start = start_measure("radix sort".to_string(), false); - scope(|scope| { - for (round, wnafs) in rounds.chunks_mut(1).zip(wnafs.chunks_mut(num_points)) { - scope.spawn(move |_| { - radix_sort(wnafs, &mut round[0]); - }); - } - }); - stop_measure(start); -} - -/// Creates the addition tree. -/// When PREPROCESS is false we just calculate the size of each level. -/// All points in a bucket need to be added to each other. Because the affine formulas -/// are used we need to add points together in pairs. So we have to make sure that -/// on each level we have an even number of points for each level. Odd points are -/// added to lower levels where the length of the addition results is odd (which then -/// makes the length even). -fn process_addition_tree(round: &mut RoundData) { - let num_levels = round.num_levels; - let bucket_sizes = &round.bucket_sizes; - let point_data = &round.point_data; - - let mut level_sizes = vec![0usize; num_levels]; - let mut level_offset = vec![0usize; num_levels]; - let output_indices = &mut round.output_indices; - let scatter_map = &mut round.scatter_map; - let base_positions = &mut round.base_positions; - let mut point_idx = bucket_sizes[0]; - - if !PREPROCESS { - // Set the offsets to the different levels in the tree - level_offset[0] = 0; - for idx in 1..level_offset.len() { - level_offset[idx] = level_offset[idx - 1] + round.level_sizes[idx - 1]; - } - } - - // The level where all bucket results will be stored - let bucket_level = num_levels - 1; - - // Run over all buckets - for bucket_size in bucket_sizes.iter().skip(1) { - let mut size = *bucket_size; - if size == 0 { - level_sizes[bucket_level] += 1; - } else if size == 1 { - if !PREPROCESS { - scatter_map[round.scatter_map_len] = ScatterData { - position: (level_offset[bucket_level] + level_sizes[bucket_level]) as u32, - point_data: point_data[point_idx], - }; - round.scatter_map_len += 1; - point_idx += 1; - } - level_sizes[bucket_level] += 1; - } else { - #[derive(Clone, Copy, PartialEq)] - enum State { - Even, - OddPoint(usize), - OddResult(usize), - } - let mut state = State::Even; - let num_levels_bucket = get_num_tree_levels(size); - - let mut start_level_size = level_sizes[0]; - for level in 0..num_levels_bucket - 1 { - let is_level_odd = size & 1; - let first_level = level == 0; - let last_level = level == num_levels_bucket - 2; - - // If this level has odd size we have to handle it - if is_level_odd == 1 { - // If we already have a point saved from a previous odd level, use it - // to make the current level even - if state != State::Even { - if !PREPROCESS { - let pos = (level_offset[level] + level_sizes[level]) as u32; - match state { - State::OddPoint(point_idx) => { - scatter_map[round.scatter_map_len] = ScatterData { - position: pos, - point_data: point_data[point_idx], - }; - round.scatter_map_len += 1; - } - State::OddResult(output_idx) => { - output_indices[output_idx] = pos; - } - _ => unreachable!(), - }; - } - level_sizes[level] += 1; - size += 1; - state = State::Even; - } else { - // Not odd yet, so the state is now odd - // Store the point we have to add later - if !PREPROCESS { - if first_level { - state = State::OddPoint(point_idx + size - 1); - } else { - state = State::OddResult( - (level_offset[level] + level_sizes[level] + size) >> 1, - ); - } - } else { - // Just mark it as odd, we won't use the actual value anywhere - state = State::OddPoint(0); - } - size -= 1; - } - } - - // Write initial points on the first level - if first_level { - if !PREPROCESS { - // Just write all points (except the odd size one) - let pos = level_offset[level] + level_sizes[level]; - base_positions[pos..pos + size] - .copy_from_slice(&point_data[point_idx..point_idx + size]); - point_idx += size + is_level_odd; - } - level_sizes[level] += size; - } - - // Write output indices - // If the next level would be odd, we have to make it even - // by writing the last result of this level to the next level that is odd - // (unless we are writing the final result to the bucket level) - let next_level_size = size >> 1; - let next_level_odd = next_level_size & 1 == 1; - let redirect = - if next_level_odd && state == State::Even && level < num_levels_bucket - 2 { - 1usize - } else { - 0usize - }; - // An addition works on two points and has one result, so this takes only half the size - let sub_level_offset = (level_offset[level] + start_level_size) >> 1; - // Cache the start position of the next level - start_level_size = level_sizes[level + 1]; - if !PREPROCESS { - // Write the destination positions of the addition results in the tree - let dst_pos = level_offset[level + 1] + level_sizes[level + 1]; - for (idx, output_index) in output_indices - [sub_level_offset..sub_level_offset + next_level_size] - .iter_mut() - .enumerate() - { - *output_index = (dst_pos + idx) as u32; - } - } - if last_level { - // The result of the last addition for this bucket is written - // to the last level (so all bucket results are nicely after each other). - // Overwrite the output locations of the last result here. - if !PREPROCESS { - output_indices[sub_level_offset] = - (level_offset[bucket_level] + level_sizes[bucket_level]) as u32; - } - level_sizes[bucket_level] += 1; - } else { - // Update the sizes - level_sizes[level + 1] += next_level_size - redirect; - size -= redirect; - // We have to redirect the last result to a lower level - if redirect == 1 { - state = State::OddResult(sub_level_offset + next_level_size - 1); - } - } - - // We added pairs of points together so the next level has half the size - size >>= 1; - } - } - } - - // Store the tree level data - round.level_sizes = level_sizes; - round.level_offset = level_offset; -} - -/// The affine formula is only efficient for independent point additions -/// (using the result of the addition requires an inversion which needs to be avoided as much as possible). -/// And so we try to add as many points together on each level of the tree, writing the result of the addition -/// to a lower level. Each level thus contains independent point additions, with only requiring a single inversion -/// per level in the tree. -fn create_addition_trees(rounds: &mut [RoundData]) { - let start = start_measure("create addition trees".to_string(), false); - scope(|scope| { - for round in rounds.chunks_mut(1) { - scope.spawn(move |_| { - // Collect tree levels sizes - process_addition_tree::(&mut round[0]); - // Construct the tree - process_addition_tree::(&mut round[0]); - }); - } - }); - stop_measure(start); -} - -/// Here we write the odd points in odd length buckets (the other points are loaded on the fly). -/// This will do random reads AND random writes, which is normally terrible for performance. -/// Luckily this doesn't really matter because we only have to write at most num_buckets points. -fn do_point_scatter(round: &RoundData, bases: &[C], points: &mut [C]) { - let num_threads = current_num_threads(); - let scatter_map = &round.scatter_map[..round.scatter_map_len]; - let mut points_box = ThreadBox::wrap(points); - let start = start_measure("point scatter".to_string(), false); - if !scatter_map.is_empty() { - scope(|scope| { - let num_copies_per_thread = div_up(scatter_map.len(), num_threads); - for scatter_map in scatter_map.chunks(num_copies_per_thread) { - scope.spawn(move |_| { - let points = points_box.unwrap(); - for scatter_data in scatter_map.iter() { - let target_idx = scatter_data.position as usize; - let negate = scatter_data.point_data & 0x80000000 != 0; - let base_idx = (scatter_data.point_data & 0x7FFFFFFF) as usize; - if negate { - points[target_idx] = bases[base_idx].neg(); - } else { - points[target_idx] = bases[base_idx]; - } - } - }); - } - }); - } - stop_measure(start); -} - -/// Finally do all additions using the addition tree we've setup. -fn do_batch_additions( - round: &RoundData, - bases: &[C], - points: &mut [C], - complete: bool, -) { - let num_threads = current_num_threads(); - - let num_levels = round.num_levels; - let level_counter = &round.level_sizes; - let level_offset = &round.level_offset; - let output_indices = &round.output_indices; - let base_positions = &round.base_positions; - let mut points_box = ThreadBox::wrap(points); - - let start = start_measure("batch additions".to_string(), false); - for i in 0..num_levels - 1 { - let start = level_offset[i]; - let num_points = level_counter[i]; - scope(|scope| { - // We have to make sure we have an even amount here so we don't split within a pair - let num_points_per_thread = div_up(num_points / 2, num_threads) * 2; - for thread_idx in 0..num_threads { - scope.spawn(move |_| { - let points = points_box.unwrap(); - - let thread_start = thread_idx * num_points_per_thread; - let mut thread_num_points = num_points_per_thread; - - if thread_start < num_points { - if thread_start + thread_num_points > num_points { - thread_num_points = num_points - thread_start; - } - - let points = &mut points[(start + thread_start)..]; - let output_indices = &output_indices[(start + thread_start) / 2..]; - let offset = start + thread_start; - if i == 0 { - let base_positions = &base_positions[(start + thread_start)..]; - if complete { - C::batch_add::( - points, - output_indices, - thread_num_points, - offset, - bases, - base_positions, - ); - } else { - C::batch_add::( - points, - output_indices, - thread_num_points, - offset, - bases, - base_positions, - ); - } - } else { - #[allow(collapsible-else-if)] - if complete { - C::batch_add::( - points, - output_indices, - thread_num_points, - offset, - &[], - &[], - ); - } else { - C::batch_add::( - points, - output_indices, - thread_num_points, - offset, - &[], - &[], - ); - } - } - } - }); - } - }); - } - stop_measure(start); -} - -/// Accumulate all bucket results to get the result of the round -fn accumulate_buckets( - round: &RoundData, - points: &mut [C], - c: usize, -) -> C::Curve { - let num_threads = current_num_threads(); - let num_buckets = get_num_buckets(c); - - let num_levels = round.num_levels; - let bucket_sizes = &round.bucket_sizes; - let level_offset = &round.level_offset; - - let start_time = start_measure("accumulate buckets".to_string(), false); - let start = level_offset[num_levels - 1]; - let buckets = &mut points[start..(start + num_buckets)]; - let mut results: Vec = vec![C::Curve::identity(); num_threads]; - scope(|scope| { - let chunk_size = num_buckets / num_threads; - for (thread_idx, ((bucket_sizes, buckets), result)) in bucket_sizes[1..] - .chunks(chunk_size) - .zip(buckets[..].chunks_mut(chunk_size)) - .zip(results.chunks_mut(1)) - .enumerate() - { - scope.spawn(move |_| { - // Accumulate all bucket results - let num_buckets_thread = bucket_sizes.len(); - let mut acc = C::Curve::identity(); - let mut running_sum = C::Curve::identity(); - for b in (0..num_buckets_thread).rev() { - if bucket_sizes[b] > 0 { - running_sum = running_sum + buckets[b]; - } - acc = acc + &running_sum; - } - - // Each thread started at a different bucket location - // so correct for that here - let bucket_start = thread_idx * chunk_size; - let num_bits = num_bits(bucket_start); - let mut accumulator = C::Curve::identity(); - for idx in (0..num_bits).rev() { - accumulator = accumulator.double(); - if (bucket_start >> idx) & 1 != 0 { - accumulator += running_sum; - } - } - acc += accumulator; - - // Store the result - result[0] = acc; - }); - } - }); - stop_measure(start_time); - - // Add the results of all threads together - results - .iter() - .fold(C::Curve::identity(), |acc, result| acc + result) -} - -use crate::CurveAffineExt; -use std::{ - env::var, - sync::atomic::{AtomicUsize, Ordering}, - time::Instant, -}; - -#[allow(missing_debug_implementations)] -pub struct MeasurementInfo { - /// Show measurement - pub show: bool, - /// The start time - pub time: Instant, - /// What is being measured - pub message: String, - /// The indent - pub indent: usize, -} - -/// Global indent counter -pub static NUM_INDENT: AtomicUsize = AtomicUsize::new(0); - -/// Gets the time difference between the current time and the passed in time -pub fn get_duration(start: Instant) -> usize { - let final_time = Instant::now() - start; - let secs = final_time.as_secs() as usize; - let millis = final_time.subsec_millis() as usize; - let micros = (final_time.subsec_micros() % 1000) as usize; - secs * 1000000 + millis * 1000 + micros -} - -/// Prints a measurement on screen -pub fn log_measurement(indent: Option, msg: String, duration: usize) { - let indent = indent.unwrap_or(0); - println!( - "{}{} ........ {}s", - "*".repeat(indent), - msg, - (duration as f32) / 1000000.0 - ); -} - -/// Starts a measurement -pub fn start_measure(msg: String, always: bool) -> MeasurementInfo { - let measure = env_value("MEASURE", 0); - let indent = NUM_INDENT.fetch_add(1, Ordering::Relaxed); - MeasurementInfo { - show: always || measure == 1, - time: Instant::now(), - message: msg, - indent, - } -} - -/// Stops a measurement, returns the duration -pub fn stop_measure(info: MeasurementInfo) -> usize { - NUM_INDENT.fetch_sub(1, Ordering::Relaxed); - let duration = get_duration(info.time); - if info.show { - log_measurement(Some(info.indent), info.message, duration); - } - duration -} - -/// Gets the ENV variable if defined, otherwise returns the default value -pub fn env_value(key: &str, default: usize) -> usize { - match var(key) { - Ok(val) => val.parse().unwrap(), - Err(_) => default, - } -} diff --git a/src/derive/curve.rs b/src/derive/curve.rs index d42384ba..866fe8e7 100644 --- a/src/derive/curve.rs +++ b/src/derive/curve.rs @@ -1,145 +1,3 @@ -#[macro_export] -macro_rules! batch_add { - () => { - fn batch_add( - points: &mut [Self], - output_indices: &[u32], - num_points: usize, - offset: usize, - bases: &[Self], - base_positions: &[u32], - ) { - // assert!(Self::constant_a().is_zero()); - - let get_point = |point_data: u32| -> Self { - let negate = point_data & 0x80000000 != 0; - let base_idx = (point_data & 0x7FFFFFFF) as usize; - if negate { - bases[base_idx].neg() - } else { - bases[base_idx] - } - }; - - // Affine addition formula (P != Q): - // - lambda = (y_2 - y_1) / (x_2 - x_1) - // - x_3 = lambda^2 - (x_2 + x_1) - // - y_3 = lambda * (x_1 - x_3) - y_1 - - // Batch invert accumulator - let mut acc = Self::Base::one(); - - for i in (0..num_points).step_by(2) { - // Where that result of the point addition will be stored - let out_idx = output_indices[i >> 1] as usize - offset; - - #[cfg(all(feature = "prefetch", target_arch = "x86_64"))] - if i < num_points - 2 { - if LOAD_POINTS { - $crate::prefetch::(bases, base_positions[i + 2] as usize); - $crate::prefetch::(bases, base_positions[i + 3] as usize); - } - $crate::prefetch::( - points, - output_indices[(i >> 1) + 1] as usize - offset, - ); - } - if LOAD_POINTS { - points[i] = get_point(base_positions[i]); - points[i + 1] = get_point(base_positions[i + 1]); - } - - if COMPLETE { - // Nothing to do here if one of the points is zero - if (points[i].is_identity() | points[i + 1].is_identity()).into() { - continue; - } - - if points[i].x == points[i + 1].x { - if points[i].y == points[i + 1].y { - // Point doubling (P == Q) - // - s = (3 * x^2) / (2 * y) - // - x_2 = s^2 - (2 * x) - // - y_2 = s * (x - x_2) - y - - // (2 * x) - points[out_idx].x = points[i].x + points[i].x; - // x^2 - let xx = points[i].x.square(); - // (2 * y) - points[i + 1].x = points[i].y + points[i].y; - // (3 * x^2) * acc - points[i + 1].y = (xx + xx + xx) * acc; - // acc * (2 * y) - acc *= points[i + 1].x; - continue; - } else { - // Zero - points[i] = Self::identity(); - points[i + 1] = Self::identity(); - continue; - } - } - } - - // (x_2 + x_1) - points[out_idx].x = points[i].x + points[i + 1].x; - // (x_2 - x_1) - points[i + 1].x -= points[i].x; - // (y2 - y1) * acc - points[i + 1].y = (points[i + 1].y - points[i].y) * acc; - // acc * (x_2 - x_1) - acc *= points[i + 1].x; - } - - // Batch invert - if COMPLETE { - if (!acc.is_zero()).into() { - acc = acc.invert().unwrap(); - } - } else { - acc = acc.invert().unwrap(); - } - - for i in (0..num_points).step_by(2).rev() { - // Where that result of the point addition will be stored - let out_idx = output_indices[i >> 1] as usize - offset; - - #[cfg(all(feature = "prefetch", target_arch = "x86_64"))] - if i > 0 { - $crate::prefetch::( - points, - output_indices[(i >> 1) - 1] as usize - offset, - ); - } - - if COMPLETE { - // points[i] is zero so the sum is points[i + 1] - if points[i].is_identity().into() { - points[out_idx] = points[i + 1]; - continue; - } - // points[i + 1] is zero so the sum is points[i] - if points[i + 1].is_identity().into() { - points[out_idx] = points[i]; - continue; - } - } - - // lambda - points[i + 1].y *= acc; - // acc * (x_2 - x_1) - acc *= points[i + 1].x; - // x_3 = lambda^2 - (x_2 + x_1) - points[out_idx].x = points[i + 1].y.square() - points[out_idx].x; - // y_3 = lambda * (x_1 - x_3) - y_1 - points[out_idx].y = - points[i + 1].y * (points[i].x - points[out_idx].x) - points[i].y; - } - } - }; -} - #[macro_export] macro_rules! endo { ($params:expr) => { diff --git a/src/secp256k1/curve.rs b/src/secp256k1/curve.rs index 3f644249..93530d80 100644 --- a/src/secp256k1/curve.rs +++ b/src/secp256k1/curve.rs @@ -1,6 +1,6 @@ use crate::secp256k1::Fp; use crate::secp256k1::Fq; -use crate::{Coordinates, CurveAffine, CurveAffineExt, CurveExt}; +use crate::{Coordinates, CurveAffine, CurveExt}; use core::cmp; use core::fmt::Debug; use core::iter::Sum; @@ -44,9 +44,9 @@ const SECP_GENERATOR_Y: Fp = Fp::from_raw([ const SECP_B: Fp = Fp::from_raw([7, 0, 0, 0]); use crate::{ - batch_add, impl_add_binop_specify_output, impl_binops_additive, - impl_binops_additive_specify_output, impl_binops_multiplicative, - impl_binops_multiplicative_mixed, impl_sub_binop_specify_output, new_curve_impl, + impl_add_binop_specify_output, impl_binops_additive, impl_binops_additive_specify_output, + impl_binops_multiplicative, impl_binops_multiplicative_mixed, impl_sub_binop_specify_output, + new_curve_impl, }; new_curve_impl!( From 2afd5e23adccc5c181a459bcbef8a4ec01669f26 Mon Sep 17 00:00:00 2001 From: kilic Date: Wed, 1 Mar 2023 15:57:40 +0300 Subject: [PATCH 7/7] fix typo --- src/bn256/msm/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/bn256/msm/mod.rs b/src/bn256/msm/mod.rs index 2ce91643..2c0d4de8 100644 --- a/src/bn256/msm/mod.rs +++ b/src/bn256/msm/mod.rs @@ -62,7 +62,7 @@ pub struct MSM { } impl MSM { - pub fn alloacate(n_points: usize, override_window: Option) -> Self { + pub fn allocate(n_points: usize, override_window: Option) -> Self { fn best_window(n: usize) -> usize { if n >= 262144 { 15 @@ -163,7 +163,7 @@ impl MSM { acc: &mut G1, override_window: Option, ) { - let mut msm = Self::alloacate(points.len(), override_window); + let mut msm = Self::allocate(points.len(), override_window); msm.decompose(scalars); for w_i in (0..msm.n_windows).rev() { if w_i != msm.n_windows - 1 {