Skip to content

Commit

Permalink
feat(multilinear-extension): initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
mratsim committed Jul 4, 2024
1 parent dcc9310 commit 01c039c
Show file tree
Hide file tree
Showing 6 changed files with 213 additions and 1 deletion.
65 changes: 65 additions & 0 deletions benchmarks/bench_mle.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import
constantine/boolean_hypercube/multilinear_extensions,
constantine/named/algebras,
constantine/math/arithmetic,
constantine/math/io/io_fields,
helpers/prng_unsafe,
benchmarks/bench_blueprint,
std/macros

var rng*: RngState
let seed = 1234
rng.seed(seed)
echo "bench xoshiro512** seed: ", seed

proc report(op, field: string, start, stop: MonoTime, startClk, stopClk: int64, iters: int) =
let ns = inNanoseconds((stop-start) div iters)
let throughput = 1e9 / float64(ns)
when SupportsGetTicks:
echo &"{op:<70} {field:<18} {throughput:>15.3f} ops/s {ns:>9} ns/op {(stopClk - startClk) div iters:>9} CPU cycles (approx)"
else:
echo &"{op:<70} {field:<18} {throughput:>15.3f} ops/s {ns:>9} ns/op"

macro fixFieldDisplay(T: typedesc): untyped =
# At compile-time, enums are integers and their display is buggy
# we get the Curve ID instead of the curve name.
let instantiated = T.getTypeInst()
var name = $instantiated[1][0] # 𝔽p
name.add "[" & $Algebra(instantiated[1][1].intVal) & "]"
result = newLit name

template bench(op: string, T: typedesc, iters: int, body: untyped): untyped =
measure(iters, startTime, stopTime, startClk, stopClk, body)
report(op, fixFieldDisplay(T), startTime, stopTime, startClk, stopClk, iters)

func toFp[N: static int](a: array[N, SomeUnsignedInt], Name: static Algebra): array[N, Fp[Name]] =
for i in 0 ..< N:
result[i] = Fp[Name].fromUint(a[i])

proc bench_thaler() =
var evals = [uint32 1, 2, 1, 4].toFp(F5)
let mle = MultilinearExtension[Fp[F5]].new(2, evals)

var r: Fp[F5]
bench("Multilinear Extension: Evaluate", Fp[F5], 100):
r.evalMultilinearExtensionAt_reference(mle, [Fp[F5].fromUint(0'u), Fp[F5].fromUint(0'u)])

proc bench_arkworks(num_vars: int) =
var evals = newSeq[Fr[BLS12_381]](1 shl num_vars)
for eval in evals.mitems():
eval = rng.random_unsafe(Fr[BLS12_381])

let mle = MultilinearExtension[Fr[BLS12_381]].new(num_vars, evals)

var coords = newSeq[Fr[BLS12_381]](num_vars)
for coord in coords.mitems():
coord = rng.random_unsafe(Fr[BLS12_381])

var r: Fr[BLS12_381]
bench("Multilinear Extension: Evaluate/" & $num_vars, Fr[BLS12_381], 100):
r.evalMultilinearExtensionAt_reference(mle, coords)

bench_thaler()
bench_arkworks(10)
bench_arkworks(11)
bench_arkworks(12)
2 changes: 1 addition & 1 deletion constantine.nimble
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,7 @@ proc setupTestCommand(flags, path: string): string =
if existsEnv"TEST_LANG":
lang = getEnv"TEST_LANG"

return "nim " & lang &
return "../Nim/bin/nim " & lang &
" -r " &
flags &
releaseBuildOptions() &
Expand Down
7 changes: 7 additions & 0 deletions constantine/boolean_hypercube/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Boolean Hypercube

This folder holds utilities to work with the Boolean Hypercube,
via MLE, Multilinear Extension of polynomials evaluated at {0, 1}ⁿ.

This is a heavy area of experimentation to nail down useful software architecture.
Expect many refactorings.
95 changes: 95 additions & 0 deletions constantine/boolean_hypercube/multilinear_extensions.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.

import
constantine/platforms/[abstractions, allocs]

# Resources:
# - https://people.cs.georgetown.edu/jthaler/IPsandextensions.pdf
# - https://people.cs.georgetown.edu/jthaler/ProofsArgsAndZK.pdf
# Chapter 3.5

type
MultilinearExtension*[F] = object
## Note: to follow mathematical description, indices start at 1 and end of range is inclusive
## actual implementation will have indices start at 0
##
## Given a sequence of bits of size s {0,1}ˢ
## and an arbitrary function f: {0,1}ˢ -> 𝔽
## i.e. that maps a sequence of bits to a finite field 𝔽
##
## there is an unique multilinear polynomial f̃
## called multilinear extension of f
## that satisfies f̃(xᵢ) = f(xᵢ) for xᵢ ∈ {0,1}ˢ
##
## f̃(x₁, ...,xₛ) = ∑ₑ f(e) 𝛘ₑ(x₁, ...,xₛ)
## with e ∈ {0,1}ˢ and f(e) the evaluation of f at e.
##
## 𝛘ₑ(x₁, ...,xₛ) is the multilinear Lagrange basis polynomial
## which takes value 1 at 𝛘ₑ(e) and 0 at Xₑ(e̅) e̅ any other element ∈ {0,1}ˢ
##
## 𝛘ₑ(x₁, ...,xₛ) = ∏ᵢ₌₁ˢ(xᵢeᵢ + (1-xᵢ)(1-eᵢ)), i ∈ [1, s]
##
## A multilinear polynomial is linear (i.e. degree at most 1) in each
## of its variables.
num_vars*: int
base_poly_evals*: ptr UncheckedArray[F]

proc `=destroy`*[F](mle: var MultilinearExtension[F]) =
if not mle.base_poly_evals.isNil:
freeHeapAligned(mle.base_poly_evals)

func new*[F](T: type MultilinearExtension[F], num_vars: int, poly_evals: openArray[F]): T =
debug:
doAssert 1 shl num_vars == poly_evals.len, block:
"The MLE has " & $num_vars " variables\n" &
"but the poly it's derived from has " & $poly_evals.len & " evaluations.\n" &
"2^" & $num_vars & " = " & $(1 shl num_vars) & " were expected instead."
result.num_vars = num_vars
let L = 1 shl num_vars
result.base_poly_evals = allocHeapArrayAligned(F, L, alignment = 64)
for i in 0 ..< L:
result.base_poly_evals[i] = poly_evals[i]

iterator bits_be(n: SomeInteger, len: int): (int, bool) =
for i in 0 ..< len:
yield (i, bool((n shr (len-1-i) and 1)))

func evalMultilinearExtensionAt_reference*[F](
r: var F,
mle: MultilinearExtension[F],
xs: openArray[F]) =
## Compute
## f̃(x₁, ...,xₛ) = ∑ₑ f(e) ∏ᵢ(xᵢeᵢ + (1-xᵢ)(1-eᵢ))
## at supplied (x₁, ...,xₛ)
##
## This is a reference implementation using naive computation
## in O(n log n) with n being the numb
debug: doAssert mle.num_vars == coords.len

let L = 1 shl mle.num_vars

r.setZero()
for e in 0 ..< L:
# 𝛘ₑ(x₁, ...,xₛ) = ∏ᵢ(xᵢeᵢ + (1-xᵢ)(1-eᵢ))
# e ∈ {0,1}ˢ hence each factor is either:
# (1-xᵢ) or xᵢ
var chi_e {.noInit.}: F
chi_e.setOne()

for (i, ei) in bits_be(e, mle.num_vars):
if ei:
chi_e *= xs[i]
else:
var t {.noInit.}: F
t.diff(F.getOne(), xs[i])
chi_e *= t

var t {.noInit.}: F
t.prod(mle.base_poly_evals[e], chi_e)
r += t
4 changes: 4 additions & 0 deletions constantine/named/config_fields_and_curves.nim
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ export CurveFamily, SexticTwist
declareCurves:
# -----------------------------------------------------------------------------
# Curves added when passed "-d:CTT_TEST_CURVES"
curve F5:
testingCurve: true
bitwidth: 3
modulus: "0x5"
curve Fake101:
testingCurve: true
bitwidth: 7
Expand Down
41 changes: 41 additions & 0 deletions tests/math_interactive_proofs/t_mle.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import
constantine/boolean_hypercube/multilinear_extensions,
constantine/named/algebras,
constantine/math/arithmetic,
constantine/math/io/io_fields,
helpers/prng_unsafe

# Compile with -d:CTT_TEST_CURVES to define F5

func toF5[N: static int](a: array[N, SomeUnsignedInt]): array[N, Fp[F5]] =
for i in 0 ..< N:
result[i] = Fp[F5].fromUint(a[i])

# - https://people.cs.georgetown.edu/jthaler/IPsandextensions.pdf\
# Note: first row is
# 1 2 3 4 0 not 1 2 3 4 5 (though 5 ≡ 0 (mod 5) so arguably not wrong)
# - https://people.cs.georgetown.edu/jthaler/ProofsArgsAndZK.pdf
# Chapter 3.5

let evals = [uint32 1, 2, 1, 4].toF5()
let mle_evals = [
[byte 1, 2, 3, 4, 0],
[byte 1, 4, 2, 0, 3],
[byte 1, 1, 1, 1, 1],
[byte 1, 3, 0, 2, 4],
[byte 1, 0, 4, 3, 2],
]

let mle = MultilinearExtension[Fp[F5]].new(2, evals)

for i in 0'u32 .. 4:
var row: array[5, byte]
for j in 0'u32 .. 4:
var r: Fp[F5]
r.evalMultilinearExtensionAt_reference(mle, [Fp[F5].fromUint(i), Fp[F5].fromUint(j)])
var buf: array[1, byte]
buf.marshal(r, bigEndian)
row[j] = buf[0]

echo row
doAssert row == mle_evals[i]

0 comments on commit 01c039c

Please sign in to comment.