Skip to content

Commit

Permalink
feat(bench): zkalc prepare for adding pairing benches - generic type …
Browse files Browse the repository at this point in the history
…resulution issue
  • Loading branch information
mratsim committed Jul 12, 2024
1 parent 94b5d04 commit 0c7d22e
Show file tree
Hide file tree
Showing 30 changed files with 216 additions and 67 deletions.
2 changes: 1 addition & 1 deletion benchmarks/bench_elliptic_parallel_template.nim
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ proc multiAddParallelBench*(EC: typedesc, numInputs: int, iters: int) =
type BenchMsmContext*[EC] = object
tp: Threadpool
numInputs: int
coefs: seq[getBigInt(EC.F.Name, kScalarField)]
coefs: seq[getBigInt(EC.getName(), kScalarField)]
points: seq[affine(EC)]

proc createBenchMsmContext*(EC: typedesc, inputSizes: openArray[int]): BenchMsmContext[EC] =
Expand Down
119 changes: 107 additions & 12 deletions benchmarks/zkalc.nim
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ import
constantine/lowlevel_fields,
constantine/lowlevel_elliptic_curves,
constantine/lowlevel_elliptic_curves_parallel,
# constantine/lowlevel_extension_fields,
constantine/lowlevel_pairing_curves,
constantine/threadpool,
# Helpers
helpers/prng_unsafe,
# Standard library
std/[stats, monotimes, times, strformat, strutils, cmdline],
std/[stats, monotimes, times, strformat, strutils, cmdline, macros],
# Third-party
jsony, cliche

Expand Down Expand Up @@ -87,6 +89,7 @@ template bench(body: untyped): AggStats =
let (candidateIters, elapsedNs) = warmup(warmupMs)

# Deduce batch size for bench iterations so that each batch is atleast 10ms to amortize clock overhead
# See https://gms.tf/on-the-costs-of-syscalls.html on clock and syscall latencies and vDSO.
let batchSize = max(1, int(candidateIters.float64 * batchMs.float64 / warmupMs.float64))
# Compute the number of iterations for ~5s of benchmarks
let iters = int(
Expand Down Expand Up @@ -216,11 +219,11 @@ proc benchFrIP(rng: var RngState, curve: static Algebra): ZkalcBenchDetails =
# EC benches
# -------------------------------------------------------------------------------------

proc benchEcAdd(rng: var RngState, EC: type, useVartime: bool): ZkalcBenchDetails =
proc benchEcAdd(rng: var RngState, EC: typedesc, useVartime: bool): ZkalcBenchDetails =
const G =
when EC.G == G1: "𝔾1"
else: "𝔾2"
const curve = EC.F.Name
const curve = EC.getName()

var r {.noInit.}: EC
let P = rng.random_unsafe(EC)
Expand All @@ -243,11 +246,11 @@ proc benchEcAdd(rng: var RngState, EC: type, useVartime: bool): ZkalcBenchDetail
report(G & " Addition " & align("| constant-time", 29), curve, stats)
stats.toZkalc()

proc benchEcMul(rng: var RngState, EC: type, useVartime: bool): ZkalcBenchDetails =
proc benchEcMul(rng: var RngState, EC: typedesc, useVartime: bool): ZkalcBenchDetails =
const G =
when EC.G == G1: "𝔾1"
else: "𝔾2"
const curve = EC.F.Name
const curve = EC.getName()

var r {.noInit.}: EC
var P = rng.random_unsafe(EC)
Expand All @@ -273,7 +276,7 @@ proc benchEcMul(rng: var RngState, EC: type, useVartime: bool): ZkalcBenchDetail
# EC Msm benches
# -------------------------------------------------------------------------------------

type BenchMsmContext*[EC] = object
type BenchMsmContext[EC] = object
numInputs: int
coefs: seq[getBigInt(EC.F.Name, kScalarField)]
points: seq[affine(EC)]
Expand Down Expand Up @@ -343,16 +346,15 @@ proc benchEcMsm[EC](ctx: BenchMsmContext[EC]): ZkalcBenchDetails =

tp.shutdown()

# EC Misc benches
# EC serialization benches
# -------------------------------------------------------------------------------------

proc benchEcIsInSubgroup(rng: var RngState, EC: type): ZkalcBenchDetails =
const G =
when EC.G == G1: "𝔾1"
else: "𝔾2"
const curve = EC.F.Name
const curve = EC.getName()

var r {.noInit.}: EC
var P = rng.random_unsafe(EC)
P.clearCofactor()
preventOptimAway(P)
Expand All @@ -367,7 +369,7 @@ proc benchEcHashToCurve(rng: var RngState, EC: type): ZkalcBenchDetails =
const G =
when EC.G == G1: "𝔾1"
else: "𝔾2"
const curve = EC.F.Name
const curve = EC.getName()

const dst = "Constantine_Zkalc_Bench_HashToCurve"
# Gnark uses a message of size 54, probably to not spill over padding with SHA256
Expand All @@ -388,6 +390,62 @@ proc benchEcHashToCurve(rng: var RngState, EC: type): ZkalcBenchDetails =
report(G & " Hash-to-Curve", curve, stats)
stats.toZkalc()

# Pairing benches
# -------------------------------------------------------------------------------------

func clearCofactor[F; G: static Subgroup](
ec: var EC_ShortW_Aff[F, G]) =
# For now we don't have any affine operation defined
var t {.noInit.}: EC_ShortW_Prj[F, G]
t.fromAffine(ec)
t.clearCofactor()
ec.affine(t)

func random_point*(rng: var RngState, EC: typedesc): EC {.noInit.} =
result = rng.random_unsafe(EC)
result.clearCofactor()

# proc benchPairing*(rng: var RngState, curve: static Algebra): ZkalcBenchDetails =
# let
# P = rng.random_point(EC_ShortW_Aff[Fp[curve], G1])
# Q = rng.random_point(EC_ShortW_Aff[Fp2[curve], G2])

# var f: Fp12[curve]
# let stats = bench():
# f.pairing(P, Q)

# report("Pairing", curve, stats)
# stats.toZkalc()

# proc benchMultiPairing*(rng: var RngState, curve: static Algebra, maxNumInputs: int): ZkalcBenchDetails =
# var
# Ps = newSeq[EC_ShortW_Aff[Fp[curve], G1]](maxNumInputs)
# Qs = newSeq[EC_ShortW_Aff[Fp2[curve], G2]](maxNumInputs)

# stdout.write &"Generating {maxNumInputs} (𝔾1, 𝔾2) pairs ... "
# stdout.flushFile()

# let start = getMonotime()

# for i in 0 ..< maxNumInputs:
# Ps[i] = rng.random_point(typeof(Ps[0]))
# Qs[i] = rng.random_point(typeof(Qs[0]))

# let stop = getMonotime()
# stdout.write &"in {float64(inNanoSeconds(stop-start)) / 1e6:6.3f} ms\n"
# separator()

# var size = 2
# while size <= maxNumInputs:
# var f{.noInit.}: Fp12[curve]
# let stats = bench():
# f.pairing(Ps.toOpenArray(0, size-1), Qs.toOpenArray(0, size-1))

# report("Multipairing " & align($size, 5), curve, stats)
# result.append(stats, size)

# size *= 2

# Run benches
# -------------------------------------------------------------------------------------

Expand All @@ -397,13 +455,19 @@ proc runBenches(curve: static Algebra, useVartime: bool) =

var zkalc: ZkalcBenchResult

type EcG1 = EC_ShortW_Jac[Fp[curve], G1]
# Fields
# --------------------------------------------------------------------
separator()
zkalc.add_ff = rng.benchFrAdd(curve)
zkalc.mul_ff = rng.benchFrMul(curve)
zkalc.invert = rng.benchFrInv(curve, useVartime)
zkalc.ip_ff = rng.benchFrIP(curve)
separator()

# Elliptic curve
# --------------------------------------------------------------------
type EcG1 = EC_ShortW_Jac[Fp[curve], G1]

zkalc.add_g1 = rng.benchEcAdd(EcG1, useVartime)
zkalc.mul_g1 = rng.benchEcMul(EcG1, useVartime)
separator()
Expand All @@ -416,6 +480,36 @@ proc runBenches(curve: static Algebra, useVartime: bool) =
zkalc.hash_G1 = rng.benchEcHashToCurve(EcG1)
separator()

# # Pairing-friendly curve only
# # --------------------------------------------------------------------

# when curve.isPairingFriendly():

# # Elliptic curve 𝔾2
# # --------------------------------------------------------------------

# type EcG2 = EC_ShortW_Jac[Fp2[curve], G2] # For now we only supports G2 on Fp2 (not Fp like BW6 or Fp4 like BLS24)

# zkalc.add_g2 = rng.benchEcAdd(EcG2, useVartime)
# zkalc.mul_g2 = rng.benchEcMul(EcG2, useVartime)
# separator()
# let ctxG2 = rng.createBenchMsmContext(EcG2, maxNumInputs = 2097152)
# separator()
# zkalc.msm_g2 = benchEcMsm(ctxG2)
# separator()
# zkalc.is_in_sub_G2 = rng.benchEcIsInSubgroup(EcG2)
# when curve in {BN254_Snarks, BLS12_381}:
# zkalc.hash_G2 = rng.benchEcHashToCurve(EcG2)
# separator()

# # Pairings
# # --------------------------------------------------------------------

# zkalc.pairing = rng.benchPairing(curve)
# separator()
# zkalc.multipairing = rng.benchMultiPairing(curve, maxNumInputs = 1024)
# separator()

proc main() =
let cmd = commandLineParams()
cmd.getOpt (curve: BN254_Snarks, vartime: true)
Expand All @@ -429,4 +523,5 @@ proc main() =
else:
echo "This curve '" & $curve & "' is not configured for benchmarking at the moment."

main()
when isMainModule:
main()
6 changes: 4 additions & 2 deletions constantine/lowlevel_elliptic_curves.nim
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ export
abstractions.BigInt,
algebras.Algebra,
algebras.getBigInt,
algebras.FieldKind
algebras.FieldKind,
algebras.isPairingFriendly

# Generic sandwich
export abstractions
Expand All @@ -52,7 +53,8 @@ export
ec_shortweierstrass.EC_ShortW_Aff,
ec_shortweierstrass.EC_ShortW_Jac,
ec_shortweierstrass.EC_ShortW_Prj,
ec_shortweierstrass.EC_ShortW
ec_shortweierstrass.EC_ShortW,
ec_shortweierstrass.getName

export ec_shortweierstrass.`==`
export ec_shortweierstrass.isNeutral
Expand Down
10 changes: 8 additions & 2 deletions constantine/lowlevel_extension_fields.nim
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,21 @@ export
algebras.Algebra,
algebras.getBigInt

export
algebras.Fp,
algebras.Fr,
algebras.FF

# Extension fields
# ------------------------------------------------------------

export
extension_fields.Fp2
extension_fields.Name,
extension_fields.Fp2,
# TODO: deal with Fp2->Fp6 vs Fp3->Fp6 and Fp2->Fp6->Fp12 vs Fp2->Fp4->Fp12
# extension_fields.Fp4,
# extension_fields.Fp6,
# extension_fields.Fp12
extension_fields.Fp12

# Generic sandwich - https://github.com/nim-lang/Nim/issues/11225
export extension_fields.c0, extension_fields.`c0=`
Expand Down
4 changes: 2 additions & 2 deletions constantine/math/elliptic/ec_multi_scalar_mul.nim
Original file line number Diff line number Diff line change
Expand Up @@ -429,13 +429,13 @@ template withEndo[coefsBits: static int, EC, ECaff](
coefs: ptr UncheckedArray[BigInt[coefsBits]],
points: ptr UncheckedArray[ECaff],
N: int, c: static int) =
when hasEndomorphismAcceleration(EC.F.Name) and
when hasEndomorphismAcceleration(EC.getName()) and
EndomorphismThreshold <= coefsBits and
coefsBits <= EC.getScalarField().bits() and
# computeEndomorphism assumes they can be applied to affine repr
# but this is not the case for Bandersnatch/wagon
# instead Twisted Edwards MSM should be overloaded for Projective/ProjectiveExtended
EC.F.Name notin {Bandersnatch, Banderwagon}:
EC.getName() notin {Bandersnatch, Banderwagon}:
let (endoCoefs, endoPoints, endoN) = applyEndomorphism(coefs, points, N)
# Given that bits and N changed, we are able to use a bigger `c`
# but it has no significant impact on performance
Expand Down
6 changes: 3 additions & 3 deletions constantine/math/elliptic/ec_multi_scalar_mul_parallel.nim
Original file line number Diff line number Diff line change
Expand Up @@ -495,13 +495,13 @@ template withEndo[coefsBits: static int, EC, ECaff](
coefs: ptr UncheckedArray[BigInt[coefsBits]],
points: ptr UncheckedArray[ECaff],
N: int, c: static int) =
when hasEndomorphismAcceleration(EC.F.Name) and
when hasEndomorphismAcceleration(EC.getName()) and
EndomorphismThreshold <= coefsBits and
coefsBits <= EC.getScalarField().bits() and
# computeEndomorphism assumes they can be applied to affine repr
# but this is not the case for Bandersnatch/wagon
# instead Twisted Edwards MSM should be overloaded for Projective/ProjectiveExtended
EC.F.Name notin {Bandersnatch, Banderwagon}:
EC.getName() notin {Bandersnatch, Banderwagon}:
let (endoCoefs, endoPoints, endoN) = applyEndomorphism_parallel(tp, coefs, points, N)
# Given that bits and N changed, we are able to use a bigger `c`
# but it has no significant impact on performance
Expand All @@ -518,7 +518,7 @@ template withEndo[coefsBits: static int, EC, ECaff](
coefs: ptr UncheckedArray[BigInt[coefsBits]],
points: ptr UncheckedArray[ECaff],
N: int, c: static int, useParallelBuckets: static bool) =
when coefsBits <= EC.getScalarField().bits() and hasEndomorphismAcceleration(EC.F.Name):
when coefsBits <= EC.getScalarField().bits() and hasEndomorphismAcceleration(EC.getName()):
let (endoCoefs, endoPoints, endoN) = applyEndomorphism_parallel(tp, coefs, points, N)
# Given that bits and N changed, we are able to use a bigger `c`
# but it has no significant impact on performance
Expand Down
4 changes: 2 additions & 2 deletions constantine/math/elliptic/ec_multi_scalar_mul_scheduler.nim
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@

import
constantine/platforms/abstractions,
constantine/math/arithmetic,
constantine/math/[arithmetic, extension_fields],
./ec_shortweierstrass_affine,
./ec_shortweierstrass_jacobian,
./ec_shortweierstrass_projective,
./ec_shortweierstrass_batch_ops,
./ec_twistededwards_projective,
./ec_twistededwards_affine

export abstractions, arithmetic,
export abstractions, arithmetic, extension_fields,
ec_shortweierstrass_affine,
ec_shortweierstrass_jacobian,
ec_shortweierstrass_projective,
Expand Down
2 changes: 1 addition & 1 deletion constantine/math/elliptic/ec_scalar_mul.nim
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ func scalarMul*[EC](P: var EC, scalar: BigInt) {.inline, meter.} =
## - Cofactor to be cleared
## - 0 <= scalar < curve order
## Those will be assumed to maintain constant-time property
when EC.F.Name.hasEndomorphismAcceleration() and
when EC.getName().hasEndomorphismAcceleration() and
BigInt.bits >= EndomorphismThreshold:
when EC.F is Fp:
P.scalarMulGLV_m2w2(scalar)
Expand Down
2 changes: 1 addition & 1 deletion constantine/math/elliptic/ec_scalar_mul_vartime.nim
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ func scalarMul_vartime*[scalBits; EC](P: var EC, scalar: BigInt[scalBits]) {.met

let usedBits = scalar.limbs.getBits_LE_vartime()

when EC.F.Name.hasEndomorphismAcceleration():
when EC.getName().hasEndomorphismAcceleration():
when scalBits >= EndomorphismThreshold: # Skip static: doAssert when multiplying by intentionally small scalars.
if usedBits >= EndomorphismThreshold:
when EC.F is Fp:
Expand Down
3 changes: 3 additions & 0 deletions constantine/math/elliptic/ec_shortweierstrass_affine.nim
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ type

SexticNonResidue* = NonResidue

template getName*(EC: type EC_ShortW_Aff): untyped =
EC.F.Name

template getScalarField*(EC: type EC_ShortW_Aff): untyped =
Fr[EC.F.Name]

Expand Down
3 changes: 3 additions & 0 deletions constantine/math/elliptic/ec_shortweierstrass_jacobian.nim
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ type EC_ShortW_Jac*[F; G: static Subgroup] = object
## Note that jacobian coordinates are not unique
x*, y*, z*: F

template getName*(EC: type EC_ShortW_Jac): untyped =
EC.F.Name

template getScalarField*(EC: type EC_ShortW_Jac): untyped =
Fr[EC.F.Name]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ type EC_ShortW_JacExt*[F; G: static Subgroup] = object
## Note that extended jacobian coordinates are not unique
x*, y*, zz*, zzz*: F

template getName*(EC: type EC_ShortW_JacExt): untyped =
EC.F.Name

func fromAffine*[F; G](jacext: var EC_ShortW_JacExt[F, G], aff: EC_ShortW_Aff[F, G]) {.inline.}

func isNeutral*(P: EC_ShortW_JacExt): SecretBool {.inline, meter.} =
Expand Down
3 changes: 3 additions & 0 deletions constantine/math/elliptic/ec_shortweierstrass_projective.nim
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ type EC_ShortW_Prj*[F; G: static Subgroup] = object
## Note that projective coordinates are not unique
x*, y*, z*: F

template getName*(EC: type EC_ShortW_Prj): untyped =
EC.F.Name

template getScalarField*(EC: type EC_ShortW_Prj): untyped =
Fr[EC.F.Name]

Expand Down
3 changes: 3 additions & 0 deletions constantine/math/elliptic/ec_twistededwards_affine.nim
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ type EC_TwEdw_Aff*[F] = object
## over a field F
x*, y*: F

template getName*(EC: type EC_TwEdw_Aff): untyped =
EC.F.Name

template getScalarField*(EC: type EC_TwEdw_Aff): untyped =
Fr[EC.F.Name]

Expand Down
Loading

0 comments on commit 0c7d22e

Please sign in to comment.