Skip to content

Commit

Permalink
feat(special primes accel): Support Crandall primes / Pseudo-Mersenne…
Browse files Browse the repository at this point in the history
… Prime fast reduction - closes #11
  • Loading branch information
mratsim committed Jul 25, 2024
1 parent 8b70dc1 commit 67a4b2c
Show file tree
Hide file tree
Showing 10 changed files with 413 additions and 95 deletions.
22 changes: 12 additions & 10 deletions benchmarks/bench_fp.nim
Original file line number Diff line number Diff line change
Expand Up @@ -55,22 +55,24 @@ proc main() =
sqr2xUnrBench(Fp[curve], Iters)
rdc2xBench(Fp[curve], Iters)
smallSeparator()
sumprodBench(Fp[curve], Iters)
smallSeparator()
when not Fp[curve].isCrandallPrimeField():
sumprodBench(Fp[curve], Iters)
smallSeparator()
toBigBench(Fp[curve], Iters)
toFieldBench(Fp[curve], Iters)
smallSeparator()
invBench(Fp[curve], ExponentIters)
invVartimeBench(Fp[curve], ExponentIters)
isSquareBench(Fp[curve], ExponentIters)
sqrtBench(Fp[curve], ExponentIters)
sqrtRatioBench(Fp[curve], ExponentIters)
when curve == Bandersnatch:
sqrtVartimeBench(Fp[curve], ExponentIters)
sqrtRatioVartimeBench(Fp[curve], ExponentIters)
# Exponentiation by a "secret" of size ~the curve order
powBench(Fp[curve], ExponentIters)
powVartimeBench(Fp[curve], ExponentIters)
when not Fp[curve].isCrandallPrimeField(): # TODO implement
sqrtBench(Fp[curve], ExponentIters)
sqrtRatioBench(Fp[curve], ExponentIters)
when curve == Bandersnatch:
sqrtVartimeBench(Fp[curve], ExponentIters)
sqrtRatioVartimeBench(Fp[curve], ExponentIters)
# Exponentiation by a "secret" of size ~the curve order
powBench(Fp[curve], ExponentIters)
powVartimeBench(Fp[curve], ExponentIters)
separator()

main()
Expand Down
6 changes: 3 additions & 3 deletions constantine.nimble
Original file line number Diff line number Diff line change
Expand Up @@ -402,10 +402,10 @@ const testDesc: seq[tuple[path: string, useGMP: bool]] = @[
("tests/math_fields/t_io_fields", false),
# ("tests/math_fields/t_finite_fields.nim", false),
# ("tests/math_fields/t_finite_fields_conditional_arithmetic.nim", false),
# ("tests/math_fields/t_finite_fields_mulsquare.nim", false),
("tests/math_fields/t_finite_fields_mulsquare.nim", false),
# ("tests/math_fields/t_finite_fields_sqrt.nim", false),
("tests/math_fields/t_finite_fields_powinv.nim", false),
# ("tests/math_fields/t_finite_fields_vs_gmp.nim", true),
# ("tests/math_fields/t_finite_fields_powinv.nim", false),
("tests/math_fields/t_finite_fields_vs_gmp.nim", true),
# ("tests/math_fields/t_fp_cubic_root.nim", false),

# Double-precision finite fields
Expand Down
146 changes: 101 additions & 45 deletions constantine/math/arithmetic/finite_fields.nim
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ import
constantine/platforms/abstractions,
constantine/serialization/endians,
constantine/named/properties_fields,
./bigints, ./bigints_montgomery
./bigints, ./bigints_montgomery,
./limbs_crandall, ./limbs_extmul

when UseASM_X86_64:
import ./assembly/limbs_asm_modular_x86
Expand All @@ -54,18 +55,24 @@ export Fp, Fr, FF

func fromBig*(dst: var FF, src: BigInt) =
## Convert a BigInt to its Montgomery form
when nimvm:
dst.mres.montyResidue_precompute(src, FF.getModulus(), FF.getR2modP(), FF.getNegInvModWord())
when FF.isCrandallPrimeField():
dst.mres = src
else:
dst.mres.getMont(src, FF.getModulus(), FF.getR2modP(), FF.getNegInvModWord(), FF.getSpareBits())
when nimvm:
dst.mres.montyResidue_precompute(src, FF.getModulus(), FF.getR2modP(), FF.getNegInvModWord())
else:
dst.mres.getMont(src, FF.getModulus(), FF.getR2modP(), FF.getNegInvModWord(), FF.getSpareBits())

func fromBig*[Name: static Algebra](T: type FF[Name], src: BigInt): FF[Name] {.noInit.} =
## Convert a BigInt to its Montgomery form
result.fromBig(src)

func fromField*(dst: var BigInt, src: FF) {.inline.} =
## Convert a finite-field element to a BigInt in natural representation
dst.fromMont(src.mres, FF.getModulus(), FF.getNegInvModWord(), FF.getSpareBits())
when FF.isCrandallPrimeField():
dst = src.mres
else:
dst.fromMont(src.mres, FF.getModulus(), FF.getNegInvModWord(), FF.getSpareBits())

func toBig*(src: FF): auto {.noInit, inline.} =
## Convert a finite-field element to a BigInt in natural representation
Expand Down Expand Up @@ -121,11 +128,17 @@ func isZero*(a: FF): SecretBool =

func isOne*(a: FF): SecretBool =
## Constant-time check if one
a.mres == FF.getMontyOne()
when FF.isCrandallPrimeField():
a.mres.isOne()
else:
a.mres == FF.getMontyOne()

func isMinusOne*(a: FF): SecretBool =
## Constant-time check if -1 (mod p)
a.mres == FF.getMontyPrimeMinus1()
when FF.isCrandallPrimeField:
{.error: "Not implemented".}
else:
a.mres == FF.getMontyPrimeMinus1()

func isOdd*(a: FF): SecretBool {.
error: "Do you need the actual value to be odd\n" &
Expand All @@ -141,14 +154,20 @@ func setOne*(a: var FF) =
# Note: we need 1 in Montgomery residue form
# TODO: Nim codegen is not optimal it uses a temporary
# Check if the compiler optimizes it away
a.mres = FF.getMontyOne()
when FF.isCrandallPrimeField():
a.mres.setOne()
else:
a.mres = FF.getMontyOne()

func setMinusOne*(a: var FF) =
## Set ``a`` to -1 (mod p)
# Note: we need -1 in Montgomery residue form
# TODO: Nim codegen is not optimal it uses a temporary
# Check if the compiler optimizes it away
a.mres = FF.getMontyPrimeMinus1()
when FF.isCrandallPrimeField():
{.error: "Not implemented".}
else:
a.mres = FF.getMontyPrimeMinus1()

func neg*(r: var FF, a: FF) {.meter.} =
## Negate modulo p
Expand Down Expand Up @@ -237,19 +256,36 @@ func double*(r: var FF, a: FF) {.meter.} =
func prod*(r: var FF, a, b: FF, skipFinalSub: static bool = false) {.meter.} =
## Store the product of ``a`` by ``b`` modulo p into ``r``
## ``r`` is initialized / overwritten
r.mres.mulMont(a.mres, b.mres, FF.getModulus(), FF.getNegInvModWord(), FF.getSpareBits(), skipFinalSub)
when FF.isCrandallPrimeField():
var r2 {.noInit.}: FF.Name.getLimbs2x()
r2.prod(a.mres.limbs, b.mres.limbs)
r.mres.limbs.reduce_crandall_partial(r2, FF.bits(), FF.getCrandallPrimeSubterm())
when not skipFinalSub:
r.mres.limbs.reduce_crandall_final(FF.bits(), FF.getCrandallPrimeSubterm())
else:
r.mres.mulMont(a.mres, b.mres, FF.getModulus(), FF.getNegInvModWord(), FF.getSpareBits(), skipFinalSub)

func square*(r: var FF, a: FF, skipFinalSub: static bool = false) {.meter.} =
## Squaring modulo p
r.mres.squareMont(a.mres, FF.getModulus(), FF.getNegInvModWord(), FF.getSpareBits(), skipFinalSub)
when FF.isCrandallPrimeField():
var r2 {.noInit.}: FF.Name.getLimbs2x()
r2.square(a.mres.limbs)
r.mres.limbs.reduce_crandall_partial(r2, FF.bits(), FF.getCrandallPrimeSubterm())
when not skipFinalSub:
r.mres.limbs.reduce_crandall_final(FF.bits(), FF.getCrandallPrimeSubterm())
else:
r.mres.squareMont(a.mres, FF.getModulus(), FF.getNegInvModWord(), FF.getSpareBits(), skipFinalSub)

func sumprod*[N: static int](r: var FF, a, b: array[N, FF], skipFinalSub: static bool = false) {.meter.} =
## Compute r <- ⅀aᵢ.bᵢ (mod M) (sum of products)
# We rely on FF and Bigints having the same repr to avoid array copies
r.mres.sumprodMont(
cast[ptr array[N, typeof(a[0].mres)]](a.unsafeAddr)[],
cast[ptr array[N, typeof(b[0].mres)]](b.unsafeAddr)[],
FF.getModulus(), FF.getNegInvModWord(), FF.getSpareBits(), skipFinalSub)
when FF.isCrandallPrimeField():
{.error: "Not implemented".}
else:
r.mres.sumprodMont(
cast[ptr array[N, typeof(a[0].mres)]](a.unsafeAddr)[],
cast[ptr array[N, typeof(b[0].mres)]](b.unsafeAddr)[],
FF.getModulus(), FF.getNegInvModWord(), FF.getSpareBits(), skipFinalSub)

# ############################################################
#
Expand Down Expand Up @@ -329,7 +365,10 @@ func inv*(r: var FF, a: FF) =
## Incidentally this avoids extra check
## to convert Jacobian and Projective coordinates
## to affine for elliptic curve
r.mres.invmod(a.mres, FF.getR2modP(), FF.getModulus())
when FF.isCrandallPrimeField():
r.mres.invmod(a.mres, FF.getModulus())
else:
r.mres.invmod(a.mres, FF.getR2modP(), FF.getModulus())

func inv*(a: var FF) =
## Inversion modulo p
Expand All @@ -347,7 +386,10 @@ func inv_vartime*(r: var FF, a: FF) {.tags: [VarTime].} =
## Incidentally this avoids extra check
## to convert Jacobian and Projective coordinates
## to affine for elliptic curve
r.mres.invmod_vartime(a.mres, FF.getR2modP(), FF.getModulus())
when FF.isCrandallPrimeField():
r.mres.invmod_vartime(a.mres, FF.getModulus())
else:
r.mres.invmod_vartime(a.mres, FF.getR2modP(), FF.getModulus())

func inv_vartime*(a: var FF) {.tags: [VarTime].} =
## Variable-time Inversion modulo p
Expand Down Expand Up @@ -509,25 +551,31 @@ func pow*(a: var FF, exponent: BigInt) =
## Exponentiation modulo p
## ``a``: a field element to be exponentiated
## ``exponent``: a big integer
const windowSize = 5 # TODO: find best window size for each curves
a.mres.powMont(
exponent,
FF.getModulus(), FF.getMontyOne(),
FF.getNegInvModWord(), windowSize,
FF.getSpareBits()
)
when FF.isCrandallPrimeField():
{.error: "Not implemented".}
else:
const windowSize = 5 # TODO: find best window size for each curves
a.mres.powMont(
exponent,
FF.getModulus(), FF.getMontyOne(),
FF.getNegInvModWord(), windowSize,
FF.getSpareBits()
)

func pow*(a: var FF, exponent: openarray[byte]) =
## Exponentiation modulo p
## ``a``: a field element to be exponentiated
## ``exponent``: a big integer in canonical big endian representation
const windowSize = 5 # TODO: find best window size for each curves
a.mres.powMont(
exponent,
FF.getModulus(), FF.getMontyOne(),
FF.getNegInvModWord(), windowSize,
FF.getSpareBits()
)
when FF.isCrandallPrimeField():
{.error: "Not implemented".}
else:
const windowSize = 5 # TODO: find best window size for each curves
a.mres.powMont(
exponent,
FF.getModulus(), FF.getMontyOne(),
FF.getNegInvModWord(), windowSize,
FF.getSpareBits()
)

func pow*(a: var FF, exponent: FF) =
## Exponentiation modulo p
Expand Down Expand Up @@ -557,13 +605,17 @@ func pow_vartime*(a: var FF, exponent: BigInt) =
## - memory access analysis
## - power analysis
## - timing analysis
const windowSize = 5 # TODO: find best window size for each curves
a.mres.powMont_vartime(
exponent,
FF.getModulus(), FF.getMontyOne(),
FF.getNegInvModWord(), windowSize,
FF.getSpareBits()
)

when FF.isCrandallPrimeField():
{.error: "Not implemented".}
else:
const windowSize = 5 # TODO: find best window size for each curves
a.mres.powMont_vartime(
exponent,
FF.getModulus(), FF.getMontyOne(),
FF.getNegInvModWord(), windowSize,
FF.getSpareBits()
)

func pow_vartime*(a: var FF, exponent: openarray[byte]) =
## Exponentiation modulo p
Expand All @@ -576,13 +628,17 @@ func pow_vartime*(a: var FF, exponent: openarray[byte]) =
## - memory access analysis
## - power analysis
## - timing analysis
const windowSize = 5 # TODO: find best window size for each curves
a.mres.powMont_vartime(
exponent,
FF.getModulus(), FF.getMontyOne(),
FF.getNegInvModWord(), windowSize,
FF.getSpareBits()
)

when FF.isCrandallPrimeField():
{.error: "Not implemented".}
else:
const windowSize = 5 # TODO: find best window size for each curves
a.mres.powMont_vartime(
exponent,
FF.getModulus(), FF.getMontyOne(),
FF.getNegInvModWord(), windowSize,
FF.getSpareBits()
)

func pow_vartime*(a: var FF, exponent: FF) =
## Exponentiation modulo p
Expand Down
Loading

0 comments on commit 67a4b2c

Please sign in to comment.