Skip to content

Commit

Permalink
feat: improved Gadget API (#9)
Browse files Browse the repository at this point in the history
* Refactoring of Gadget API

* Added support for arrays in gadgets. Tested from_binary and to_binary

* Fixed gadget call with array

* Renamed GadgetDefine to DefineGadget
  • Loading branch information
Eagle941 authored Jul 5, 2023
1 parent f08f001 commit 70db057
Show file tree
Hide file tree
Showing 5 changed files with 212 additions and 71 deletions.
8 changes: 6 additions & 2 deletions abstractor/abstractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@ package abstractor
import "github.com/consensys/gnark/frontend"

type Gadget interface {
Call(args ...frontend.Variable) []frontend.Variable
Call(gadget GadgetDefinition) []frontend.Variable
}

type GadgetDefinition interface {
DefineGadget(api API) []frontend.Variable
}

type API interface {
frontend.API
DefineGadget(name string, arity int, constructor func(api API, args ...frontend.Variable) []frontend.Variable) Gadget
DefineGadget(gadget GadgetDefinition) Gadget
}

type Circuit interface {
Expand Down
11 changes: 5 additions & 6 deletions abstractor/concretizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@ import (
)

type ConcreteGadget struct {
api API
constructor func(api API, args ...frontend.Variable) []frontend.Variable
api API
}

func (g *ConcreteGadget) Call(args ...frontend.Variable) []frontend.Variable {
return g.constructor(g.api, args...)
func (g *ConcreteGadget) Call(gadget GadgetDefinition) []frontend.Variable {
return gadget.DefineGadget(g.api)
}

type Concretizer struct {
Expand Down Expand Up @@ -119,8 +118,8 @@ func (c *Concretizer) ConstantValue(v frontend.Variable) (*big.Int, bool) {
return c.api.ConstantValue(v)
}

func (c *Concretizer) DefineGadget(name string, arity int, constructor func(api API, args ...frontend.Variable) []frontend.Variable) Gadget {
return &ConcreteGadget{c, constructor}
func (c *Concretizer) DefineGadget(gadget GadgetDefinition) Gadget {
return &ConcreteGadget{c}
}

var _ API = &(Concretizer{})
93 changes: 71 additions & 22 deletions extractor/extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/backend/hint"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/schema"
)

type Operand interface {
Expand Down Expand Up @@ -44,6 +45,12 @@ type Proj struct {

func (_ Proj) isOperand() {}

type ProjArray struct {
Proj []Operand
}

func (_ ProjArray) isOperand() {}

type Op interface {
isOp()
}
Expand Down Expand Up @@ -91,14 +98,27 @@ type ExGadget struct {
Code []App
Outputs []Operand
Extractor *CodeExtractor
Fields []schema.Field
Args []ExArg
}

func (g *ExGadget) isOp() {}

func (g *ExGadget) Call(args ...frontend.Variable) []frontend.Variable {
if len(args) != g.Arity {
panic("wrong number of arguments")
func (g *ExGadget) Call(gadget abstractor.GadgetDefinition) []frontend.Variable {
args := []frontend.Variable{}

rv := reflect.Indirect(reflect.ValueOf(gadget))
rt := rv.Type()
for i := 0; i < rt.NumField(); i++ {
fld := rt.Field(i)
v := rv.FieldByName(fld.Name)
if v.Kind() == reflect.Slice || v.Kind() == reflect.Array {
args = append(args, v.Interface().([]frontend.Variable))
} else {
args = append(args, v.Elem().Interface().(frontend.Variable))
}
}

gate := g.Extractor.AddApp(g, args...)
outs := make([]frontend.Variable, len(g.Outputs))
if len(g.Outputs) == 1 {
Expand Down Expand Up @@ -134,23 +154,33 @@ type CodeExtractor struct {
Field ecc.ID
}

func operandFromArray(arg []frontend.Variable) Operand {
return arg[0].(Proj).Operand
}

func sanitizeVars(args ...frontend.Variable) []Operand {
func operandFromArray(args []frontend.Variable) []Operand {
ops := make([]Operand, len(args))
for i, arg := range args {
switch arg.(type) {
case Input, Gate, Proj, Const:
ops[i] = arg.(Operand)
default:
ops[i] = arg.(Proj).Operand
}
}
return ops
}

func sanitizeVars(args ...frontend.Variable) []Operand {
ops := []Operand{}
for _, arg := range args {
switch arg.(type) {
case Input, Gate, Proj, Const:
ops = append(ops, arg.(Operand))
case int:
ops[i] = Const{big.NewInt(int64(arg.(int)))}
ops = append(ops, Const{big.NewInt(int64(arg.(int)))})
case big.Int:
casted := arg.(big.Int)
ops[i] = Const{&casted}
ops = append(ops, Const{&casted})
case []frontend.Variable:
ops[i] = operandFromArray(arg.([]frontend.Variable))
opsArray := operandFromArray(arg.([]frontend.Variable))
ops = append(ops, ProjArray{opsArray})
default:
fmt.Printf("invalid argument of type %T\n%#v\n", arg, arg)
panic("invalid argument")
Expand All @@ -160,7 +190,8 @@ func sanitizeVars(args ...frontend.Variable) []Operand {
}

func (ce *CodeExtractor) AddApp(op Op, args ...frontend.Variable) Operand {
ce.Code = append(ce.Code, App{op, sanitizeVars(args...)})
app := App{op, sanitizeVars(args...)}
ce.Code = append(ce.Code, app)
return Gate{len(ce.Code) - 1}
}

Expand Down Expand Up @@ -290,25 +321,43 @@ func (ce *CodeExtractor) ConstantValue(v frontend.Variable) (*big.Int, bool) {
}
}

func (ce *CodeExtractor) DefineGadget(name string, arity int, constructor func(api abstractor.API, args ...frontend.Variable) []frontend.Variable) abstractor.Gadget {
func (ce *CodeExtractor) DefineGadget(gadget abstractor.GadgetDefinition) abstractor.Gadget {
schema, _ := GetSchema(gadget)
CircuitInit(gadget, schema)
// Can't use `schema.NbPublic + schema.NbSecret`
// for arity because each array element is considered
// a parameter
arity := len(schema.Fields)
name := reflect.TypeOf(gadget).Elem().Name()
args := GetExArgs(gadget, schema.Fields)

// To distinguish between gadgets instantiated with different array
// sizes, add a suffix to the name. The suffix of each instantiation
// is made up of the concatenation of the length of all the array
// fields in the gadget
suffix := ""
for _, a := range args {
if a.Kind == reflect.Array || a.Kind == reflect.Slice {
suffix += fmt.Sprintf("_%d", a.Type.Size)
}
}

oldCode := ce.Code
ce.Code = make([]App, 0)
inputs := make([]frontend.Variable, arity)
for i := 0; i < arity; i++ {
inputs[i] = Input{i}
}
outputs := constructor(ce, inputs...)
outputs := gadget.DefineGadget(ce)
newCode := ce.Code
ce.Code = oldCode
gadget := ExGadget{
Name: name,
exGadget := ExGadget{
Name: fmt.Sprintf("%s%s", name, suffix),
Arity: arity,
Code: newCode,
Outputs: sanitizeVars(outputs...),
Extractor: ce,
Fields: schema.Fields,
Args: args,
}
ce.Gadgets = append(ce.Gadgets, gadget)
return &gadget
ce.Gadgets = append(ce.Gadgets, exGadget)
return &exGadget
}

var _ abstractor.API = &CodeExtractor{}
102 changes: 82 additions & 20 deletions extractor/extractor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,49 @@ import (
"github.com/consensys/gnark/frontend"
)

// Example: circuit with constant parameter
type SliceGadget struct {
In_1 []frontend.Variable
In_2 []frontend.Variable
}

func (gadget SliceGadget) DefineGadget(api abstractor.API) []frontend.Variable {
for i := 0; i < len(gadget.In_1); i++ {
api.Mul(gadget.In_1[i], gadget.In_2[i])
}

r := api.FromBinary(gadget.In_1...)
return []frontend.Variable{r}
}

type CircuitWithParameter struct {
In frontend.Variable `gnark:",public"`
In frontend.Variable `gnark:",public"`
Path []frontend.Variable `gnark:",public"`
Tree []frontend.Variable `gnark:",public"`
Param int
}

func (circuit *CircuitWithParameter) AbsDefine(api abstractor.API) error {
slice_3 := api.DefineGadget(&SliceGadget{
In_1: make([]frontend.Variable, 3),
In_2: make([]frontend.Variable, 3),
})

slice_2 := api.DefineGadget(&SliceGadget{
In_1: make([]frontend.Variable, 2),
In_2: make([]frontend.Variable, 2),
})

api.FromBinary(circuit.Path...)
bin := api.ToBinary(circuit.In)
bin = api.ToBinary(circuit.Param)

dec := api.FromBinary(bin...)
api.AssertIsEqual(circuit.Param, dec)
slice_3.Call(SliceGadget{circuit.Path, circuit.Path})

api.Mul(circuit.Path[0], circuit.Path[0])
slice_2.Call(SliceGadget{circuit.Tree, circuit.Tree})
api.AssertIsEqual(circuit.Param, circuit.In)

return nil
Expand All @@ -25,7 +62,7 @@ func (circuit CircuitWithParameter) Define(api frontend.API) error {
}

func TestCircuitWithParameter(t *testing.T) {
assignment := CircuitWithParameter{}
assignment := CircuitWithParameter{Path: make([]frontend.Variable, 3), Tree: make([]frontend.Variable, 2)}
assignment.Param = 20
err := CircuitToLean(&assignment, ecc.BW6_756)
if err != nil {
Expand All @@ -34,6 +71,17 @@ func TestCircuitWithParameter(t *testing.T) {
}
}

// Example: circuit with arrays and gadget
type DummyHash struct {
In_1 frontend.Variable
In_2 frontend.Variable
}

func (gadget DummyHash) DefineGadget(api abstractor.API) []frontend.Variable {
r := api.Mul(gadget.In_1, gadget.In_2)
return []frontend.Variable{r}
}

type MerkleRecover struct {
Root frontend.Variable `gnark:",public"`
Element frontend.Variable `gnark:",public"`
Expand All @@ -42,14 +90,12 @@ type MerkleRecover struct {
}

func (circuit *MerkleRecover) AbsDefine(api abstractor.API) error {
hash := api.DefineGadget("hash", 2, func(api abstractor.API, args ...frontend.Variable) []frontend.Variable {
return []frontend.Variable{api.Mul(args[0], args[1])}
})
hash := api.DefineGadget(&DummyHash{})

current := circuit.Element
for i := 0; i < len(circuit.Path); i++ {
leftHash := hash.Call(current, circuit.Proof[i])[0]
rightHash := hash.Call(circuit.Proof[i], current)[0]
leftHash := hash.Call(DummyHash{current, circuit.Proof[i]})[0]
rightHash := hash.Call(DummyHash{circuit.Proof[i], current})[0]
current = api.Select(circuit.Path[i], rightHash, leftHash)
}
api.AssertIsEqual(current, circuit.Root)
Expand All @@ -70,28 +116,44 @@ func TestMerkleRecover(t *testing.T) {
}
}

// Example: circuit with multiple gadgets
type MyWidget struct {
Test_1 frontend.Variable
Test_2 frontend.Variable
}

func (gadget MyWidget) DefineGadget(api abstractor.API) []frontend.Variable {
sum := api.Add(gadget.Test_1, gadget.Test_2)
mul := api.Mul(gadget.Test_1, gadget.Test_2)
r := api.Div(sum, mul)
return []frontend.Variable{r}
}

type MySecondWidget struct {
Test_1 frontend.Variable
Test_2 frontend.Variable
}

func (gadget MySecondWidget) DefineGadget(api abstractor.API) []frontend.Variable {
my_widget := api.DefineGadget(&MyWidget{})

mul := api.Mul(gadget.Test_1, gadget.Test_2)
snd := my_widget.Call(MyWidget{gadget.Test_1, gadget.Test_2})[0]
r := api.Mul(mul, snd)
return []frontend.Variable{r}
}

type TwoGadgets struct {
In_1 frontend.Variable
In_2 frontend.Variable
}

func (circuit *TwoGadgets) AbsDefine(api abstractor.API) error {
my_widget := api.DefineGadget("my_widget", 2, func(api abstractor.API, args ...frontend.Variable) []frontend.Variable {
sum := api.Add(args[0], args[1])
mul := api.Mul(args[0], args[1])
r := api.Div(sum, mul)
return []frontend.Variable{r}
})
my_snd_widget := api.DefineGadget("my_snd_widget", 2, func(api abstractor.API, args ...frontend.Variable) []frontend.Variable {
mul := api.Mul(args[0], args[1])
snd := my_widget.Call(args[0], args[1])
r := api.Mul(mul, snd[0])
return []frontend.Variable{r}
})
my_snd_widget := api.DefineGadget(&MySecondWidget{})

sum := api.Add(circuit.In_1, circuit.In_2)
prod := api.Mul(circuit.In_1, circuit.In_2)
my_snd_widget.Call(sum, prod)
my_snd_widget.Call(MySecondWidget{sum, prod})

return nil
}
Expand Down
Loading

0 comments on commit 70db057

Please sign in to comment.