Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize modexp #6247

Merged
merged 1 commit into from
Dec 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions common/math/modexp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package math

import (
"math/big"
"math/bits"

"github.com/ledgerwatch/erigon/common"
)

// FastExp is semantically equivalent to x.Exp(x,y, m), but is faster for even
// modulus.
func FastExp(x, y, m *big.Int) *big.Int {
// Split m = m1 × m2 where m1 = 2ⁿ
n := m.TrailingZeroBits()
m1 := new(big.Int).Lsh(common.Big1, n)
mask := new(big.Int).Sub(m1, common.Big1)
m2 := new(big.Int).Rsh(m, n)

// We want z = x**y mod m.
// z1 = x**y mod m1 = (x**y mod m) mod m1 = z mod m1
// z2 = x**y mod m2 = (x**y mod m) mod m2 = z mod m2
z1 := fastExpPow2(x, y, mask)
z2 := new(big.Int).Exp(x, y, m2)

// Reconstruct z from z1, z2 using CRT, using algorithm from paper,
// which uses only a single modInverse.
// p = (z1 - z2) * m2⁻¹ (mod m1)
// z = z2 + p * m2
z := new(big.Int).Set(z2)

// Compute (z1 - z2) mod m1 [m1 == 2**n] into z1.
z1 = z1.And(z1, mask)
z2 = z2.And(z2, mask)
z1 = z1.Sub(z1, z2)
if z1.Sign() < 0 {
z1 = z1.Add(z1, m1)
}

// Reuse z2 for p = z1 * m2inv.
m2inv := new(big.Int).ModInverse(m2, m1)
z2 = z2.Mul(z1, m2inv)
z2 = z2.And(z2, mask)

// Reuse z1 for m2 * p.
z = z.Add(z, z1.Mul(z2, m2))
z = z.Rem(z, m)

return z
}

func fastExpPow2(x, y *big.Int, mask *big.Int) *big.Int {
z := big.NewInt(1)
if y.Sign() == 0 {
return z
}
p := new(big.Int).Set(x)
p = p.And(p, mask)
if p.Cmp(z) <= 0 { // p <= 1
return p
}
if y.Cmp(mask) > 0 {
y = new(big.Int).And(y, mask)
}
t := new(big.Int)

for _, b := range y.Bits() {
for i := 0; i < bits.UintSize; i++ {
if b&1 != 0 {
z, t = t.Mul(z, p), z
z = z.And(z, mask)
}
p, t = t.Mul(p, p), p
p = p.And(p, mask)
b >>= 1
}
}
return z
}
15 changes: 13 additions & 2 deletions core/vm/contracts.go
Original file line number Diff line number Diff line change
Expand Up @@ -448,12 +448,23 @@ func (c *bigModExp) Run(input []byte) ([]byte, error) {
base = new(big.Int).SetBytes(getData(input, 0, baseLen))
exp = new(big.Int).SetBytes(getData(input, baseLen, expLen))
mod = new(big.Int).SetBytes(getData(input, baseLen+expLen, modLen))
v []byte
)
if mod.Sign() == 0 {
switch {
case mod.BitLen() == 0:
// Modulo 0 is undefined, return zero
return common.LeftPadBytes([]byte{}, int(modLen)), nil
case base.Cmp(common.Big1) == 0:
//If base == 1, then we can just return base % mod (if mod >= 1, which it is)
v = base.Mod(base, mod).Bytes()
//case mod.Bit(0) == 0:
// // Modulo is even
// v = math.FastExp(base, exp, mod).Bytes()
default:
// Modulo is odd
v = base.Exp(base, exp, mod).Bytes()
}
return common.LeftPadBytes(base.Exp(base, exp, mod).Bytes(), int(modLen)), nil
return common.LeftPadBytes(v, int(modLen)), nil
}

// newCurvePoint unmarshals a binary blob into a bn256 elliptic curve point,
Expand Down