Skip to content

Commit

Permalink
fix(ec): non canonical endomorphism acceleration (#414)
Browse files Browse the repository at this point in the history
  • Loading branch information
mratsim authored Jun 30, 2024
1 parent f0d5d2f commit dcc9310
Show file tree
Hide file tree
Showing 12 changed files with 77 additions and 70 deletions.
18 changes: 9 additions & 9 deletions constantine.nimble
Original file line number Diff line number Diff line change
Expand Up @@ -415,21 +415,21 @@ const testDesc: seq[tuple[path: string, useGMP: bool]] = @[
("tests/math_elliptic_curves/t_ec_shortw_jac_g1_add_double.nim", false),
# ("tests/math_elliptic_curves/t_ec_shortw_jac_g1_mul_sanity.nim", false),
# ("tests/math_elliptic_curves/t_ec_shortw_jac_g1_mul_distri.nim", false),
("tests/math_elliptic_curves/t_ec_shortw_jac_g1_mul_vs_ref.nim", false),
# ("tests/math_elliptic_curves/t_ec_shortw_jac_g1_mul_vs_ref.nim", false),
# ("tests/math_elliptic_curves/t_ec_shortw_jac_g1_mixed_add.nim", false),

("tests/math_elliptic_curves/t_ec_shortw_jacext_g1_add_double.nim", false),
("tests/math_elliptic_curves/t_ec_shortw_jacext_g1_mixed_add.nim", false),

# ("tests/math_elliptic_curves/t_ec_twedw_prj_add_double", false),
("tests/math_elliptic_curves/t_ec_twedw_prj_mul_sanity", false),
# ("tests/math_elliptic_curves/t_ec_twedw_prj_mul_sanity", false),
("tests/math_elliptic_curves/t_ec_twedw_prj_mul_distri", false),

# ("tests/math_elliptic_curves/t_ec_shortw_jac_g1_mul_endomorphism_bls12_381", false),
("tests/math_elliptic_curves/t_ec_shortw_jac_g1_mul_endomorphism_bls12_381", false),
# ("tests/math_elliptic_curves/t_ec_shortw_prj_g1_mul_endomorphism_bls12_381", false),
# ("tests/math_elliptic_curves/t_ec_shortw_jac_g1_mul_endomorphism_bn254_snarks", false),
("tests/math_elliptic_curves/t_ec_shortw_jac_g1_mul_endomorphism_bn254_snarks", false),
# ("tests/math_elliptic_curves/t_ec_shortw_prj_g1_mul_endomorphism_bn254_snarks", false),
# ("tests/math_elliptic_curves/t_ec_twedwards_mul_endomorphism_bandersnatch", false),
("tests/math_elliptic_curves/t_ec_twedwards_mul_endomorphism_bandersnatch", false),


# Elliptic curve arithmetic G2
Expand Down Expand Up @@ -461,13 +461,13 @@ const testDesc: seq[tuple[path: string, useGMP: bool]] = @[
# ("tests/math_elliptic_curves/t_ec_shortw_jac_g2_add_double_bn254_snarks.nim", false),
# ("tests/math_elliptic_curves/t_ec_shortw_jac_g2_mul_sanity_bn254_snarks.nim", false),
# ("tests/math_elliptic_curves/t_ec_shortw_jac_g2_mul_distri_bn254_snarks.nim", false),
("tests/math_elliptic_curves/t_ec_shortw_jac_g2_mul_vs_ref_bn254_snarks.nim", false),
# ("tests/math_elliptic_curves/t_ec_shortw_jac_g2_mul_vs_ref_bn254_snarks.nim", false),
# ("tests/math_elliptic_curves/t_ec_shortw_jac_g2_mixed_add_bn254_snarks.nim", false),

# ("tests/math_elliptic_curves/t_ec_shortw_jac_g2_add_double_bls12_381.nim", false),
# ("tests/math_elliptic_curves/t_ec_shortw_jac_g2_mul_sanity_bls12_381.nim", false),
# ("tests/math_elliptic_curves/t_ec_shortw_jac_g2_mul_distri_bls12_381.nim", false),
("tests/math_elliptic_curves/t_ec_shortw_jac_g2_mul_vs_ref_bls12_381.nim", false),
# ("tests/math_elliptic_curves/t_ec_shortw_jac_g2_mul_vs_ref_bls12_381.nim", false),
# ("tests/math_elliptic_curves/t_ec_shortw_jac_g2_mixed_add_bls12_381.nim", false),

# ("tests/math_elliptic_curves/t_ec_shortw_jac_g2_add_double_bls12_377.nim", false),
Expand All @@ -482,9 +482,9 @@ const testDesc: seq[tuple[path: string, useGMP: bool]] = @[
("tests/math_elliptic_curves/t_ec_shortw_jac_g2_mul_vs_ref_bw6_761.nim", false),
# ("tests/math_elliptic_curves/t_ec_shortw_jac_g2_mixed_add_bw6_761.nim", false),

# ("tests/math_elliptic_curves/t_ec_shortw_jac_g2_mul_endomorphism_bls12_381", false),
("tests/math_elliptic_curves/t_ec_shortw_jac_g2_mul_endomorphism_bls12_381", false),
# ("tests/math_elliptic_curves/t_ec_shortw_prj_g2_mul_endomorphism_bls12_381", false),
# ("tests/math_elliptic_curves/t_ec_shortw_jac_g2_mul_endomorphism_bn254_snarks", false),
("tests/math_elliptic_curves/t_ec_shortw_jac_g2_mul_endomorphism_bn254_snarks", false),
# ("tests/math_elliptic_curves/t_ec_shortw_prj_g2_mul_endomorphism_bn254_snarks", false),

# Elliptic curve arithmetic vs Sagemath
Expand Down
26 changes: 14 additions & 12 deletions constantine/math/elliptic/ec_endomorphism_accel.nim
Original file line number Diff line number Diff line change
Expand Up @@ -48,24 +48,26 @@ template decomposeEndoImpl[scalBits: static int](
copyMiniScalarsResult: untyped) =
static: doAssert scalBits >= L, "Cannot decompose a scalar smaller than a mini-scalar or the decomposition coefficient"
# Equal when no window or no negative handling, greater otherwise
static: doAssert L >= ceilDiv_vartime(scalBits, M) + 1
const w = Fr[F.Name].bits().wordsRequired()
const frBits = Fr[F.Name].bits()
static: doAssert frBits >= scalBits
static: doAssert L >= ceilDiv_vartime(frBits, M) + 1
const w = frBits.wordsRequired()

# Upstream bug:
# {.noInit.} variables must be {.inject.} as well
# or they'll be mangled as foo`gensym12345 instead of fooX60gensym12345 in C codegen

when M == 2:
var alphas{.noInit, inject.}: (
BigInt[scalBits + babai(F)[0][0].bits],
BigInt[scalBits + babai(F)[1][0].bits]
BigInt[frBits + babai(F)[0][0].bits],
BigInt[frBits + babai(F)[1][0].bits]
)
elif M == 4:
var alphas{.noInit, inject.}: (
BigInt[scalBits + babai(F)[0][0].bits],
BigInt[scalBits + babai(F)[1][0].bits],
BigInt[scalBits + babai(F)[2][0].bits],
BigInt[scalBits + babai(F)[3][0].bits]
BigInt[frBits + babai(F)[0][0].bits],
BigInt[frBits + babai(F)[1][0].bits],
BigInt[frBits + babai(F)[2][0].bits],
BigInt[frBits + babai(F)[3][0].bits]
)
else:
{.error: "The decomposition degree " & $M & " is not configured".}
Expand All @@ -79,9 +81,9 @@ template decomposeEndoImpl[scalBits: static int](
# We have k0 = s - 𝛼0 b00 - 𝛼1 b10 ... - 𝛼m bm0
# and kj = 0 - 𝛼j b0j - 𝛼1 b1j ... - 𝛼m bmj
var
k {.inject.}: array[M, BigInt[scalBits]] # zero-init required
alphaB {.noInit, inject.}: BigInt[scalBits]
k[0] = scalar
k {.inject.}: array[M, BigInt[frBits]] # zero-init required
alphaB {.noInit, inject.}: BigInt[frBits]
k[0].copyTruncatedFrom(scalar)
staticFor miniScalarIdx, 0, M:
staticFor basisIdx, 0, M:
when not bool lattice(F)[basisIdx][miniScalarIdx][0].isZero():
Expand Down Expand Up @@ -342,7 +344,7 @@ func scalarMulEndo*[scalBits; EC](
endos.computeEndomorphisms(P)

# 2. Decompose scalar into mini-scalars
const L = scalBits.ceilDiv_vartime(M) + 1
const L = EC.getScalarField().bits().ceilDiv_vartime(M) + 1
var miniScalars {.noInit.}: array[M, BigInt[L]]
var negatePoints {.noInit.}: array[M, SecretBool]
miniScalars.decomposeEndo(negatePoints, scalar, P.F)
Expand Down
10 changes: 8 additions & 2 deletions constantine/math/elliptic/ec_multi_scalar_mul.nim
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ proc applyEndomorphism[bits: static int, ECaff](
elif ECaff.F is Fp2: 4
else: {.error: "Unconfigured".}

const L = bits.ceilDiv_vartime(M) + 1
const L = ECaff.getScalarField().bits().ceilDiv_vartime(M) + 1
let splitCoefs = allocHeapArray(array[M, BigInt[L]], N)
let endoBasis = allocHeapArray(array[M, ECaff], N)

Expand Down Expand Up @@ -429,7 +429,13 @@ template withEndo[coefsBits: static int, EC, ECaff](
coefs: ptr UncheckedArray[BigInt[coefsBits]],
points: ptr UncheckedArray[ECaff],
N: int, c: static int) =
when coefsBits <= EC.getScalarField().bits() and hasEndomorphismAcceleration(EC.F.Name):
when hasEndomorphismAcceleration(EC.F.Name) and
EndomorphismThreshold <= coefsBits and
coefsBits <= EC.getScalarField().bits() and
# computeEndomorphism assumes they can be applied to affine repr
# but this is not the case for Bandersnatch/wagon
# instead Twisted Edwards MSM should be overloaded for Projective/ProjectiveExtended
EC.F.Name notin {Bandersnatch, Banderwagon}:
let (endoCoefs, endoPoints, endoN) = applyEndomorphism(coefs, points, N)
# Given that bits and N changed, we are able to use a bigger `c`
# but it has no significant impact on performance
Expand Down
10 changes: 8 additions & 2 deletions constantine/math/elliptic/ec_multi_scalar_mul_parallel.nim
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ proc applyEndomorphism_parallel[bits: static int, ECaff](
elif ECaff.F is Fp2: 4
else: {.error: "Unconfigured".}

const L = bits.ceilDiv_vartime(M) + 1
const L = ECaff.getScalarField().bits().ceilDiv_vartime(M) + 1
let splitCoefs = allocHeapArray(array[M, BigInt[L]], N)
let endoBasis = allocHeapArray(array[M, ECaff], N)

Expand Down Expand Up @@ -495,7 +495,13 @@ template withEndo[coefsBits: static int, EC, ECaff](
coefs: ptr UncheckedArray[BigInt[coefsBits]],
points: ptr UncheckedArray[ECaff],
N: int, c: static int) =
when coefsBits <= EC.getScalarField().bits() and hasEndomorphismAcceleration(EC.F.Name):
when hasEndomorphismAcceleration(EC.F.Name) and
EndomorphismThreshold <= coefsBits and
coefsBits <= EC.getScalarField().bits() and
# computeEndomorphism assumes they can be applied to affine repr
# but this is not the case for Bandersnatch/wagon
# instead Twisted Edwards MSM should be overloaded for Projective/ProjectiveExtended
EC.F.Name notin {Bandersnatch, Banderwagon}:
let (endoCoefs, endoPoints, endoN) = applyEndomorphism_parallel(tp, coefs, points, N)
# Given that bits and N changed, we are able to use a bigger `c`
# but it has no significant impact on performance
Expand Down
5 changes: 2 additions & 3 deletions constantine/math/elliptic/ec_scalar_mul.nim
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,8 @@ func scalarMul*[EC](P: var EC, scalar: BigInt) {.inline, meter.} =
## - Cofactor to be cleared
## - 0 <= scalar < curve order
## Those will be assumed to maintain constant-time property
when BigInt.bits <= EC.getScalarField().bits() and
EC.F.Name.hasEndomorphismAcceleration():
# TODO, min amount of bits for endomorphisms?
when EC.F.Name.hasEndomorphismAcceleration() and
BigInt.bits >= EndomorphismThreshold:
when EC.F is Fp:
P.scalarMulGLV_m2w2(scalar)
elif EC.F is Fp2:
Expand Down
22 changes: 11 additions & 11 deletions constantine/math/elliptic/ec_scalar_mul_vartime.nim
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ func scalarMulEndo_minHammingWeight_windowed_vartime*[scalBits: static int; EC](
endos.computeEndomorphisms(P)

# 2. Decompose scalar into mini-scalars
const L = scalBits.ceilDiv_vartime(M) + 1
const L = EC.getScalarField().bits().ceilDiv_vartime(M) + 1
var miniScalars {.noInit.}: array[M, BigInt[L]]
var negatePoints {.noInit.}: array[M, SecretBool]
miniScalars.decomposeEndo(negatePoints, scalar, EC.F)
Expand Down Expand Up @@ -340,16 +340,16 @@ func scalarMul_vartime*[scalBits; EC](P: var EC, scalar: BigInt[scalBits]) {.met

let usedBits = scalar.limbs.getBits_LE_vartime()

when scalBits == EC.getScalarField().bits() and
EC.F.Name.hasEndomorphismAcceleration():
if usedBits >= L:
when EC.F is Fp:
P.scalarMulEndo_minHammingWeight_windowed_vartime(scalar, window = 4)
elif EC.F is Fp2:
P.scalarMulEndo_minHammingWeight_windowed_vartime(scalar, window = 3)
else: # Curves defined on Fp^m with m > 2
{.error: "Unreachable".}
return
when EC.F.Name.hasEndomorphismAcceleration():
when scalBits >= EndomorphismThreshold: # Skip static: doAssert when multiplying by intentionally small scalars.
if usedBits >= EndomorphismThreshold:
when EC.F is Fp:
P.scalarMulEndo_minHammingWeight_windowed_vartime(scalar, window = 4)
elif EC.F is Fp2:
P.scalarMulEndo_minHammingWeight_windowed_vartime(scalar, window = 3)
else: # Curves defined on Fp^m with m > 2
{.error: "Unreachable".}
return

if 64 < usedBits:
# With a window of 5, we precompute 2^3 = 8 points
Expand Down
11 changes: 5 additions & 6 deletions constantine/named/zoo_endomorphisms.nim
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,8 @@ func computeEndomorphisms*[EC; M: static int](endos: var array[M-1, EC], P: EC)

func hasEndomorphismAcceleration*(Name: static Algebra): bool {.compileTime.} =
Name in {
# TODO: MSM assumes that endomorphism can be computed with affine coordinates
# Bandersnatch,
# Banderwagon,
Bandersnatch,
Banderwagon,
BN254_Nogami,
BN254_Snarks,
BLS12_377,
Expand All @@ -117,9 +116,9 @@ func hasEndomorphismAcceleration*(Name: static Algebra): bool {.compileTime.} =
Vesta
}

const EndomorphismThreshold* = 196
const EndomorphismThreshold* = 192
## We use substraction by maximum infinity norm coefficient
## to split scalars for endomorphisms
## For small scalars the substraction will overflow
##
## TODO: implement an alternative way to split scalars.
## TODO: explore an alternative way to split scalars, for example via division
## https://github.com/mratsim/constantine/issues/347
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

import
# Internals
../../constantine/math/config/[type_ff, curves],
../../constantine/math/ec_shortweierstrass,
constantine/named/algebras,
constantine/math/ec_shortweierstrass,
# Test utilities
./t_ec_template

Expand All @@ -18,7 +18,7 @@ const
ItersMul = Iters div 4

run_EC_mul_endomorphism_impl(
ec = ECP_ShortW_Jac[Fp[BLS12_381], G1],
ec = EC_ShortW_Jac[Fp[BLS12_381], G1],
ItersMul = ItersMul,
moduleName = "test_ec_shortw_jac_g1_mul_endomorphism_" & $BLS12_381
)
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

import
# Internals
../../constantine/math/config/[type_ff, curves],
../../constantine/math/ec_shortweierstrass,
constantine/named/algebras,
constantine/math/ec_shortweierstrass,
# Test utilities
./t_ec_template

Expand All @@ -18,7 +18,7 @@ const
ItersMul = Iters div 4

run_EC_mul_endomorphism_impl(
ec = ECP_ShortW_Jac[Fp[BN254_Snarks], G1],
ec = EC_ShortW_Jac[Fp[BN254_Snarks], G1],
ItersMul = ItersMul,
moduleName = "test_ec_shortw_jac_g1_mul_endomorphism_" & $BN254_Snarks
)
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@

import
# Internals
../../constantine/math/config/[type_ff, curves],
../../constantine/math/ec_shortweierstrass,
constantine/named/algebras,
constantine/math/ec_shortweierstrass,
constantine/math/extension_fields,
# Test utilities
./t_ec_template

Expand All @@ -18,7 +19,7 @@ const
ItersMul = Iters div 4

run_EC_mul_endomorphism_impl(
ec = ECP_ShortW_Jac[Fp2[BLS12_381], G2],
ec = EC_ShortW_Jac[Fp2[BLS12_381], G2],
ItersMul = ItersMul,
moduleName = "test_ec_shortw_jac_g2_mul_endomorphism_" & $BLS12_381
)
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@

import
# Internals
../../constantine/math/config/[type_ff, curves],
../../constantine/math/ec_shortweierstrass,
constantine/named/algebras,
constantine/math/ec_shortweierstrass,
constantine/math/extension_fields,
# Test utilities
./t_ec_template

Expand All @@ -18,7 +19,7 @@ const
ItersMul = Iters div 4

run_EC_mul_endomorphism_impl(
ec = ECP_ShortW_Jac[Fp2[BN254_Snarks], G2],
ec = EC_ShortW_Jac[Fp2[BN254_Snarks], G2],
ItersMul = ItersMul,
moduleName = "test_ec_shortw_jac_g2_mul_endomorphism_" & $BN254_Snarks
)
19 changes: 6 additions & 13 deletions tests/math_elliptic_curves/t_ec_template.nim
Original file line number Diff line number Diff line change
Expand Up @@ -809,19 +809,12 @@ proc run_EC_mul_endomorphism_impl*(
test(ec, bits = ec.getScalarField().bits() - 4, randZ = false, gen = Long01Sequence)
test(ec, bits = ec.getScalarField().bits() - 4, randZ = true, gen = Long01Sequence)

test(ec, bits = ec.getScalarField().bits() div 2, randZ = false, gen = Uniform)
test(ec, bits = ec.getScalarField().bits() div 2, randZ = true, gen = Uniform)
test(ec, bits = ec.getScalarField().bits() div 2, randZ = false, gen = HighHammingWeight)
test(ec, bits = ec.getScalarField().bits() div 2, randZ = true, gen = HighHammingWeight)
test(ec, bits = ec.getScalarField().bits() div 2, randZ = false, gen = Long01Sequence)
test(ec, bits = ec.getScalarField().bits() div 2, randZ = true, gen = Long01Sequence)

test(ec, bits = ec.getScalarField().bits() div 4, randZ = false, gen = Uniform)
test(ec, bits = ec.getScalarField().bits() div 4, randZ = true, gen = Uniform)
test(ec, bits = ec.getScalarField().bits() div 4, randZ = false, gen = HighHammingWeight)
test(ec, bits = ec.getScalarField().bits() div 4, randZ = true, gen = HighHammingWeight)
test(ec, bits = ec.getScalarField().bits() div 4, randZ = false, gen = Long01Sequence)
test(ec, bits = ec.getScalarField().bits() div 4, randZ = true, gen = Long01Sequence)
test(ec, bits = EndomorphismThreshold, randZ = false, gen = Uniform)
test(ec, bits = EndomorphismThreshold, randZ = true, gen = Uniform)
test(ec, bits = EndomorphismThreshold, randZ = false, gen = HighHammingWeight)
test(ec, bits = EndomorphismThreshold, randZ = true, gen = HighHammingWeight)
test(ec, bits = EndomorphismThreshold, randZ = false, gen = Long01Sequence)
test(ec, bits = EndomorphismThreshold, randZ = true, gen = Long01Sequence)

proc run_EC_mixed_add_impl*(
ec: typedesc,
Expand Down

0 comments on commit dcc9310

Please sign in to comment.