Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: range check gadget #472

Merged
merged 8 commits into from
Mar 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions frontend/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,12 @@ type Committer interface {
// Commit commits to the variables and returns the commitment.
Commit(toCommit ...Variable) (commitment Variable, err error)
}

// Rangechecker allows to externally range-check the variables to be of
// specified width. Not all compilers implement this interface. Users should
// instead use [github.com/consensys/gnark/std/rangecheck] package which
// automatically chooses most optimal method for range checking the variables.
type Rangechecker interface {
// Check checks that the given variable v has bit-length bits.
Check(v Variable, bits int)
}
5 changes: 5 additions & 0 deletions frontend/cs/r1cs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/consensys/gnark/frontend/internal/expr"
"github.com/consensys/gnark/frontend/schema"
"github.com/consensys/gnark/internal/circuitdefer"
"github.com/consensys/gnark/internal/frontendtype"
"github.com/consensys/gnark/internal/kvstore"
"github.com/consensys/gnark/internal/tinyfield"
"github.com/consensys/gnark/internal/utils"
Expand Down Expand Up @@ -452,3 +453,7 @@ func (builder *builder) compress(le expr.LinearExpression) expr.LinearExpression
func (builder *builder) Defer(cb func(frontend.API) error) {
circuitdefer.Put(builder, cb)
}

func (*builder) FrontendType() frontendtype.Type {
return frontendtype.R1CS
}
5 changes: 5 additions & 0 deletions frontend/cs/scs/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/internal/expr"
"github.com/consensys/gnark/frontend/schema"
"github.com/consensys/gnark/internal/frontendtype"
"github.com/consensys/gnark/std/math/bits"
)

Expand Down Expand Up @@ -557,3 +558,7 @@ func (builder *builder) printArg(log *constraint.LogEntry, sbb *strings.Builder,
func (builder *builder) Compiler() frontend.Compiler {
return builder
}

func (*builder) FrontendType() frontendtype.Type {
return frontendtype.SCS
}
13 changes: 13 additions & 0 deletions internal/frontendtype/frontendtype.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Package frontendtype allows to assert frontend type.
package frontendtype

type Type int

const (
R1CS Type = iota
SCS
)

type FrontendTyper interface {
FrontendType() Type
}
Binary file modified internal/stats/latest.stats
Binary file not shown.
2 changes: 2 additions & 0 deletions std/hints.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/consensys/gnark/std/algebra/native/sw_bls24315"
"github.com/consensys/gnark/std/math/bits"
"github.com/consensys/gnark/std/math/emulated"
"github.com/consensys/gnark/std/rangecheck"
"github.com/consensys/gnark/std/selector"
)

Expand All @@ -34,4 +35,5 @@ func registerHints() {
solver.RegisterHint(selector.MuxIndicators)
solver.RegisterHint(selector.MapIndicators)
solver.RegisterHint(emulated.GetHints()...)
solver.RegisterHint(rangecheck.CountHint, rangecheck.DecomposeHint)
}
3 changes: 3 additions & 0 deletions std/math/emulated/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/internal/utils"
"github.com/consensys/gnark/logger"
"github.com/consensys/gnark/std/rangecheck"
"github.com/rs/zerolog"
"golang.org/x/exp/constraints"
)
Expand Down Expand Up @@ -38,6 +39,7 @@ type Field[T FieldParams] struct {
log zerolog.Logger

constrainedLimbs map[uint64]struct{}
checker frontend.Rangechecker
}

// NewField returns an object to be used in-circuit to perform emulated
Expand All @@ -53,6 +55,7 @@ func NewField[T FieldParams](native frontend.API) (*Field[T], error) {
api: native,
log: logger.Logger(),
constrainedLimbs: make(map[uint64]struct{}),
checker: rangecheck.New(native),
}

// ensure prime is correctly set
Expand Down
55 changes: 14 additions & 41 deletions std/math/emulated/field_assert.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@ import (
"math/big"

"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/bits"
)

// assertLimbsEqualitySlow is the main routine in the package. It asserts that the
// two slices of limbs represent the same integer value. This is also the most
// costly operation in the package as it does bit decomposition of the limbs.
func assertLimbsEqualitySlow(api frontend.API, l, r []frontend.Variable, nbBits, nbCarryBits uint) {
func (f *Field[T]) assertLimbsEqualitySlow(api frontend.API, l, r []frontend.Variable, nbBits, nbCarryBits uint) {

nbLimbs := max(len(l), len(r))
maxValue := new(big.Int).Lsh(big.NewInt(1), nbBits+nbCarryBits)
Expand All @@ -33,52 +32,29 @@ func assertLimbsEqualitySlow(api frontend.API, l, r []frontend.Variable, nbBits,
// carry is stored in the highest bits of diff[nbBits:nbBits+nbCarryBits+1]
// we know that diff[:nbBits] are 0 bits, but still need to constrain them.
// to do both; we do a "clean" right shift and only need to boolean constrain the carry part
carry = rsh(api, diff, int(nbBits), int(nbBits+nbCarryBits+1))
carry = f.rsh(diff, int(nbBits), int(nbBits+nbCarryBits+1))
}
api.AssertIsEqual(carry, maxValueShift)
}

// rsh right shifts a variable endDigit-startDigit bits and returns it.
func rsh(api frontend.API, v frontend.Variable, startDigit, endDigit int) frontend.Variable {
func (f *Field[T]) rsh(v frontend.Variable, startDigit, endDigit int) frontend.Variable {
// if v is a constant, work with the big int value.
if c, ok := api.Compiler().ConstantValue(v); ok {
if c, ok := f.api.Compiler().ConstantValue(v); ok {
bits := make([]frontend.Variable, endDigit-startDigit)
for i := 0; i < len(bits); i++ {
bits[i] = c.Bit(i + startDigit)
}
return bits
}

bits, err := api.Compiler().NewHint(NBitsShifted, endDigit-startDigit, v, startDigit)
shifted, err := f.api.Compiler().NewHint(RightShift, 1, startDigit, v)
if err != nil {
panic(err)
}

// we compute 2 sums;
// Σbi ensures that "ignoring" the lowest bits (< startDigit) still is a valid bit decomposition.
// that is, it ensures that bits from startDigit to endDigit * corresponding coefficients (powers of 2 shifted)
// are equal to the input variable
// ΣbiRShift computes the actual result; that is, the Σ (2**i * b[i])
Σbi := frontend.Variable(0)
ΣbiRShift := frontend.Variable(0)

cRShift := big.NewInt(1)
c := big.NewInt(1)
c.Lsh(c, uint(startDigit))

for i := 0; i < len(bits); i++ {
Σbi = api.MulAcc(Σbi, bits[i], c)
ΣbiRShift = api.MulAcc(ΣbiRShift, bits[i], cRShift)

c.Lsh(c, 1)
cRShift.Lsh(cRShift, 1)
api.AssertIsBoolean(bits[i])
panic(fmt.Sprintf("right shift: %v", err))
}

// constraint Σ (2**i_shift * b[i]) == v
api.AssertIsEqual(Σbi, v)
return ΣbiRShift

f.checker.Check(shifted[0], endDigit-startDigit)
shift := new(big.Int).Lsh(big.NewInt(1), uint(startDigit))
composed := f.api.Mul(shifted[0], shift)
f.api.AssertIsEqual(composed, v)
return shifted[0]
}

// AssertLimbsEquality asserts that the limbs represent a same integer value.
Expand Down Expand Up @@ -107,9 +83,9 @@ func (f *Field[T]) AssertLimbsEquality(a, b *Element[T]) {
// TODO: we previously assumed that one side was "larger" than the other
// side, but I think this assumption is not valid anymore
if a.overflow > b.overflow {
assertLimbsEqualitySlow(f.api, ca, cb, bitsPerLimb, a.overflow)
f.assertLimbsEqualitySlow(f.api, ca, cb, bitsPerLimb, a.overflow)
} else {
assertLimbsEqualitySlow(f.api, cb, ca, bitsPerLimb, b.overflow)
f.assertLimbsEqualitySlow(f.api, cb, ca, bitsPerLimb, b.overflow)
}
}

Expand All @@ -133,10 +109,7 @@ func (f *Field[T]) enforceWidth(a *Element[T], modWidth bool) {
// take only required bits from the most significant limb
limbNbBits = ((f.fParams.Modulus().BitLen() - 1) % int(f.fParams.BitsPerLimb())) + 1
}
// bits.ToBinary restricts the least significant NbDigits to be equal to
// the limb value. This is sufficient to restrict for the bitlength and
// we can discard the bits themselves.
bits.ToBinary(f.api, a.Limbs[i], bits.WithNbDigits(limbNbBits))
f.checker.Check(a.Limbs[i], limbNbBits)
}
}

Expand Down
25 changes: 16 additions & 9 deletions std/math/emulated/hints.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func GetHints() []solver.Hint {
InverseHint,
MultiplicationHint,
RemHint,
NBitsShifted,
RightShift,
}
}

Expand Down Expand Up @@ -287,13 +287,20 @@ func parseHintDivInputs(inputs []*big.Int) (uint, int, *big.Int, *big.Int, error
return nbBits, nbLimbs, x, y, nil
}

// NBitsShifted returns the first bits of the input, with a shift. The number of returned bits is
// defined by the length of the results slice.
func NBitsShifted(_ *big.Int, inputs []*big.Int, results []*big.Int) error {
n := inputs[0]
shift := inputs[1].Uint64() // TODO @gbotrel validate input vs perf in large circuits.
for i := 0; i < len(results); i++ {
results[i].SetUint64(uint64(n.Bit(i + int(shift))))
}
// RightShift shifts input by the given number of bits. Expects two inputs:
// - first input is the shift, will be represented as uint64;
// - second input is the value to be shifted.
//
// Returns a single output which is the value shifted. Errors if number of
// inputs is not 2 and number of outputs is not 1.
func RightShift(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error {
if len(inputs) != 2 {
return fmt.Errorf("expecting two inputs")
}
if len(outputs) != 1 {
return fmt.Errorf("expecting single output")
}
shift := inputs[0].Uint64()
outputs[0].Rsh(inputs[1], uint(shift))
return nil
}
30 changes: 30 additions & 0 deletions std/rangecheck/rangecheck.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Package rangecheck implements range checking gadget
//
// This package chooses the most optimal path for performing range checks:
// - if the backend supports native range checking and the frontend exports the variables in the proprietary format by implementing [frontend.Rangechecker], then use it directly;
// - if the backend supports creating a commitment of variables by implementing [frontend.Committer], then we use the product argument as in [BCG+18]. [r1cs.NewBuilder] returns a builder which implements this interface;
// - lacking these, we perform binary decomposition of variable into bits.
//
// [BCG+18]: https://eprint.iacr.org/2018/380
package rangecheck

import (
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/cs/r1cs"
)

// only for documentation purposes. If we import the package then godoc knows
// how to refer to package r1cs and we get nice links in godoc. We import the
// package anyway in test.
var _ = r1cs.NewBuilder

// New returns a new range checker depending on the frontend capabilities.
func New(api frontend.API) frontend.Rangechecker {
if rc, ok := api.(frontend.Rangechecker); ok {
return rc
}
if _, ok := api.(frontend.Committer); ok {
return newCommitRangechecker(api)
}
return plainChecker{api: api}
}
Loading