Skip to content

Commit

Permalink
Torus-acceleration for multiexponentiation on GT (#485)
Browse files Browse the repository at this point in the history
* gt-torus: add Fp6 mul/sqr with Toom-C00k-3 + DFT

* initial support of torus-based crypto

* gt: add torus tests and benchmarks, make cyclotomic/pairing proc towering agnostic

* gt: batch conversion

* gt: stash progress, Fp12 over Fp6 fails ref or opt multiexp while Fp12 over Fp4 doesn't (without Torus)

* gt: add preliminary benchmarks for Torus based cryptography

* gt: fix exponentiation by 1 and GT torus conversion

* gt: fix aliasing issue in mixed torus multiplication

* gt: add torus optimization to optimized GT multiexp

* gt: combine endomorphism acceleration and Torus acceleration

* gt: parallel torus multiexp

* gt: enable endomorphism + torus

* gt: rework conversion

* test: add GT multiexp to test suite

* GT: fix memory leak

* windows: aligned alloc need explicit aligned dealloc
  • Loading branch information
mratsim authored Dec 1, 2024
1 parent 2d67670 commit bc3845a
Show file tree
Hide file tree
Showing 27 changed files with 1,836 additions and 257 deletions.
26 changes: 15 additions & 11 deletions benchmarks/bench_fields_template.nim
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,30 @@ import
./bench_blueprint

export notes, abstractions
proc separator*() = separator(165)
proc separator*() = separator(145)
proc smallSeparator*() = separator(8)

proc report(op, field: string, start, stop: MonoTime, startClk, stopClk: int64, iters: int) =
let ns = inNanoseconds((stop-start) div iters)
let throughput = 1e9 / float64(ns)
when SupportsGetTicks:
echo &"{op:<70} {field:<18} {throughput:>15.3f} ops/s {ns:>9} ns/op {(stopClk - startClk) div iters:>9} CPU cycles (approx)"
echo &"{op:<49} {field:<18} {throughput:>15.3f} ops/s {ns:>9} ns/op {(stopClk - startClk) div iters:>9} CPU cycles (approx)"
else:
echo &"{op:<70} {field:<18} {throughput:>15.3f} ops/s {ns:>9} ns/op"
echo &"{op:<49} {field:<18} {throughput:>15.3f} ops/s {ns:>9} ns/op"

macro fixFieldDisplay(T: typedesc): untyped =
# At compile-time, enums are integers and their display is buggy
# we get the Curve ID instead of the curve name.
let instantiated = T.getTypeInst()
var name = $instantiated[1][0] # 𝔽p
name.add "[" & $Algebra(instantiated[1][1].intVal) & "]"
if instantiated[1][1].kind == nnkIntLit:
name.add "[" & $Algebra(instantiated[1][1].intVal) & "]"
else:
name.add "[" & $instantiated[1][1][0] # QuadraticExt[𝔽p6[
name.add "[" & $Algebra(instantiated[1][1][1].intVal) & "]]"
result = newLit name

template bench(op: string, T: typedesc, iters: int, body: untyped): untyped =
template bench*(op: string, T: typedesc, iters: int, body: untyped): untyped =
measure(iters, startTime, stopTime, startClk, stopClk, body)
report(op, fixFieldDisplay(T), startTime, stopTime, startClk, stopClk, iters)

Expand Down Expand Up @@ -184,10 +188,10 @@ proc sqrtBench*(T: typedesc, iters: int) =
"Tonelli-Shanks"
const addchain = block:
when T.Name.hasSqrtAddchain() or T.Name.hasTonelliShanksAddchain():
"with addition chain"
"+ addchain"
else:
"without addition chain"
const desc = "Square Root (constant-time " & algoType & " " & addchain & ")"
"no addchain"
const desc = "Sqrt (constant-time " & algoType & " " & addchain & ")"
bench(desc, T, iters):
var r = x
discard r.sqrt_if_square()
Expand All @@ -211,10 +215,10 @@ proc sqrtVartimeBench*(T: typedesc, iters: int) =
"Tonelli-Shanks"
const addchain = block:
when T.Name.hasSqrtAddchain() or T.Name.hasTonelliShanksAddchain():
"with addition chain"
"+ addchain"
else:
"without addition chain"
const desc = "Square Root (vartime " & algoType & " " & addchain & ")"
"no addchain"
const desc = "Sqrt (vartime " & algoType & " " & addchain & ")"
bench(desc, T, iters):
var r = x
discard r.sqrt_if_square_vartime()
Expand Down
14 changes: 11 additions & 3 deletions benchmarks/bench_gt_multiexp_bls12_381.nim
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,25 @@ const AvailableCurves = [
BLS12_381,
]

const testNumPoints = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
# const testNumPoints = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
const testNumPoints = [128, 256]


type Fp12over4[C: static Algebra] = CubicExt[Fp4[C]]
type Fp12over6[C: static Algebra] = QuadraticExt[Fp6[C]]

proc main() =
separator()
staticFor i, 0, AvailableCurves.len:
const curve = AvailableCurves[i]
var ctx = createBenchMultiExpContext(Fp12[curve], testNumPoints)
var ctx12o4 = createBenchMultiExpContext(Fp12over4[curve], testNumPoints)
var ctx12o6 = createBenchMultiExpContext(Fp12over6[curve], testNumPoints)
separator()
for numPoints in testNumPoints:
let batchIters = max(1, Iters div numPoints)
ctx.multiExpParallelBench(numPoints, batchIters)
ctx12o4.multiExpParallelBench(numPoints, batchIters)
echo "----"
ctx12o6.multiExpParallelBench(numPoints, batchIters)
separator()
separator()

Expand Down
80 changes: 64 additions & 16 deletions benchmarks/bench_gt_parallel_template.nim
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ proc report(op, domain: string, start, stop: MonoTime, startClk, stopClk: int64,
let ns = inNanoseconds((stop-start) div iters)
let throughput = 1e9 / float64(ns)
when SupportsGetTicks:
echo &"{op:<68} {domain:<20} {throughput:>15.3f} ops/s {ns:>9} ns/op {(stopClk - startClk) div iters:>9} CPU cycles (approx)"
echo &"{op:<65} {domain:<20} {throughput:>15.3f} ops/s {ns:>9} ns/op {(stopClk - startClk) div iters:>9} CPU cycles (approx)"
else:
echo &"{op:<68} {domain:<20} {throughput:>15.3f} ops/s {ns:>9} ns/op"
echo &"{op:<65} {domain:<20} {throughput:>15.3f} ops/s {ns:>9} ns/op"

macro fixFieldDisplay(T: typedesc): untyped =
# At compile-time, enums are integers and their display is buggy
Expand All @@ -52,7 +52,7 @@ macro fixFieldDisplay(T: typedesc): untyped =
result = newLit name

func fixDisplay(T: typedesc): string =
when T is (Fp or Fp2 or Fp4 or Fp6 or Fp12):
when T is (Fp or ExtensionField):
fixFieldDisplay(T)
else:
$T
Expand All @@ -68,7 +68,7 @@ func random_gt*(rng: var RngState, F: typedesc): F {.inline, noInit.} =
result = rng.random_unsafe(F)
result.finalExp()

# Multi-exponentiations
# multi-exp
# ---------------------------------------------------------------------------

type BenchMultiexpContext*[GT] = object
Expand Down Expand Up @@ -126,11 +126,19 @@ proc multiExpParallelBench*[GT](ctx: var BenchMultiExpContext[GT], numInputs: in

var r{.noInit.}: GT
var startNaive, stopNaive, startMultiExpBaseline, stopMultiExpBaseline: MonoTime
var startMultiExpOpt, stopMultiExpOpt, startMultiExpPara, stopMultiExpPara: MonoTime
var startMultiExpOpt, stopMultiExpOpt: MonoTime
var startMultiExpPara, stopMultiExpPara: MonoTime
var startMultiExpParaTorus, stopMultiExpParaTorus: MonoTime

when GT is QuadraticExt:
var startMultiExpBaselineTorus: MonoTime
var stopMultiExpBaselineTorus: MonoTime
var startMultiExpOptTorus: MonoTime
var stopMultiExpOptTorus: Monotime

if numInputs <= 100000:
# startNaive = getMonotime()
bench("𝔾ₜ exponentiations " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
bench("𝔾ₜ exponentiations " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
var tmp: GT
r.setOne()
for i in 0 ..< elems.len:
Expand All @@ -140,7 +148,7 @@ proc multiExpParallelBench*[GT](ctx: var BenchMultiExpContext[GT], numInputs: in

if numInputs <= 100000:
startNaive = getMonotime()
bench("𝔾ₜ exponentiations vartime " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
bench("𝔾ₜ exponentiations vartime " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
var tmp: GT
r.setOne()
for i in 0 ..< elems.len:
Expand All @@ -150,30 +158,59 @@ proc multiExpParallelBench*[GT](ctx: var BenchMultiExpContext[GT], numInputs: in

if numInputs <= 100000:
startMultiExpBaseline = getMonotime()
bench("𝔾ₜ multi-exponentiations baseline " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
r.multiExp_reference_vartime(elems, exponents)
bench("𝔾ₜ multi-exp baseline " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
r.multiExp_reference_vartime(elems, exponents, useTorus = false)
stopMultiExpBaseline = getMonotime()

if numInputs <= 100000:
when GT is QuadraticExt:
startMultiExpBaselineTorus = getMonotime()
bench("𝔾ₜ multi-exp baseline + torus" & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
r.multiExp_reference_vartime(elems, exponents, useTorus = true)
stopMultiExpBaselineTorus = getMonotime()

block:
startMultiExpOpt = getMonotime()
bench("𝔾ₜ multi-exponentiations optimized " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
r.multiExp_vartime(elems, exponents)
bench("𝔾ₜ multi-exp opt " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
r.multiExp_vartime(elems, exponents, useTorus = false)
stopMultiExpOpt = getMonotime()

when GT is QuadraticExt:
block:
startMultiExpOptTorus = getMonotime()
bench("𝔾ₜ multi-exp opt + torus " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
r.multiExp_vartime(elems, exponents, useTorus = true)
stopMultiExpOptTorus = getMonotime()

block:
ctx.tp = Threadpool.new()

startMultiExpPara = getMonotime()
bench("𝔾ₜ multi-exponentiations" & align($ctx.tp.numThreads & " threads", 11) & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
ctx.tp.multiExp_vartime_parallel(r, elems, exponents)
bench("𝔾ₜ multi-exp " & align($ctx.tp.numThreads & " threads", 11) & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
ctx.tp.multiExp_vartime_parallel(r, elems, exponents, useTorus = false)
stopMultiExpPara = getMonotime()

ctx.tp.shutdown()

when GT is QuadraticExt:
block:
ctx.tp = Threadpool.new()

startMultiExpParaTorus = getMonotime()
bench("𝔾ₜ multi-exp torus" & align($ctx.tp.numThreads & " threads", 11) & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
ctx.tp.multiExp_vartime_parallel(r, elems, exponents, useTorus = true)
stopMultiExpParaTorus = getMonotime()

ctx.tp.shutdown()

let perfNaive = inNanoseconds((stopNaive-startNaive) div iters)
let perfMultiExpBaseline = inNanoseconds((stopMultiExpBaseline-startMultiExpBaseline) div iters)
let perfMultiExpOpt = inNanoseconds((stopMultiExpOpt-startMultiExpOpt) div iters)
let perfMultiExpPara = inNanoseconds((stopMultiExpPara-startMultiExpPara) div iters)
when GT is QuadraticExt:
let perfMultiExpBaselineTorus = inNanoseconds((stopMultiExpBaselineTorus-startMultiExpBaselineTorus) div iters)
let perfMultiExpOptTorus = inNanoseconds((stopMultiExpOptTorus-startMultiExpOptTorus) div iters)
let perfMultiExpParaTorus = inNanoSeconds((stopMultiExpParaTorus-startMultiExpParaTorus) div iters)

if numInputs <= 100000:
let speedupBaseline = float(perfNaive) / float(perfMultiExpBaseline)
Expand All @@ -182,8 +219,19 @@ proc multiExpParallelBench*[GT](ctx: var BenchMultiExpContext[GT], numInputs: in
let speedupOpt = float(perfNaive) / float(perfMultiExpOpt)
echo &"Speedup ratio optimized over naive linear combination: {speedupOpt:>6.3f}x"

let speedupOptBaseline = float(perfMultiExpBaseline) / float(perfMultiExpOpt)
echo &"Speedup ratio optimized over baseline linear combination: {speedupOptBaseline:>6.3f}x"
when GT is QuadraticExt:
let speedupTorusOverBaseline = float(perfMultiExpBaseline) / float(perfMultiExpBaselineTorus)
echo &"Speedup ratio baseline + Torus over baseline linear combination: {speedupTorusOverBaseline:>6.3f}x"

let speedupTorusOverOpt = float(perfMultiExpOpt) / float(perfMultiExpOptTorus)
echo &"Speedup ratio optimized + Torus over optimized: {speedupTorusOverOpt:>6.3f}x"

let speedupParaOpt = float(perfMultiExpOpt) / float(perfMultiExpPara)
echo &"Speedup ratio parallel over optimized linear combination: {speedupParaOpt:>6.3f}x"
echo &"Speedup ratio parallel over serial optimized linear combination: {speedupParaOpt:>6.3f}x"

when GT is QuadraticExt:
let speedupParaTorus = float(perfMultiExpOptTorus) / float(perfMultiExpParaTorus)
echo &"Speedup ratio parallel over serial for Torus-based multiexp: {speedupParaTorus:>6.3f}x"

let speedupParaTorusOpt = float(perfMultiExpPara) / float(perfMultiExpParaTorus)
echo &"Speedup ratio parallel over parallel Torus-based multiexp: {speedupParaTorusOpt:>6.3f}x"
Loading

0 comments on commit bc3845a

Please sign in to comment.