Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improved Gadget API #9

Merged
merged 10 commits into from
Jul 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions abstractor/abstractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@ package abstractor
import "github.com/consensys/gnark/frontend"

type Gadget interface {
Call(args ...frontend.Variable) []frontend.Variable
Call(gadget GadgetDefinition) []frontend.Variable
}

type GadgetDefinition interface {
DefineGadget(api API) []frontend.Variable
}

type API interface {
frontend.API
DefineGadget(name string, arity int, constructor func(api API, args ...frontend.Variable) []frontend.Variable) Gadget
DefineGadget(gadget GadgetDefinition) Gadget
}

type Circuit interface {
Expand Down
11 changes: 5 additions & 6 deletions abstractor/concretizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@ import (
)

type ConcreteGadget struct {
api API
constructor func(api API, args ...frontend.Variable) []frontend.Variable
api API
}

func (g *ConcreteGadget) Call(args ...frontend.Variable) []frontend.Variable {
return g.constructor(g.api, args...)
func (g *ConcreteGadget) Call(gadget GadgetDefinition) []frontend.Variable {
return gadget.DefineGadget(g.api)
}

type Concretizer struct {
Expand Down Expand Up @@ -119,8 +118,8 @@ func (c *Concretizer) ConstantValue(v frontend.Variable) (*big.Int, bool) {
return c.api.ConstantValue(v)
}

func (c *Concretizer) DefineGadget(name string, arity int, constructor func(api API, args ...frontend.Variable) []frontend.Variable) Gadget {
return &ConcreteGadget{c, constructor}
func (c *Concretizer) DefineGadget(gadget GadgetDefinition) Gadget {
return &ConcreteGadget{c}
}

var _ API = &(Concretizer{})
93 changes: 71 additions & 22 deletions extractor/extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/backend/hint"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/schema"
)

type Operand interface {
Expand Down Expand Up @@ -44,6 +45,12 @@ type Proj struct {

func (_ Proj) isOperand() {}

type ProjArray struct {
Proj []Operand
}

func (_ ProjArray) isOperand() {}

type Op interface {
isOp()
}
Expand Down Expand Up @@ -91,14 +98,27 @@ type ExGadget struct {
Code []App
Outputs []Operand
Extractor *CodeExtractor
Fields []schema.Field
Args []ExArg
}

func (g *ExGadget) isOp() {}

func (g *ExGadget) Call(args ...frontend.Variable) []frontend.Variable {
if len(args) != g.Arity {
panic("wrong number of arguments")
func (g *ExGadget) Call(gadget abstractor.GadgetDefinition) []frontend.Variable {
args := []frontend.Variable{}

rv := reflect.Indirect(reflect.ValueOf(gadget))
rt := rv.Type()
for i := 0; i < rt.NumField(); i++ {
fld := rt.Field(i)
v := rv.FieldByName(fld.Name)
if v.Kind() == reflect.Slice || v.Kind() == reflect.Array {
args = append(args, v.Interface().([]frontend.Variable))
} else {
args = append(args, v.Elem().Interface().(frontend.Variable))
}
}

gate := g.Extractor.AddApp(g, args...)
outs := make([]frontend.Variable, len(g.Outputs))
if len(g.Outputs) == 1 {
Expand Down Expand Up @@ -134,23 +154,33 @@ type CodeExtractor struct {
Field ecc.ID
}

func operandFromArray(arg []frontend.Variable) Operand {
return arg[0].(Proj).Operand
}

func sanitizeVars(args ...frontend.Variable) []Operand {
func operandFromArray(args []frontend.Variable) []Operand {
ops := make([]Operand, len(args))
for i, arg := range args {
switch arg.(type) {
case Input, Gate, Proj, Const:
ops[i] = arg.(Operand)
default:
ops[i] = arg.(Proj).Operand
}
}
return ops
}

func sanitizeVars(args ...frontend.Variable) []Operand {
ops := []Operand{}
for _, arg := range args {
switch arg.(type) {
case Input, Gate, Proj, Const:
ops = append(ops, arg.(Operand))
case int:
ops[i] = Const{big.NewInt(int64(arg.(int)))}
ops = append(ops, Const{big.NewInt(int64(arg.(int)))})
case big.Int:
casted := arg.(big.Int)
ops[i] = Const{&casted}
ops = append(ops, Const{&casted})
case []frontend.Variable:
ops[i] = operandFromArray(arg.([]frontend.Variable))
opsArray := operandFromArray(arg.([]frontend.Variable))
ops = append(ops, ProjArray{opsArray})
default:
fmt.Printf("invalid argument of type %T\n%#v\n", arg, arg)
panic("invalid argument")
Expand All @@ -160,7 +190,8 @@ func sanitizeVars(args ...frontend.Variable) []Operand {
}

func (ce *CodeExtractor) AddApp(op Op, args ...frontend.Variable) Operand {
ce.Code = append(ce.Code, App{op, sanitizeVars(args...)})
app := App{op, sanitizeVars(args...)}
ce.Code = append(ce.Code, app)
return Gate{len(ce.Code) - 1}
}

Expand Down Expand Up @@ -290,25 +321,43 @@ func (ce *CodeExtractor) ConstantValue(v frontend.Variable) (*big.Int, bool) {
}
}

func (ce *CodeExtractor) DefineGadget(name string, arity int, constructor func(api abstractor.API, args ...frontend.Variable) []frontend.Variable) abstractor.Gadget {
func (ce *CodeExtractor) DefineGadget(gadget abstractor.GadgetDefinition) abstractor.Gadget {
schema, _ := GetSchema(gadget)
CircuitInit(gadget, schema)
// Can't use `schema.NbPublic + schema.NbSecret`
// for arity because each array element is considered
// a parameter
arity := len(schema.Fields)
name := reflect.TypeOf(gadget).Elem().Name()
args := GetExArgs(gadget, schema.Fields)

// To distinguish between gadgets instantiated with different array
// sizes, add a suffix to the name. The suffix of each instantiation
// is made up of the concatenation of the length of all the array
// fields in the gadget
suffix := ""
for _, a := range args {
if a.Kind == reflect.Array || a.Kind == reflect.Slice {
suffix += fmt.Sprintf("_%d", a.Type.Size)
}
}

oldCode := ce.Code
ce.Code = make([]App, 0)
inputs := make([]frontend.Variable, arity)
for i := 0; i < arity; i++ {
inputs[i] = Input{i}
}
outputs := constructor(ce, inputs...)
outputs := gadget.DefineGadget(ce)
newCode := ce.Code
ce.Code = oldCode
gadget := ExGadget{
Name: name,
exGadget := ExGadget{
Name: fmt.Sprintf("%s%s", name, suffix),
Arity: arity,
Code: newCode,
Outputs: sanitizeVars(outputs...),
Extractor: ce,
Fields: schema.Fields,
Args: args,
}
ce.Gadgets = append(ce.Gadgets, gadget)
return &gadget
ce.Gadgets = append(ce.Gadgets, exGadget)
return &exGadget
}

var _ abstractor.API = &CodeExtractor{}
102 changes: 82 additions & 20 deletions extractor/extractor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,49 @@ import (
"github.com/consensys/gnark/frontend"
)

// Example: circuit with constant parameter
type SliceGadget struct {
In_1 []frontend.Variable
In_2 []frontend.Variable
}

func (gadget SliceGadget) DefineGadget(api abstractor.API) []frontend.Variable {
for i := 0; i < len(gadget.In_1); i++ {
api.Mul(gadget.In_1[i], gadget.In_2[i])
}

r := api.FromBinary(gadget.In_1...)
return []frontend.Variable{r}
}

type CircuitWithParameter struct {
In frontend.Variable `gnark:",public"`
In frontend.Variable `gnark:",public"`
Path []frontend.Variable `gnark:",public"`
Tree []frontend.Variable `gnark:",public"`
Param int
}

func (circuit *CircuitWithParameter) AbsDefine(api abstractor.API) error {
slice_3 := api.DefineGadget(&SliceGadget{
In_1: make([]frontend.Variable, 3),
In_2: make([]frontend.Variable, 3),
})

slice_2 := api.DefineGadget(&SliceGadget{
In_1: make([]frontend.Variable, 2),
In_2: make([]frontend.Variable, 2),
})

api.FromBinary(circuit.Path...)
bin := api.ToBinary(circuit.In)
bin = api.ToBinary(circuit.Param)

dec := api.FromBinary(bin...)
api.AssertIsEqual(circuit.Param, dec)
slice_3.Call(SliceGadget{circuit.Path, circuit.Path})

api.Mul(circuit.Path[0], circuit.Path[0])
slice_2.Call(SliceGadget{circuit.Tree, circuit.Tree})
api.AssertIsEqual(circuit.Param, circuit.In)

return nil
Expand All @@ -25,7 +62,7 @@ func (circuit CircuitWithParameter) Define(api frontend.API) error {
}

func TestCircuitWithParameter(t *testing.T) {
assignment := CircuitWithParameter{}
assignment := CircuitWithParameter{Path: make([]frontend.Variable, 3), Tree: make([]frontend.Variable, 2)}
assignment.Param = 20
err := CircuitToLean(&assignment, ecc.BW6_756)
if err != nil {
Expand All @@ -34,6 +71,17 @@ func TestCircuitWithParameter(t *testing.T) {
}
}

// Example: circuit with arrays and gadget
type DummyHash struct {
In_1 frontend.Variable
In_2 frontend.Variable
}

func (gadget DummyHash) DefineGadget(api abstractor.API) []frontend.Variable {
r := api.Mul(gadget.In_1, gadget.In_2)
return []frontend.Variable{r}
}

type MerkleRecover struct {
Root frontend.Variable `gnark:",public"`
Element frontend.Variable `gnark:",public"`
Expand All @@ -42,14 +90,12 @@ type MerkleRecover struct {
}

func (circuit *MerkleRecover) AbsDefine(api abstractor.API) error {
hash := api.DefineGadget("hash", 2, func(api abstractor.API, args ...frontend.Variable) []frontend.Variable {
return []frontend.Variable{api.Mul(args[0], args[1])}
})
hash := api.DefineGadget(&DummyHash{})

current := circuit.Element
for i := 0; i < len(circuit.Path); i++ {
leftHash := hash.Call(current, circuit.Proof[i])[0]
rightHash := hash.Call(circuit.Proof[i], current)[0]
leftHash := hash.Call(DummyHash{current, circuit.Proof[i]})[0]
rightHash := hash.Call(DummyHash{circuit.Proof[i], current})[0]
current = api.Select(circuit.Path[i], rightHash, leftHash)
}
api.AssertIsEqual(current, circuit.Root)
Expand All @@ -70,28 +116,44 @@ func TestMerkleRecover(t *testing.T) {
}
}

// Example: circuit with multiple gadgets
type MyWidget struct {
Test_1 frontend.Variable
Test_2 frontend.Variable
}

func (gadget MyWidget) DefineGadget(api abstractor.API) []frontend.Variable {
sum := api.Add(gadget.Test_1, gadget.Test_2)
mul := api.Mul(gadget.Test_1, gadget.Test_2)
r := api.Div(sum, mul)
return []frontend.Variable{r}
}

type MySecondWidget struct {
Test_1 frontend.Variable
Test_2 frontend.Variable
}

func (gadget MySecondWidget) DefineGadget(api abstractor.API) []frontend.Variable {
my_widget := api.DefineGadget(&MyWidget{})

mul := api.Mul(gadget.Test_1, gadget.Test_2)
snd := my_widget.Call(MyWidget{gadget.Test_1, gadget.Test_2})[0]
r := api.Mul(mul, snd)
return []frontend.Variable{r}
}

type TwoGadgets struct {
In_1 frontend.Variable
In_2 frontend.Variable
}

func (circuit *TwoGadgets) AbsDefine(api abstractor.API) error {
my_widget := api.DefineGadget("my_widget", 2, func(api abstractor.API, args ...frontend.Variable) []frontend.Variable {
sum := api.Add(args[0], args[1])
mul := api.Mul(args[0], args[1])
r := api.Div(sum, mul)
return []frontend.Variable{r}
})
my_snd_widget := api.DefineGadget("my_snd_widget", 2, func(api abstractor.API, args ...frontend.Variable) []frontend.Variable {
mul := api.Mul(args[0], args[1])
snd := my_widget.Call(args[0], args[1])
r := api.Mul(mul, snd[0])
return []frontend.Variable{r}
})
my_snd_widget := api.DefineGadget(&MySecondWidget{})

sum := api.Add(circuit.In_1, circuit.In_2)
prod := api.Mul(circuit.In_1, circuit.In_2)
my_snd_widget.Call(sum, prod)
my_snd_widget.Call(MySecondWidget{sum, prod})

return nil
}
Expand Down
Loading