Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tentative fix for #345 - constant-time scalar mul with endomorphism acceleration wrong result #346

Merged
merged 3 commits into from
Jan 20, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions constantine.nimble
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,7 @@ const testDesc: seq[tuple[path: string, useGMP: bool]] = @[
# Edge cases highlighted by past bugs
# ----------------------------------------------------------
("tests/math_elliptic_curves/t_ec_shortw_prj_edge_cases.nim", false),
("tests/math_elliptic_curves/t_ec_shortw_prj_edge_case_345.nim", false),

# Elliptic curve arithmetic - batch computation
# ----------------------------------------------------------
Expand Down
103 changes: 61 additions & 42 deletions constantine/math/elliptic/ec_endomorphism_accel.nim
Original file line number Diff line number Diff line change
Expand Up @@ -38,30 +38,13 @@ type
MultiScalar[M, LengthInBits: static int] = array[M, BigInt[LengthInBits]]
## Decomposition of a secret scalar in multiple scalars

func decomposeEndo*[M, scalBits, L: static int](
miniScalars: var MultiScalar[M, L],
negatePoints: var array[M, SecretBool],
template decomposeEndoImpl[scalBits: static int](
scalar: BigInt[scalBits],
F: typedesc[Fp or Fp2]
) =
## Decompose a secret scalar into M mini-scalars
## using a curve endomorphism(s) characteristics.
##
## A scalar decomposition might lead to negative miniscalar(s).
## For proper handling it requires either:
## 1. Negating it and then negating the corresponding curve point P
## 2. Adding an extra bit to the recoding, which will do the right thing™
##
## For implementation solution 1 is faster:
## - Double + Add is about 5000~8000 cycles on 6 64-bits limbs (BLS12-381)
## - Conditional negate is about 10 cycles per Fp, on G2 projective we have 3 (coords) * 2 (Fp2) * 10 (cycles) ~= 60 cycles
## We need to test the mini scalar, which is 65 bits so 2 Fp so about 2 cycles
## and negate it as well.

F: typedesc[Fp or Fp2],
copyMiniScalarsResult: untyped) =
static: doAssert scalBits >= L, "Cannot decompose a scalar smaller than a mini-scalar or the decomposition coefficient"

# Equal when no window or no negative handling, greater otherwise
static: doAssert L >= scalBits.ceilDiv_vartime(M) + 1
static: doAssert L >= ceilDiv_vartime(scalBits, M) + 1
const w = F.C.getCurveOrderBitwidth().wordsRequired()

when M == 2:
Expand All @@ -84,15 +67,11 @@ func decomposeEndo*[M, scalBits, L: static int](
alphas[i].setZero()
else:
alphas[i].prod_high_words(babai(F)[i][0], scalar, w)
when babai(F)[i][1]:
# prod_high_words works like logical right shift
# When negative, we should add 1 to properly round toward -infinity
alphas[i] += One
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the main change. Papers that introduce the Babai's rounding unfortunately use shifts but do not go over the negative special case.


# We have k0 = s - 𝛼0 b00 - 𝛼1 b10 ... - 𝛼m bm0
# and kj = 0 - 𝛼j b0j - 𝛼1 b1j ... - 𝛼m bmj
var
k: array[M, BigInt[scalBits]] # zero-init required
k {.inject.}: array[M, BigInt[scalBits]] # zero-init required
alphaB {.noInit.}: BigInt[scalBits]
k[0] = scalar
staticFor miniScalarIdx, 0, M:
Expand All @@ -108,11 +87,62 @@ func decomposeEndo*[M, scalBits, L: static int](
else:
k[miniScalarIdx] -= alphaB

copyMiniScalarsResult

func decomposeEndo*[M, scalBits, L: static int](
miniScalars: var MultiScalar[M, L],
negatePoints: var array[M, SecretBool],
scalar: BigInt[scalBits],
F: typedesc[Fp or Fp2]) =
## Decompose a secret scalar into M mini-scalars
## using a curve endomorphism(s) characteristics.
##
## A scalar decomposition might lead to negative miniscalar(s).
## For proper handling it requires either:
## 1. Negating it and then negating the corresponding curve point P
## 2. Adding an extra bit to the recoding, which will do the right thing™
##
## For implementation solution 1 is faster:
## - Double + Add is about 5000~8000 cycles on 6 64-bits limbs (BLS12-381)
## - Conditional negate is about 10 cycles per Fp, on G2 projective we have 3 (coords) * 2 (Fp2) * 10 (cycles) ~= 60 cycles
## We need to test the mini scalar, which is 65 bits so 2 Fp so about 2 cycles
## and negate it as well.
##
## This implements solution 1.
decomposeEndoImpl(scalar, F):
# Negative miniscalars are turned positive
# Caller should negate the corresponding Elliptic Curve points
let isNeg = k[miniScalarIdx].isMsbSet()
negatePoints[miniScalarIdx] = isNeg
k[miniScalarIdx].cneg(isNeg)
miniScalars[miniScalarIdx].copyTruncatedFrom(k[miniScalarIdx])

func decomposeEndo*[M, scalBits, L: static int](
miniScalars: var MultiScalar[M, L],
scalar: BigInt[scalBits],
F: typedesc[Fp or Fp2]) =
## Decompose a secret scalar into M mini-scalars
## using a curve endomorphism(s) characteristics.
##
## A scalar decomposition might lead to negative miniscalar(s).
## For proper handling it requires either:
## 1. Negating it and then negating the corresponding curve point P
## 2. Adding an extra bit to the recoding, which will do the right thing™
##
## For implementation solution 1 is faster:
## - Double + Add is about 5000~8000 cycles on 6 64-bits limbs (BLS12-381)
## - Conditional negate is about 10 cycles per Fp, on G2 projective we have 3 (coords) * 2 (Fp2) * 10 (cycles) ~= 60 cycles
## We need to test the mini scalar, which is 65 bits so 2 Fp so about 2 cycles
## and negate it as well.
##
## However, when dealing with scalars that do not use the full bitwidth
## the extra bit avoids potential underflows.
## Also for partitioned GLV-SAC (with 8-way decomposition) it is necessary.
##
## This implements solution 2.
decomposeEndoImpl(scalar, F):
miniScalars[miniScalarIdx].copyTruncatedFrom(k[miniScalarIdx])

# Secret scalar + dynamic point
# ----------------------------------------------------------------
#
Expand Down Expand Up @@ -184,8 +214,7 @@ proc `[]=`(recoding: var Recoded,

func nDimMultiScalarRecoding[M, L: static int](
dst: var GLV_SAC[M, L],
src: MultiScalar[M, L]
) =
src: MultiScalar[M, L]) =
## This recodes N scalar for GLV multi-scalar multiplication
## with side-channel resistance.
##
Expand Down Expand Up @@ -316,7 +345,7 @@ func scalarMulEndo*[scalBits; EC](
{.error: "Unconfigured".}

# 2. Decompose scalar into mini-scalars
const L = scalBits.ceilDiv_vartime(M) + 1 # Alternatively, negative can be handled with an extra "+1"
const L = scalBits.ceilDiv_vartime(M) + 1
var miniScalars {.noInit.}: array[M, BigInt[L]]
var negatePoints {.noInit.}: array[M, SecretBool]
miniScalars.decomposeEndo(negatePoints, scalar, P.F)
Expand All @@ -325,13 +354,7 @@ func scalarMulEndo*[scalBits; EC](
# A scalar decomposition might lead to negative miniscalar.
# For proper handling it requires either:
# 1. Negating it and then negating the corresponding curve point P
# 2. Adding an extra bit to the recoding, which will do the right thing™
#
# For implementation solution 1 is faster:
# - Double + Add is about 5000~8000 cycles on 6 64-bits limbs (BLS12-381)
# - Conditional negate is about 10 cycles per Fp, on G2 projective we have 3 (coords) * 2 (Fp2) * 10 (cycles) ~= 60 cycles
# We need to test the mini scalar, which is 65 bits so 2 Fp so about 2 cycles
# and negate it as well.
# 2. Adding an extra bit to L for the recoding, which will do the right thing™
block:
P.cneg(negatePoints[0])
staticFor i, 1, M:
Expand Down Expand Up @@ -401,8 +424,7 @@ func scalarMulEndo*[scalBits; EC](
func buildLookupTable_m2w2[EC, Ecaff](
P0: EC,
P1: EC,
lut: var array[8, Ecaff],
) =
lut: var array[8, Ecaff]) =
## Build a lookup table for GLV with 2-dimensional decomposition
## and window of size 2

Expand Down Expand Up @@ -473,10 +495,7 @@ func computeRecodedLength(bitWidth, window: int): int =
let lw = bitWidth.ceilDiv_vartime(window) + 1
result = (lw mod window) + lw

func scalarMulGLV_m2w2*[scalBits; EC](
P0: var EC,
scalar: BigInt[scalBits]
) =
func scalarMulGLV_m2w2*[scalBits; EC](P0: var EC, scalar: BigInt[scalBits]) =
## Elliptic Curve Scalar Multiplication
##
## P <- [k] P
Expand Down
2 changes: 1 addition & 1 deletion tests/math_elliptic_curves/t_ec_sage_template.nim
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ proc run_scalar_mul_test_vs_sage*(
const coord = when EC is ECP_ShortW_Prj: " Projective coordinates "
elif EC is ECP_ShortW_Jac: " Jacobian coordinates "

const testSuiteDesc = "Scalar Multiplication " & $EC.F.C & " " & G1_or_G2 & " vs SageMath - " & $bits & "-bit scalar"
const testSuiteDesc = "Scalar Multiplication " & $EC.F.C & " " & G1_or_G2 & " " & coord & " vs SageMath - " & $bits & "-bit scalar"

suite testSuiteDesc & " [" & $WordBitWidth & "-bit words]":
for i in 0 ..< vec.vectors.len:
Expand Down
181 changes: 181 additions & 0 deletions tests/math_elliptic_curves/t_ec_shortw_prj_edge_case_345.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
# https://github.com/mratsim/constantine/issues/345

import ../../constantine/math/arithmetic
import ../../constantine/math/io/io_fields
import ../../constantine/math/io/io_bigints
import ../../constantine/math/config/curves
import ../../constantine/math/extension_fields/towers
import ../../constantine/math/elliptic/ec_shortweierstrass_affine
import ../../constantine/math/elliptic/ec_shortweierstrass_projective
import ../../constantine/math/elliptic/ec_scalar_mul
import ../../constantine/math/elliptic/ec_scalar_mul_vartime

#-------------------------------------------------------------------------------

type B = BigInt[254]
type F = Fp[BN254Snarks]
type F2 = QuadraticExt[F]
type G = ECP_ShortW_Prj[F2, G2]

#-------------------------------------------------------------------------------

# size of the scalar field
let r : B = fromHex( B ,"0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001" )

let expo : B = fromHex( B, "0x7b17fcc286b01af79176aa7da3a8615020eacda89a90e4ff5d0a085483f0448" )

let expoA_fr = fromHex( Fr[BN254Snarks],"0x1234567890123456789001234567890" )
var expoB_fr = fromHex( Fr[BN254Snarks],"0x7b17fcc286b01af79176aa7da3a8615020eacda89a90e4ff5d0a085483f0448" )
expoB_fr -= expoA_fr

let expoA = expoA_fr.toBig()
let expoB = expoB_fr.toBig()

# debugEcho "expo:" , expo.toHex()

let zeroF : F = fromHex( F, "0x00" )
let oneF : F = fromHex( F, "0x01" )

#-------------------------------------------------------------------------------

# standard generator of G2

let gen2_xi : F = fromHex( F, "0x1adcd0ed10df9cb87040f46655e3808f98aa68a570acf5b0bde23fab1f149701" )
let gen2_xu : F = fromHex( F, "0x09e847e9f05a6082c3cd2a1d0a3a82e6fbfbe620f7f31269fa15d21c1c13b23b" )
let gen2_yi : F = fromHex( F, "0x056c01168a5319461f7ca7aa19d4fcfd1c7cdf52dbfc4cbee6f915250b7f6fc8" )
let gen2_yu : F = fromHex( F, "0x0efe500a2d02dd77f5f401329f30895df553b878fc3c0dadaaa86456a623235c" )

let gen2_x : F2 = F2( coords: [gen2_xi, gen2_xu] )
let gen2_y : F2 = F2( coords: [gen2_yi, gen2_yu] )
let gen2_z : F2 = F2( coords: [oneF , zeroF ] )

let gen2 : G = G( x: gen2_x, y: gen2_y, z: gen2_z )

#-------------------------------------------------------------------------------

template echo(intercept: untyped) =
# This intercepts system.echo
# Delete this template to debug intermediate steps
discard

proc printF( x: F ) =
echo(" = " & x.toDecimal)

proc printF2( z: F2) =
echo(" 1 ~> " & z.coords[0].toDecimal )
echo(" u ~> " & z.coords[1].toDecimal )


proc printG( pt: G ) =
var aff : ECP_ShortW_Aff[F2, G2];
aff.affine(pt)
echo(" affine x coord: "); printF2( aff.x )
echo(" affine y coord: "); printF2( aff.y )

#-------------------------------------------------------------------------------

template test(scalarProc: untyped) =
proc `test _ scalarProc`() =
var p : G
var q : G

echo("")
echo("sanity check: g2^r should be infinity")
p = gen2
p.scalarProc(r)
printG(p)

echo("")
echo("LHS = g2^expo")
p = gen2
p.scalarProc(expo)
printG(p)
let lhs : G = p

echo("")
echo("RHS = g2^expoA * g2^expoB, where expo = expoA + expoB")
p = gen2
q = gen2
p.scalarProc(expoA)
q.scalarProc(expoB)
p += q
printG(p)
let rhs : G = p

echo("")
echo("reference from SageMath")
echo(" sage x coord:")
echo(" 1 -> 17216390949661727229956939928583223226083668728437958793715435751523027888005 ")
echo(" u -> 3082945034329785101034278215941854680789766318859358488904629243958221738137 ")
echo(" sage y coord:")
echo(" 1 -> 20108673238932196920264801702661201943173809015346082727725783869161803474440 ")
echo(" u -> 10405477402946058176045590740070709500904395284580129777629727895349459816649 ")

echo("")
echo("LHS - RHS = ")
p = lhs
p -= rhs
printG(p)

doAssert p.isInf().bool()

`test _ scalarProc`()

system.echo "issue #345 - scalarMul"
test(scalarMul)
system.echo "issue #345 - scalarMul_vartime"
test(scalarMul_vartime)

system.echo "SUCCESS - issue #345"

#-------------------------------------------------------------------------------

#[

SageMath code

# BN128 elliptic curve
p = 21888242871839275222246405745257275088696311157297823662689037894645226208583
r = 21888242871839275222246405745257275088548364400416034343698204186575808495617
h = 1
Fp = GF(p)
Fr = GF(r)
A = Fp(0)
B = Fp(3)
E = EllipticCurve(Fp,[A,B])
gx = Fp(1)
gy = Fp(2)
gen = E(gx,gy) # subgroup generator
print("scalar field check: ", gen.additive_order() == r )
print("cofactor check: ", E.cardinality() == r*h )

# extension field
R.<x> = Fp[]
Fp2.<u> = Fp.extension(x^2+1)

# twisted curve
B_twist = Fp2(19485874751759354771024239261021720505790618469301721065564631296452457478373 + 266929791119991161246907387137283842545076965332900288569378510910307636690*u )
E2 = EllipticCurve(Fp2,[0,B_twist])
size_E2 = E2.cardinality();
cofactor_E2 = size_E2 / r;

gen2_xi = Fp( 0x1adcd0ed10df9cb87040f46655e3808f98aa68a570acf5b0bde23fab1f149701 )
gen2_xu = Fp( 0x09e847e9f05a6082c3cd2a1d0a3a82e6fbfbe620f7f31269fa15d21c1c13b23b )
gen2_yi = Fp( 0x056c01168a5319461f7ca7aa19d4fcfd1c7cdf52dbfc4cbee6f915250b7f6fc8 )
gen2_yu = Fp( 0x0efe500a2d02dd77f5f401329f30895df553b878fc3c0dadaaa86456a623235c )

gen2_x = gen2_xi + u * gen2_xu
gen2_y = gen2_yi + u * gen2_yu

gen2 = E2(gen2_x, gen2_y)

print("g2^r: ", gen2*r )

expo = 0x7b17fcc286b01af79176aa7da3a8615020eacda89a90e4ff5d0a085483f0448

print("g2^expo: ")
print(gen2*expo)

]#

#-------------------------------------------------------------------------------
Loading
Loading