Skip to content

Commit

Permalink
feat(gt-multiexp): add optimized multi-exponentiation in π”Ύβ‚œ
Browse files Browse the repository at this point in the history
  • Loading branch information
mratsim committed Jul 16, 2024
1 parent 8fb61c6 commit 874f8d8
Show file tree
Hide file tree
Showing 10 changed files with 376 additions and 84 deletions.
File renamed without changes.
1 change: 1 addition & 0 deletions benchmarks/bench_ec_msm_bandersnatch.nim.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
--threads:on
1 change: 1 addition & 0 deletions benchmarks/bench_ec_msm_bls12_381_g1.nim.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
--threads:on
1 change: 1 addition & 0 deletions benchmarks/bench_ec_msm_bls12_381_g2.nim.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
--threads:on
1 change: 1 addition & 0 deletions benchmarks/bench_ec_msm_bn254_snarks_g1.nim.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
--threads:on
1 change: 1 addition & 0 deletions benchmarks/bench_ec_msm_pasta.nim.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
--threads:on
24 changes: 18 additions & 6 deletions benchmarks/bench_gt_parallel_template.nim
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,11 @@ proc multiExpParallelBench*[GT](ctx: var BenchMultiExpContext[GT], numInputs: in


var r{.noInit.}: GT
var startNaive, stopNaive, startMultiExpBaseline, stopMultiExpBaseline: MonoTime
var startNaive, stopNaive, startMultiExpBaseline, stopMultiExpBaseline, startMultiExpOpt, stopMultiExpOpt: MonoTime

if numInputs <= 100000:
# startNaive = getMonotime()
bench("π”Ύβ‚œ exponentiations " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
bench("π”Ύβ‚œ exponentiations " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
var tmp: GT
r.setOne()
for i in 0 ..< elems.len:
Expand All @@ -139,7 +139,7 @@ proc multiExpParallelBench*[GT](ctx: var BenchMultiExpContext[GT], numInputs: in

if numInputs <= 100000:
startNaive = getMonotime()
bench("π”Ύβ‚œ exponentiations vartime " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
bench("π”Ύβ‚œ exponentiations vartime " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
var tmp: GT
r.setOne()
for i in 0 ..< elems.len:
Expand All @@ -149,14 +149,26 @@ proc multiExpParallelBench*[GT](ctx: var BenchMultiExpContext[GT], numInputs: in

if numInputs <= 100000:
startMultiExpBaseline = getMonotime()
bench("π”Ύβ‚œ multi-exponentiations baseline " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
bench("π”Ύβ‚œ multi-exponentiations baseline " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
r.multiExp_reference_vartime(elems, exponents)
stopMultiExpBaseline = getMonotime()

block:
startMultiExpOpt = getMonotime()
bench("π”Ύβ‚œ multi-exponentiations optimized " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
r.multiExp_vartime(elems, exponents)
stopMultiExpOpt = getMonotime()

let perfNaive = inNanoseconds((stopNaive-startNaive) div iters)
let perfMSMbaseline = inNanoseconds((stopMultiExpBaseline-startMultiExpBaseline) div iters)
let perfMultiExpBaseline = inNanoseconds((stopMultiExpBaseline-startMultiExpBaseline) div iters)
let perfMultiExpOpt = inNanoseconds((stopMultiExpOpt-startMultiExpOpt) div iters)

if numInputs <= 100000:
let speedupBaseline = float(perfNaive) / float(perfMSMbaseline)
let speedupBaseline = float(perfNaive) / float(perfMultiExpBaseline)
echo &"Speedup ratio baseline over naive linear combination: {speedupBaseline:>6.3f}x"

let speedupOpt = float(perfNaive) / float(perfMultiExpOpt)
echo &"Speedup ratio optimized over naive linear combination: {speedupOpt:>6.3f}x"

let speedupOptBaseline = float(perfMultiExpBaseline) / float(perfMultiExpOpt)
echo &"Speedup ratio optimized over baseline linear combination: {speedupOptBaseline:>6.3f}x"
89 changes: 49 additions & 40 deletions constantine/math/elliptic/ec_multi_scalar_mul.nim
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ func miniMSM[bits: static int, EC, ECaff](
for _ in 0 ..< c:
r.double()

func multiScalarMul_vartime*[bits: static int, EC, ECaff](
func msmImpl_vartime[bits: static int, EC, ECaff](
r: var EC,
coefs: ptr UncheckedArray[BigInt[bits]], points: ptr UncheckedArray[ECaff],
N: int, c: static int) {.tags:[VarTime, HeapAlloc], meter.} =
Expand Down Expand Up @@ -295,6 +295,9 @@ func multiScalarMul_vartime*[bits: static int, EC, ECaff](
# -------
buckets.freeHeap()

# Multi scalar multiplication with batched affine additions
# -----------------------------------------------------------------------------------------------------------------------

func schedAccumulate*[NumBuckets, QueueLen, F, G; bits: static int](
sched: ptr Scheduler[NumBuckets, QueueLen, F, G],
bitIndex: int, miniMsmKind: static MiniMsmKind, c: static int,
Expand Down Expand Up @@ -344,7 +347,7 @@ func miniMSM_affine[NumBuckets, QueueLen, EC, ECaff; bits: static int](
for _ in 0 ..< c:
r.double()

func multiScalarMulAffine_vartime[bits: static int, EC, ECaff](
func msmAffineImpl_vartime[bits: static int, EC, ECaff](
r: var EC,
coefs: ptr UncheckedArray[BigInt[bits]], points: ptr UncheckedArray[ECaff],
N: int, c: static int) {.tags:[VarTime, Alloca, HeapAlloc], meter.} =
Expand Down Expand Up @@ -389,6 +392,9 @@ func multiScalarMulAffine_vartime[bits: static int, EC, ECaff](
sched.freeHeap()
buckets.freeHeap()

# Endomorphism acceleration
# -----------------------------------------------------------------------------------------------------------------------

proc applyEndomorphism[bits: static int, ECaff](
coefs: ptr UncheckedArray[BigInt[bits]],
points: ptr UncheckedArray[ECaff],
Expand All @@ -403,7 +409,7 @@ proc applyEndomorphism[bits: static int, ECaff](
const G = when ECaff isnot EC_ShortW_Aff: G1
else: ECaff.G

const L = ECaff.getScalarField().bits().ceilDiv_vartime(M) + 1
const L = ECaff.getScalarField().bits().computeEndoRecodedLength(M)
let splitCoefs = allocHeapArray(array[M, BigInt[L]], N)
let endoBasis = allocHeapArray(array[M, ECaff], N)

Expand Down Expand Up @@ -447,7 +453,10 @@ template withEndo[coefsBits: static int, EC, ECaff](
else:
msmProc(r, coefs, points, N, c)

func multiScalarMul_dispatch_vartime[bits: static int, F, G](
# Algorithm selection
# -----------------------------------------------------------------------------------------------------------------------

func msm_dispatch_vartime[bits: static int, F, G](
r: var (EC_ShortW_Jac[F, G] or EC_ShortW_Prj[F, G]),
coefs: ptr UncheckedArray[BigInt[bits]],
points: ptr UncheckedArray[EC_ShortW_Aff[F, G]], N: int) =
Expand All @@ -460,27 +469,27 @@ func multiScalarMul_dispatch_vartime[bits: static int, F, G](
# but it has no significant impact on performance

case c
of 2: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 2)
of 3: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 3)
of 4: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 4)
of 5: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 5)
of 6: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 6)
of 7: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 7)
of 8: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 8)

of 9: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 9)
of 10: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 10)
of 11: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 11)
of 12: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 12)
of 13: withEndo(multiScalarMulAffine_vartime, r, coefs, points, N, c = 13)
of 14: multiScalarMulAffine_vartime(r, coefs, points, N, c = 14)
of 15: multiScalarMulAffine_vartime(r, coefs, points, N, c = 15)

of 16..17: multiScalarMulAffine_vartime(r, coefs, points, N, c = 16)
of 2: withEndo(msmImpl_vartime, r, coefs, points, N, c = 2)
of 3: withEndo(msmImpl_vartime, r, coefs, points, N, c = 3)
of 4: withEndo(msmImpl_vartime, r, coefs, points, N, c = 4)
of 5: withEndo(msmImpl_vartime, r, coefs, points, N, c = 5)
of 6: withEndo(msmImpl_vartime, r, coefs, points, N, c = 6)
of 7: withEndo(msmImpl_vartime, r, coefs, points, N, c = 7)
of 8: withEndo(msmImpl_vartime, r, coefs, points, N, c = 8)

of 9: withEndo(msmAffineImpl_vartime, r, coefs, points, N, c = 9)
of 10: withEndo(msmAffineImpl_vartime, r, coefs, points, N, c = 10)
of 11: withEndo(msmAffineImpl_vartime, r, coefs, points, N, c = 11)
of 12: withEndo(msmAffineImpl_vartime, r, coefs, points, N, c = 12)
of 13: withEndo(msmAffineImpl_vartime, r, coefs, points, N, c = 13)
of 14: msmAffineImpl_vartime(r, coefs, points, N, c = 14)
of 15: msmAffineImpl_vartime(r, coefs, points, N, c = 15)

of 16..17: msmAffineImpl_vartime(r, coefs, points, N, c = 16)
else:
unreachable()

func multiScalarMul_dispatch_vartime[bits: static int, F](
func msm_dispatch_vartime[bits: static int, F](
r: var EC_TwEdw_Prj[F], coefs: ptr UncheckedArray[BigInt[bits]],
points: ptr UncheckedArray[EC_TwEdw_Aff[F]], N: int) =
## Multiscalar multiplication:
Expand All @@ -494,22 +503,22 @@ func multiScalarMul_dispatch_vartime[bits: static int, F](
# but it has no significant impact on performance

case c
of 2: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 2)
of 3: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 3)
of 4: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 4)
of 5: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 5)
of 6: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 6)
of 7: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 7)
of 8: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 8)
of 9: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 9)
of 10: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 10)
of 11: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 11)
of 12: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 12)
of 13: withEndo(multiScalarMul_vartime, r, coefs, points, N, c = 13)
of 14: multiScalarMul_vartime(r, coefs, points, N, c = 14)
of 15: multiScalarMul_vartime(r, coefs, points, N, c = 15)

of 16..17: multiScalarMul_vartime(r, coefs, points, N, c = 16)
of 2: withEndo(msmImpl_vartime, r, coefs, points, N, c = 2)
of 3: withEndo(msmImpl_vartime, r, coefs, points, N, c = 3)
of 4: withEndo(msmImpl_vartime, r, coefs, points, N, c = 4)
of 5: withEndo(msmImpl_vartime, r, coefs, points, N, c = 5)
of 6: withEndo(msmImpl_vartime, r, coefs, points, N, c = 6)
of 7: withEndo(msmImpl_vartime, r, coefs, points, N, c = 7)
of 8: withEndo(msmImpl_vartime, r, coefs, points, N, c = 8)
of 9: withEndo(msmImpl_vartime, r, coefs, points, N, c = 9)
of 10: withEndo(msmImpl_vartime, r, coefs, points, N, c = 10)
of 11: withEndo(msmImpl_vartime, r, coefs, points, N, c = 11)
of 12: withEndo(msmImpl_vartime, r, coefs, points, N, c = 12)
of 13: withEndo(msmImpl_vartime, r, coefs, points, N, c = 13)
of 14: msmImpl_vartime(r, coefs, points, N, c = 14)
of 15: msmImpl_vartime(r, coefs, points, N, c = 15)

of 16..17: msmImpl_vartime(r, coefs, points, N, c = 16)
else:
unreachable()

Expand All @@ -521,7 +530,7 @@ func multiScalarMul_vartime*[bits: static int, EC, ECaff](
## Multiscalar multiplication:
## r <- [aβ‚€]Pβ‚€ + [a₁]P₁ + ... + [aₙ₋₁]Pₙ₋₁

multiScalarMul_dispatch_vartime(r, coefs, points, len)
msm_dispatch_vartime(r, coefs, points, len)

func multiScalarMul_vartime*[bits: static int, EC, ECaff](
r: var EC,
Expand All @@ -531,7 +540,7 @@ func multiScalarMul_vartime*[bits: static int, EC, ECaff](
## r <- [aβ‚€]Pβ‚€ + [a₁]P₁ + ... + [aₙ₋₁]Pₙ₋₁
debug: doAssert coefs.len == points.len
let N = points.len
multiScalarMul_dispatch_vartime(r, coefs.asUnchecked(), points.asUnchecked(), N)
msm_dispatch_vartime(r, coefs.asUnchecked(), points.asUnchecked(), N)

func multiScalarMul_vartime*[F, EC, ECaff](
r: var EC,
Expand Down
Loading

0 comments on commit 874f8d8

Please sign in to comment.