Skip to content

Commit

Permalink
fix: batch inversion zero edge cases, introduced in #278
Browse files Browse the repository at this point in the history
  • Loading branch information
mratsim committed Jun 10, 2024
1 parent 35d9938 commit 9b7bc95
Show file tree
Hide file tree
Showing 10 changed files with 204 additions and 99 deletions.
2 changes: 1 addition & 1 deletion constantine.nimble
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ const testDesc: seq[tuple[path: string, useGMP: bool]] = @[
# ("tests/math_fields/t_finite_fields_conditional_arithmetic.nim", false),
# ("tests/math_fields/t_finite_fields_mulsquare.nim", false),
# ("tests/math_fields/t_finite_fields_sqrt.nim", false),
# ("tests/math_fields/t_finite_fields_powinv.nim", false),
("tests/math_fields/t_finite_fields_powinv.nim", false),
# ("tests/math_fields/t_finite_fields_vs_gmp.nim", true),
# ("tests/math_fields/t_fp_cubic_root.nim", false),

Expand Down
2 changes: 1 addition & 1 deletion constantine/eth_verkle_ipa/barycentric_form.nim
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func computeBarycentricCoefficients*(res_inv: var openArray[Fr[Banderwagon]], pr

totalProd *= tmp

res_inv.batchInvert(res)
res_inv.batchInv_vartime(res)

for i in 0 ..< VerkleDomain:
res_inv[i] *= totalProd
Expand Down
2 changes: 1 addition & 1 deletion constantine/eth_verkle_ipa/ipa_verifier.nim
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func checkIPAProof* (ic: IPASettings, transcript: var CryptoHash, got: var EC_P,
challenges[i].fromBig(challenges_big[i])

var challengesInv {.noInit.}: array[8,Fr[Banderwagon]]
challengesInv.batchInvert(challenges)
challengesInv.batchInv_vartime(challenges)

for i in 0 ..< challenges.len:
var x = challenges[i]
Expand Down
4 changes: 2 additions & 2 deletions constantine/eth_verkle_ipa/multiproof.nim
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ func createMultiProof* [MultiProof] (res: var MultiProof, transcript: var Crypto


var denInv_prime {.noInit.}: array[VerkleDomain, Fr[Banderwagon]]
denInv_prime.batchInvert(denInv)
denInv_prime.batchInv_vartime(denInv)

#Compute h(X) = g1(X)
var hx {.noInit.}: array[VerkleDomain, Fr[Banderwagon]]
Expand Down Expand Up @@ -267,7 +267,7 @@ func verifyMultiproof*[MultiProof](multiProof: var MultiProof, transcript : var
helperScalarDeno[i].diff(t_fr, z)

var helperScalarDeno_prime {.noInit.}: array[VerkleDomain, Fr[Banderwagon]]
helperScalarDeno_prime.batchInvert(helperScalarDeno)
helperScalarDeno_prime.batchInv_vartime(helperScalarDeno)

# Compute g_2(t) = SUMMATION (y_i * r^i) / (t - z_i) = SUMMATION (y_i * r) * helperScalarDeno
var g2t {.noInit.}: Fr[Banderwagon]
Expand Down
7 changes: 2 additions & 5 deletions constantine/ethereum_verkle_primitives.nim
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
import
./math/config/[type_ff, curves],
./math/arithmetic,
./math/elliptic/[
ec_twistededwards_projective,
ec_twistededwards_batch_ops
],
./math/elliptic/ec_twistededwards_projective,
./math/io/[io_bigints, io_fields],
./curves_primitives

Expand Down Expand Up @@ -81,7 +78,7 @@ func batchMapToScalarField*(
for i in 0 ..< N:
ys[i] = points[i].y

ys_inv.batchInvert(ys, N)
ys_inv.batchInv_vartime(ys, N)

for i in 0 ..< N:
var mappedElement: Fp[Banderwagon]
Expand Down
94 changes: 93 additions & 1 deletion constantine/math/arithmetic/finite_fields.nim
Original file line number Diff line number Diff line change
Expand Up @@ -613,4 +613,96 @@ func inv_vartime*(a: var FF) {.tags: [VarTime].} =
## Incidentally this avoids extra check
## to convert Jacobian and Projective coordinates
## to affine for elliptic curve
a.inv_vartime(a)
a.inv_vartime(a)

# ############################################################
#
# Batch operations
#
# ############################################################

func batchInv*[F](
dst: ptr UncheckedArray[F],
elements: ptr UncheckedArray[F],
N: int,
useVartime: static bool = false
) {.noInline.} =
## Batch inversion
## If an element is 0, the inverse stored will be 0.
var zeros = allocStackArray(SecretBool, N)
zeroMem(zeros, N)

var acc: F
acc.setOne()

for i in 0 ..< N:
# Skip zeros
zeros[i] = elements[i].isZero()
var z = elements[i]
z.csetOne(zeros[i])

dst[i] = acc
if i != N-1:
acc.prod(acc, z, skipFinalSub = true)
else:
acc.prod(acc, z, skipFinalSub = false)

acc.inv()

for i in countdown(N-1, 0):
# Extract 1/elemᵢ
dst[i] *= acc
dst[i].csetZero(zeros[i])

# next iteration
var eli = elements[i]
eli.csetOne(zeros[i])
acc.prod(acc, eli, skipFinalSub = true)

func batchInv_vartime*[F](
dst: ptr UncheckedArray[F],
elements: ptr UncheckedArray[F],
N: int,
useVartime: static bool = false
) {.noInline.} =
## Batch inversion
## If an element is 0, the inverse stored will be 0.
var zeros = allocStackArray(bool, N)
zeroMem(zeros, N)

var acc: F
acc.setOne()

for i in 0 ..< N:
if elements[i].isZero().bool():
zeros[i] = true
dst[i].setZero()
continue

dst[i] = acc
if i != N-1:
acc.prod(acc, elements[i], skipFinalSub = true)
else:
acc.prod(acc, elements[i], skipFinalSub = false)

acc.inv_vartime()

for i in countdown(N-1, 0):
if zeros[i] == true:
continue
dst[i] *= acc
acc.prod(acc, elements[i], skipFinalSub = true)

func batchInv*[F](dst: var openArray[F], source: openArray[F]) {.inline.} =
debug: doAssert dst.len == source.len
batchInv(dst.asUnchecked(), source.asUnchecked(), dst.len)

func batchInv*[N: static int, F](dst: var array[N, F], src: array[N, F]) =
batchInv(dst.asUnchecked(), src.asUnchecked(), N)

func batchInv_vartime*[F](dst: var openArray[F], source: openArray[F]) {.inline.} =
debug: doAssert dst.len == source.len
batchInv_vartime(dst.asUnchecked(), source.asUnchecked(), dst.len)

func batchInv_vartime*[N: static int, F](dst: var array[N, F], src: array[N, F]) =
batchInv_vartime(dst.asUnchecked(), src.asUnchecked(), N)
35 changes: 0 additions & 35 deletions constantine/math/elliptic/ec_twistededwards_batch_ops.nim
Original file line number Diff line number Diff line change
Expand Up @@ -86,38 +86,3 @@ func batchAffine*[M, N: static int, F](
affs: var array[M, array[N, ECP_TwEdwards_Aff[F]]],
projs: array[M, array[N, ECP_TwEdwards_Prj[F]]]) {.inline.} =
batchAffine(affs[0].asUnchecked(), projs[0].asUnchecked(), M*N)

func batchInvert*[F](
dst: ptr UncheckedArray[F],
elements: ptr UncheckedArray[F],
N: int
) {.noInline.} =
## Montgomery's batch inversion
var zeros = allocStackArray(bool, N)
zeroMem(zeros, N)

var accumulator: F
accumulator.setOne() # sets the accumulator to 1

for i in 0 ..< N:
if elements[i].isZero().bool():
zeros[i] = true
continue

This comment has been minimized.

Copy link
@mratsim

mratsim Jun 10, 2024

Author Owner

Here dst[i] should be set to zero or the result is inconsistent with single inversion in presence of zeros.


dst[i] = accumulator
accumulator *= elements[i]

accumulator.inv() # inversion of the accumulator

for i in countdown(N-1, 0):
if zeros[i] == true:
continue
dst[i] *= accumulator
accumulator *= elements[i]

func batchInvert*[F](dst: var openArray[F], source: openArray[F]) {.inline.} =
debug: doAssert dst.len == source.len
batchInvert(dst.asUnchecked(), source.asUnchecked(), dst.len)

func batchInvert*[N: static int, F](dst: var array[N, F], src: array[N, F]) =
batchInvert(dst.asUnchecked(), src.asUnchecked(), N)
50 changes: 25 additions & 25 deletions constantine/serialization/codecs_banderwagon.nim
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,12 @@ func make_scalar_mod_order*(reduced_scalar: var Fr[Banderwagon], src: array[32,

func serialize*(dst: var array[32, byte], P: EC_Prj): CttCodecEccStatus =
## Serialize a Banderwagon point(x, y) in the format
##
##
## serialize = bigEndian( sign(y) * x )
## If y is not lexicographically largest
## set x -> -x
## then serialize
##
##
## Returns cttCodecEcc_Success if successful
## Spec: https://hackmd.io/@6iQDuIePQjyYBqDChYw_jg/BJBNcv9fq#Serialisation

Expand All @@ -81,7 +81,7 @@ func serialize*(dst: var array[32, byte], P: EC_Prj): CttCodecEccStatus =
for i in 0 ..< dst.len:
dst[i] = byte 0
return cttCodecEcc_Success

# Convert the projective points into affine format before encoding
var aff {.noInit.}: EC_Aff
aff.affine(P)
Expand All @@ -96,10 +96,10 @@ func serialize*(dst: var array[32, byte], P: EC_Prj): CttCodecEccStatus =

func deserialize_unchecked*(dst: var EC_Prj, src: array[32, byte]): CttCodecEccStatus =
## Deserialize a Banderwagon point (x, y) in format
##
##
## if y is not lexicographically largest
## set y -> -y
##
##
## Returns cttCodecEcc_Success if successful
## https://hackmd.io/@6iQDuIePQjyYBqDChYw_jg/BJBNcv9fq#Serialisation
# If infinity, src must be all zeros
Expand All @@ -111,7 +111,7 @@ func deserialize_unchecked*(dst: var EC_Prj, src: array[32, byte]): CttCodecEccS
if check:
dst.setInf()
return cttCodecEcc_PointAtInfinity

var t{.noInit.}: matchingBigInt(Banderwagon)
t.unmarshal(src, bigEndian)

Expand All @@ -133,10 +133,10 @@ func deserialize_unchecked*(dst: var EC_Prj, src: array[32, byte]): CttCodecEccS
func deserialize_unchecked_vartime*(dst: var EC_Prj, src: array[32, byte]): CttCodecEccStatus =
## This is not in constant-time
## Deserialize a Banderwagon point (x, y) in format
##
##
## if y is not lexicographically largest
## set y -> -y
##
##
## Returns cttCodecEcc_Success if successful
## https://hackmd.io/@6iQDuIePQjyYBqDChYw_jg/BJBNcv9fq#Serialisation
# If infinity, src must be all zeros
Expand All @@ -148,7 +148,7 @@ func deserialize_unchecked_vartime*(dst: var EC_Prj, src: array[32, byte]): CttC
if check:
dst.setInf()
return cttCodecEcc_PointAtInfinity

var t{.noInit.}: matchingBigInt(Banderwagon)
t.unmarshal(src, bigEndian)

Expand All @@ -169,9 +169,9 @@ func deserialize_unchecked_vartime*(dst: var EC_Prj, src: array[32, byte]): CttC

func deserialize*(dst: var EC_Prj, src: array[32, byte]): CttCodecEccStatus =
## Deserialize a Banderwagon point (x, y) in format
##
##
## Also checks if the point lies in the banderwagon scheme subgroup
##
##
## Returns cttCodecEcc_Success if successful
## Returns cttCodecEcc_PointNotInSubgroup if doesn't lie in subgroup
result = deserialize_unchecked(dst, src)
Expand All @@ -185,9 +185,9 @@ func deserialize*(dst: var EC_Prj, src: array[32, byte]): CttCodecEccStatus =

func deserialize_vartime*(dst: var EC_Prj, src: array[32, byte]): CttCodecEccStatus =
## Deserialize a Banderwagon point (x, y) in format
##
##
## Also checks if the point lies in the banderwagon scheme subgroup
##
##
## Returns cttCodecEcc_Success if successful
## Returns cttCodecEcc_PointNotInSubgroup if doesn't lie in subgroup
result = deserialize_unchecked_vartime(dst, src)
Expand All @@ -204,7 +204,7 @@ func deserialize_vartime*(dst: var EC_Prj, src: array[32, byte]): CttCodecEccSta
## Banderwagon Scalar Serialization
##
## ############################################################
##
##
func serialize_scalar*(dst: var array[32, byte], scalar: matchingOrderBigInt(Banderwagon), order: static Endianness = bigEndian): CttCodecScalarStatus =
## Adding an optional Endianness param default at BigEndian
## Serialize a scalar
Expand All @@ -217,7 +217,7 @@ func serialize_scalar*(dst: var array[32, byte], scalar: matchingOrderBigInt(Ban
## Banderwagon Scalar Deserialization
##
## ############################################################
##
##
func deserialize_scalar*(dst: var matchingOrderBigInt(Banderwagon), src: array[32, byte], order: static Endianness = bigEndian): CttCodecScalarStatus =
## Adding an optional Endianness param default at BigEndian
## Deserialize a scalar
Expand All @@ -243,7 +243,7 @@ func deserialize_scalar_mod_order* (dst: var Fr[Banderwagon], src: array[32, byt
debug: doAssert stat, "transcript_gen.deserialize_scalar_mod_order: Unexpected failure"

return cttCodecScalar_Success

## ############################################################
##
## Banderwagon Batch Serialization
Expand All @@ -262,8 +262,8 @@ func serializeBatch*(
for i in 0 ..< N:
zs[i] = points[i].z

zs_inv.batchInvert(zs, N)
zs_inv.batchInv_vartime(zs, N)

for i in 0 ..< N:
var X: Fp[Banderwagon]
var Y: Fp[Banderwagon]
Expand All @@ -288,15 +288,15 @@ func serializeBatchUncompressed*(
## In uncompressed format
## serialize = [ bigEndian( x ) , bigEndian( y ) ]
## Returns cttCodecEcc_Success if successful

# collect all the z coordinates
var zs = allocStackArray(Fp[Banderwagon], N)
var zs_inv = allocStackArray(Fp[Banderwagon], N)
for i in 0 ..< N:
zs[i] = points[i].z

zs_inv.batchInvert(zs, N)
zs_inv.batchInv_vartime(zs, N)

for i in 0 ..< N:
var X: Fp[Banderwagon]
var Y: Fp[Banderwagon]
Expand Down Expand Up @@ -334,9 +334,9 @@ func serializeBatch*[N: static int](

func serializeUncompressed*(dst: var array[64, byte], P: EC_Prj): CttCodecEccStatus =
## Serialize a Banderwagon point(x, y) in the format
##
##
## serialize = [ bigEndian( x ) , bigEndian( y ) ]
##
##
## Returns cttCodecEcc_Success if successful
var aff {.noInit.}: EC_Aff
aff.affine(P)
Expand Down Expand Up @@ -380,9 +380,9 @@ func deserializeUncompressed_unchecked*(dst: var EC_Prj, src: array[64, byte]):

func deserializeUncompressed*(dst: var EC_Prj, src: array[64, byte]): CttCodecEccStatus =
## Deserialize a Banderwagon point (x, y) in format
##
##
## Also checks if the point lies in the banderwagon scheme subgroup
##
##
## Returns cttCodecEcc_Success if successful
result = dst.deserializeUncompressed_unchecked(src)
if not(bool dst.isInSubgroup()):
Expand Down
Loading

0 comments on commit 9b7bc95

Please sign in to comment.