diff --git a/constantine.nimble b/constantine.nimble index 18410246..6c5404b3 100644 --- a/constantine.nimble +++ b/constantine.nimble @@ -489,6 +489,7 @@ const testDesc: seq[tuple[path: string, useGMP: bool]] = @[ # Edge cases highlighted by past bugs # ---------------------------------------------------------- ("tests/math_elliptic_curves/t_ec_shortw_prj_edge_cases.nim", false), + ("tests/math_elliptic_curves/t_ec_shortw_prj_edge_case_345.nim", false), # Elliptic curve arithmetic - batch computation # ---------------------------------------------------------- diff --git a/constantine/math/elliptic/ec_endomorphism_accel.nim b/constantine/math/elliptic/ec_endomorphism_accel.nim index 64becf15..fff009c6 100644 --- a/constantine/math/elliptic/ec_endomorphism_accel.nim +++ b/constantine/math/elliptic/ec_endomorphism_accel.nim @@ -38,39 +38,26 @@ type MultiScalar[M, LengthInBits: static int] = array[M, BigInt[LengthInBits]] ## Decomposition of a secret scalar in multiple scalars -func decomposeEndo*[M, scalBits, L: static int]( - miniScalars: var MultiScalar[M, L], - negatePoints: var array[M, SecretBool], +template decomposeEndoImpl[scalBits: static int]( scalar: BigInt[scalBits], - F: typedesc[Fp or Fp2] - ) = - ## Decompose a secret scalar into M mini-scalars - ## using a curve endomorphism(s) characteristics. - ## - ## A scalar decomposition might lead to negative miniscalar(s). - ## For proper handling it requires either: - ## 1. Negating it and then negating the corresponding curve point P - ## 2. Adding an extra bit to the recoding, which will do the right thing™ - ## - ## For implementation solution 1 is faster: - ## - Double + Add is about 5000~8000 cycles on 6 64-bits limbs (BLS12-381) - ## - Conditional negate is about 10 cycles per Fp, on G2 projective we have 3 (coords) * 2 (Fp2) * 10 (cycles) ~= 60 cycles - ## We need to test the mini scalar, which is 65 bits so 2 Fp so about 2 cycles - ## and negate it as well. 
- + F: typedesc[Fp or Fp2], + copyMiniScalarsResult: untyped) = static: doAssert scalBits >= L, "Cannot decompose a scalar smaller than a mini-scalar or the decomposition coefficient" - # Equal when no window or no negative handling, greater otherwise - static: doAssert L >= scalBits.ceilDiv_vartime(M) + 1 + static: doAssert L >= ceilDiv_vartime(scalBits, M) + 1 const w = F.C.getCurveOrderBitwidth().wordsRequired() + # Upstream bug: + # {.noInit.} variables must be {.inject.} as well + # or they'll be mangled as foo`gensym12345 instead of fooX60gensym12345 in C codegen + when M == 2: - var alphas{.noInit.}: ( + var alphas{.noInit, inject.}: ( BigInt[scalBits + babai(F)[0][0].bits], BigInt[scalBits + babai(F)[1][0].bits] ) elif M == 4: - var alphas{.noInit.}: ( + var alphas{.noInit, inject.}: ( BigInt[scalBits + babai(F)[0][0].bits], BigInt[scalBits + babai(F)[1][0].bits], BigInt[scalBits + babai(F)[2][0].bits], @@ -84,16 +71,12 @@ func decomposeEndo*[M, scalBits, L: static int]( alphas[i].setZero() else: alphas[i].prod_high_words(babai(F)[i][0], scalar, w) - when babai(F)[i][1]: - # prod_high_words works like logical right shift - # When negative, we should add 1 to properly round toward -infinity - alphas[i] += One # We have k0 = s - 𝛼0 b00 - 𝛼1 b10 ... - 𝛼m bm0 # and kj = 0 - 𝛼j b0j - 𝛼1 b1j ... 
- 𝛼m bmj var - k: array[M, BigInt[scalBits]] # zero-init required - alphaB {.noInit.}: BigInt[scalBits] + k {.inject.}: array[M, BigInt[scalBits]] # zero-init required + alphaB {.noInit, inject.}: BigInt[scalBits] k[0] = scalar staticFor miniScalarIdx, 0, M: staticFor basisIdx, 0, M: @@ -108,11 +91,62 @@ func decomposeEndo*[M, scalBits, L: static int]( else: k[miniScalarIdx] -= alphaB + copyMiniScalarsResult + +func decomposeEndo*[M, scalBits, L: static int]( + miniScalars: var MultiScalar[M, L], + negatePoints: var array[M, SecretBool], + scalar: BigInt[scalBits], + F: typedesc[Fp or Fp2]) = + ## Decompose a secret scalar into M mini-scalars + ## using a curve endomorphism(s) characteristics. + ## + ## A scalar decomposition might lead to negative miniscalar(s). + ## For proper handling it requires either: + ## 1. Negating it and then negating the corresponding curve point P + ## 2. Adding an extra bit to the recoding, which will do the right thing™ + ## + ## For implementation solution 1 is faster: + ## - Double + Add is about 5000~8000 cycles on 6 64-bits limbs (BLS12-381) + ## - Conditional negate is about 10 cycles per Fp, on G2 projective we have 3 (coords) * 2 (Fp2) * 10 (cycles) ~= 60 cycles + ## We need to test the mini scalar, which is 65 bits so 2 Fp so about 2 cycles + ## and negate it as well. + ## + ## This implements solution 1. + decomposeEndoImpl(scalar, F): + # Negative miniscalars are turned positive + # Caller should negate the corresponding Elliptic Curve points let isNeg = k[miniScalarIdx].isMsbSet() negatePoints[miniScalarIdx] = isNeg k[miniScalarIdx].cneg(isNeg) miniScalars[miniScalarIdx].copyTruncatedFrom(k[miniScalarIdx]) +func decomposeEndo*[M, scalBits, L: static int]( + miniScalars: var MultiScalar[M, L], + scalar: BigInt[scalBits], + F: typedesc[Fp or Fp2]) = + ## Decompose a secret scalar into M mini-scalars + ## using a curve endomorphism(s) characteristics. + ## + ## A scalar decomposition might lead to negative miniscalar(s). 
+ ## For proper handling it requires either: + ## 1. Negating it and then negating the corresponding curve point P + ## 2. Adding an extra bit to the recoding, which will do the right thing™ + ## + ## For implementation solution 1 is faster: + ## - Double + Add is about 5000~8000 cycles on 6 64-bits limbs (BLS12-381) + ## - Conditional negate is about 10 cycles per Fp, on G2 projective we have 3 (coords) * 2 (Fp2) * 10 (cycles) ~= 60 cycles + ## We need to test the mini scalar, which is 65 bits so 2 Fp so about 2 cycles + ## and negate it as well. + ## + ## However, when dealing with scalars that do not use the full bitwidth + ## the extra bit avoids potential underflows. + ## Also for partitioned GLV-SAC (with 8-way decomposition) it is necessary. + ## + ## This implements solution 2. + decomposeEndoImpl(scalar, F): + miniScalars[miniScalarIdx].copyTruncatedFrom(k[miniScalarIdx]) + # Secret scalar + dynamic point # ---------------------------------------------------------------- # @@ -184,8 +218,7 @@ proc `[]=`(recoding: var Recoded, func nDimMultiScalarRecoding[M, L: static int]( dst: var GLV_SAC[M, L], - src: MultiScalar[M, L] - ) = + src: MultiScalar[M, L]) = ## This recodes N scalar for GLV multi-scalar multiplication ## with side-channel resistance. ## @@ -316,7 +349,7 @@ func scalarMulEndo*[scalBits; EC]( {.error: "Unconfigured".} # 2. Decompose scalar into mini-scalars - const L = scalBits.ceilDiv_vartime(M) + 1 # Alternatively, negative can be handled with an extra "+1" + const L = scalBits.ceilDiv_vartime(M) + 1 var miniScalars {.noInit.}: array[M, BigInt[L]] var negatePoints {.noInit.}: array[M, SecretBool] miniScalars.decomposeEndo(negatePoints, scalar, P.F) @@ -325,13 +358,7 @@ func scalarMulEndo*[scalBits; EC]( # A scalar decomposition might lead to negative miniscalar. # For proper handling it requires either: # 1. Negating it and then negating the corresponding curve point P - # 2. 
Adding an extra bit to the recoding, which will do the right thing™ - # - # For implementation solution 1 is faster: - # - Double + Add is about 5000~8000 cycles on 6 64-bits limbs (BLS12-381) - # - Conditional negate is about 10 cycles per Fp, on G2 projective we have 3 (coords) * 2 (Fp2) * 10 (cycles) ~= 60 cycles - # We need to test the mini scalar, which is 65 bits so 2 Fp so about 2 cycles - # and negate it as well. + # 2. Adding an extra bit to L for the recoding, which will do the right thing™ block: P.cneg(negatePoints[0]) staticFor i, 1, M: @@ -401,8 +428,7 @@ func scalarMulEndo*[scalBits; EC]( func buildLookupTable_m2w2[EC, Ecaff]( P0: EC, P1: EC, - lut: var array[8, Ecaff], - ) = + lut: var array[8, Ecaff]) = ## Build a lookup table for GLV with 2-dimensional decomposition ## and window of size 2 @@ -473,10 +499,7 @@ func computeRecodedLength(bitWidth, window: int): int = let lw = bitWidth.ceilDiv_vartime(window) + 1 result = (lw mod window) + lw -func scalarMulGLV_m2w2*[scalBits; EC]( - P0: var EC, - scalar: BigInt[scalBits] - ) = +func scalarMulGLV_m2w2*[scalBits; EC](P0: var EC, scalar: BigInt[scalBits]) = ## Elliptic Curve Scalar Multiplication ## ## P <- [k] P diff --git a/constantine/math/elliptic/ec_multi_scalar_mul_parallel.nim b/constantine/math/elliptic/ec_multi_scalar_mul_parallel.nim index fe1c8700..8fb62bd2 100644 --- a/constantine/math/elliptic/ec_multi_scalar_mul_parallel.nim +++ b/constantine/math/elliptic/ec_multi_scalar_mul_parallel.nim @@ -329,7 +329,7 @@ proc msmAffine_vartime_parallel[bits: static int, EC, ECaff]( # Prologue # -------- - const numBuckets = 1 shl (c-1) + const numBuckets {.used.} = 1 shl (c-1) const numFullWindows = bits div c const numWindows = numFullWindows + 1 # Even if `bits div c` is exact, the signed recoding needs to see an extra 0 after the MSB diff --git a/tests/math_elliptic_curves/t_ec_sage_template.nim b/tests/math_elliptic_curves/t_ec_sage_template.nim index 121ca9b5..7d155cf9 100644 --- 
a/tests/math_elliptic_curves/t_ec_sage_template.nim +++ b/tests/math_elliptic_curves/t_ec_sage_template.nim @@ -146,7 +146,7 @@ proc run_scalar_mul_test_vs_sage*( const coord = when EC is ECP_ShortW_Prj: " Projective coordinates " elif EC is ECP_ShortW_Jac: " Jacobian coordinates " - const testSuiteDesc = "Scalar Multiplication " & $EC.F.C & " " & G1_or_G2 & " vs SageMath - " & $bits & "-bit scalar" + const testSuiteDesc = "Scalar Multiplication " & $EC.F.C & " " & G1_or_G2 & " " & coord & " vs SageMath - " & $bits & "-bit scalar" suite testSuiteDesc & " [" & $WordBitWidth & "-bit words]": for i in 0 ..< vec.vectors.len: diff --git a/tests/math_elliptic_curves/t_ec_shortw_prj_edge_case_345.nim b/tests/math_elliptic_curves/t_ec_shortw_prj_edge_case_345.nim new file mode 100644 index 00000000..9761e908 --- /dev/null +++ b/tests/math_elliptic_curves/t_ec_shortw_prj_edge_case_345.nim @@ -0,0 +1,181 @@ +# https://github.com/mratsim/constantine/issues/345 + +import ../../constantine/math/arithmetic +import ../../constantine/math/io/io_fields +import ../../constantine/math/io/io_bigints +import ../../constantine/math/config/curves +import ../../constantine/math/extension_fields/towers +import ../../constantine/math/elliptic/ec_shortweierstrass_affine +import ../../constantine/math/elliptic/ec_shortweierstrass_projective +import ../../constantine/math/elliptic/ec_scalar_mul +import ../../constantine/math/elliptic/ec_scalar_mul_vartime + +#------------------------------------------------------------------------------- + +type B = BigInt[254] +type F = Fp[BN254Snarks] +type F2 = QuadraticExt[F] +type G = ECP_ShortW_Prj[F2, G2] + +#------------------------------------------------------------------------------- + +# size of the scalar field +let r : B = fromHex( B ,"0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001" ) + +let expo : B = fromHex( B, "0x7b17fcc286b01af79176aa7da3a8615020eacda89a90e4ff5d0a085483f0448" ) + +let expoA_fr = fromHex( 
Fr[BN254Snarks],"0x1234567890123456789001234567890" ) +var expoB_fr = fromHex( Fr[BN254Snarks],"0x7b17fcc286b01af79176aa7da3a8615020eacda89a90e4ff5d0a085483f0448" ) +expoB_fr -= expoA_fr + +let expoA = expoA_fr.toBig() +let expoB = expoB_fr.toBig() + +# debugEcho "expo:" , expo.toHex() + +let zeroF : F = fromHex( F, "0x00" ) +let oneF : F = fromHex( F, "0x01" ) + +#------------------------------------------------------------------------------- + +# standard generator of G2 + +let gen2_xi : F = fromHex( F, "0x1adcd0ed10df9cb87040f46655e3808f98aa68a570acf5b0bde23fab1f149701" ) +let gen2_xu : F = fromHex( F, "0x09e847e9f05a6082c3cd2a1d0a3a82e6fbfbe620f7f31269fa15d21c1c13b23b" ) +let gen2_yi : F = fromHex( F, "0x056c01168a5319461f7ca7aa19d4fcfd1c7cdf52dbfc4cbee6f915250b7f6fc8" ) +let gen2_yu : F = fromHex( F, "0x0efe500a2d02dd77f5f401329f30895df553b878fc3c0dadaaa86456a623235c" ) + +let gen2_x : F2 = F2( coords: [gen2_xi, gen2_xu] ) +let gen2_y : F2 = F2( coords: [gen2_yi, gen2_yu] ) +let gen2_z : F2 = F2( coords: [oneF , zeroF ] ) + +let gen2 : G = G( x: gen2_x, y: gen2_y, z: gen2_z ) + +#------------------------------------------------------------------------------- + +template echo(intercept: untyped) = + # This intercepts system.echo + # Delete this template to debug intermediate steps + discard + +proc printF( x: F ) = + echo(" = " & x.toDecimal) + +proc printF2( z: F2) = + echo(" 1 ~> " & z.coords[0].toDecimal ) + echo(" u ~> " & z.coords[1].toDecimal ) + + +proc printG( pt: G ) = + var aff : ECP_ShortW_Aff[F2, G2]; + aff.affine(pt) + echo(" affine x coord: "); printF2( aff.x ) + echo(" affine y coord: "); printF2( aff.y ) + +#------------------------------------------------------------------------------- + +template test(scalarProc: untyped) = + proc `test _ scalarProc`() = + var p : G + var q : G + + echo("") + echo("sanity check: g2^r should be infinity") + p = gen2 + p.scalarProc(r) + printG(p) + + echo("") + echo("LHS = g2^expo") + p = gen2 + 
p.scalarProc(expo) + printG(p) + let lhs : G = p + + echo("") + echo("RHS = g2^expoA * g2^expoB, where expo = expoA + expoB") + p = gen2 + q = gen2 + p.scalarProc(expoA) + q.scalarProc(expoB) + p += q + printG(p) + let rhs : G = p + + echo("") + echo("reference from SageMath") + echo(" sage x coord:") + echo(" 1 -> 17216390949661727229956939928583223226083668728437958793715435751523027888005 ") + echo(" u -> 3082945034329785101034278215941854680789766318859358488904629243958221738137 ") + echo(" sage y coord:") + echo(" 1 -> 20108673238932196920264801702661201943173809015346082727725783869161803474440 ") + echo(" u -> 10405477402946058176045590740070709500904395284580129777629727895349459816649 ") + + echo("") + echo("LHS - RHS = ") + p = lhs + p -= rhs + printG(p) + + doAssert p.isInf().bool() + + `test _ scalarProc`() + +system.echo "issue #345 - scalarMul" +test(scalarMul) +system.echo "issue #345 - scalarMul_vartime" +test(scalarMul_vartime) + +system.echo "SUCCESS - issue #345" + +#------------------------------------------------------------------------------- + +#[ + +SageMath code + +# BN128 elliptic curve +p = 21888242871839275222246405745257275088696311157297823662689037894645226208583 +r = 21888242871839275222246405745257275088548364400416034343698204186575808495617 +h = 1 +Fp = GF(p) +Fr = GF(r) +A = Fp(0) +B = Fp(3) +E = EllipticCurve(Fp,[A,B]) +gx = Fp(1) +gy = Fp(2) +gen = E(gx,gy) # subgroup generator +print("scalar field check: ", gen.additive_order() == r ) +print("cofactor check: ", E.cardinality() == r*h ) + +# extension field +R.<x> = Fp[] +Fp2.<u>
= Fp.extension(x^2+1) + +# twisted curve +B_twist = Fp2(19485874751759354771024239261021720505790618469301721065564631296452457478373 + 266929791119991161246907387137283842545076965332900288569378510910307636690*u ) +E2 = EllipticCurve(Fp2,[0,B_twist]) +size_E2 = E2.cardinality(); +cofactor_E2 = size_E2 / r; + +gen2_xi = Fp( 0x1adcd0ed10df9cb87040f46655e3808f98aa68a570acf5b0bde23fab1f149701 ) +gen2_xu = Fp( 0x09e847e9f05a6082c3cd2a1d0a3a82e6fbfbe620f7f31269fa15d21c1c13b23b ) +gen2_yi = Fp( 0x056c01168a5319461f7ca7aa19d4fcfd1c7cdf52dbfc4cbee6f915250b7f6fc8 ) +gen2_yu = Fp( 0x0efe500a2d02dd77f5f401329f30895df553b878fc3c0dadaaa86456a623235c ) + +gen2_x = gen2_xi + u * gen2_xu +gen2_y = gen2_yi + u * gen2_yu + +gen2 = E2(gen2_x, gen2_y) + +print("g2^r: ", gen2*r ) + +expo = 0x7b17fcc286b01af79176aa7da3a8615020eacda89a90e4ff5d0a085483f0448 + +print("g2^expo: ") +print(gen2*expo) + +]# + +#------------------------------------------------------------------------------- diff --git a/tests/math_elliptic_curves/t_ec_template.nim b/tests/math_elliptic_curves/t_ec_template.nim index 8362d838..3e71f236 100644 --- a/tests/math_elliptic_curves/t_ec_template.nim +++ b/tests/math_elliptic_curves/t_ec_template.nim @@ -685,6 +685,35 @@ proc run_EC_mul_vs_ref_impl*( test(ec, bits = ec.F.C.getCurveOrderBitwidth(), randZ = false, gen = Long01Sequence) test(ec, bits = ec.F.C.getCurveOrderBitwidth(), randZ = true, gen = Long01Sequence) + # Scalars that don't use the full bit length + test(ec, bits = ec.F.C.getCurveOrderBitwidth() - 2, randZ = false, gen = Uniform) + test(ec, bits = ec.F.C.getCurveOrderBitwidth() - 2, randZ = true, gen = Uniform) + test(ec, bits = ec.F.C.getCurveOrderBitwidth() - 2, randZ = false, gen = HighHammingWeight) + test(ec, bits = ec.F.C.getCurveOrderBitwidth() - 2, randZ = true, gen = HighHammingWeight) + test(ec, bits = ec.F.C.getCurveOrderBitwidth() - 2, randZ = false, gen = Long01Sequence) + test(ec, bits = ec.F.C.getCurveOrderBitwidth() - 2, randZ = true, 
gen = Long01Sequence) + + test(ec, bits = ec.F.C.getCurveOrderBitwidth() - 4, randZ = false, gen = Uniform) + test(ec, bits = ec.F.C.getCurveOrderBitwidth() - 4, randZ = true, gen = Uniform) + test(ec, bits = ec.F.C.getCurveOrderBitwidth() - 4, randZ = false, gen = HighHammingWeight) + test(ec, bits = ec.F.C.getCurveOrderBitwidth() - 4, randZ = true, gen = HighHammingWeight) + test(ec, bits = ec.F.C.getCurveOrderBitwidth() - 4, randZ = false, gen = Long01Sequence) + test(ec, bits = ec.F.C.getCurveOrderBitwidth() - 4, randZ = true, gen = Long01Sequence) + + test(ec, bits = ec.F.C.getCurveOrderBitwidth() div 2, randZ = false, gen = Uniform) + test(ec, bits = ec.F.C.getCurveOrderBitwidth() div 2, randZ = true, gen = Uniform) + test(ec, bits = ec.F.C.getCurveOrderBitwidth() div 2, randZ = false, gen = HighHammingWeight) + test(ec, bits = ec.F.C.getCurveOrderBitwidth() div 2, randZ = true, gen = HighHammingWeight) + test(ec, bits = ec.F.C.getCurveOrderBitwidth() div 2, randZ = false, gen = Long01Sequence) + test(ec, bits = ec.F.C.getCurveOrderBitwidth() div 2, randZ = true, gen = Long01Sequence) + + test(ec, bits = ec.F.C.getCurveOrderBitwidth() div 4, randZ = false, gen = Uniform) + test(ec, bits = ec.F.C.getCurveOrderBitwidth() div 4, randZ = true, gen = Uniform) + test(ec, bits = ec.F.C.getCurveOrderBitwidth() div 4, randZ = false, gen = HighHammingWeight) + test(ec, bits = ec.F.C.getCurveOrderBitwidth() div 4, randZ = true, gen = HighHammingWeight) + test(ec, bits = ec.F.C.getCurveOrderBitwidth() div 4, randZ = false, gen = Long01Sequence) + test(ec, bits = ec.F.C.getCurveOrderBitwidth() div 4, randZ = true, gen = Long01Sequence) + proc run_EC_mixed_add_impl*( ec: typedesc, Iters: static int,