Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: assert that the binary decomposition of a variable is less than the modulus #835

Merged
merged 18 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions frontend/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,14 @@ type API interface {
// IsZero returns 1 if a is zero, 0 otherwise
IsZero(i1 Variable) Variable

// Cmp returns 1 if i1>i2, 0 if i1=i2, -1 if i1<i2
// Cmp returns:
// * 1 if i1>i2,
// * 0 if i1=i2,
// * -1 if i1<i2.
//
// If the absolute difference between the variables i1 and i2 is known, then
// it is more efficient to use the bounded methdods in package
// [github.com/consensys/gnark/std/math/bits].
Cmp(i1, i2 Variable) Variable

// ---------------------------------------------------------------------------------------------
Expand All @@ -115,7 +122,11 @@ type API interface {
// AssertIsBoolean fails if v != 0 ∥ v != 1
AssertIsBoolean(i1 Variable)

// AssertIsLessOrEqual fails if v > bound
// AssertIsLessOrEqual fails if v > bound.
//
// If the absolute difference between the variables b and bound is known, then
// it is more efficient to use the bounded methdods in package
// [github.com/consensys/gnark/std/math/bits].
AssertIsLessOrEqual(v Variable, bound Variable)

// Println behaves like fmt.Println but accepts cd.Variable as parameter
Expand Down
12 changes: 8 additions & 4 deletions frontend/cs/r1cs/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ package r1cs
import (
"errors"
"fmt"
"github.com/consensys/gnark/internal/utils"
"path/filepath"
"reflect"
"runtime"
"strings"

"github.com/consensys/gnark/internal/utils"

"github.com/consensys/gnark/debug"
"github.com/consensys/gnark/frontend/cs"

Expand Down Expand Up @@ -570,9 +571,12 @@ func (builder *builder) IsZero(i1 frontend.Variable) frontend.Variable {
// Cmp returns 1 if i1>i2, 0 if i1=i2, -1 if i1<i2
func (builder *builder) Cmp(i1, i2 frontend.Variable) frontend.Variable {

vars, _ := builder.toVariables(i1, i2)
bi1 := builder.ToBinary(vars[0], builder.cs.FieldBitLen())
bi2 := builder.ToBinary(vars[1], builder.cs.FieldBitLen())
nbBits := builder.cs.FieldBitLen()
// in AssertIsLessOrEq we omitted comparison against modulus for the left
// side as if `a+r<b` implies `a<b`, then here we compute the inequality
// directly.
bi1 := bits.ToBinary(builder, i1, bits.WithNbDigits(nbBits))
bi2 := bits.ToBinary(builder, i2, bits.WithNbDigits(nbBits))

res := builder.cstZero()

Expand Down
27 changes: 17 additions & 10 deletions frontend/cs/r1cs/api_assertions.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,12 @@ func (builder *builder) AssertIsLessOrEqual(v frontend.Variable, bound frontend.
}
}

nbBits := builder.cs.FieldBitLen()
vBits := bits.ToBinary(builder, v, bits.WithNbDigits(nbBits), bits.WithUnconstrainedOutputs())

// bound is constant
if bConst {
vv := builder.toVariable(v)
builder.mustBeLessOrEqCst(vv, builder.cs.ToBigInt(cb))
builder.MustBeLessOrEqCst(vBits, builder.cs.ToBigInt(cb), v)
return
}

Expand All @@ -119,8 +121,8 @@ func (builder *builder) mustBeLessOrEqVar(a, bound frontend.Variable) {

nbBits := builder.cs.FieldBitLen()

aBits := bits.ToBinary(builder, a, bits.WithNbDigits(nbBits), bits.WithUnconstrainedOutputs())
boundBits := builder.ToBinary(bound, nbBits)
aBits := bits.ToBinary(builder, a, bits.WithNbDigits(nbBits), bits.WithUnconstrainedOutputs(), bits.OmitModulusCheck())
boundBits := bits.ToBinary(builder, bound, bits.WithNbDigits(nbBits))

// constraint added
added := make([]int, 0, nbBits)
Expand Down Expand Up @@ -166,9 +168,18 @@ func (builder *builder) mustBeLessOrEqVar(a, bound frontend.Variable) {

}

func (builder *builder) mustBeLessOrEqCst(a expr.LinearExpression, bound *big.Int) {
// MustBeLessOrEqCst asserts that value represented using its bit decomposition
// aBits is less or equal than constant bound. The method boolean constraints
// the bits in aBits, so the caller can provide unconstrained bits.
func (builder *builder) MustBeLessOrEqCst(aBits []frontend.Variable, bound *big.Int, aForDebug frontend.Variable) {

nbBits := builder.cs.FieldBitLen()
if len(aBits) > nbBits {
panic("more input bits than field bit length")
}
for i := len(aBits); i < nbBits; i++ {
aBits = append(aBits, 0)
}

// ensure the bound is positive, it's bit-len doesn't matter
if bound.Sign() == -1 {
Expand All @@ -179,11 +190,7 @@ func (builder *builder) mustBeLessOrEqCst(a expr.LinearExpression, bound *big.In
}

// debug info
debug := builder.newDebugInfo("mustBeLessOrEq", a, " <= ", builder.toVariable(bound))

// note that at this stage, we didn't boolean-constraint these new variables yet
// (as opposed to ToBinary)
aBits := bits.ToBinary(builder, a, bits.WithNbDigits(nbBits), bits.WithUnconstrainedOutputs())
debug := builder.newDebugInfo("mustBeLessOrEq", aForDebug, " <= ", builder.toVariable(bound))

// t trailing bits in the bound
t := 0
Expand Down
8 changes: 6 additions & 2 deletions frontend/cs/scs/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -474,8 +474,12 @@ func (builder *builder) IsZero(i1 frontend.Variable) frontend.Variable {
// Cmp returns 1 if i1>i2, 0 if i1=i2, -1 if i1<i2
func (builder *builder) Cmp(i1, i2 frontend.Variable) frontend.Variable {

bi1 := builder.ToBinary(i1, builder.cs.FieldBitLen())
bi2 := builder.ToBinary(i2, builder.cs.FieldBitLen())
nbBits := builder.cs.FieldBitLen()
// in AssertIsLessOrEq we omitted comparison against modulus for the left
// side as if `a+r<b` implies `a<b`, then here we compute the inequality
// directly.
bi1 := bits.ToBinary(builder, i1, bits.WithNbDigits(nbBits))
bi2 := bits.ToBinary(builder, i2, bits.WithNbDigits(nbBits))

var res frontend.Variable
res = 0
Expand Down
57 changes: 36 additions & 21 deletions frontend/cs/scs/api_assertions.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"github.com/consensys/gnark/debug"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/internal/expr"
"github.com/consensys/gnark/internal/utils"
"github.com/consensys/gnark/std/math/bits"
)

Expand Down Expand Up @@ -131,11 +130,30 @@ func (builder *builder) AssertIsBoolean(i1 frontend.Variable) {

// AssertIsLessOrEqual fails if v > bound
func (builder *builder) AssertIsLessOrEqual(v frontend.Variable, bound frontend.Variable) {
switch b := bound.(type) {
case expr.Term:
cv, vConst := builder.constantValue(v)
cb, bConst := builder.constantValue(bound)

// both inputs are constants
if vConst && bConst {
bv, bb := builder.cs.ToBigInt(cv), builder.cs.ToBigInt(cb)
if bv.Cmp(bb) == 1 {
panic(fmt.Sprintf("AssertIsLessOrEqual: %s > %s", bv.String(), bb.String()))
}
}

nbBits := builder.cs.FieldBitLen()
vBits := bits.ToBinary(builder, v, bits.WithNbDigits(nbBits), bits.WithUnconstrainedOutputs())

// bound is constant
if bConst {
builder.MustBeLessOrEqCst(vBits, builder.cs.ToBigInt(cb), v)
return
}

if b, ok := bound.(expr.Term); ok {
builder.mustBeLessOrEqVar(v, b)
default:
builder.mustBeLessOrEqCst(v, utils.FromInterface(b))
} else {
panic(fmt.Sprintf("expected bound type expr.Term, got %T", bound))
}
}

Expand All @@ -145,8 +163,8 @@ func (builder *builder) mustBeLessOrEqVar(a frontend.Variable, bound expr.Term)

nbBits := builder.cs.FieldBitLen()

aBits := bits.ToBinary(builder, a, bits.WithNbDigits(nbBits), bits.WithUnconstrainedOutputs())
boundBits := builder.ToBinary(bound, nbBits)
aBits := bits.ToBinary(builder, a, bits.WithNbDigits(nbBits), bits.WithUnconstrainedOutputs(), bits.OmitModulusCheck())
boundBits := bits.ToBinary(builder, bound, bits.WithNbDigits(nbBits)) // enforces range check against modulus

p := make([]frontend.Variable, nbBits+1)
p[nbBits] = 1
Expand Down Expand Up @@ -191,9 +209,18 @@ func (builder *builder) mustBeLessOrEqVar(a frontend.Variable, bound expr.Term)

}

func (builder *builder) mustBeLessOrEqCst(a frontend.Variable, bound big.Int) {
// MustBeLessOrEqCst asserts that value represented using its bit decomposition
// aBits is less or equal than constant bound. The method boolean constraints
// the bits in aBits, so the caller can provide unconstrained bits.
func (builder *builder) MustBeLessOrEqCst(aBits []frontend.Variable, bound *big.Int, aForDebug frontend.Variable) {

nbBits := builder.cs.FieldBitLen()
if len(aBits) > nbBits {
panic("more input bits than field bit length")
}
for i := len(aBits); i < nbBits; i++ {
aBits = append(aBits, 0)
}

// ensure the bound is positive, it's bit-len doesn't matter
if bound.Sign() == -1 {
Expand All @@ -203,20 +230,8 @@ func (builder *builder) mustBeLessOrEqCst(a frontend.Variable, bound big.Int) {
panic("AssertIsLessOrEqual: bound is too large, constraint will never be satisfied")
}

if ca, ok := builder.constantValue(a); ok {
// a is constant, compare the big int values
ba := builder.cs.ToBigInt(ca)
if ba.Cmp(&bound) == 1 {
panic(fmt.Sprintf("AssertIsLessOrEqual: %s > %s", ba.String(), bound.String()))
}
}

// debug info
debug := builder.newDebugInfo("mustBeLessOrEq", a, " <= ", bound)

// note that at this stage, we didn't boolean-constraint these new variables yet
// (as opposed to ToBinary)
aBits := bits.ToBinary(builder, a, bits.WithNbDigits(nbBits), bits.WithUnconstrainedOutputs())
debug := builder.newDebugInfo("mustBeLessOrEq", aForDebug, " <= ", bound)

// t trailing bits in the bound
t := 0
Expand Down
2 changes: 1 addition & 1 deletion internal/backend/circuits/hint.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (circuit *recursiveHint) Define(api frontend.API) error {
// api.ToBinary calls another hint (bits.NBits) with linearExpression as input
// however, when the solver will resolve bits[...] it will need to detect w1 as a dependency
// in order to compute the correct linearExpression value
bits := api.ToBinary(linearExpression, 10)
bits := api.ToBinary(linearExpression, 6)

a := api.FromBinary(bits...)

Expand Down
2 changes: 2 additions & 0 deletions internal/regression_tests/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
// Package regressiontests includes tests to avoid re-introducing regressions.
package regressiontests
Loading