diff --git a/.gitignore b/.gitignore index bceff729..de51da8b 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ Cargo.lock **/*.rs.bk .vscode -**/*.html \ No newline at end of file +**/*.html +*.tmp \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 822c6acc..2effac82 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ rand_xorshift = "0.3" ark-std = { version = "0.3" } [dependencies] +rayon = "1.5.1" subtle = "2.4" ff = "0.13.0" group = "0.13.0" diff --git a/src/arithmetic.rs b/src/arithmetic.rs index 42b84772..6c1c0527 100644 --- a/src/arithmetic.rs +++ b/src/arithmetic.rs @@ -4,22 +4,26 @@ //! This module is temporary, and the extension traits defined here are expected to be //! upstreamed into the `ff` and `group` crates after some refactoring. -pub trait CurveAffineExt: pasta_curves::arithmetic::CurveAffine { - fn batch_add( - points: &mut [Self], - output_indices: &[u32], - num_points: usize, - offset: usize, - bases: &[Self], - base_positions: &[u32], - ); - - /// Unlike the `Coordinates` trait, this just returns the raw affine coordinates without checking `is_on_curve` - fn into_coordinates(self) -> (Self::Base, Self::Base) { - // fallback implementation - let coordinates = self.coordinates().unwrap(); - (*coordinates.x(), *coordinates.y()) - } +use ff::PrimeField; +use group::Group; +use pasta_curves::arithmetic::CurveAffine; + +use crate::CurveExt; + +pub(crate) struct EndoParameters { + pub(crate) gamma1: [u64; 4], + pub(crate) gamma2: [u64; 4], + pub(crate) b1: [u64; 4], + pub(crate) b2: [u64; 4], +} + +pub trait CurveEndo: CurveExt { + fn decompose_scalar(e: &Self::ScalarExt) -> (u128, bool, u128, bool); +} + +pub trait CurveAffineExt: CurveAffine { + fn decompose_scalar(k: &Self::ScalarExt) -> (u128, bool, u128, bool); + fn endo(&self) -> Self; } /// Compute a + b + carry, returning the result and the new carry over. 
@@ -42,3 +46,367 @@ pub(crate) const fn mac(a: u64, b: u64, c: u64, carry: u64) -> (u64, u64) { let ret = (a as u128) + ((b as u128) * (c as u128)) + (carry as u128); (ret as u64, (ret >> 64) as u64) } + +/// Compute a + (b * c), returning the result and the new carry over. +#[inline(always)] +pub(crate) const fn macx(a: u64, b: u64, c: u64) -> (u64, u64) { + let res = (a as u128) + ((b as u128) * (c as u128)); + (res as u64, (res >> 64) as u64) +} + +/// Compute a * b, returning the result. +#[inline(always)] +pub(crate) fn mul_512(a: [u64; 4], b: [u64; 4]) -> [u64; 8] { + let (r0, carry) = macx(0, a[0], b[0]); + let (r1, carry) = macx(carry, a[0], b[1]); + let (r2, carry) = macx(carry, a[0], b[2]); + let (r3, carry_out) = macx(carry, a[0], b[3]); + + let (r1, carry) = macx(r1, a[1], b[0]); + let (r2, carry) = mac(r2, a[1], b[1], carry); + let (r3, carry) = mac(r3, a[1], b[2], carry); + let (r4, carry_out) = mac(carry_out, a[1], b[3], carry); + + let (r2, carry) = macx(r2, a[2], b[0]); + let (r3, carry) = mac(r3, a[2], b[1], carry); + let (r4, carry) = mac(r4, a[2], b[2], carry); + let (r5, carry_out) = mac(carry_out, a[2], b[3], carry); + + let (r3, carry) = macx(r3, a[3], b[0]); + let (r4, carry) = mac(r4, a[3], b[1], carry); + let (r5, carry) = mac(r5, a[3], b[2], carry); + let (r6, carry_out) = mac(carry_out, a[3], b[3], carry); + + [r0, r1, r2, r3, r4, r5, r6, carry_out] +} + +// taken from zcash/halo2::arithmetic +pub(crate) fn msm_zcash(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) { + let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect(); + + let c = if bases.len() < 4 { + 1 + } else if bases.len() < 32 { + 3 + } else { + (f64::from(bases.len() as u32)).ln().ceil() as usize + }; + + fn get_at(segment: usize, c: usize, bytes: &F::Repr) -> usize { + let skip_bits = segment * c; + let skip_bytes = skip_bits / 8; + + if skip_bytes >= 32 { + return 0; + } + + let mut v = [0; 8]; + for (v, o) in 
v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) { + *v = *o; + } + + let mut tmp = u64::from_le_bytes(v); + tmp >>= skip_bits - (skip_bytes * 8); + tmp %= 1 << c; + + tmp as usize + } + + let segments = (256 / c) + 1; + + for current_segment in (0..segments).rev() { + for _ in 0..c { + *acc = acc.double(); + } + + #[derive(Clone, Copy)] + enum Bucket { + None, + Affine(C), + Projective(C::Curve), + } + + impl Bucket { + fn add_assign(&mut self, other: &C) { + *self = match *self { + Bucket::None => Bucket::Affine(*other), + Bucket::Affine(a) => Bucket::Projective(a + *other), + Bucket::Projective(mut a) => { + a += *other; + Bucket::Projective(a) + } + } + } + + fn add(self, mut other: C::Curve) -> C::Curve { + match self { + Bucket::None => other, + Bucket::Affine(a) => { + other += a; + other + } + Bucket::Projective(a) => other + &a, + } + } + } + + let mut buckets: Vec> = vec![Bucket::None; (1 << c) - 1]; + + for (coeff, base) in coeffs.iter().zip(bases.iter()) { + let coeff = get_at::(current_segment, c, coeff); + if coeff != 0 { + buckets[coeff - 1].add_assign(base); + } + } + + // Summation by parts + // e.g. 
3a + 2b + 1c = a + + // (a) + b + + // ((a) + b) + c + let mut running_sum = C::Curve::identity(); + for exp in buckets.into_iter().rev() { + running_sum = exp.add(running_sum); + *acc += &running_sum; + } + } +} + +#[cfg(test)] +mod test { + use super::CurveEndo; + use crate::bn256::G1; + use ff::Field; + use pasta_curves::Ep; + use pasta_curves::Eq; + use rand_core::OsRng; + + // naive glv multiplication implementation + fn glv_mul(point: C, scalar: &C::ScalarExt) -> C { + const WINDOW: usize = 3; + // decompose scalar and convert to wnaf representation + let (k1, k1_neg, k2, k2_neg) = C::decompose_scalar(scalar); + + let mut k1_wnaf: Vec = Vec::new(); + let mut k2_wnaf: Vec = Vec::new(); + wnaf::form(&mut k1_wnaf, k1.to_le_bytes(), WINDOW); + wnaf::form(&mut k2_wnaf, k2.to_le_bytes(), WINDOW); + + let n = std::cmp::max(k1_wnaf.len(), k2_wnaf.len()); + k1_wnaf.resize(n, 0); + k2_wnaf.resize(n, 0); + + // prepare tables + let two_p = point.double(); + // T1 = {P, 3P, 5P, ...} + let mut table_k1 = vec![point]; + // T2 = {λP, 3λP, 5λP, ...} + let mut table_k2 = vec![point.endo()]; + for i in 1..WINDOW - 1 { + table_k1.push(table_k1[i - 1] + two_p); + table_k2.push(table_k1[i].endo()) + } + if !k2_neg { + table_k2.iter_mut().for_each(|p| *p = -*p); + } + if k1_neg { + table_k1.iter_mut().for_each(|p| *p = -*p); + } + // TODO: batch affine tables for mixed add? + + macro_rules! 
add { + ($acc:expr, $e:expr, $table:expr) => { + let idx = ($e.abs() >> 1) as usize; + $acc += if $e.is_positive() { + $table[idx] + } else if $e.is_negative() { + -$table[idx] + } else { + C::identity() + }; + }; + } + + // apply simultaneous double add + k1_wnaf + .iter() + .rev() + .zip(k2_wnaf.iter().rev()) + .fold(C::identity(), |acc, (e1, e2)| { + let mut acc = acc.double(); + add!(acc, e1, table_k1); + add!(acc, e2, table_k2); + acc + }) + } + + fn run_glv_mul_test() { + for _ in 0..10000 { + let point = C::random(OsRng); + let scalar = C::ScalarExt::random(OsRng); + let r0 = point * scalar; + let r1 = glv_mul(point, &scalar); + assert_eq!(r0, r1); + } + } + + #[test] + fn test_glv_mul() { + run_glv_mul_test::(); + // run_glv_mul_test::(); + // run_glv_mul_test::(git); + } + + #[test] + fn test_wnaf_form() { + use rand::Rng; + fn from_wnaf(wnaf: &[i64]) -> u128 { + wnaf.iter().rev().fold(0, |acc, next| { + let mut acc = acc * 2; + acc += *next as u128; + acc + }) + } + for w in 2..64 { + for e in 0..=u16::MAX { + let mut wnaf = vec![]; + wnaf::form(&mut wnaf, e.to_le_bytes(), w); + assert_eq!(e as u128, from_wnaf(&wnaf)); + } + } + let mut wnaf = vec![]; + for w in 2..64 { + for e in u128::MAX - 10000..=u128::MAX { + wnaf::form(&mut wnaf, e.to_le_bytes(), w); + assert_eq!(e, from_wnaf(&wnaf)); + } + } + for w in 2..10 { + for _ in 0..10000 { + let e: u128 = OsRng.gen(); + wnaf::form(&mut wnaf, e.to_le_bytes(), w); + assert_eq!(e as u128, from_wnaf(&wnaf)); + } + } + } + + // taken from zkcrypto/group + mod wnaf { + use std::convert::TryInto; + + #[derive(Debug, Clone)] + struct LimbBuffer<'a> { + buf: &'a [u8], + cur_idx: usize, + cur_limb: u64, + next_limb: u64, + } + + impl<'a> LimbBuffer<'a> { + fn new(buf: &'a [u8]) -> Self { + let mut ret = Self { + buf, + cur_idx: 0, + cur_limb: 0, + next_limb: 0, + }; + + // Initialise the limb buffers. 
+ ret.increment_limb(); + ret.increment_limb(); + ret.cur_idx = 0usize; + ret + } + + fn increment_limb(&mut self) { + self.cur_idx += 1; + self.cur_limb = self.next_limb; + match self.buf.len() { + // There are no more bytes in the buffer; zero-extend. + 0 => self.next_limb = 0, + + // There are fewer bytes in the buffer than a u64 limb; zero-extend. + x @ 1..=7 => { + let mut next_limb = [0; 8]; + next_limb[..x].copy_from_slice(self.buf); + self.next_limb = u64::from_le_bytes(next_limb); + self.buf = &[]; + } + + // There are at least eight bytes in the buffer; read the next u64 limb. + _ => { + let (next_limb, rest) = self.buf.split_at(8); + self.next_limb = u64::from_le_bytes(next_limb.try_into().unwrap()); + self.buf = rest; + } + } + } + + fn get(&mut self, idx: usize) -> (u64, u64) { + assert!([self.cur_idx, self.cur_idx + 1].contains(&idx)); + if idx > self.cur_idx { + self.increment_limb(); + } + (self.cur_limb, self.next_limb) + } + } + + /// Replaces the contents of `wnaf` with the w-NAF representation of a little-endian + /// scalar. + pub(crate) fn form>(wnaf: &mut Vec, c: S, window: usize) { + // Required by the NAF definition + debug_assert!(window >= 2); + // Required so that the NAF digits fit in i64 + debug_assert!(window < 64); + + let bit_len = c.as_ref().len() * 8; + + wnaf.truncate(0); + wnaf.reserve(bit_len + 1); + + // Initialise the current and next limb buffers. 
+ let mut limbs = LimbBuffer::new(c.as_ref()); + + let width = 1u64 << window; + let window_mask = width - 1; + + let mut pos = 0; + let mut carry = 0; + while pos <= bit_len { + // Construct a buffer of bits of the scalar, starting at bit `pos` + let u64_idx = pos / 64; + let bit_idx = pos % 64; + let (cur_u64, next_u64) = limbs.get(u64_idx); + let bit_buf = if bit_idx + window < 64 { + // This window's bits are contained in a single u64 + cur_u64 >> bit_idx + } else { + // Combine the current u64's bits with the bits from the next u64 + (cur_u64 >> bit_idx) | (next_u64 << (64 - bit_idx)) + }; + + // Add the carry into the current window + let window_val = carry + (bit_buf & window_mask); + + if window_val & 1 == 0 { + // If the window value is even, preserve the carry and emit 0. + // Why is the carry preserved? + // If carry == 0 and window_val & 1 == 0, then the next carry should be 0 + // If carry == 1 and window_val & 1 == 0, then bit_buf & 1 == 1 so the next carry should be 1 + wnaf.push(0); + pos += 1; + } else { + wnaf.push(if window_val < width / 2 { + carry = 0; + window_val as i64 + } else { + carry = 1; + (window_val as i64).wrapping_sub(width as i64) + }); + wnaf.extend(std::iter::repeat(0).take(window - 1)); + pos += window; + } + } + wnaf.truncate(wnaf.len().saturating_sub(window - 1)); + } + } +} diff --git a/src/bn256/curve.rs b/src/bn256/curve.rs index 7c359233..e288f4fb 100644 --- a/src/bn256/curve.rs +++ b/src/bn256/curve.rs @@ -1,6 +1,11 @@ +use crate::arithmetic::mul_512; +use crate::arithmetic::sbb; +use crate::arithmetic::CurveEndo; +use crate::arithmetic::EndoParameters; use crate::bn256::Fq; use crate::bn256::Fq2; use crate::bn256::Fr; +use crate::endo; use crate::{Coordinates, CurveAffine, CurveAffineExt, CurveExt}; use core::cmp; use core::fmt::Debug; @@ -11,12 +16,13 @@ use ff::{Field, PrimeField}; use group::Curve; use group::{cofactor::CofactorGroup, prime::PrimeCurveAffine, Group, GroupEncoding}; use rand::RngCore; +use 
std::convert::TryInto; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; use crate::{ - batch_add, impl_add_binop_specify_output, impl_binops_additive, - impl_binops_additive_specify_output, impl_binops_multiplicative, - impl_binops_multiplicative_mixed, impl_sub_binop_specify_output, new_curve_impl, + impl_add_binop_specify_output, impl_binops_additive, impl_binops_additive_specify_output, + impl_binops_multiplicative, impl_binops_multiplicative_mixed, impl_sub_binop_specify_output, + new_curve_impl, }; new_curve_impl!( @@ -44,18 +50,28 @@ new_curve_impl!( ); impl CurveAffineExt for G1Affine { - batch_add!(); + endo!(ENDO_PARAMS); - fn into_coordinates(self) -> (Self::Base, Self::Base) { - (self.x, self.y) + fn endo(&self) -> Self { + Self { + x: self.x * Self::Base::ZETA, + y: self.y, + } } } +impl CurveEndo for G1 { + endo!(ENDO_PARAMS); +} + impl CurveAffineExt for G2Affine { - batch_add!(); + endo!(ENDO_PARAMS); - fn into_coordinates(self) -> (Self::Base, Self::Base) { - (self.x, self.y) + fn endo(&self) -> Self { + Self { + x: self.x * Self::Base::ZETA, + y: self.y, + } } } @@ -109,6 +125,18 @@ const G2_GENERATOR_Y: Fq2 = Fq2 { ]), }; +const ENDO_PARAMS: EndoParameters = EndoParameters { + gamma1: [ + 0x7a7bd9d4391eb18du64, + 0x4ccef014a773d2cfu64, + 0x0000000000000002u64, + 0u64, + ], + gamma2: [0xd91d232ec7e0b3d7u64, 0x0000000000000002u64, 0u64, 0u64], + b1: [0x8211bbeb7d4f1128u64, 0x6f4d8248eeb859fcu64, 0u64, 0u64], + b2: [0x89d3256894d213e3u64, 0u64, 0u64, 0u64], +}; + impl group::cofactor::CofactorGroup for G1 { type Subgroup = G1; @@ -178,9 +206,14 @@ impl CofactorGroup for G2 { #[cfg(test)] mod tests { + + use crate::arithmetic::CurveEndo; use crate::bn256::{Fr, G1, G2}; use crate::CurveExt; + use ff::Field; + use ff::PrimeField; use ff::WithSmallOrderMulGroup; + use rand_core::OsRng; #[test] fn test_curve() { @@ -189,12 +222,24 @@ mod tests { } #[test] - fn test_endo_consistency() { + fn test_endo() { let g = G1::generator(); 
assert_eq!(g * Fr::ZETA, g.endo()); - let g = G2::generator(); assert_eq!(g * Fr::ZETA, g.endo()); + for _ in 0..100000 { + let k = Fr::random(OsRng); + let (k1, k1_neg, k2, k2_neg) = G1::decompose_scalar(&k); + if k1_neg & k2_neg { + assert_eq!(k, -Fr::from_u128(k1) + Fr::ZETA * Fr::from_u128(k2)) + } else if k1_neg { + assert_eq!(k, -Fr::from_u128(k1) - Fr::ZETA * Fr::from_u128(k2)) + } else if k2_neg { + assert_eq!(k, Fr::from_u128(k1) + Fr::ZETA * Fr::from_u128(k2)) + } else { + assert_eq!(k, Fr::from_u128(k1) - Fr::ZETA * Fr::from_u128(k2)) + } + } } #[test] diff --git a/src/bn256/mod.rs b/src/bn256/mod.rs index 9cd08946..cd8c5bad 100644 --- a/src/bn256/mod.rs +++ b/src/bn256/mod.rs @@ -5,6 +5,7 @@ mod fq12; mod fq2; mod fq6; mod fr; +pub mod msm; #[cfg(feature = "asm")] mod assembly; diff --git a/src/bn256/msm/mod.rs b/src/bn256/msm/mod.rs new file mode 100644 index 00000000..2c0d4de8 --- /dev/null +++ b/src/bn256/msm/mod.rs @@ -0,0 +1,284 @@ +use super::{Fr, G1Affine}; +use crate::arithmetic::msm_zcash; +use crate::bn256::{msm::round::Round, G1}; +use crate::group::Group; +use ff::PrimeField; +use rayon::{current_num_threads, scope}; + +#[macro_export] +macro_rules! div_ceil { + ($a:expr, $b:expr) => { + (($a - 1) / $b) + 1 + }; +} + +#[macro_export] +macro_rules! double_n { + ($acc:expr, $n:expr) => { + (0..$n).fold($acc, |acc, _| acc.double()) + }; +} + +#[macro_export] +macro_rules! range { + ($index:expr, $n_items:expr) => { + $index * $n_items..($index + 1) * $n_items + }; +} + +#[macro_export] +macro_rules! index { + ($digit:expr) => { + ($digit & 0x7fffffff) as usize + }; +} + +#[macro_export] +macro_rules! is_neg { + ($digit:expr) => { + sign_bit!($digit) != 0 + }; +} + +#[macro_export] +macro_rules! 
sign_bit { + ($digit:expr) => { + $digit & 0x80000000 + }; +} + +mod round; + +pub struct MSM { + signed_digits: Vec, + sorted_positions: Vec, + bucket_sizes: Vec, + bucket_offsets: Vec, + n_windows: usize, + window: usize, + n_buckets: usize, + n_points: usize, + round: Round, +} + +impl MSM { + pub fn allocate(n_points: usize, override_window: Option) -> Self { + fn best_window(n: usize) -> usize { + if n >= 262144 { + 15 + } else if n >= 65536 { + 12 + } else if n >= 16384 { + 11 + } else if n >= 8192 { + 10 + } else if n >= 1024 { + 9 + } else { + 7 + } + } + let window = match override_window { + Some(window) => { + let overriden = best_window(n_points); + println!("override window from {} to {}", overriden, window); + window + } + None => best_window(n_points), + }; + let n_windows = div_ceil!(Fr::NUM_BITS as usize, window); + let n_buckets = (1 << (window - 1)) + 1; + let round = Round::new(n_buckets, n_points); + MSM { + signed_digits: vec![0u32; n_windows * n_points], + sorted_positions: vec![0u32; n_windows * n_points], + bucket_sizes: vec![0usize; n_windows * n_buckets], + bucket_offsets: vec![0; n_buckets], + n_windows, + window, + n_buckets, + n_points, + round, + } + } + + fn decompose(&mut self, scalars: &[Fr]) { + pub(crate) fn get_bits(segment: usize, window: usize, bytes: &[u8]) -> u32 { + let skip_bits = segment * window; + let skip_bytes = skip_bits / 8; + if skip_bytes >= 32 { + return 0; + } + let mut v = [0; 4]; + for (v, o) in v.iter_mut().zip(bytes[skip_bytes..].iter()) { + *v = *o; + } + let mut tmp = u32::from_le_bytes(v); + tmp >>= skip_bits - (skip_bytes * 8); + tmp %= 1 << window; + tmp + } + let max = 1 << (self.window - 1); + for (point_idx, scalar) in scalars.iter().enumerate() { + let repr = scalar.to_repr(); + let mut borrow = 0u32; + for window_idx in 0..self.n_windows { + let windowed_digit = get_bits(window_idx, self.window, repr.as_ref()) + borrow; + let signed_digit = if windowed_digit >= max { + borrow = 1; + ((1 << 
self.window) - windowed_digit) | 0x80000000 + } else { + borrow = 0; + windowed_digit + }; + self.bucket_sizes[window_idx * self.n_buckets + index!(signed_digit)] += 1; + self.signed_digits[window_idx * self.n_points + point_idx] = signed_digit; + } + } + self.sort(); + } + + fn sort(&mut self) { + for w_i in 0..self.n_windows { + let sorted_positions = &mut self.sorted_positions[range!(w_i, self.n_points)]; + let bucket_sizes = &self.bucket_sizes[range!(w_i, self.n_buckets)]; + let signed_digits = &self.signed_digits[range!(w_i, self.n_points)]; + let mut offset = 0; + for (i, size) in bucket_sizes.iter().enumerate() { + self.bucket_offsets[i] = offset; + offset += size; + } + for (sorted_idx, signed_digit) in signed_digits.iter().enumerate() { + let bucket_idx = index!(signed_digit); + sorted_positions[self.bucket_offsets[bucket_idx]] = + sign_bit!(signed_digit) | (sorted_idx as u32); + self.bucket_offsets[bucket_idx] += 1; + } + } + } + + pub fn evaluate_with( + scalars: &[Fr], + points: &[G1Affine], + acc: &mut G1, + override_window: Option, + ) { + let mut msm = Self::allocate(points.len(), override_window); + msm.decompose(scalars); + for w_i in (0..msm.n_windows).rev() { + if w_i != msm.n_windows - 1 { + *acc = double_n!(*acc, msm.window); + } + msm.round.init( + points, + &msm.sorted_positions[range!(w_i, msm.n_points)], + &msm.bucket_sizes[range!(w_i, msm.n_buckets)], + ); + let buckets = msm.round.evaluate(); + let mut running_sum = G1::identity(); + for bucket in buckets.into_iter().skip(1).rev() { + running_sum += bucket; + *acc += &running_sum; + } + } + } + + pub fn evaluate(scalars: &[Fr], points: &[G1Affine], acc: &mut G1) { + Self::evaluate_with(scalars, points, acc, None) + } + + pub fn best(scalars: &[Fr], points: &[G1Affine]) -> G1 { + assert_eq!(scalars.len(), points.len()); + let num_threads = current_num_threads(); + if scalars.len() > num_threads { + let chunk = scalars.len() / num_threads; + let num_chunks = scalars.chunks(chunk).len(); + 
let mut results = vec![G1::identity(); num_chunks]; + scope(|scope| { + let chunk = scalars.len() / num_threads; + + for ((scalars, points), acc) in scalars + .chunks(chunk) + .zip(points.chunks(chunk)) + .zip(results.iter_mut()) + { + scope.spawn(move |_| { + if points.len() < 1 << 8 { + msm_zcash(scalars, points, acc); + } else { + Self::evaluate(scalars, points, acc); + } + }); + } + }); + results.iter().fold(G1::identity(), |a, b| a + b) + } else { + let mut acc = G1::identity(); + Self::evaluate(scalars, points, &mut acc); + acc + } + } +} + +#[cfg(test)] +mod test { + use crate::arithmetic::msm_zcash; + use crate::bn256::{Fr, G1Affine, G1}; + use crate::group::Group; + use crate::serde::SerdeObject; + use ff::Field; + use group::Curve; + use rand::Rng; + use rand_core::OsRng; + use std::fs::File; + use std::path::Path; + + fn read_data(n: usize) -> (Vec, Vec) { + let mut file = File::open("data.tmp").unwrap(); + (0..n) + .map(|_| { + let point = G1Affine::read_raw_unchecked(&mut file); + let scalar = Fr::read_raw_unchecked(&mut file); + (point, scalar) + }) + .unzip() + } + + pub(crate) fn get_data(n: usize) -> (Vec, Vec) { + const MAX_N: usize = 1 << 22; + assert!(n <= MAX_N); + if Path::new("data.tmp").is_file() { + read_data(n) + } else { + let mut file = File::create("data.tmp").unwrap(); + (0..MAX_N) + .map(|_| { + let point = G1::random(OsRng).to_affine(); + let scalar = Fr::random(OsRng); + point.write_raw(&mut file).unwrap(); + scalar.write_raw(&mut file).unwrap(); + (point, scalar) + }) + .take(n) + .unzip() + } + } + + #[test] + fn test_msm() { + let (min_k, max_k) = (4, 20); + let (points, scalars) = get_data(1 << max_k); + + for k in min_k..=max_k { + let mut rng = OsRng; + let n_points = rng.gen_range(1 << (k - 1)..1 << k); + let scalars = &scalars[..n_points]; + let points = &points[..n_points]; + let mut r0 = G1::identity(); + msm_zcash(scalars, points, &mut r0); + let mut r1 = G1::identity(); + super::MSM::evaluate(scalars, points, &mut r1); + 
assert_eq!(r0, r1); + } + } +} diff --git a/src/bn256/msm/round.rs b/src/bn256/msm/round.rs new file mode 100644 index 00000000..6fb03384 --- /dev/null +++ b/src/bn256/msm/round.rs @@ -0,0 +1,360 @@ +use super::super::{Fq, G1Affine}; +use crate::group::prime::PrimeCurveAffine; +use ff::Field; + +macro_rules! log_ceil { + ($a:expr) => { + $a.next_power_of_two().trailing_zeros() + }; +} + +pub(crate) struct Round { + t: Vec, + odd_points: Vec, + bucket_sizes: Vec, + n_buckets: usize, + n_points: usize, + out_off: usize, + in_off: usize, +} + +macro_rules! first_phase { + ($out:expr, $p0:expr, $p1:expr, $acc:expr) => { + $out.x = $p0.x + $p1.x; + $p1.x = $p1.x - $p0.x; + $p1.y = ($p1.y - $p0.y) * $acc; + $acc = $acc * $p1.x; + }; +} +macro_rules! second_phase { + ($out:expr, $p0:expr, $p1:expr, $acc:expr) => { + $p1.y = $p1.y * $acc; + $acc = $acc * $p1.x; + $out.x = $p1.y.square() - $out.x; + $out.y = ($p1.y * ($p0.x - $out.x)) - $p0.y; + }; +} + +impl Round { + pub(crate) fn new(n_buckets: usize, n_points: usize) -> Self { + let odd_points = vec![G1Affine::identity(); n_buckets]; + let bucket_sizes = vec![0; n_buckets]; + // TODO: requires less than allocated + let t = vec![G1Affine::identity(); n_points * 2]; + Self { + t, + odd_points, + bucket_sizes, + n_buckets, + n_points, + out_off: 0, + in_off: 0, + } + } + pub(crate) fn init(&mut self, bases: &[G1Affine], positions: &[u32], bucket_sizes: &[usize]) { + { + assert_eq!(self.n_points, positions.len()); + assert_eq!(self.n_points, bases.len()); + assert_eq!(self.bucket_sizes.len(), self.n_buckets); + } + macro_rules! 
get_base { + ($positions:expr, $off:expr) => { + if is_neg!($positions[$off]) { + -bases[index!($positions[$off])] + } else { + bases[index!($positions[$off])] + } + }; + } + let n_additions = bucket_sizes + .iter() + .map(|bucket_sizes| bucket_sizes / 2) + .collect::>(); + let mut out_off = n_additions.iter().sum::(); + let mut tmp_off = 0; + let mut position_off = 0; + let mut acc = Fq::ONE; + for (bucket_index, bucket_size) in bucket_sizes.iter().enumerate() { + let positions = &positions[position_off..position_off + bucket_size]; + position_off += bucket_size; + let bucket_size_pre = positions.len(); + if bucket_size_pre == 0 { + self.odd_points[bucket_index] = G1Affine::identity(); + self.bucket_sizes[bucket_index] = 0; + } else { + let mut in_off = 0; + let bucket_size_post = (bucket_size_pre + 1) / 2; + let n_additions = bucket_size_pre / 2; + // process even number of additions + for _ in 0..n_additions & (usize::MAX - 1) { + // second operand must be mutable + self.t[tmp_off] = get_base!(positions, in_off); + let lhs = get_base!(positions, in_off + 1); + first_phase!(self.t[out_off], lhs, self.t[tmp_off], acc); + tmp_off += 1; + out_off += 1; + in_off += 2; + } + // process the latest elements if there are odd number of additions + match (bucket_size_pre & 1 == 1, bucket_size_post & 1 == 1) { + // 1 base point left + // move to odd-point cache + (true, true) => { + assert_eq!(positions.len() - 1, in_off); + self.odd_points[bucket_index] = get_base!(positions, in_off); + } + // 2 base point left + // move addition result to odd-point cache + (false, true) => { + self.t[tmp_off] = get_base!(positions, in_off); + let lhs = get_base!(positions, in_off + 1); + first_phase!(self.odd_points[bucket_index], lhs, self.t[tmp_off], acc); + tmp_off += 1; + } + // 3 base point left + // move addition of first two to intermediate and last to odd-point cache + (true, false) => { + self.t[tmp_off] = get_base!(positions, in_off); + let lhs = get_base!(positions, in_off + 
1); + first_phase!(self.t[out_off], lhs, self.t[tmp_off], acc); + self.t[out_off + 1] = get_base!(positions, in_off + 2); + tmp_off += 1; + out_off += 2; + } + _ => { /* 0 base point left */ } + } + self.bucket_sizes[bucket_index] = bucket_size_post; + } + } + self.in_off = tmp_off; + self.out_off = out_off; + tmp_off -= 1; + out_off -= 1; + acc = acc.invert().unwrap(); + for (bucket_index, bucket_size_pre) in bucket_sizes.iter().enumerate().rev() { + let positions = &positions[position_off - bucket_size_pre..position_off]; + position_off -= bucket_size_pre; + if *bucket_size_pre == 0 { + // already updated in first phase + } else { + if positions.len() != 0 { + let bucket_size_post = (bucket_size_pre + 1) / 2; + let n_additions = bucket_size_pre / 2; + let mut in_off = positions.len() - 1; + // process the latest elements if there are odd number of additions + match (bucket_size_pre & 1 == 1, bucket_size_post & 1 == 1) { + // 1 base point left + // move to odd-point cache + (true, true) => { + in_off -= 1; + } + // 2 base point left + // move addition result to odd-point cache + (false, true) => { + let lhs = get_base!(positions, in_off); + second_phase!(self.odd_points[bucket_index], lhs, self.t[tmp_off], acc); + tmp_off -= 1; + in_off -= 2; + } + // 3 base point left + // move addition of first two to intermediate and last to odd-point cache + (true, false) => { + in_off -= 1; + out_off -= 1; + let lhs = get_base!(positions, in_off); + second_phase!(self.t[out_off], lhs, self.t[tmp_off], acc); + tmp_off -= 1; + in_off -= 2; + out_off -= 1; + } + _ => { /* 0 base point left */ } + } + // process even number of additions + for _ in (0..n_additions & (usize::MAX - 1)).rev() { + let lhs = get_base!(positions, in_off); + second_phase!(self.t[out_off], lhs, self.t[tmp_off], acc); + tmp_off -= 1; + out_off -= 1; + in_off -= 2; + } + } + } + } + } + fn batch_add(&mut self) { + let (mut out_off, mut in_off) = (self.out_off, self.in_off); + let mut acc = Fq::ONE; + for 
bucket_index in 0..self.n_buckets { + let bucket_size_pre = self.bucket_sizes[bucket_index]; + let n_additions = bucket_size_pre / 2; + let bucket_size_post = (bucket_size_pre + 1) / 2; + for _ in 0..n_additions & (usize::MAX - 1) { + first_phase!(self.t[out_off], self.t[in_off], self.t[in_off + 1], acc); + (out_off, in_off) = (out_off + 1, in_off + 2); + } + match (bucket_size_pre & 1 == 1, bucket_size_post & 1 == 1) { + (true, false) => { + first_phase!(self.t[out_off], self.t[in_off], self.t[in_off + 1], acc); + self.t[out_off + 1] = self.odd_points[bucket_index]; + out_off += 2; + in_off += 2; + } + (false, true) => { + first_phase!( + self.odd_points[bucket_index], + self.t[in_off], + self.t[in_off + 1], + acc + ); + in_off += 2; + } + _ => { /* clean sheets */ } + } + } + self.out_off = out_off; + self.in_off = in_off; + out_off -= 1; + in_off -= 2; + acc = acc.invert().unwrap(); + // process second phase + for bucket_index in (0..self.n_buckets).rev() { + let bucket_size_pre = self.bucket_sizes[bucket_index]; + let n_additions = bucket_size_pre / 2; + let bucket_size_post = (bucket_size_pre + 1) / 2; + + match (bucket_size_pre & 1 == 1, bucket_size_post & 1 == 1) { + (true, false) => { + out_off -= 1; + second_phase!(self.t[out_off], self.t[in_off], self.t[in_off + 1], acc); + out_off -= 1; + in_off -= 2; + } + (false, true) => { + second_phase!( + self.odd_points[bucket_index], + self.t[in_off], + self.t[in_off + 1], + acc + ); + in_off -= 2; + } + _ => { /* clean sheets */ } + } + + for _ in 0..n_additions & (usize::MAX - 1) { + second_phase!(self.t[out_off], self.t[in_off], self.t[in_off + 1], acc); + out_off -= 1; + in_off -= 2; + } + self.bucket_sizes[bucket_index] = bucket_size_post; + } + } + + fn max_tree_height(&self) -> usize { + *self.tree_heights().iter().max().unwrap() + } + fn tree_heights(&self) -> Vec { + self.bucket_sizes + .iter() + .map(|bucket_size| log_ceil!(bucket_size) as usize) + .collect() + } + pub fn evaluate(&mut self) -> 
&[G1Affine] { + for _ in 0..self.max_tree_height() { + self.batch_add(); + } + &self.odd_points + } +} + +#[cfg(test)] +mod tests { + use super::Round; + use crate::bn256::msm::test::get_data; + use crate::bn256::{G1Affine, G1}; + use crate::group::Group; + use crate::CurveAffine; + use group::Curve; + use rand::seq::SliceRandom; + use rand::Rng; + use rand_core::OsRng; + + pub(crate) fn rand_positions( + rng: &mut impl Rng, + n_buckets: usize, + n_points: usize, + ) -> (Vec, Vec) { + let mut positions: Vec> = vec![vec![]; n_buckets]; + (0..n_points).for_each(|i| { + let bucket_index = rng.gen_range(0..n_buckets); + let is_neg: bool = rng.gen(); + let signed_index = if is_neg { + i as u32 | 0x80000000 + } else { + i as u32 + }; + positions[bucket_index].push(signed_index); + }); + positions.iter_mut().for_each(|positions| { + positions.shuffle(rng); + }); + let bucket_sizes = positions.iter().map(|positions| positions.len()).collect(); + let positions: Vec = positions.iter().flatten().cloned().collect(); + (positions, bucket_sizes) + } + + impl Round { + fn sanity_check(&self, bases: &[G1Affine], positions: &[u32], bucket_sizes: &[usize]) { + { + let mut off = 0; + let sums: Vec<_> = bucket_sizes + .iter() + .map(|bucket_size| { + let sum = (off..off + bucket_size) + .map(|i| { + let index = index!(positions[i]); + let is_neg = is_neg!(positions[i]); + if is_neg { + -bases[index] + } else { + bases[index] + } + }) + .fold(G1::identity(), |acc, next| acc + next); + off += bucket_size; + sum + }) + .collect(); + self.odd_points + .iter() + .for_each(|r| assert!(bool::from(r.is_on_curve()))); + sums.iter() + .zip(self.odd_points.iter()) + .for_each(|(r0, r1)| assert_eq!(r0.to_affine(), *r1)); + self.bucket_sizes + .iter() + .for_each(|bucket_size| assert_eq!(bucket_size & 1, *bucket_size)); + assert_eq!(self.max_tree_height(), 0); + self.tree_heights() + .iter() + .for_each(|tree_height| assert_eq!(tree_height, &0)); + } + } + } + + #[test] + fn test_round() { + let 
mut rng = OsRng; + + let n_points = 1 << 12; + let window_size = 5; + let n_buckets = (1 << window_size) - 1; + + let (bases, _) = get_data(n_points); + let (positions, bucket_sizes) = rand_positions(&mut rng, n_buckets, n_points); + let mut round = Round::new(n_buckets, n_points); + round.init(&bases[..n_points], &positions, &bucket_sizes); + round.evaluate(); + round.sanity_check(&bases, &positions, &bucket_sizes); + } +} diff --git a/src/derive/curve.rs b/src/derive/curve.rs index ba065fe2..866fe8e7 100644 --- a/src/derive/curve.rs +++ b/src/derive/curve.rs @@ -1,141 +1,48 @@ #[macro_export] -macro_rules! batch_add { - () => { - fn batch_add( - points: &mut [Self], - output_indices: &[u32], - num_points: usize, - offset: usize, - bases: &[Self], - base_positions: &[u32], - ) { - // assert!(Self::constant_a().is_zero()); - - let get_point = |point_data: u32| -> Self { - let negate = point_data & 0x80000000 != 0; - let base_idx = (point_data & 0x7FFFFFFF) as usize; - if negate { - bases[base_idx].neg() - } else { - bases[base_idx] - } +macro_rules! 
endo { + ($params:expr) => { + fn decompose_scalar(k: &Self::ScalarExt) -> (u128, bool, u128, bool) { + let to_limbs = |e: &Self::ScalarExt| { + let repr = e.to_repr(); + let repr = repr.as_ref(); + let tmp0 = u64::from_le_bytes(repr[0..8].try_into().unwrap()); + let tmp1 = u64::from_le_bytes(repr[8..16].try_into().unwrap()); + let tmp2 = u64::from_le_bytes(repr[16..24].try_into().unwrap()); + let tmp3 = u64::from_le_bytes(repr[24..32].try_into().unwrap()); + [tmp0, tmp1, tmp2, tmp3] }; - // Affine addition formula (P != Q): - // - lambda = (y_2 - y_1) / (x_2 - x_1) - // - x_3 = lambda^2 - (x_2 + x_1) - // - y_3 = lambda * (x_1 - x_3) - y_1 - - // Batch invert accumulator - let mut acc = Self::Base::one(); - - for i in (0..num_points).step_by(2) { - // Where that result of the point addition will be stored - let out_idx = output_indices[i >> 1] as usize - offset; - - #[cfg(all(feature = "prefetch", target_arch = "x86_64"))] - if i < num_points - 2 { - if LOAD_POINTS { - $crate::prefetch::(bases, base_positions[i + 2] as usize); - $crate::prefetch::(bases, base_positions[i + 3] as usize); - } - $crate::prefetch::( - points, - output_indices[(i >> 1) + 1] as usize - offset, - ); - } - if LOAD_POINTS { - points[i] = get_point(base_positions[i]); - points[i + 1] = get_point(base_positions[i + 1]); - } - - if COMPLETE { - // Nothing to do here if one of the points is zero - if (points[i].is_identity() | points[i + 1].is_identity()).into() { - continue; - } - - if points[i].x == points[i + 1].x { - if points[i].y == points[i + 1].y { - // Point doubling (P == Q) - // - s = (3 * x^2) / (2 * y) - // - x_2 = s^2 - (2 * x) - // - y_2 = s * (x - x_2) - y - - // (2 * x) - points[out_idx].x = points[i].x + points[i].x; - // x^2 - let xx = points[i].x.square(); - // (2 * y) - points[i + 1].x = points[i].y + points[i].y; - // (3 * x^2) * acc - points[i + 1].y = (xx + xx + xx) * acc; - // acc * (2 * y) - acc *= points[i + 1].x; - continue; - } else { - // Zero - points[i] = 
Self::identity(); - points[i + 1] = Self::identity(); - continue; - } - } - } - - // (x_2 + x_1) - points[out_idx].x = points[i].x + points[i + 1].x; - // (x_2 - x_1) - points[i + 1].x -= points[i].x; - // (y2 - y1) * acc - points[i + 1].y = (points[i + 1].y - points[i].y) * acc; - // acc * (x_2 - x_1) - acc *= points[i + 1].x; - } - - // Batch invert - if COMPLETE { - if (!acc.is_zero()).into() { - acc = acc.invert().unwrap(); - } - } else { - acc = acc.invert().unwrap(); - } - - for i in (0..num_points).step_by(2).rev() { - // Where that result of the point addition will be stored - let out_idx = output_indices[i >> 1] as usize - offset; - - #[cfg(all(feature = "prefetch", target_arch = "x86_64"))] - if i > 0 { - $crate::prefetch::( - points, - output_indices[(i >> 1) - 1] as usize - offset, - ); - } + let get_lower_128 = |e: &Self::ScalarExt| { + let e = to_limbs(e); + u128::from(e[0]) | (u128::from(e[1]) << 64) + }; - if COMPLETE { - // points[i] is zero so the sum is points[i + 1] - if points[i].is_identity().into() { - points[out_idx] = points[i + 1]; - continue; - } - // points[i + 1] is zero so the sum is points[i] - if points[i + 1].is_identity().into() { - points[out_idx] = points[i]; - continue; - } - } + let is_neg = |e: &Self::ScalarExt| { + let e = to_limbs(e); + let (_, borrow) = sbb(0xffffffffffffffff, e[0], 0); + let (_, borrow) = sbb(0xffffffffffffffff, e[1], borrow); + let (_, borrow) = sbb(0xffffffffffffffff, e[2], borrow); + let (_, borrow) = sbb(0x00, e[3], borrow); + borrow & 1 != 0 + }; - // lambda - points[i + 1].y *= acc; - // acc * (x_2 - x_1) - acc *= points[i + 1].x; - // x_3 = lambda^2 - (x_2 + x_1) - points[out_idx].x = points[i + 1].y.square() - points[out_idx].x; - // y_3 = lambda * (x_1 - x_3) - y_1 - points[out_idx].y = - points[i + 1].y * (points[i].x - points[out_idx].x) - points[i].y; - } + let input = to_limbs(&k); + let c1 = mul_512($params.gamma2, input); + let c2 = mul_512($params.gamma1, input); + let c1 = [c1[4], c1[5], 
c1[6], c1[7]]; + let c2 = [c2[4], c2[5], c2[6], c2[7]]; + let q1 = mul_512(c1, $params.b1); + let q2 = mul_512(c2, $params.b2); + let q1 = Self::ScalarExt::from_raw([q1[0], q1[1], q1[2], q1[3]]); + let q2 = Self::ScalarExt::from_raw([q2[0], q2[1], q2[2], q2[3]]); + let k2 = q2 - q1; + let k1 = k + k2 * Self::ScalarExt::ZETA; + let k1_neg = is_neg(&k1); + let k2_neg = is_neg(&k2); + let k1 = if k1_neg { -k1 } else { k1 }; + let k2 = if k2_neg { -k2 } else { k2 }; + + (get_lower_128(&k1), k1_neg, get_lower_128(&k2), k2_neg) } }; } diff --git a/src/pasta/mod.rs b/src/pasta/mod.rs index f6aee547..6831372c 100644 --- a/src/pasta/mod.rs +++ b/src/pasta/mod.rs @@ -1,27 +1,24 @@ +use crate::arithmetic::mul_512; +use crate::arithmetic::sbb; +use crate::{ + arithmetic::{CurveEndo, EndoParameters}, + endo, +}; +use ff::PrimeField; +use ff::WithSmallOrderMulGroup; pub use pasta_curves::{pallas, vesta, Ep, EpAffine, Eq, EqAffine, Fp, Fq}; +use std::convert::TryInto; -impl crate::CurveAffineExt for EpAffine { - fn batch_add( - _: &mut [Self], - _: &[u32], - _: usize, - _: usize, - _: &[Self], - _: &[u32], - ) { - unimplemented!(); - } -} +const ENDO_PARAMS_EQ: EndoParameters = EndoParameters { + gamma1: [0x32c49e4c00000003, 0x279a745902a2654e, 0x1, 0x0], + gamma2: [0x31f0256800000002, 0x4f34e8b2066389a4, 0x2, 0x0], + b1: [0x8cb1279300000001, 0x49e69d1640a89953, 0x0, 0x0], + b2: [0x0c7c095a00000001, 0x93cd3a2c8198e269, 0x0, 0x0], +}; -impl crate::CurveAffineExt for EqAffine { - fn batch_add( - _: &mut [Self], - _: &[u32], - _: usize, - _: usize, - _: &[Self], - _: &[u32], - ) { - unimplemented!(); - } -} +const ENDO_PARAMS_EP: EndoParameters = EndoParameters { + gamma1: [0x32c49e4bffffffff, 0x279a745902a2654e, 0x1, 0x0], + gamma2: [0x31f0256800000002, 0x4f34e8b2066389a4, 0x2, 0x0], + b1: [0x8cb1279300000000, 0x49e69d1640a89953, 0x0, 0x0], + b2: [0x0c7c095a00000001, 0x93cd3a2c8198e269, 0x0, 0x0], +}; diff --git a/src/secp256k1/curve.rs b/src/secp256k1/curve.rs index 
89c197b5..93530d80 100644 --- a/src/secp256k1/curve.rs +++ b/src/secp256k1/curve.rs @@ -1,6 +1,6 @@ use crate::secp256k1::Fp; use crate::secp256k1::Fq; -use crate::{Coordinates, CurveAffine, CurveAffineExt, CurveExt}; +use crate::{Coordinates, CurveAffine, CurveExt}; use core::cmp; use core::fmt::Debug; use core::iter::Sum; @@ -44,9 +44,9 @@ const SECP_GENERATOR_Y: Fp = Fp::from_raw([ const SECP_B: Fp = Fp::from_raw([7, 0, 0, 0]); use crate::{ - batch_add, impl_add_binop_specify_output, impl_binops_additive, - impl_binops_additive_specify_output, impl_binops_multiplicative, - impl_binops_multiplicative_mixed, impl_sub_binop_specify_output, new_curve_impl, + impl_add_binop_specify_output, impl_binops_additive, impl_binops_additive_specify_output, + impl_binops_multiplicative, impl_binops_multiplicative_mixed, impl_sub_binop_specify_output, + new_curve_impl, }; new_curve_impl!( @@ -61,14 +61,6 @@ new_curve_impl!( "secp256k1", ); -impl CurveAffineExt for Secp256k1Affine { - batch_add!(); - - fn into_coordinates(self) -> (Self::Base, Self::Base) { - (self.x, self.y) - } -} - #[test] fn test_curve() { crate::tests::curve::curve_tests::();