Skip to content

Commit

Permalink
feat(gt-multiexp): add parallel multi-exponentiation in ๐”พโ‚œ
Browse files Browse the repository at this point in the history
  • Loading branch information
mratsim committed Jul 17, 2024
1 parent 874f8d8 commit 242798f
Show file tree
Hide file tree
Showing 20 changed files with 469 additions and 154 deletions.
19 changes: 17 additions & 2 deletions benchmarks/bench_gt_parallel_template.nim
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import
pairings_generic,
gt_exponentiations,
gt_exponentiations_vartime,
gt_multiexp
gt_multiexp, gt_multiexp_parallel,
],
constantine/threadpool,
# Helpers
Expand Down Expand Up @@ -125,7 +125,8 @@ proc multiExpParallelBench*[GT](ctx: var BenchMultiExpContext[GT], numInputs: in


var r{.noInit.}: GT
var startNaive, stopNaive, startMultiExpBaseline, stopMultiExpBaseline, startMultiExpOpt, stopMultiExpOpt: MonoTime
var startNaive, stopNaive, startMultiExpBaseline, stopMultiExpBaseline: MonoTime
var startMultiExpOpt, stopMultiExpOpt, startMultiExpPara, stopMultiExpPara: MonoTime

if numInputs <= 100000:
# startNaive = getMonotime()
Expand Down Expand Up @@ -159,9 +160,20 @@ proc multiExpParallelBench*[GT](ctx: var BenchMultiExpContext[GT], numInputs: in
r.multiExp_vartime(elems, exponents)
stopMultiExpOpt = getMonotime()

block:
ctx.tp = Threadpool.new()

startMultiExpPara = getMonotime()
bench("๐”พโ‚œ multi-exponentiations" & align($ctx.tp.numThreads & " threads", 11) & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
ctx.tp.multiExp_vartime_parallel(r, elems, exponents)
stopMultiExpPara = getMonotime()

ctx.tp.shutdown()

let perfNaive = inNanoseconds((stopNaive-startNaive) div iters)
let perfMultiExpBaseline = inNanoseconds((stopMultiExpBaseline-startMultiExpBaseline) div iters)
let perfMultiExpOpt = inNanoseconds((stopMultiExpOpt-startMultiExpOpt) div iters)
let perfMultiExpPara = inNanoseconds((stopMultiExpPara-startMultiExpPara) div iters)

if numInputs <= 100000:
let speedupBaseline = float(perfNaive) / float(perfMultiExpBaseline)
Expand All @@ -172,3 +184,6 @@ proc multiExpParallelBench*[GT](ctx: var BenchMultiExpContext[GT], numInputs: in

let speedupOptBaseline = float(perfMultiExpBaseline) / float(perfMultiExpOpt)
echo &"Speedup ratio optimized over baseline linear combination: {speedupOptBaseline:>6.3f}x"

let speedupParaOpt = float(perfMultiExpOpt) / float(perfMultiExpPara)
echo &"Speedup ratio parallel over optimized linear combination: {speedupParaOpt:>6.3f}x"
58 changes: 27 additions & 31 deletions constantine/math/elliptic/ec_multi_scalar_mul_parallel.nim
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import constantine/named/algebras,
constantine/math/endomorphisms/split_scalars,
constantine/math/extension_fields,
constantine/named/zoo_endomorphisms,
../../threadpool/[threadpool, partitioners]
constantine/threadpool/[threadpool, partitioners]
export bestBucketBitSize

# No exceptions allowed in core cryptographic operations
Expand Down Expand Up @@ -145,7 +145,7 @@ proc bucketAccumReduce_withInit[bits: static int, EC, ECaff](
buckets[i].setNeutral()
bucketAccumReduce(windowSum[], buckets, bitIndex, miniMsmKind, c, coefs, points, N)

proc msm_vartime_parallel[bits: static int, EC, ECaff](
proc msmImpl_vartime_parallel[bits: static int, EC, ECaff](
tp: Threadpool,
r: ptr EC,
coefs: ptr UncheckedArray[BigInt[bits]], points: ptr UncheckedArray[EC_aff],
Expand Down Expand Up @@ -465,7 +465,7 @@ proc applyEndomorphism_parallel[bits: static int, ECaff](
const G = when ECaff isnot EC_ShortW_Aff: G1
else: ECaff.G

const L = ECaff.getScalarField().bits().ceilDiv_vartime(M) + 1
const L = ECaff.getScalarField().bits().computeEndoRecodedLength(M)
let splitCoefs = allocHeapArray(array[M, BigInt[L]], N)
let endoBasis = allocHeapArray(array[M, ECaff], N)

Expand Down Expand Up @@ -544,14 +544,14 @@ proc multiScalarMul_dispatch_vartime_parallel[bits: static int, F, G](
# but it has no significant impact on performance

case c
of 2: withEndo(msm_vartime_parallel, tp, r, coefs, points, N, c = 2)
of 3: withEndo(msm_vartime_parallel, tp, r, coefs, points, N, c = 3)
of 4: withEndo(msm_vartime_parallel, tp, r, coefs, points, N, c = 4)
of 5: withEndo(msm_vartime_parallel, tp, r, coefs, points, N, c = 5)
of 6: withEndo(msm_vartime_parallel, tp, r, coefs, points, N, c = 6)
of 2: withEndo(msmImpl_vartime_parallel, tp, r, coefs, points, N, c = 2)
of 3: withEndo(msmImpl_vartime_parallel, tp, r, coefs, points, N, c = 3)
of 4: withEndo(msmImpl_vartime_parallel, tp, r, coefs, points, N, c = 4)
of 5: withEndo(msmImpl_vartime_parallel, tp, r, coefs, points, N, c = 5)
of 6: withEndo(msmImpl_vartime_parallel, tp, r, coefs, points, N, c = 6)

of 7: msm_vartime_parallel(tp, r, coefs, points, N, c = 7)
of 8: msm_vartime_parallel(tp, r, coefs, points, N, c = 8)
of 7: msmImpl_vartime_parallel(tp, r, coefs, points, N, c = 7)
of 8: msmImpl_vartime_parallel(tp, r, coefs, points, N, c = 8)

of 9: withEndo(msmAffine_vartime_parallel_split, tp, r, coefs, points, N, c = 9, useParallelBuckets = true)
of 10: withEndo(msmAffine_vartime_parallel_split, tp, r, coefs, points, N, c = 10, useParallelBuckets = true)
Expand Down Expand Up @@ -579,23 +579,23 @@ proc multiScalarMul_dispatch_vartime_parallel[bits: static int, F](
# but it has no significant impact on performance

case c
of 2: withEndo(msm_vartime_parallel, tp, r, coefs, points, N, c = 2)
of 3: withEndo(msm_vartime_parallel, tp, r, coefs, points, N, c = 3)
of 4: withEndo(msm_vartime_parallel, tp, r, coefs, points, N, c = 4)
of 5: withEndo(msm_vartime_parallel, tp, r, coefs, points, N, c = 5)
of 6: withEndo(msm_vartime_parallel, tp, r, coefs, points, N, c = 6)

of 7: msm_vartime_parallel(tp, r, coefs, points, N, c = 7)
of 8: msm_vartime_parallel(tp, r, coefs, points, N, c = 8)
of 9: msm_vartime_parallel(tp, r, coefs, points, N, c = 9)
of 10: msm_vartime_parallel(tp, r, coefs, points, N, c = 10)
of 11: msm_vartime_parallel(tp, r, coefs, points, N, c = 11)
of 12: msm_vartime_parallel(tp, r, coefs, points, N, c = 12)
of 13: msm_vartime_parallel(tp, r, coefs, points, N, c = 13)
of 14: msm_vartime_parallel(tp, r, coefs, points, N, c = 14)
of 15: msm_vartime_parallel(tp, r, coefs, points, N, c = 16)

of 16..17: msm_vartime_parallel(tp, r, coefs, points, N, c = 16)
of 2: withEndo(msmImpl_vartime_parallel, tp, r, coefs, points, N, c = 2)
of 3: withEndo(msmImpl_vartime_parallel, tp, r, coefs, points, N, c = 3)
of 4: withEndo(msmImpl_vartime_parallel, tp, r, coefs, points, N, c = 4)
of 5: withEndo(msmImpl_vartime_parallel, tp, r, coefs, points, N, c = 5)
of 6: withEndo(msmImpl_vartime_parallel, tp, r, coefs, points, N, c = 6)

of 7: msmImpl_vartime_parallel(tp, r, coefs, points, N, c = 7)
of 8: msmImpl_vartime_parallel(tp, r, coefs, points, N, c = 8)
of 9: msmImpl_vartime_parallel(tp, r, coefs, points, N, c = 9)
of 10: msmImpl_vartime_parallel(tp, r, coefs, points, N, c = 10)
of 11: msmImpl_vartime_parallel(tp, r, coefs, points, N, c = 11)
of 12: msmImpl_vartime_parallel(tp, r, coefs, points, N, c = 12)
of 13: msmImpl_vartime_parallel(tp, r, coefs, points, N, c = 13)
of 14: msmImpl_vartime_parallel(tp, r, coefs, points, N, c = 14)
of 15: msmImpl_vartime_parallel(tp, r, coefs, points, N, c = 16)

of 16..17: msmImpl_vartime_parallel(tp, r, coefs, points, N, c = 16)
else:
unreachable()

Expand All @@ -620,7 +620,6 @@ proc multiScalarMul_vartime_parallel*[bits: static int, EC, ECaff](
## This function can be nested in another parallel function
debug: doAssert coefs.len == points.len
let N = points.len

tp.multiScalarMul_dispatch_vartime_parallel(r.addr, coefs.asUnchecked(), points.asUnchecked(), N)

proc multiScalarMul_vartime_parallel*[F, EC, ECaff](
Expand All @@ -631,7 +630,6 @@ proc multiScalarMul_vartime_parallel*[F, EC, ECaff](
len: int) {.meter.} =
## Multiscalar multiplication:
## r <- [aโ‚€]Pโ‚€ + [aโ‚]Pโ‚ + ... + [aโ‚™โ‚‹โ‚]Pโ‚™โ‚‹โ‚

let n = cast[int](len)
let coefs_big = allocHeapArrayAligned(F.getBigInt(), n, alignment = 64)

Expand All @@ -650,8 +648,6 @@ proc multiScalarMul_vartime_parallel*[EC, ECaff](
points: openArray[ECaff]) {.inline.} =
## Multiscalar multiplication:
## r <- [aโ‚€]Pโ‚€ + [aโ‚]Pโ‚ + ... + [aโ‚™โ‚‹โ‚]Pโ‚™โ‚‹โ‚

debug: doAssert coefs.len == points.len
let N = points.len

tp.multiScalarMul_vartime_parallel(r.addr, coefs.asUnchecked(), points.asUnchecked(), N)
34 changes: 13 additions & 21 deletions constantine/math/pairings/gt_multiexp.nim
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,6 @@
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.

# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy Andrรฉ-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.

import constantine/named/algebras,
constantine/math/endomorphisms/split_scalars,
constantine/math/extension_fields,
Expand All @@ -35,7 +27,7 @@ import constantine/named/algebras,
# General utilities
# -------------------------------------------------------------

func bestBucketBitSize*(inputSize: int, scalarBitwidth: static int, useSignedBuckets, useManualTuning: static bool): int {.inline.} =
func bestBucketBitSize(inputSize: int, scalarBitwidth: static int, useSignedBuckets, useManualTuning: static bool): int {.inline.} =
## Evaluate the best bucket bit-size for the input size.
## That bucket size minimize group operations.
## This ignore cache effect. Computation can become memory-bound, especially with large buckets
Expand Down Expand Up @@ -86,7 +78,7 @@ func bestBucketBitSize*(inputSize: int, scalarBitwidth: static int, useSignedBuc
if 13 <= result:
result -= 1

func `~*=`*[Gt: ExtensionField](a: var Gt, b: Gt) {.inline.} =
func `~*=`[Gt: ExtensionField](a: var Gt, b: Gt) {.inline.} =
# TODO: Analyze the inputs to see if there is avalue in more complex shortcuts (-1, or partial 0 coordinates)
if a.isOne().bool():
a = b
Expand All @@ -95,13 +87,13 @@ func `~*=`*[Gt: ExtensionField](a: var Gt, b: Gt) {.inline.} =
else:
a *= b

func `~/=`*[Gt: ExtensionField](a: var Gt, b: Gt) {.inline.} =
func `~/=`[Gt: ExtensionField](a: var Gt, b: Gt) {.inline.} =
## Cyclotomic division
var t {.noInit.}: Gt
t.cyclotomic_inv(b)
a ~*= t

func setNeutral*[Gt: ExtensionField](a: var Gt) {.inline.} =
func setNeutral[Gt: ExtensionField](a: var Gt) {.inline.} =
a.setOne()

# Reference multi-exponentiation
Expand Down Expand Up @@ -269,7 +261,7 @@ type MiniMultiExpKind* = enum
kFullWindow
kBottomWindow

func bucketAccumReduce*[bits: static int, GT](
func bucketAccumReduce[bits: static int, GT](
r: var GT,
buckets: ptr UncheckedArray[GT],
bitIndex: int, miniMultiExpKind: static MiniMultiExpKind, c: static int,
Expand Down Expand Up @@ -371,8 +363,8 @@ proc applyEndomorphism[bits: static int, GT](
expos: ptr UncheckedArray[BigInt[bits]],
N: int): auto =
## Decompose (elems, expos) into mini-scalars
## Returns a new triplet (endoElems, endoexpos, N)
## endoElems and endoexpos MUST be freed afterwards
## Returns a new triplet (endoElems, endoExpos, N)
## endoElems and endoExpos MUST be freed afterwards

const M = when Gt is Fp6: 2
elif Gt is Fp12: 4
Expand All @@ -383,16 +375,16 @@ proc applyEndomorphism[bits: static int, GT](
let endoBasis = allocHeapArray(array[M, GT], N)

for i in 0 ..< N:
var negatePoints {.noinit.}: array[M, SecretBool]
splitExpos[i].decomposeEndo(negatePoints, expos[i], Fr[Gt.Name].bits(), Gt.Name, G2) # ๐”พโ‚œ has same decomposition as ๐”พโ‚‚
if negatePoints[0].bool:
var negateElems {.noinit.}: array[M, SecretBool]
splitExpos[i].decomposeEndo(negateElems, expos[i], Fr[Gt.Name].bits(), Gt.Name, G2) # ๐”พโ‚œ has same decomposition as ๐”พโ‚‚
if negateElems[0].bool:
endoBasis[i][0].cyclotomic_inv(elems[i])
else:
endoBasis[i][0] = elems[i]

cast[ptr array[M-1, GT]](endoBasis[i][1].addr)[].computeEndomorphisms(elems[i])
for m in 1 ..< M:
if negatePoints[m].bool:
if negateElems[m].bool:
endoBasis[i][m].cyclotomic_inv()

let endoElems = cast[ptr UncheckedArray[GT]](endoBasis)
Expand Down Expand Up @@ -457,15 +449,15 @@ func multiExp_vartime*[bits: static int, GT](
r: var GT,
elems: ptr UncheckedArray[GT],
expos: ptr UncheckedArray[BigInt[bits]],
len: int) {.tags:[VarTime, Alloca, HeapAlloc], meter.} =
len: int) {.tags:[VarTime, Alloca, HeapAlloc], meter, inline.} =
## Multiexponentiation:
## r <- gโ‚€^aโ‚€ + gโ‚^aโ‚ + ... + gโ‚™^aโ‚™
multiExp_dispatch_vartime(r, elems, expos, len)

func multiExp_vartime*[bits: static int, GT](
r: var GT,
elems: openArray[GT],
expos: openArray[BigInt[bits]]) {.tags:[VarTime, Alloca, HeapAlloc], meter.} =
expos: openArray[BigInt[bits]]) {.tags:[VarTime, Alloca, HeapAlloc], meter, inline.} =
## Multiexponentiation:
## r <- gโ‚€^aโ‚€ + gโ‚^aโ‚ + ... + gโ‚™^aโ‚™
debug: doAssert elems.len == expos.len
Expand Down
Loading

0 comments on commit 242798f

Please sign in to comment.