precomp sqrt optimization #354

Merged
merged 14 commits into from
Feb 15, 2024

Conversation

advaita-saha
Collaborator

@advaita-saha advaita-saha commented Feb 2, 2024

Optimizes Tonelli-Shanks square roots using precomputed dlog (discrete logarithm) tables.

Fixes #236

Notes about the approach :

Reference Implementation from Gottfried in gnark
https://github.com/GottfriedHerold/Bandersnatch/blob/f665f90b64892b9c4c89cff3219e70456bb431e5/bandersnatch/fieldElements/field_element_square_root.go

Precomputed tables are currently added for Bandersnatch and Banderwagon.
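
To make the approach concrete, here is a minimal Python sketch of a Tonelli-Shanks square root that replaces the inner search loop with a windowed discrete-log lookup. It is not the code from this PR: the toy prime `P = 12289`, the window size `W`, and the names `subgroup_generator`, `dlog_dyadic` and `sqrt_precomp` are assumptions chosen for illustration, and the sketch is variable-time.

```python
# Toy parameters (assumptions for illustration): P - 1 = M * 2^E with M odd.
P, E, M = 12289, 12, 3
W = 4  # window size in bits; E must be a multiple of W in this sketch


def subgroup_generator():
    """Return g = n^M for a quadratic non-residue n; g has exact order 2^E."""
    for n in range(2, P):
        if pow(n, (P - 1) // 2, P) == P - 1:
            return pow(n, M, P)


G = subgroup_generator()
H = pow(G, 1 << (E - W), P)                          # element of order 2^W
DLOG_LUT = {pow(H, j, P): j for j in range(1 << W)}  # precomputed once


def dlog_dyadic(x):
    """Return k with G^k == x, for x in the subgroup of order 2^E."""
    k = 0
    for i in range(E // W):
        y = x * pow(G, -k, P) % P                    # strip the digits found so far
        t = pow(y, 1 << (E - W * (i + 1)), P)        # isolate the next W-bit digit
        k += DLOG_LUT[t] << (W * i)
    return k


def sqrt_precomp(a):
    """Return a square root of a mod P, or None if a is a non-residue."""
    a %= P
    if a == 0:
        return 0
    w = pow(a, (M - 1) // 2, P)
    candidate = a * w % P                 # a^((M+1)/2)
    root_of_unity = candidate * w % P     # a^M, lies in the 2^E subgroup
    k = dlog_dyadic(root_of_unity)
    if k & 1:                             # odd dlog means a is not a square
        return None
    # candidate^2 == a * G^k, so dividing by G^(k/2) yields sqrt(a)
    return candidate * pow(G, -(k // 2), P) % P
```

In this sketch the table has only 2^4 entries; a field with a larger power of two dividing P - 1 would trade table size against the number of windows.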

acc.prod(acc, z255)
square_repeated(acc, 8)
acc.prod(acc, z255)
# acc is now z^((BaseFieldMultiplicativeOddOrder - 1)/2)
Owner

What's the BaseFieldMultiplicativeOddOrder?

If it's the prime p, it would be about 4 to 5 times faster to use the Legendre symbol instead of computing it via Fermat's Little Theorem:

func legendre*(a, M: Limbs, bits: static int): SecretWord =
  ## Compute the Legendre symbol
  ##
  ## (a/p)ₗ ≡ a^((p-1)/2) ≡  1 (mod p), iff a is a square
  ##                      ≡ -1 (mod p), iff a is quadratic non-residue
  ##                      ≡  0 (mod p), iff a is 0
  const Excess = 2
  const k = WordBitWidth - Excess
  const NumUnsatWords = bits.ceilDiv_vartime(k)

  # Convert values to unsaturated repr
  var m2 {.noInit.}: LimbsUnsaturated[NumUnsatWords, Excess]
  m2.fromPackedRepr(M)

  var a2 {.noInit.}: LimbsUnsaturated[NumUnsatWords, Excess]
  a2.fromPackedRepr(a)

  legendreImpl(a2, m2, k, bits)

func legendre*(a: Limbs, M: static Limbs, bits: static int): SecretWord =
  ## Compute the Legendre symbol (compile-time modulus overload)
  ##
  ## (a/p)ₗ ≡ a^((p-1)/2) ≡  1 (mod p), iff a is a square
  ##                      ≡ -1 (mod p), iff a is quadratic non-residue
  ##                      ≡  0 (mod p), iff a is 0
  const Excess = 2
  const k = WordBitWidth - Excess
  const NumUnsatWords = bits.ceilDiv_vartime(k)

  # Convert values to unsaturated repr
  const m2 = LimbsUnsaturated[NumUnsatWords, Excess].fromPackedRepr(M)

  var a2 {.noInit.}: LimbsUnsaturated[NumUnsatWords, Excess]
  a2.fromPackedRepr(a)

  legendreImpl(a2, m2, k, bits)
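
As a rough illustration of the two routes being compared, the Python sketch below computes the Legendre symbol once by exponentiation (Euler's criterion) and once with the classic binary Jacobi-symbol algorithm, which needs no modular exponentiation. The function names are hypothetical, and the GCD-style routine is a textbook stand-in, not a port of Constantine's `legendreImpl`.

```python
def legendre_by_exponentiation(a, p):
    """Euler's criterion: a^((p-1)/2) mod p, mapped to {-1, 0, 1}."""
    r = pow(a, (p - 1) // 2, p)
    return -1 if r == p - 1 else r


def legendre_by_gcd(a, p):
    """Binary Jacobi-symbol algorithm; equals the Legendre symbol for odd prime p."""
    n = p
    a %= n
    result = 1
    while a != 0:
        while a % 2 == 0:              # pull out factors of two
            a //= 2
            if n % 8 in (3, 5):
                result = -result
        a, n = n, a                    # quadratic reciprocity step
        if a % 4 == 3 and n % 4 == 3:
            result = -result
        a %= n
    return result if n == 1 else 0
```

Both functions agree for every input modulo an odd prime; the second only uses shifts, swaps and reductions, which is where the speed-up over a full exponentiation comes from.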

Owner

@mratsim mratsim left a comment

typo

@mratsim
Owner

mratsim commented Feb 3, 2024

Some more comments: I think we can totally remove the if checks inside the algorithm.

  1. This facilitates constant time implementation
  2. Checking whether the number is a square is easy once we have computed the candidate sqrt: we can just multiply it by itself, which is somewhat cheap, and it is what we already do:
    func sqrt_invsqrt_if_square*[C](sqrt, invsqrt: var Fp[C], a: Fp[C]): SecretBool =
      ## Compute the square root and inverse square root of ``a``
      ##
      ## This returns true if ``a`` is square and sqrt/invsqrt contains the square root/inverse square root
      ##
      ## The result is undefined otherwise
      ##
      ## The square root, if it exists, is multivalued,
      ## i.e. both x² == (-x)²
      ## This procedure returns a deterministic result
      sqrt_invsqrt(sqrt, invsqrt, a)
      var test {.noInit.}: Fp[C]
      test.square(sqrt)
      result = test == a
  3. The case that matters and does NOT need constant-time is deserialization, and wrong deserialization means some protocol is not followed properly which means we need to stop interacting with the other party (if network) or the database (if local). This is useful for batch deserialization to shave 10~100 microseconds over thousands of points, but failing earlier after 5us instead of 10us is not necessary. (even for a DOS we're not in the milliseconds range here)
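
A small Python sketch of point 2: compute a candidate unconditionally, then square it and compare with the input. To keep it short it assumes a prime `p ≡ 3 (mod 4)`, where the candidate is a single exponentiation; that shortcut and the name `sqrt_if_square` are illustrative assumptions, not the Bandersnatch code path.

```python
def sqrt_if_square(a, p):
    """Return (candidate, is_square) for a prime p with p % 4 == 3.

    No branch depends on whether a is a square: the candidate is always
    computed, and squareness is decided by one extra multiplication.
    """
    a %= p
    candidate = pow(a, (p + 1) // 4, p)
    is_square = candidate * candidate % p == a
    return candidate, is_square
```

For a non-residue the candidate squares to `-a` instead of `a`, so the comparison fails and the candidate must be discarded, which matches the "result is undefined otherwise" contract quoted above.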

@mratsim
Owner

mratsim commented Feb 3, 2024

I can confirm a 50% perf improvement on my machine 🔥 🎉

[image: benchmark screenshot]

@advaita-saha
Collaborator Author

I will address the constant-time issues.
I first wanted to make sure that the optimisation works as expected.

@mratsim
Owner

mratsim commented Feb 4, 2024

Strangely I can't find the sqrt or deserialization benches in go-ipa.

@advaita-saha
Collaborator Author

> Strangely I can't find the sqrt or deserialization benches in go-ipa.

There aren't any; I will push some in a few hours.
The only bench that Ignacio did at my request was for serialisation:
crate-crypto/go-ipa#62

@advaita-saha
Collaborator Author

Benchmarking code for go-ipa
https://github.com/advaita-saha/go-ipa/blob/15686fb2ed7bf62b2c9883b17c77e3de1d09be0a/banderwagon/element_test.go#L393-L409

@agnxsh
Collaborator

agnxsh commented Feb 4, 2024

> I will be solving the constant time issues
>
> I wanted to make sure first that if the optimisation is working as expected

Also, if necessary, I can provide a benchmark on an M2 Pro by next Monday/Tuesday, to follow up on the observation that there is no significant difference there after the precomp sqrt optimisation. Moreover, both the unoptimised and optimised versions are significantly slower on the M2 Pro chip (ARM architecture).

@advaita-saha
Collaborator Author

Current benchmarks, tested on an AWS instance with an AMD EPYC 7R13 processor:

[screenshot: benchmark results, 2024-02-07]

Results

w.r.t go-ipa ~23% performance improvement 🔥
w.r.t constantine(prev-impl) ~50% performance improvement 🚀

@advaita-saha advaita-saha marked this pull request as ready for review February 6, 2024 21:15
@mratsim
Owner

mratsim commented Feb 10, 2024

Confirmed on my machine:

go-ipa:
[image: go-ipa benchmark screenshot]

Constantine:
[image: Constantine benchmark screenshot]

Owner

@mratsim mratsim left a comment

The most important thing is to rename the sqrtAlg_NegDlogInSmallDyadicSubgroup function with the _vartime suffix as well as its callers.

The functions in finite_fields_square_root.nim starting from invsqrt should be duplicated to have a ..._vartime suffix and trySetFromCoordX should have a trySetFromCoordX_vartime version as well.

This is important so that protocols that use invsqrt or sqrt with secret data do not suddenly leak secrets.

The deserialization functions should then point to that vartime function.

Plan for the future:

  1. Add Sage code to generate the precomputation automatically
  2. Expand that precomputation to other curves with no fast sqrt, like the Pasta curves and BLS12-377
  3. Refactor the algorithm to make it constant-time and remove the dependency on the Nim standard library.

if key in Fp.C.tonelliShanks(sqrtPrecomp_dlogLUT):
return Fp.C.tonelliShanks(sqrtPrecomp_dlogLUT)[key]
return 0
return Fp.C.sqrtDlog(dlogLUT).getOrDefault(key, 0)
Owner

This unfortunately isn't constant-time, see: https://github.com/nim-lang/Nim/blob/57658b685cd08dc5a0c1f7a9aa58fafa553efc4e/lib/pure/collections/tableimpl.nim#L187-L191

template getOrDefaultImpl(t, key, default: untyped): untyped =
  mixin rawGet
  var hc: Hash
  var index = rawGet(t, key, hc)
  result = if index >= 0: t.data[index].val else: default

There is an if-branch.

To be fully constant-time, and ensure it stays that way, Constantine cannot depend on the standard library for code that may handle secrets.

Here, we can instead rename to sqrtAlg_NegDlogInSmallDyadicSubgroup_vartime and suffix all callers with vartime. I'll try to see later if we can have a constant-time LUT, or maybe split it into 2 LUTs.
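
For readers wondering what a branch-free table lookup could look like, here is a minimal Python sketch that scans every entry and selects the match with a mask instead of an `if`. Python itself gives no timing guarantees, so this only illustrates the data flow; the names `ct_is_zero` and `ct_lookup` and the flat key/value lists are assumptions, not Constantine APIs.

```python
def ct_is_zero(x):
    """Return 1 if x == 0 else 0, for 0 <= x < 2^63, without branching on x."""
    return (((x | -x) >> 63) & 1) ^ 1


def ct_lookup(keys, values, key, default=0):
    """Scan the whole table and keep the value whose key matches.

    Every entry is touched regardless of where (or whether) the key is
    found, so the access pattern does not depend on the key.
    """
    result = default
    for k, v in zip(keys, values):
        mask = -ct_is_zero(k ^ key)        # all ones on a match, zero otherwise
        result ^= (result ^ v) & mask      # masked select between result and v
    return result
```

The cost is linear in the table size, which is one reason a single large table is unattractive for a constant-time variant and smaller or split tables come up as an option.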

dst.prod(candidate, rootOfUnity)
return SecretBool(true)

dst.inv()
Owner

Assuming the input x

The result of invSqrtEqDyadic(rootOfUnity)

should be x⁻¹ᐟ² according to https://github.com/GottfriedHerold/Bandersnatch/blob/7c0464f7dfae8f1139c9f812d1543de631f5262b/bandersnatch/fieldElements/field_element_square_root.go#L273-L277

// invSqrtEqDyadic asserts that z is a 2^32 root of unity and tries to set z := 1/sqrt(z).
//
// If z is actually a 2^32th *primitive* root of unity, the square root does not exist and we return false without modifying z.
// Otherwise, z is changed to 1/sqrt(z) and we return true
func (z *feType_SquareRoot) invSqrtEqDyadic() (ok bool) {

then the last two lines do

dst <- candidate * x⁻¹ᐟ²
dst <- dst⁻¹ (and this is equal to x⁻¹ᐟ² because it matches with invsqrt)

Ergo, candidate = x and rootOfUnity = x⁻¹ᐟ² before the last 2 lines, so they are unnecessary.

This can be raised as a refactoring issue and then cleaned up afterwards.

Collaborator Author

I noticed this while implementing, which is why I haven't commented this particular function.
The problem is that when I try to work backwards from the end, it doesn't work, which implies x != candidate, although they should be equal.

But again, if sqrtX = candidate * x⁻¹ᐟ², then candidate should be equal to x. I was confused by this initially, but left it as is in order to proceed with the implementation.

Owner

I see, then we can leave it as is and schedule that as a follow-up refactoring.

Owner

A trick similar to the one used for Tonelli-Shanks, as described here, can probably be applied:

def sqrt_inv_sqrt_tonelli_shanks_impl(a, a_pre_exp, s, e, root_of_unity):
  ## Square root and inverse square root for any `a` in a field of prime characteristic p
  ##
  ## a_pre_exp = a^((q-1-2^e)/(2*2^e))
  ## with
  ##  s and e, precomputed values
  ##  such as q == s * 2^e + 1

  # Implementation
  # 1/√a * a = √a
  # Notice that in Tonelli Shanks, the result `r` is bootstrapped by "z*a"
  # We bootstrap it instead by just z to get invsqrt for free

  z = a_pre_exp
  t = z*z*a
  r = z
  b = t
  root = root_of_unity
  for i in range(e, 1, -1):   # e .. 2
    for j in range(1, i-1):   # 1 .. i-2
      b *= b
    doCopy = b != 1
    r = ccopy(r, r * root, doCopy)
    root *= root
    t = ccopy(t, t * root, doCopy)
    b = t

  return r*a, r
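
The snippet above relies on Sage field elements and an external `ccopy`. A self-contained Python port over plain integers, useful for checking the invariant `(r*a)^2 == a` on a small prime, might look like the sketch below. The wrapper name, the in-function search for a non-residue, and the branching `ccopy` stand-in are assumptions for illustration; a real constant-time `ccopy` would select without branching.

```python
def ccopy(current, new, ctl):
    """Stand-in for a conditional copy: returns `new` when ctl is true."""
    return new if ctl else current


def sqrt_invsqrt_tonelli_shanks(a, p):
    """Return (sqrt(a), 1/sqrt(a)) mod p, assuming a is a non-zero square."""
    # Decompose p - 1 == s * 2^e with s odd
    s, e = p - 1, 0
    while s % 2 == 0:
        s //= 2
        e += 1
    # root of unity of order 2^e, derived from any quadratic non-residue
    non_residue = next(n for n in range(2, p) if pow(n, (p - 1) // 2, p) == p - 1)
    root = pow(non_residue, s, p)

    z = pow(a, (s - 1) // 2, p)   # a_pre_exp
    t = z * z * a % p             # a^s
    r = z                         # bootstrapped with z (not z*a) to get invsqrt
    b = t
    for i in range(e, 1, -1):     # e .. 2
        for _ in range(1, i - 1): # 1 .. i-2
            b = b * b % p
        do_copy = b != 1
        r = ccopy(r, r * root % p, do_copy)
        root = root * root % p
        t = ccopy(t, t * root % p, do_copy)
        b = t
    return r * a % p, r
```

For example, `sqrt_invsqrt_tonelli_shanks(4, 97)` returns a pair whose first element squares to 4 and whose product is 1 modulo 97.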

Owner

@mratsim mratsim left a comment

For the current state:

Also @GottfriedHerold, it seems like your code has no license, can you clarify whether a direct port in Constantine (license MIT+Apache) is OK.


For a future PR

After thinking a bit, in a future PR, to enable this optimization for:

  • constant-time square root (for example hash-to-curve)
  • other curves
  • remove Nim std/tables dependency (due to exceptions, there is no alloc since it's compile-time tables)

We will need:

  • smaller tables, ideally configurable, because 65536 size is too big for constant-time
  • SageMath code to generate the lookup tables (LUTs).

As such it makes sense to reuse Pornin's:

for the constant-time implementation. It is probably possible to get it within 5% of Gottfried's algorithm for the variable-time case, so we can reuse the same tables.

@mratsim
Copy link
Owner

mratsim commented Feb 11, 2024

It probably would be cleaner to rebase on top of master instead of merging master into this branch to keep a linear history.

@GottfriedHerold

> Also @GottfriedHerold, it seems like your code has no license, can you clarify whether a direct port in Constantine (license MIT+Apache) is OK.

Yes, that's OK.

mratsim
mratsim previously approved these changes Feb 12, 2024
Owner

@mratsim mratsim left a comment

LGTM, some nits.

# acc is now z^((BaseFieldMultiplicativeOddOrder - 1)/2)
rootOfUnity.square(acc)
rootOfUnity *= z
squareRootCandidate.prod(acc, z)
Owner

For future reference, I have a feeling that this addchain computes the same thing as

func precompute_tonelli_shanks_addchain*(
       r: var Fp[Bandersnatch],
       a: Fp[Bandersnatch]) {.addchain.} =
  ## Does a^Bandersnatch_TonelliShanks_exponent
  ## via an addition-chain

  var
    x10 {.noInit.}: Fp[Bandersnatch]
    x100 {.noInit.}: Fp[Bandersnatch]
    x110 {.noInit.}: Fp[Bandersnatch]
    x1100 {.noInit.}: Fp[Bandersnatch]
    x10010 {.noInit.}: Fp[Bandersnatch]
    x10011 {.noInit.}: Fp[Bandersnatch]
    x10110 {.noInit.}: Fp[Bandersnatch]
    x11000 {.noInit.}: Fp[Bandersnatch]
    x11010 {.noInit.}: Fp[Bandersnatch]
    x100010 {.noInit.}: Fp[Bandersnatch]
    x110101 {.noInit.}: Fp[Bandersnatch]
    x111011 {.noInit.}: Fp[Bandersnatch]
    x1001011 {.noInit.}: Fp[Bandersnatch]
    x1001101 {.noInit.}: Fp[Bandersnatch]
    x1010101 {.noInit.}: Fp[Bandersnatch]
    x1100111 {.noInit.}: Fp[Bandersnatch]
    x1101001 {.noInit.}: Fp[Bandersnatch]
    x10000011 {.noInit.}: Fp[Bandersnatch]
    x10011001 {.noInit.}: Fp[Bandersnatch]
    x10011101 {.noInit.}: Fp[Bandersnatch]
    x10111111 {.noInit.}: Fp[Bandersnatch]
    x11010111 {.noInit.}: Fp[Bandersnatch]
    x11011011 {.noInit.}: Fp[Bandersnatch]
    x11100111 {.noInit.}: Fp[Bandersnatch]
    x11101111 {.noInit.}: Fp[Bandersnatch]
    x11111111 {.noInit.}: Fp[Bandersnatch]

  x10 .square(a)
  x100 .square(x10)
  x110 .prod(x10, x100)
  x1100 .square(x110)
  x10010 .prod(x110, x1100)
  x10011 .prod(a, x10010)
  x10110 .prod(x100, x10010)
  x11000 .prod(x10, x10110)
  x11010 .prod(x10, x11000)
  x100010 .prod(x1100, x10110)
  x110101 .prod(x10011, x100010)
  x111011 .prod(x110, x110101)
  x1001011 .prod(x10110, x110101)
  x1001101 .prod(x10, x1001011)
  x1010101 .prod(x11010, x111011)
  x1100111 .prod(x10010, x1010101)
  x1101001 .prod(x10, x1100111)
  x10000011 .prod(x11010, x1101001)
  x10011001 .prod(x10110, x10000011)
  x10011101 .prod(x100, x10011001)
  x10111111 .prod(x100010, x10011101)
  x11010111 .prod(x11000, x10111111)
  x11011011 .prod(x100, x11010111)
  x11100111 .prod(x1100, x11011011)
  x11101111 .prod(x11000, x11010111)
  x11111111 .prod(x11000, x11100111)
  # 26 operations

  let a = a # Allow aliasing between r and a

  # 26+28 = 54 operations
  r.square_repeated(x11100111, 8)
  r *= x11011011
  r.square_repeated(9)
  r *= x10011101
  r.square_repeated(9)

  # 54 + 20 = 74 operations
  r *= x10011001
  r.square_repeated(9)
  r *= x10011001
  r.square_repeated(8)
  r *= x11010111

  # 74 + 27 = 101 operations
  r.square_repeated(6)
  r *= x110101
  r.square_repeated(10)
  r *= x10000011
  r.square_repeated(9)

  # 101 + 19 = 120 operations
  r *= x1100111
  r.square_repeated(8)
  r *= x111011
  r.square_repeated(8)
  r *= a

  # 120 + 41 = 161 operations
  r.square_repeated(14)
  r *= x1001101
  r.square_repeated(10)
  r *= x111011
  r.square_repeated(15)

  # 161 + 21 = 182 operations
  r *= x1010101
  r.square_repeated(10)
  r *= x11101111
  r.square_repeated(8)
  r *= x1101001

  # 182 + 33 = 215 operations
  r.square_repeated(16)
  r *= x10111111
  r.square_repeated(8)
  r *= x11111111
  r.square_repeated(7)

  # 215 + 20 = 235 operations
  r *= x1001011
  r.square_repeated(9)
  r *= x11111111
  r.square_repeated(8)
  r *= x10111111

  # 235 + 26 = 261 operations
  r.square_repeated(8)
  r *= x11111111
  r.square_repeated(8)
  r *= x11111111
  r.square_repeated(8)

  # 261 + 3 = 264 operations
  r *= x11111111
  r.square()
  r *= a

Also it would need to be moved to Banderwagon/Bandersnatch specific files in the future
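
One way to check a hunch like this without field arithmetic is to run the chain on exponents only: squaring doubles the exponent and a product adds them. The Python sketch below does that for the first few steps of the chain quoted above; the `Exp` class is a hypothetical helper, and the same idea could be extended to both full chains to compare their final exponents.

```python
class Exp:
    """Tracks only the exponent k of a^k, not the field element itself."""

    def __init__(self, k):
        self.k = k

    def square(self):
        return Exp(2 * self.k)          # (a^k)^2 == a^(2k)

    def __mul__(self, other):
        return Exp(self.k + other.k)    # a^j * a^k == a^(j+k)


a = Exp(1)
# First steps of the chain, mirroring the variable names used above
x10 = a.square()
x100 = x10.square()
x110 = x10 * x100
x1100 = x110.square()
x10010 = x110 * x1100
x10011 = a * x10010
x10110 = x100 * x10010
x11000 = x10 * x10110
```

Each variable name is the binary expansion of the exponent it holds, so `x10011.k` is `0b10011`. Two chains compute the same power exactly when their final tracked exponents are equal.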

@advaita-saha
Collaborator Author

advaita-saha commented Feb 12, 2024

@mratsim
completed the suggested changes

mratsim
mratsim previously approved these changes Feb 13, 2024
Owner

@mratsim mratsim left a comment

LGTM.

I can either merge as-is or wait for the extra benches to be added.

Successfully merging this pull request may close these issues.

Faster point decompression
4 participants