From 874f8d8f99db3988256ca315db9eebd2e286cbf2 Mon Sep 17 00:00:00 2001 From: Mamy Ratsimbazafy Date: Wed, 17 Jul 2024 00:25:59 +0200 Subject: [PATCH] =?UTF-8?q?feat(gt-multiexp):=20add=20optimized=20multi-ex?= =?UTF-8?q?ponentiation=20in=20=F0=9D=94=BE=E2=82=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ...late.nim.cfg => bench_ec_g1_batch.nim.cfg} | 0 benchmarks/bench_ec_msm_bandersnatch.nim.cfg | 1 + benchmarks/bench_ec_msm_bls12_381_g1.nim.cfg | 1 + benchmarks/bench_ec_msm_bls12_381_g2.nim.cfg | 1 + .../bench_ec_msm_bn254_snarks_g1.nim.cfg | 1 + benchmarks/bench_ec_msm_pasta.nim.cfg | 1 + benchmarks/bench_gt_parallel_template.nim | 24 +- .../math/elliptic/ec_multi_scalar_mul.nim | 89 ++--- constantine/math/pairings/gt_multiexp.nim | 338 ++++++++++++++++-- tests/math_pairings/t_pairing_template.nim | 4 +- 10 files changed, 376 insertions(+), 84 deletions(-) rename benchmarks/{bench_elliptic_parallel_template.nim.cfg => bench_ec_g1_batch.nim.cfg} (100%) create mode 100644 benchmarks/bench_ec_msm_bandersnatch.nim.cfg create mode 100644 benchmarks/bench_ec_msm_bls12_381_g1.nim.cfg create mode 100644 benchmarks/bench_ec_msm_bls12_381_g2.nim.cfg create mode 100644 benchmarks/bench_ec_msm_bn254_snarks_g1.nim.cfg create mode 100644 benchmarks/bench_ec_msm_pasta.nim.cfg diff --git a/benchmarks/bench_elliptic_parallel_template.nim.cfg b/benchmarks/bench_ec_g1_batch.nim.cfg similarity index 100% rename from benchmarks/bench_elliptic_parallel_template.nim.cfg rename to benchmarks/bench_ec_g1_batch.nim.cfg diff --git a/benchmarks/bench_ec_msm_bandersnatch.nim.cfg b/benchmarks/bench_ec_msm_bandersnatch.nim.cfg new file mode 100644 index 000000000..9d57ecf93 --- /dev/null +++ b/benchmarks/bench_ec_msm_bandersnatch.nim.cfg @@ -0,0 +1 @@ +--threads:on \ No newline at end of file diff --git a/benchmarks/bench_ec_msm_bls12_381_g1.nim.cfg b/benchmarks/bench_ec_msm_bls12_381_g1.nim.cfg new file mode 100644 index 000000000..9d57ecf93 --- /dev/null +++ b/benchmarks/bench_ec_msm_bls12_381_g1.nim.cfg @@ -0,0 +1 @@ +--threads:on \ No newline at end of file diff --git a/benchmarks/bench_ec_msm_bls12_381_g2.nim.cfg b/benchmarks/bench_ec_msm_bls12_381_g2.nim.cfg new file mode 100644 index 000000000..9d57ecf93 --- /dev/null +++ b/benchmarks/bench_ec_msm_bls12_381_g2.nim.cfg @@ -0,0 +1 @@ +--threads:on \ No newline at end of file diff --git a/benchmarks/bench_ec_msm_bn254_snarks_g1.nim.cfg b/benchmarks/bench_ec_msm_bn254_snarks_g1.nim.cfg new file mode 100644 index 000000000..9d57ecf93 --- /dev/null +++ b/benchmarks/bench_ec_msm_bn254_snarks_g1.nim.cfg @@ -0,0 +1 @@ +--threads:on \ No newline at end of file diff --git a/benchmarks/bench_ec_msm_pasta.nim.cfg b/benchmarks/bench_ec_msm_pasta.nim.cfg new file mode 100644 index 000000000..9d57ecf93 --- /dev/null +++ b/benchmarks/bench_ec_msm_pasta.nim.cfg @@ -0,0 +1 @@ +--threads:on \ No newline at end of file diff --git a/benchmarks/bench_gt_parallel_template.nim b/benchmarks/bench_gt_parallel_template.nim index 197d1b17b..c57da4909 100644 --- a/benchmarks/bench_gt_parallel_template.nim +++ b/benchmarks/bench_gt_parallel_template.nim @@ -125,11 +125,11 @@ proc multiExpParallelBench*[GT](ctx: var BenchMultiExpContext[GT], numInputs: in var r{.noInit.}: GT - var startNaive, stopNaive, startMultiExpBaseline, stopMultiExpBaseline: MonoTime + var startNaive, stopNaive, startMultiExpBaseline, stopMultiExpBaseline, startMultiExpOpt, stopMultiExpOpt: MonoTime if numInputs <= 100000: # startNaive = getMonotime() - bench("๐”พโ‚œ exponentiations " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters): + bench("๐”พโ‚œ exponentiations " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters): var tmp: GT r.setOne() for i in 0 ..< elems.len: @@ -139,7 +139,7 @@ proc multiExpParallelBench*[GT](ctx: var BenchMultiExpContext[GT], numInputs: in if numInputs <= 100000: startNaive = getMonotime() - bench("๐”พโ‚œ exponentiations vartime " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters): + bench("๐”พโ‚œ exponentiations vartime " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters): var tmp: GT r.setOne() for i in 0 ..< elems.len: @@ -149,14 +149,26 @@ proc multiExpParallelBench*[GT](ctx: var BenchMultiExpContext[GT], numInputs: in if numInputs <= 100000: startMultiExpBaseline = getMonotime() - bench("๐”พโ‚œ multi-exponentiations baseline " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters): + bench("๐”พโ‚œ multi-exponentiations baseline " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters): r.multiExp_reference_vartime(elems, exponents) stopMultiExpBaseline = getMonotime() + block: + startMultiExpOpt = getMonotime() + bench("๐”พโ‚œ multi-exponentiations optimized " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters): + r.multiExp_vartime(elems, exponents) + stopMultiExpOpt = getMonotime() let perfNaive = inNanoseconds((stopNaive-startNaive) div iters) - let perfMSMbaseline = inNanoseconds((stopMultiExpBaseline-startMultiExpBaseline) div iters) + let perfMultiExpBaseline = inNanoseconds((stopMultiExpBaseline-startMultiExpBaseline) div iters) + let perfMultiExpOpt = inNanoseconds((stopMultiExpOpt-startMultiExpOpt) div iters) if numInputs <= 100000: - let speedupBaseline = float(perfNaive) / float(perfMSMbaseline) + let speedupBaseline = float(perfNaive) / float(perfMultiExpBaseline) echo &"Speedup ratio baseline over naive linear combination: {speedupBaseline:>6.3f}x" + + let speedupOpt = float(perfNaive) / float(perfMultiExpOpt) + echo &"Speedup ratio optimized over naive linear combination: {speedupOpt:>6.3f}x" + + let speedupOptBaseline = float(perfMultiExpBaseline) / float(perfMultiExpOpt) + echo &"Speedup ratio optimized over baseline linear combination: {speedupOptBaseline:>6.3f}x" diff --git a/constantine/math/elliptic/ec_multi_scalar_mul.nim b/constantine/math/elliptic/ec_multi_scalar_mul.nim index 51f9e9573..25c1fc754 100644 --- a/constantine/math/elliptic/ec_multi_scalar_mul.nim +++ b/constantine/math/elliptic/ec_multi_scalar_mul.nim @@ -253,7 +253,7 @@ func miniMSM[bits: static int, EC, ECaff]( for _ in 0 ..< c: r.double() -func multiScalarMul_vartime*[bits: static int, EC, ECaff]( +func msmImpl_vartime[bits: static int, EC, ECaff]( r: var EC, coefs: ptr UncheckedArray[BigInt[bits]], points: ptr UncheckedArray[ECaff], N: int, c: static int) {.tags:[VarTime, HeapAlloc], meter.} = @@ -295,6 +295,9 @@ func multiScalarMul_vartime*[bits: static int, EC, ECaff]( # ------- buckets.freeHeap() +# Multi scalar multiplication with batched affine additions +# ----------------------------------------------------------------------------------------------------------------------- + func schedAccumulate*[NumBuckets, QueueLen, F, G; bits: static int]( sched: ptr Scheduler[NumBuckets, QueueLen, F, G], bitIndex: int, miniMsmKind: static MiniMsmKind, c: static int, @@ -344,7 +347,7 @@ func miniMSM_affine[NumBuckets, QueueLen, EC, ECaff; bits: static int]( for _ in 0 ..< c: r.double() -func multiScalarMulAffine_vartime[bits: static int, EC, ECaff]( +func msmAffineImpl_vartime[bits: static int, EC, ECaff]( r: var EC, coefs: ptr UncheckedArray[BigInt[bits]], points: ptr UncheckedArray[ECaff], N: int, c: static int) {.tags:[VarTime, Alloca, HeapAlloc], meter.} = @@ -389,6 +392,9 @@ func multiScalarMulAffine_vartime[bits: static int, EC, ECaff]( sched.freeHeap() buckets.freeHeap() +# Endomorphism acceleration +# ----------------------------------------------------------------------------------------------------------------------- + proc applyEndomorphism[bits: static int, ECaff]( coefs: ptr UncheckedArray[BigInt[bits]], points: ptr UncheckedArray[ECaff], @@ -403,7 +409,7 @@ proc applyEndomorphism[bits: static int, ECaff]( const G = when ECaff isnot EC_ShortW_Aff: G1 else: ECaff.G - const L = ECaff.getScalarField().bits().ceilDiv_vartime(M) + 1 + const L = ECaff.getScalarField().bits().computeEndoRecodedLength(M) let splitCoefs = allocHeapArray(array[M, BigInt[L]], N) let endoBasis = allocHeapArray(array[M, ECaff], N) @@ -447,7 +453,10 @@ template withEndo[coefsBits: static int, EC, ECaff]( else: msmProc(r, coefs, points, N, c) -func multiScalarMul_dispatch_vartime[bits: static int, F, G]( +# Algorithm selection +# ----------------------------------------------------------------------------------------------------------------------- + +func msm_dispatch_vartime[bits: static int, F, G]( r: var (EC_ShortW_Jac[F, G] or EC_ShortW_Prj[F, G]), coefs: ptr UncheckedArray[BigInt[bits]], points: ptr UncheckedArray[EC_ShortW_Aff[F, G]], N: int) = @@ -460,27 +469,27 @@ func multiScalarMul_dispatch_vartime[bits: static int, F, G]( # but it has no significant impact on performance case c - of 2: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 2) - of 3: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 3) - of 4: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 4) - of 5: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 5) - of 6: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 6) - of 7: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 7) - of 8: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 8) - - of 9: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 9) - of 10: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 10) - of 11: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 11) - of 12: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 12) - of 13: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 13) - of 14: multiScalarMulAffine_vartime(r, coefs, points, N, c = 14) - of 15: multiScalarMulAffine_vartime(r, coefs, points, N, c = 15) - - of 16..17: multiScalarMulAffine_vartime(r, coefs, points, N, c = 16) + of 2: withEndo(msmImpl_vartime, r, coefs, points, N, c = 2) + of 3: withEndo(msmImpl_vartime, r, coefs, points, N, c = 3) + of 4: withEndo(msmImpl_vartime, r, coefs, points, N, c = 4) + of 5: withEndo(msmImpl_vartime, r, coefs, points, N, c = 5) + of 6: withEndo(msmImpl_vartime, r, coefs, points, N, c = 6) + of 7: withEndo(msmImpl_vartime, r, coefs, points, N, c = 7) + of 8: withEndo(msmImpl_vartime, r, coefs, points, N, c = 8) + + of 9: withEndo(msmAffineImpl_vartime, r, coefs, points, N, c = 9) + of 10: withEndo(msmAffineImpl_vartime, r, coefs, points, N, c = 10) + of 11: withEndo(msmAffineImpl_vartime, r, coefs, points, N, c = 11) + of 12: withEndo(msmAffineImpl_vartime, r, coefs, points, N, c = 12) + of 13: withEndo(msmAffineImpl_vartime, r, coefs, points, N, c = 13) + of 14: msmAffineImpl_vartime(r, coefs, points, N, c = 14) + of 15: msmAffineImpl_vartime(r, coefs, points, N, c = 15) + + of 16..17: msmAffineImpl_vartime(r, coefs, points, N, c = 16) else: unreachable() -func multiScalarMul_dispatch_vartime[bits: static int, F]( +func msm_dispatch_vartime[bits: static int, F]( r: var EC_TwEdw_Prj[F], coefs: ptr UncheckedArray[BigInt[bits]], points: ptr UncheckedArray[EC_TwEdw_Aff[F]], N: int) = ## Multiscalar multiplication: @@ -494,22 +503,22 @@ func multiScalarMul_dispatch_vartime[bits: static int, F]( # but it has no significant impact on performance case c - of 2: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 2) - of 3: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 3) - of 4: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 4) - of 5: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 5) - of 6: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 6) - of 7: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 7) - of 8: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 8) - of 9: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 9) - of 10: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 10) - of 11: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 11) - of 12: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 12) - of 13: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 13) - of 14: multiScalarMul_vartime(r, coefs, points, N, c = 14) - of 15: multiScalarMul_vartime(r, coefs, points, N, c = 15) - - of 16..17: multiScalarMul_vartime(r, coefs, points, N, c = 16) + of 2: withEndo(msmImpl_vartime, r, coefs, points, N, c = 2) + of 3: withEndo(msmImpl_vartime, r, coefs, points, N, c = 3) + of 4: withEndo(msmImpl_vartime, r, coefs, points, N, c = 4) + of 5: withEndo(msmImpl_vartime, r, coefs, points, N, c = 5) + of 6: withEndo(msmImpl_vartime, r, coefs, points, N, c = 6) + of 7: withEndo(msmImpl_vartime, r, coefs, points, N, c = 7) + of 8: withEndo(msmImpl_vartime, r, coefs, points, N, c = 8) + of 9: withEndo(msmImpl_vartime, r, coefs, points, N, c = 9) + of 10: withEndo(msmImpl_vartime, r, coefs, points, N, c = 10) + of 11: withEndo(msmImpl_vartime, r, coefs, points, N, c = 11) + of 12: withEndo(msmImpl_vartime, r, coefs, points, N, c = 12) + of 13: withEndo(msmImpl_vartime, r, coefs, points, N, c = 13) + of 14: msmImpl_vartime(r, coefs, points, N, c = 14) + of 15: msmImpl_vartime(r, coefs, points, N, c = 15) + + of 16..17: msmImpl_vartime(r, coefs, points, N, c = 16) else: unreachable() @@ -521,7 +530,7 @@ func multiScalarMul_vartime*[bits: static int, EC, ECaff]( ## Multiscalar multiplication: ## r <- [aโ‚€]Pโ‚€ + [aโ‚]Pโ‚ + ... + [aโ‚™โ‚‹โ‚]Pโ‚™โ‚‹โ‚ - multiScalarMul_dispatch_vartime(r, coefs, points, len) + msm_dispatch_vartime(r, coefs, points, len) func multiScalarMul_vartime*[bits: static int, EC, ECaff]( r: var EC, @@ -531,7 +540,7 @@ func multiScalarMul_vartime*[bits: static int, EC, ECaff]( ## r <- [aโ‚€]Pโ‚€ + [aโ‚]Pโ‚ + ... + [aโ‚™โ‚‹โ‚]Pโ‚™โ‚‹โ‚ debug: doAssert coefs.len == points.len let N = points.len - multiScalarMul_dispatch_vartime(r, coefs.asUnchecked(), points.asUnchecked(), N) + msm_dispatch_vartime(r, coefs.asUnchecked(), points.asUnchecked(), N) func multiScalarMul_vartime*[F, EC, ECaff]( r: var EC, diff --git a/constantine/math/pairings/gt_multiexp.nim b/constantine/math/pairings/gt_multiexp.nim index c98f905ef..7a25b60fb 100644 --- a/constantine/math/pairings/gt_multiexp.nim +++ b/constantine/math/pairings/gt_multiexp.nim @@ -86,8 +86,7 @@ func bestBucketBitSize*(inputSize: int, scalarBitwidth: static int, useSignedBuc if 13 <= result: result -= 1 -func `~*=`*[Gt: ExtensionField](a: var Gt, b: Gt) = - +func `~*=`*[Gt: ExtensionField](a: var Gt, b: Gt) {.inline.} = # TODO: Analyze the inputs to see if there is avalue in more complex shortcuts (-1, or partial 0 coordinates) if a.isOne().bool(): a = b @@ -96,11 +95,14 @@ func `~*=`*[Gt: ExtensionField](a: var Gt, b: Gt) = else: a *= b -func `~/=`*[Gt: ExtensionField](a: var Gt, b: Gt) = +func `~/=`*[Gt: ExtensionField](a: var Gt, b: Gt) {.inline.} = ## Cyclotomic division var t {.noInit.}: Gt t.cyclotomic_inv(b) - a ~*= b + a ~*= t + +func setNeutral*[Gt: ExtensionField](a: var Gt) {.inline.} = + a.setOne() # Reference multi-exponentiation # ------------------------------------------------------------- @@ -108,7 +110,7 @@ func `~/=`*[Gt: ExtensionField](a: var Gt, b: Gt) = func multiExpImpl_reference_vartime[bits: static int, Gt]( r: var Gt, elems: ptr UncheckedArray[Gt], - exponents: ptr UncheckedArray[BigInt[bits]], + expos: ptr UncheckedArray[BigInt[bits]], N: int, c: static int) {.tags:[VarTime, HeapAlloc].} = ## Inner implementation of MEXP, for static dispatch over c, the bucket bit length ## This is a straightforward simple translation of BDLO12, section 4 @@ -127,11 +129,11 @@ func multiExpImpl_reference_vartime[bits: static int, Gt]( # Place our elements in a bucket corresponding to # how many times their bit pattern in the current window of size c for i in 0 ..< numBuckets: - buckets[i].setOne() + buckets[i].setNeutral() # 1. Bucket accumulation. Cost: n - (2แถœ-1) => n elems in 2แถœ-1 elems, first elem per bucket is just copied for j in 0 ..< N: - let b = cast[int](exponents[j].getWindowAt(w*c, c)) + let b = cast[int](expos[j].getWindowAt(w*c, c)) if b == 0: # bucket 0 is unused, no need to add aโฑผโฐ continue else: @@ -166,68 +168,330 @@ func multiExpImpl_reference_vartime[bits: static int, Gt]( func multiExp_reference_dispatch_vartime[bits: static int, Gt]( r: var Gt, elems: ptr UncheckedArray[Gt], - exponents: ptr UncheckedArray[BigInt[bits]], + expos: ptr UncheckedArray[BigInt[bits]], N: int) {.tags:[VarTime, HeapAlloc].} = ## Multiexponentiation: ## r <- gโ‚€^aโ‚€ + gโ‚^aโ‚ + ... + gโ‚™^aโ‚™ let c = bestBucketBitSize(N, bits, useSignedBuckets = false, useManualTuning = false) case c - of 2: multiExpImpl_reference_vartime(r, elems, exponents, N, c = 2) - of 3: multiExpImpl_reference_vartime(r, elems, exponents, N, c = 3) - of 4: multiExpImpl_reference_vartime(r, elems, exponents, N, c = 4) - of 5: multiExpImpl_reference_vartime(r, elems, exponents, N, c = 5) - of 6: multiExpImpl_reference_vartime(r, elems, exponents, N, c = 6) - of 7: multiExpImpl_reference_vartime(r, elems, exponents, N, c = 7) - of 8: multiExpImpl_reference_vartime(r, elems, exponents, N, c = 8) - of 9: multiExpImpl_reference_vartime(r, elems, exponents, N, c = 9) - of 10: multiExpImpl_reference_vartime(r, elems, exponents, N, c = 10) - of 11: multiExpImpl_reference_vartime(r, elems, exponents, N, c = 11) - of 12: multiExpImpl_reference_vartime(r, elems, exponents, N, c = 12) - of 13: multiExpImpl_reference_vartime(r, elems, exponents, N, c = 13) - of 14: multiExpImpl_reference_vartime(r, elems, exponents, N, c = 14) - of 15: multiExpImpl_reference_vartime(r, elems, exponents, N, c = 15) - - of 16..20: multiExpImpl_reference_vartime(r, elems, exponents, N, c = 16) + of 2: multiExpImpl_reference_vartime(r, elems, expos, N, c = 2) + of 3: multiExpImpl_reference_vartime(r, elems, expos, N, c = 3) + of 4: multiExpImpl_reference_vartime(r, elems, expos, N, c = 4) + of 5: multiExpImpl_reference_vartime(r, elems, expos, N, c = 5) + of 6: multiExpImpl_reference_vartime(r, elems, expos, N, c = 6) + of 7: multiExpImpl_reference_vartime(r, elems, expos, N, c = 7) + of 8: multiExpImpl_reference_vartime(r, elems, expos, N, c = 8) + of 9: multiExpImpl_reference_vartime(r, elems, expos, N, c = 9) + of 10: multiExpImpl_reference_vartime(r, elems, expos, N, c = 10) + of 11: multiExpImpl_reference_vartime(r, elems, expos, N, c = 11) + of 12: multiExpImpl_reference_vartime(r, elems, expos, N, c = 12) + of 13: multiExpImpl_reference_vartime(r, elems, expos, N, c = 13) + of 14: multiExpImpl_reference_vartime(r, elems, expos, N, c = 14) + of 15: multiExpImpl_reference_vartime(r, elems, expos, N, c = 15) + + of 16..20: multiExpImpl_reference_vartime(r, elems, expos, N, c = 16) else: unreachable() func multiExp_reference_vartime*[bits: static int, Gt]( r: var Gt, elems: ptr UncheckedArray[Gt], - exponents: ptr UncheckedArray[BigInt[bits]], + expos: ptr UncheckedArray[BigInt[bits]], N: int) {.tags:[VarTime, HeapAlloc].} = ## Multiexponentiation: ## r <- gโ‚€^aโ‚€ + gโ‚^aโ‚ + ... + gโ‚™^aโ‚™ - multiExp_reference_dispatch_vartime(r, elems, exponents, N) + multiExp_reference_dispatch_vartime(r, elems, expos, N) -func multiExp_reference_vartime*[Gt](r: var Gt, elems: openArray[Gt], exponents: openArray[BigInt]) {.tags:[VarTime, HeapAlloc].} = +func multiExp_reference_vartime*[Gt](r: var Gt, elems: openArray[Gt], expos: openArray[BigInt]) {.tags:[VarTime, HeapAlloc].} = ## Multiexponentiation: ## r <- gโ‚€^aโ‚€ + gโ‚^aโ‚ + ... + gโ‚™^aโ‚™ - debug: doAssert exponents.len == elems.len + debug: doAssert expos.len == elems.len let N = elems.len - multiExp_reference_dispatch_vartime(r, elems.asUnchecked(), exponents.asUnchecked(), N) + multiExp_reference_dispatch_vartime(r, elems.asUnchecked(), expos.asUnchecked(), N) func multiExp_reference_vartime*[F, Gt]( r: var Gt, elems: ptr UncheckedArray[Gt], - exponents: ptr UncheckedArray[F], + expos: ptr UncheckedArray[F], len: int) {.tags:[VarTime, Alloca, HeapAlloc], meter.} = ## Multiexponentiation: ## r <- gโ‚€^aโ‚€ + gโ‚^aโ‚ + ... + gโ‚™^aโ‚™ let n = cast[int](len) - let exponents_big = allocHeapArrayAligned(F.getBigInt(), n, alignment = 64) - exponents_big.batchFromField(exponents, n) - r.multiExp_reference_vartime(elems, exponents_big, n) + let expos_big = allocHeapArrayAligned(F.getBigInt(), n, alignment = 64) + expos_big.batchFromField(expos, n) + r.multiExp_reference_vartime(elems, expos_big, n) - freeHeapAligned(exponents_big) + freeHeapAligned(expos_big) func multiExp_reference_vartime*[Gt]( r: var Gt, elems: openArray[Gt], - exponents: openArray[Fr]) {.tags:[VarTime, Alloca, HeapAlloc], inline.} = + expos: openArray[Fr]) {.tags:[VarTime, Alloca, HeapAlloc], inline.} = + ## Multiexponentiation: + ## r <- gโ‚€^aโ‚€ + gโ‚^aโ‚ + ... + gโ‚™^aโ‚™ + debug: doAssert expos.len == elems.len + let N = elems.len + multiExp_reference_vartime(r, elems.asUnchecked(), expos.asUnchecked(), N) + +# ########################################################### # +# # +# Multi-exponentiations in ๐”พโ‚œ # +# Optimized version # +# # +# ########################################################### # + + +func accumulate[GT](buckets: ptr UncheckedArray[GT], val: SecretWord, negate: SecretBool, elem: GT) {.inline, meter.} = + let val = BaseType(val) + if val == 0: # Skip gโฐ + return + elif negate.bool: + buckets[val-1] ~/= elem + else: + buckets[val-1] ~*= elem + +func bucketReduce[GT](r: var GT, buckets: ptr UncheckedArray[GT], numBuckets: static int) {.meter.} = + # We interleave reduction with one-ing the bucket to use instruction-level parallelism + + var accumBuckets{.noInit.}: typeof(r) + accumBuckets = buckets[numBuckets-1] + r = buckets[numBuckets-1] + buckets[numBuckets-1].setNeutral() + + for k in countdown(numBuckets-2, 0): + accumBuckets ~*= buckets[k] + r ~*= accumBuckets + buckets[k].setNeutral() + +type MiniMultiExpKind* = enum + kTopWindow + kFullWindow + kBottomWindow + +func bucketAccumReduce*[bits: static int, GT]( + r: var GT, + buckets: ptr UncheckedArray[GT], + bitIndex: int, miniMultiExpKind: static MiniMultiExpKind, c: static int, + elems: ptr UncheckedArray[GT], expos: ptr UncheckedArray[BigInt[bits]], N: int) = + + const excess = bits mod c + const top = bits - excess + + # 1. Bucket Accumulation + var curVal, nextVal: SecretWord + var curNeg, nextNeg: SecretBool + + template getSignedWindow(j : int): tuple[val: SecretWord, neg: SecretBool] = + when miniMultiExpKind == kBottomWindow: expos[j].getSignedBottomWindow(c) + elif miniMultiExpKind == kTopWindow: expos[j].getSignedTopWindow(top, excess) + else: expos[j].getSignedFullWindowAt(bitIndex, c) + + (curVal, curNeg) = getSignedWindow(0) + for j in 0 ..< N-1: + (nextVal, nextNeg) = getSignedWindow(j+1) + if nextVal.BaseType != 0: + # In cryptography, points are indistinguishable from random + # hence, without prefetching, accessing the next bucket is a guaranteed cache miss + prefetchLarge(buckets[nextVal.BaseType-1].addr, Write, HighTemporalLocality, maxCacheLines = 2) + buckets.accumulate(curVal, curNeg, elems[j]) + curVal = nextVal + curNeg = nextNeg + buckets.accumulate(curVal, curNeg, elems[N-1]) + + # 2. Bucket Reduction + r.bucketReduce(buckets, numBuckets = 1 shl (c-1)) + +func miniMultiExp[bits: static int, GT]( + r: var GT, + buckets: ptr UncheckedArray[GT], + bitIndex: int, miniMultiExpKind: static MiniMultiExpKind, c: static int, + elems: ptr UncheckedArray[GT], expos: ptr UncheckedArray[BigInt[bits]], N: int) {.meter.} = + ## Apply a mini-Multi-Exponentiation on [bitIndex, bitIndex+window) + ## slice of all (coef, point) pairs + + var windowProd{.noInit.}: typeof(r) + windowProd.bucketAccumReduce( + buckets, bitIndex, miniMultiExpKind, c, + elems, expos, N) + + # 3. Mini-MultiExp on the slice [bitIndex, bitIndex+window) + r ~*= windowProd + when miniMultiExpKind != kBottomWindow: + for _ in 0 ..< c: + r.cyclotomic_square() + +func multiExpImpl_vartime[bits: static int, GT]( + r: var GT, + elems: ptr UncheckedArray[GT], expos: ptr UncheckedArray[BigInt[bits]], + N: int, c: static int) {.tags:[VarTime, HeapAlloc], meter.} = + ## Multiexponentiation: + ## r <- gโ‚€^aโ‚€ + gโ‚^aโ‚ + ... + gโ‚™^aโ‚™ + + # Setup + # ----- + const numBuckets = 1 shl (c-1) + + let buckets = allocHeapArray(GT, numBuckets) + for i in 0 ..< numBuckets: + buckets[i].setNeutral() + + # Algorithm + # --------- + const excess = bits mod c + const top = bits - excess + var w = top + r.setNeutral() + + when top != 0: # Prologue + when excess != 0: + r.miniMultiExp(buckets, w, kTopWindow, c, elems, expos, N) + w -= c + else: + # If c divides bits exactly, the signed windowed recoding still needs to see an extra 0 + # Since we did r.setNeutral() earlier, this is a no-op + discard + + while w != 0: # Steady state + r.miniMultiExp(buckets, w, kFullWindow, c, elems, expos, N) + w -= c + + block: # Epilogue + r.miniMultiExp(buckets, w, kBottomWindow, c, elems, expos, N) + + # Cleanup + # ------- + buckets.freeHeap() + +# Endomorphism acceleration +# ----------------------------------------------------------------------------------------------------------------------- + +proc applyEndomorphism[bits: static int, GT]( + elems: ptr UncheckedArray[GT], + expos: ptr UncheckedArray[BigInt[bits]], + N: int): auto = + ## Decompose (elems, expos) into mini-scalars + ## Returns a new triplet (endoElems, endoexpos, N) + ## endoElems and endoexpos MUST be freed afterwards + + const M = when Gt is Fp6: 2 + elif Gt is Fp12: 4 + else: {.error: "Unconfigured".} + + const L = Fr[Gt.Name].bits().computeEndoRecodedLength(M) + let splitExpos = allocHeapArray(array[M, BigInt[L]], N) + let endoBasis = allocHeapArray(array[M, GT], N) + + for i in 0 ..< N: + var negatePoints {.noinit.}: array[M, SecretBool] + splitExpos[i].decomposeEndo(negatePoints, expos[i], Fr[Gt.Name].bits(), Gt.Name, G2) # ๐”พโ‚œ has same decomposition as ๐”พโ‚‚ + if negatePoints[0].bool: + endoBasis[i][0].cyclotomic_inv(elems[i]) + else: + endoBasis[i][0] = elems[i] + + cast[ptr array[M-1, GT]](endoBasis[i][1].addr)[].computeEndomorphisms(elems[i]) + for m in 1 ..< M: + if negatePoints[m].bool: + endoBasis[i][m].cyclotomic_inv() + + let endoElems = cast[ptr UncheckedArray[GT]](endoBasis) + let endoExpos = cast[ptr UncheckedArray[BigInt[L]]](splitExpos) + + return (endoElems, endoExpos, M*N) + +template withEndo[exponentsBits: static int, GT]( + multiExpProc: untyped, + r: var GT, + elems: ptr UncheckedArray[GT], + expos: ptr UncheckedArray[BigInt[exponentsBits]], + N: int, c: static int) = + when Gt.Name.hasEndomorphismAcceleration() and + EndomorphismThreshold <= exponentsBits and + exponentsBits <= Fr[Gt.Name].bits(): + let (endoElems, endoExpos, endoN) = applyEndomorphism(elems, expos, N) + # Given that bits and N changed, we are able to use a bigger `c` + # TODO: bench + multiExpProc(r, endoElems, endoExpos, endoN, c) + freeHeap(endoElems) + freeHeap(endoExpos) + else: + multiExpProc(r, elems, expos, N, c) + +# Algorithm selection +# ----------------------------------------------------------------------------------------------------------------------- + +func multiexp_dispatch_vartime[bits: static int, GT]( + r: var GT, + elems: ptr UncheckedArray[GT], + expos: ptr UncheckedArray[BigInt[bits]], N: int) = + ## Multiexponentiation: + ## r <- gโ‚€^aโ‚€ + gโ‚^aโ‚ + ... + gโ‚™^aโ‚™ + let c = bestBucketBitSize(N, bits, useSignedBuckets = true, useManualTuning = true) + + # Given that bits and N change after applying an endomorphism, + # we are able to use a bigger `c` + # TODO: benchmark + + case c + of 2: withEndo(multiExpImpl_vartime, r, elems, expos, N, c = 2) + of 3: withEndo(multiExpImpl_vartime, r, elems, expos, N, c = 3) + of 4: withEndo(multiExpImpl_vartime, r, elems, expos, N, c = 4) + of 5: withEndo(multiExpImpl_vartime, r, elems, expos, N, c = 5) + of 6: withEndo(multiExpImpl_vartime, r, elems, expos, N, c = 6) + of 7: withEndo(multiExpImpl_vartime, r, elems, expos, N, c = 7) + of 8: withEndo(multiExpImpl_vartime, r, elems, expos, N, c = 8) + of 9: withEndo(multiExpImpl_vartime, r, elems, expos, N, c = 9) + of 10: withEndo(multiExpImpl_vartime, r, elems, expos, N, c = 10) + of 11: withEndo(multiExpImpl_vartime, r, elems, expos, N, c = 11) + of 12: withEndo(multiExpImpl_vartime, r, elems, expos, N, c = 12) + of 13: withEndo(multiExpImpl_vartime, r, elems, expos, N, c = 13) + of 14: multiExpImpl_vartime(r, elems, expos, N, c = 14) + of 15: multiExpImpl_vartime(r, elems, expos, N, c = 15) + + of 16..17: multiExpImpl_vartime(r, elems, expos, N, c = 16) + else: + unreachable() + +func multiExp_vartime*[bits: static int, GT]( + r: var GT, + elems: ptr UncheckedArray[GT], + expos: ptr UncheckedArray[BigInt[bits]], + len: int) {.tags:[VarTime, Alloca, HeapAlloc], meter.} = + ## Multiexponentiation: + ## r <- gโ‚€^aโ‚€ + gโ‚^aโ‚ + ... + gโ‚™^aโ‚™ + multiExp_dispatch_vartime(r, elems, expos, len) + +func multiExp_vartime*[bits: static int, GT]( + r: var GT, + elems: openArray[GT], + expos: openArray[BigInt[bits]]) {.tags:[VarTime, Alloca, HeapAlloc], meter.} = + ## Multiexponentiation: + ## r <- gโ‚€^aโ‚€ + gโ‚^aโ‚ + ... + gโ‚™^aโ‚™ + debug: doAssert elems.len == expos.len + let N = elems.len + multiExp_dispatch_vartime(r, elems.asUnchecked(), expos.asUnchecked(), N) + +func multiExp_vartime*[F, GT]( + r: var GT, + elems: ptr UncheckedArray[GT], + expos: ptr UncheckedArray[F], + len: int) {.tags:[VarTime, Alloca, HeapAlloc], meter.} = + ## Multiexponentiation: + ## r <- gโ‚€^aโ‚€ + gโ‚^aโ‚ + ... + gโ‚™^aโ‚™ + let n = cast[int](len) + let expos_big = allocHeapArrayAligned(F.getBigInt(), n, alignment = 64) + expos_big.batchFromField(expos, n) + r.multiExp_vartime(elems, expos_big, n) + + freeHeapAligned(expos_big) + +func multiExp_vartime*[GT]( + r: var GT, + elems: openArray[GT], + expos: openArray[Fr]) {.tags:[VarTime, Alloca, HeapAlloc], inline.} = ## Multiexponentiation: ## r <- gโ‚€^aโ‚€ + gโ‚^aโ‚ + ... + gโ‚™^aโ‚™ - debug: doAssert exponents.len == elems.len + debug: doAssert elems.len == expos.len let N = elems.len - multiExp_reference_vartime(r, elems.asUnchecked(), exponents.asUnchecked(), N) + multiExp_vartime(r, elems.asUnchecked(), expos.asUnchecked(), N) diff --git a/tests/math_pairings/t_pairing_template.nim b/tests/math_pairings/t_pairing_template.nim index 0df29200a..10c7debb9 100644 --- a/tests/math_pairings/t_pairing_template.nim +++ b/tests/math_pairings/t_pairing_template.nim @@ -249,10 +249,12 @@ proc runGTmultiexpTests*[N: static int](GT: typedesc, num_points: array[N, int], t.gtExp_vartime(elems[i], exponents[i]) naive *= t - var mexp_ref: GT + var mexp_ref, mexp_opt: GT mexp_ref.multiExp_reference_vartime(elems, exponents) + mexp_opt.multiExp_vartime(elems, exponents) doAssert bool(naive == mexp_ref) + doAssert bool(naive == mexp_opt) stdout.write '.'