diff --git a/abstractor/abstractor.go b/abstractor/abstractor.go index 01b11ac..94d6a7b 100644 --- a/abstractor/abstractor.go +++ b/abstractor/abstractor.go @@ -3,12 +3,16 @@ package abstractor import "github.com/consensys/gnark/frontend" type Gadget interface { - Call(args ...frontend.Variable) []frontend.Variable + Call(gadget GadgetDefinition) []frontend.Variable +} + +type GadgetDefinition interface { + DefineGadget(api API) []frontend.Variable } type API interface { frontend.API - DefineGadget(name string, arity int, constructor func(api API, args ...frontend.Variable) []frontend.Variable) Gadget + DefineGadget(gadget GadgetDefinition) Gadget } type Circuit interface { diff --git a/abstractor/concretizer.go b/abstractor/concretizer.go index b8dc328..395e145 100644 --- a/abstractor/concretizer.go +++ b/abstractor/concretizer.go @@ -7,12 +7,11 @@ import ( ) type ConcreteGadget struct { - api API - constructor func(api API, args ...frontend.Variable) []frontend.Variable + api API } -func (g *ConcreteGadget) Call(args ...frontend.Variable) []frontend.Variable { - return g.constructor(g.api, args...) +func (g *ConcreteGadget) Call(gadget GadgetDefinition) []frontend.Variable { + return gadget.DefineGadget(g.api) } type Concretizer struct { @@ -119,8 +118,8 @@ func (c *Concretizer) ConstantValue(v frontend.Variable) (*big.Int, bool) { return c.api.ConstantValue(v) } -func (c *Concretizer) DefineGadget(name string, arity int, constructor func(api API, args ...frontend.Variable) []frontend.Variable) Gadget { - return &ConcreteGadget{c, constructor} +func (c *Concretizer) DefineGadget(gadget GadgetDefinition) Gadget { + return &ConcreteGadget{c} } var _ API = &(Concretizer{}) diff --git a/extractor/extractor.go b/extractor/extractor.go index f98dda8..770a1c5 100644 --- a/extractor/extractor.go +++ b/extractor/extractor.go @@ -9,6 +9,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/schema" ) type Operand interface { @@ -44,6 +45,12 @@ type Proj struct { func (_ Proj) isOperand() {} +type ProjArray struct { + Proj []Operand +} + +func (_ ProjArray) isOperand() {} + type Op interface { isOp() } @@ -91,14 +98,27 @@ type ExGadget struct { Code []App Outputs []Operand Extractor *CodeExtractor + Fields []schema.Field + Args []ExArg } func (g *ExGadget) isOp() {} -func (g *ExGadget) Call(args ...frontend.Variable) []frontend.Variable { - if len(args) != g.Arity { - panic("wrong number of arguments") +func (g *ExGadget) Call(gadget abstractor.GadgetDefinition) []frontend.Variable { + args := []frontend.Variable{} + + rv := reflect.Indirect(reflect.ValueOf(gadget)) + rt := rv.Type() + for i := 0; i < rt.NumField(); i++ { + fld := rt.Field(i) + v := rv.FieldByName(fld.Name) + if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { + args = append(args, v.Interface().([]frontend.Variable)) + } else { + args = append(args, v.Elem().Interface().(frontend.Variable)) + } } + gate := g.Extractor.AddApp(g, args...) outs := make([]frontend.Variable, len(g.Outputs)) if len(g.Outputs) == 1 { @@ -134,23 +154,33 @@ type CodeExtractor struct { Field ecc.ID } -func operandFromArray(arg []frontend.Variable) Operand { - return arg[0].(Proj).Operand -} - -func sanitizeVars(args ...frontend.Variable) []Operand { +func operandFromArray(args []frontend.Variable) []Operand { ops := make([]Operand, len(args)) for i, arg := range args { switch arg.(type) { case Input, Gate, Proj, Const: ops[i] = arg.(Operand) + default: + ops[i] = arg.(Proj).Operand + } + } + return ops +} + +func sanitizeVars(args ...frontend.Variable) []Operand { + ops := []Operand{} + for _, arg := range args { + switch arg.(type) { + case Input, Gate, Proj, Const: + ops = append(ops, arg.(Operand)) case int: - ops[i] = Const{big.NewInt(int64(arg.(int)))} + ops = append(ops, Const{big.NewInt(int64(arg.(int)))}) case big.Int: casted := arg.(big.Int) - ops[i] = Const{&casted} + ops = append(ops, Const{&casted}) case []frontend.Variable: - ops[i] = operandFromArray(arg.([]frontend.Variable)) + opsArray := operandFromArray(arg.([]frontend.Variable)) + ops = append(ops, ProjArray{opsArray}) default: fmt.Printf("invalid argument of type %T\n%#v\n", arg, arg) panic("invalid argument") @@ -160,7 +190,8 @@ func sanitizeVars(args ...frontend.Variable) []Operand { } func (ce *CodeExtractor) AddApp(op Op, args ...frontend.Variable) Operand { - ce.Code = append(ce.Code, App{op, sanitizeVars(args...)}) + app := App{op, sanitizeVars(args...)} + ce.Code = append(ce.Code, app) return Gate{len(ce.Code) - 1} } @@ -290,25 +321,43 @@ func (ce *CodeExtractor) ConstantValue(v frontend.Variable) (*big.Int, bool) { } } -func (ce *CodeExtractor) DefineGadget(name string, arity int, constructor func(api abstractor.API, args ...frontend.Variable) []frontend.Variable) abstractor.Gadget { +func (ce *CodeExtractor) DefineGadget(gadget abstractor.GadgetDefinition) abstractor.Gadget { + schema, _ := GetSchema(gadget) + CircuitInit(gadget, schema) + // Can't use `schema.NbPublic + schema.NbSecret` + // for arity because each array element is considered + // a parameter + arity := len(schema.Fields) + name := reflect.TypeOf(gadget).Elem().Name() + args := GetExArgs(gadget, schema.Fields) + + // To distinguish between gadgets instantiated with different array + // sizes, add a suffix to the name. The suffix of each instantiation + // is made up of the concatenation of the length of all the array + // fields in the gadget + suffix := "" + for _, a := range args { + if a.Kind == reflect.Array || a.Kind == reflect.Slice { + suffix += fmt.Sprintf("_%d", a.Type.Size) + } + } + oldCode := ce.Code ce.Code = make([]App, 0) - inputs := make([]frontend.Variable, arity) - for i := 0; i < arity; i++ { - inputs[i] = Input{i} - } - outputs := constructor(ce, inputs...) + outputs := gadget.DefineGadget(ce) newCode := ce.Code ce.Code = oldCode - gadget := ExGadget{ - Name: name, + exGadget := ExGadget{ + Name: fmt.Sprintf("%s%s", name, suffix), Arity: arity, Code: newCode, Outputs: sanitizeVars(outputs...), Extractor: ce, + Fields: schema.Fields, + Args: args, } - ce.Gadgets = append(ce.Gadgets, gadget) - return &gadget + ce.Gadgets = append(ce.Gadgets, exGadget) + return &exGadget } var _ abstractor.API = &CodeExtractor{} diff --git a/extractor/extractor_test.go b/extractor/extractor_test.go index 01f56b9..4d30fc9 100644 --- a/extractor/extractor_test.go +++ b/extractor/extractor_test.go @@ -9,12 +9,49 @@ import ( "github.com/consensys/gnark/frontend" ) +// Example: circuit with constant parameter +type SliceGadget struct { + In_1 []frontend.Variable + In_2 []frontend.Variable +} + +func (gadget SliceGadget) DefineGadget(api abstractor.API) []frontend.Variable { + for i := 0; i < len(gadget.In_1); i++ { + api.Mul(gadget.In_1[i], gadget.In_2[i]) + } + + r := api.FromBinary(gadget.In_1...) + return []frontend.Variable{r} +} + type CircuitWithParameter struct { - In frontend.Variable `gnark:",public"` + In frontend.Variable `gnark:",public"` + Path []frontend.Variable `gnark:",public"` + Tree []frontend.Variable `gnark:",public"` Param int } func (circuit *CircuitWithParameter) AbsDefine(api abstractor.API) error { + slice_3 := api.DefineGadget(&SliceGadget{ + In_1: make([]frontend.Variable, 3), + In_2: make([]frontend.Variable, 3), + }) + + slice_2 := api.DefineGadget(&SliceGadget{ + In_1: make([]frontend.Variable, 2), + In_2: make([]frontend.Variable, 2), + }) + + api.FromBinary(circuit.Path...) + bin := api.ToBinary(circuit.In) + bin = api.ToBinary(circuit.Param) + + dec := api.FromBinary(bin...) + api.AssertIsEqual(circuit.Param, dec) + slice_3.Call(SliceGadget{circuit.Path, circuit.Path}) + + api.Mul(circuit.Path[0], circuit.Path[0]) + slice_2.Call(SliceGadget{circuit.Tree, circuit.Tree}) api.AssertIsEqual(circuit.Param, circuit.In) return nil @@ -25,7 +62,7 @@ func (circuit CircuitWithParameter) Define(api frontend.API) error { } func TestCircuitWithParameter(t *testing.T) { - assignment := CircuitWithParameter{} + assignment := CircuitWithParameter{Path: make([]frontend.Variable, 3), Tree: make([]frontend.Variable, 2)} assignment.Param = 20 err := CircuitToLean(&assignment, ecc.BW6_756) if err != nil { @@ -34,6 +71,17 @@ func TestCircuitWithParameter(t *testing.T) { } } +// Example: circuit with arrays and gadget +type DummyHash struct { + In_1 frontend.Variable + In_2 frontend.Variable +} + +func (gadget DummyHash) DefineGadget(api abstractor.API) []frontend.Variable { + r := api.Mul(gadget.In_1, gadget.In_2) + return []frontend.Variable{r} +} + type MerkleRecover struct { Root frontend.Variable `gnark:",public"` Element frontend.Variable `gnark:",public"` @@ -42,14 +90,12 @@ type MerkleRecover struct { } func (circuit *MerkleRecover) AbsDefine(api abstractor.API) error { - hash := api.DefineGadget("hash", 2, func(api abstractor.API, args ...frontend.Variable) []frontend.Variable { - return []frontend.Variable{api.Mul(args[0], args[1])} - }) + hash := api.DefineGadget(&DummyHash{}) current := circuit.Element for i := 0; i < len(circuit.Path); i++ { - leftHash := hash.Call(current, circuit.Proof[i])[0] - rightHash := hash.Call(circuit.Proof[i], current)[0] + leftHash := hash.Call(DummyHash{current, circuit.Proof[i]})[0] + rightHash := hash.Call(DummyHash{circuit.Proof[i], current})[0] current = api.Select(circuit.Path[i], rightHash, leftHash) } api.AssertIsEqual(current, circuit.Root) @@ -70,28 +116,44 @@ func TestMerkleRecover(t *testing.T) { } } +// Example: circuit with multiple gadgets +type MyWidget struct { + Test_1 frontend.Variable + Test_2 frontend.Variable +} + +func (gadget MyWidget) DefineGadget(api abstractor.API) []frontend.Variable { + sum := api.Add(gadget.Test_1, gadget.Test_2) + mul := api.Mul(gadget.Test_1, gadget.Test_2) + r := api.Div(sum, mul) + return []frontend.Variable{r} +} + +type MySecondWidget struct { + Test_1 frontend.Variable + Test_2 frontend.Variable +} + +func (gadget MySecondWidget) DefineGadget(api abstractor.API) []frontend.Variable { + my_widget := api.DefineGadget(&MyWidget{}) + + mul := api.Mul(gadget.Test_1, gadget.Test_2) + snd := my_widget.Call(MyWidget{gadget.Test_1, gadget.Test_2})[0] + r := api.Mul(mul, snd) + return []frontend.Variable{r} +} + type TwoGadgets struct { In_1 frontend.Variable In_2 frontend.Variable } func (circuit *TwoGadgets) AbsDefine(api abstractor.API) error { - my_widget := api.DefineGadget("my_widget", 2, func(api abstractor.API, args ...frontend.Variable) []frontend.Variable { - sum := api.Add(args[0], args[1]) - mul := api.Mul(args[0], args[1]) - r := api.Div(sum, mul) - return []frontend.Variable{r} - }) - my_snd_widget := api.DefineGadget("my_snd_widget", 2, func(api abstractor.API, args ...frontend.Variable) []frontend.Variable { - mul := api.Mul(args[0], args[1]) - snd := my_widget.Call(args[0], args[1]) - r := api.Mul(mul, snd[0]) - return []frontend.Variable{r} - }) + my_snd_widget := api.DefineGadget(&MySecondWidget{}) sum := api.Add(circuit.In_1, circuit.In_2) prod := api.Mul(circuit.In_1, circuit.In_2) - my_snd_widget.Call(sum, prod) + my_snd_widget.Call(MySecondWidget{sum, prod}) return nil } diff --git a/extractor/lean_export.go b/extractor/lean_export.go index 3c0b4cc..0a47ee8 100644 --- a/extractor/lean_export.go +++ b/extractor/lean_export.go @@ -16,10 +16,7 @@ func ExportGadget(gadget ExGadget) string { if len(gadget.Outputs) > 1 { kArgsType = fmt.Sprintf("Vect F %d", len(gadget.Outputs)) } - inAssignment := make([]ExArg, gadget.Arity) - for i := 0; i < gadget.Arity; i++ { - inAssignment[i] = ExArg{fmt.Sprintf("in_%d", i), reflect.Interface, ExArgType{1, nil}} - } + inAssignment := gadget.Args return fmt.Sprintf("def %s %s (k: %s -> Prop): Prop :=\n%s", gadget.Name, genArgs(inAssignment), kArgsType, genGadgetBody(inAssignment, gadget)) } @@ -48,7 +45,7 @@ func ArrayInit(f schema.Field, v reflect.Value, op Operand) error { return nil } -func CircuitInit(class abstractor.Circuit, schema *schema.Schema) error { +func CircuitInit(class any, schema *schema.Schema) error { // https://stackoverflow.com/a/49704408 // https://stackoverflow.com/a/14162161 // https://stackoverflow.com/a/63422049 @@ -95,7 +92,7 @@ func CircuitInit(class abstractor.Circuit, schema *schema.Schema) error { return nil } -func KindOfField(a interface{}, s string) reflect.Kind { +func KindOfField(a any, s string) reflect.Kind { v := reflect.ValueOf(a).Elem() f := v.FieldByName(s) return f.Kind() @@ -114,8 +111,24 @@ func CircuitArgs(field schema.Field) ExArgType { } } +func GetExArgs(circuit any, fields []schema.Field) []ExArg { + args := []ExArg{} + for _, f := range fields { + kind := KindOfField(circuit, f.Name) + arg := ExArg{f.Name, kind, CircuitArgs(f)} + args = append(args, arg) + } + return args +} + +// Cloned version of NewSchema without constraints +func GetSchema(circuit any) (*schema.Schema, error) { + tVariable := reflect.ValueOf(struct{ A frontend.Variable }{}).FieldByName("A").Type() + return schema.New(circuit, tVariable) +} + func CircuitToLean(circuit abstractor.Circuit, field ecc.ID) error { - schema, err := frontend.NewSchema(circuit) + schema, err := GetSchema(circuit) if err != nil { return err } @@ -137,15 +150,8 @@ func CircuitToLean(circuit abstractor.Circuit, field ecc.ID) error { return err } - var circuitInputs []ExArg - for _, f := range schema.Fields { - kind := KindOfField(circuit, f.Name) - arg := ExArg{f.Name, kind, CircuitArgs(f)} - circuitInputs = append(circuitInputs, arg) - } - extractorCircuit := ExCircuit{ - Inputs: circuitInputs, + Inputs: GetExArgs(circuit, schema.Fields), Gadgets: api.Gadgets, Code: api.Code, } @@ -212,12 +218,12 @@ func assignGateVars(code []App, additional ...Operand) []string { func genGadgetCall(gateVar string, inAssignment []ExArg, gateVars []string, gadget *ExGadget, args []Operand) string { name := gadget.Name - operands := strings.Join(operandExprs(args, inAssignment, gateVars), " ") + operands := operandExprs(args, inAssignment, gateVars) binder := "_" if gateVar != "" { binder = gateVar } - return fmt.Sprintf(" %s %s fun %s =>\n", name, operands, binder) + return fmt.Sprintf(" %s %s fun %s =>\n", name, strings.Join(operands, " "), binder) } func genGateOp(op Op) string { @@ -290,9 +296,25 @@ func genFunctionalGate(gateVar string, op Op, operands []string) string { return fmt.Sprintf(" %s%s %s ∧\n", genGateBinder(gateVar), genGateOp(op), strings.Join(operands, " ")) } -func genCallbackGate(gateVar string, op Op, operands []string) string { +func genCallbackGate(gateVar string, op Op, operands []string, args []Operand) string { gateName := getGateName(gateVar, false) - return fmt.Sprintf(" ∃%s, %s %s %s ∧\n", gateName, genGateOp(op), strings.Join(operands, " "), gateName) + switch op { + case OpFromBinary: + is_gate := reflect.TypeOf(args[0]) == reflect.TypeOf(Gate{}) + if len(args) == 1 && is_gate { + return fmt.Sprintf(" ∃%s, %s %s %s ∧\n", gateName, genGateOp(op), strings.Join(operands, " "), gateName) + } + return fmt.Sprintf(" ∃%s, %s vec![%s] %s ∧\n", gateName, genGateOp(op), strings.Join(operands, ", "), gateName) + case OpToBinary: + is_const := reflect.TypeOf(args[0]) == reflect.TypeOf(Const{}) + if is_const { + operands[0] = fmt.Sprintf("(%s:F)", operands[0]) + return fmt.Sprintf(" ∃%s, %s %s %s ∧\n", gateName, genGateOp(op), strings.Join(operands, " "), gateName) + } + return fmt.Sprintf(" ∃%s, %s %s %s ∧\n", gateName, genGateOp(op), strings.Join(operands, " "), gateName) + default: + return fmt.Sprintf(" ∃%s, %s %s %s ∧\n", gateName, genGateOp(op), strings.Join(operands, " "), gateName) + } } func genGenericGate(op Op, operands []string) string { @@ -325,10 +347,11 @@ func genOpCall(gateVar string, inAssignment []ExArg, gateVars []string, op Op, a } return finalStr } + default: + return genFunctionalGate(gateVar, op, operands) } - return genFunctionalGate(gateVar, op, operands) } else if callback { - return genCallbackGate(gateVar, op, operands) + return genCallbackGate(gateVar, op, operands, args) } else { return genGenericGate(op, operands) } @@ -377,6 +400,10 @@ func operandExpr(operand Operand, inAssignment []ExArg, gateVars []string) strin return gateVars[operand.(Gate).Index] case Proj: return fmt.Sprintf("%s[%d]", operandExpr(operand.(Proj).Operand, inAssignment, gateVars), operand.(Proj).Index) + case ProjArray: + opArray := operandExprs(operand.(ProjArray).Proj, inAssignment, gateVars) + opArray = []string{strings.Join(opArray, ", ")} + return fmt.Sprintf("vec!%s", opArray) case Const: return operand.(Const).Value.Text(10) default: