From 11fa256040ecb06fa0183f1774529386b39b4787 Mon Sep 17 00:00:00 2001 From: Marcin Kostrzewa Date: Wed, 24 Jul 2024 13:33:47 +0200 Subject: [PATCH] Support structs (#47) --- extractor/extractor.go | 8 +- extractor/interface.go | 15 +-- extractor/lean_export.go | 124 +++++++++--------- extractor/misc.go | 57 +------- extractor/test/nested_test.go | 40 ++++++ .../test/nested_with_slice_reuse_test.go | 41 ++++++ test/TestNestedCircuit.lean | 20 +++ test/TestNestedReusedCircuit.lean | 20 +++ 8 files changed, 193 insertions(+), 132 deletions(-) create mode 100644 extractor/test/nested_test.go create mode 100644 extractor/test/nested_with_slice_reuse_test.go create mode 100644 test/TestNestedCircuit.lean create mode 100644 test/TestNestedReusedCircuit.lean diff --git a/extractor/extractor.go b/extractor/extractor.go index c428219..546b66b 100644 --- a/extractor/extractor.go +++ b/extractor/extractor.go @@ -170,9 +170,8 @@ type ExArgType struct { } type ExArg struct { - Name string - Kind reflect.Kind - Type ExArgType + Name string + ArrayType *ExArgType } type ExCircuit struct { @@ -423,12 +422,11 @@ func (ce *CodeExtractor) DefineGadget(gadget abstractor.GadgetDefinition) abstra panic("DefineGadget only takes pointers to the gadget") } schema, _ := getSchema(gadget) - circuitInit(gadget, schema) + args := circuitInit(gadget, schema) // Can't use `schema.NbPublic + schema.NbSecret` // for arity because each array element is considered // a parameter arity := len(schema.Fields) - args := getExArgs(gadget, schema.Fields) name := generateUniqueName(gadget, args) diff --git a/extractor/interface.go b/extractor/interface.go index b12333a..223ecac 100644 --- a/extractor/interface.go +++ b/extractor/interface.go @@ -16,14 +16,12 @@ import ( // CircuitToLean(circuit abstractor.Circuit, field ecc.ID, namespace ...string) because the long term view // is to add an optional parameter to support custom `set_option` directives in the header. func CircuitToLeanWithName(circuit frontend.Circuit, field ecc.ID, namespace string) (out string, err error) { - defer recoverError() - schema, err := getSchema(circuit) if err != nil { return "", err } - circuitInit(circuit, schema) + inputs := circuitInit(circuit, schema) api := CodeExtractor{ Code: []App{}, @@ -37,7 +35,7 @@ func CircuitToLeanWithName(circuit frontend.Circuit, field ecc.ID, namespace str } extractorCircuit := ExCircuit{ - Inputs: getExArgs(circuit, schema.Fields), + Inputs: inputs, Gadgets: api.Gadgets, Code: api.Code, Field: api.FieldID, @@ -57,8 +55,6 @@ func CircuitToLean(circuit frontend.Circuit, field ecc.ID) (string, error) { // GadgetToLeanWithName exports a `gadget` to Lean over a `field` with `namespace` // Same notes written for CircuitToLeanWithName apply to GadgetToLeanWithName and GadgetToLean func GadgetToLeanWithName(gadget abstractor.GadgetDefinition, field ecc.ID, namespace string) (out string, err error) { - defer recoverError() - api := CodeExtractor{ Code: []App{}, Gadgets: []ExGadget{}, @@ -80,7 +76,6 @@ func GadgetToLean(gadget abstractor.GadgetDefinition, field ecc.ID) (string, err // ExtractCircuits is used to export a series of `circuits` to Lean over a `field` under `namespace`. func ExtractCircuits(namespace string, field ecc.ID, circuits ...frontend.Circuit) (out string, err error) { - defer recoverError() api := CodeExtractor{ Code: []App{}, @@ -103,14 +98,14 @@ func ExtractCircuits(namespace string, field ecc.ID, circuits ...frontend.Circui if err != nil { return "", err } - args := getExArgs(circuit, schema.Fields) + args := circuitInit(circuit, schema) + name := generateUniqueName(circuit, args) if slices.Contains(past_circuits, name) { continue } past_circuits = append(past_circuits, name) - circuitInit(circuit, schema) err = circuit.Define(&api) if err != nil { return "", err @@ -136,8 +131,6 @@ func ExtractCircuits(namespace string, field ecc.ID, circuits ...frontend.Circui // ExtractGadgets is used to export a series of `gadgets` to Lean over a `field` under `namespace`. func ExtractGadgets(namespace string, field ecc.ID, gadgets ...abstractor.GadgetDefinition) (out string, err error) { - defer recoverError() - api := CodeExtractor{ Code: []App{}, Gadgets: []ExGadget{}, diff --git a/extractor/lean_export.go b/extractor/lean_export.go index 819c3bb..671b51b 100644 --- a/extractor/lean_export.go +++ b/extractor/lean_export.go @@ -99,76 +99,74 @@ func exportCircuit(circuit ExCircuit, name string) string { // circuitInit takes struct and a schema to populate all the // circuit/gagdget fields with Operand. -func circuitInit(class any, schema *schema.Schema) { - // https://stackoverflow.com/a/49704408 - // https://stackoverflow.com/a/14162161 - // https://stackoverflow.com/a/63422049 - +func circuitInit(circuit any, sch *schema.Schema) []ExArg { // The purpose of this function is to initialise the // struct fields with Operand interfaces for being // processed by the Extractor. - v := reflect.ValueOf(class) - if v.Type().Kind() == reflect.Ptr { - ptr := v - v = ptr.Elem() - } else { - ptr := reflect.New(reflect.TypeOf(class)) - temp := ptr.Elem() - temp.Set(v) - } - - tmp_c := reflect.ValueOf(&class).Elem().Elem() - tmp := reflect.New(tmp_c.Type()).Elem() - tmp.Set(tmp_c) - for j, f := range schema.Fields { - field_name := f.Name - field := v.FieldByName(field_name) - field_type := field.Type() - - // Can't assign an array to another array, therefore - // initialise each element in the array - - if field_type.Kind() == reflect.Array { - arrayInit(f, tmp.Elem().FieldByName(field_name), Input{j}) - } else if field_type.Kind() == reflect.Slice { - // Recreate a zeroed array to remove overlapping pointers if input - // arguments are duplicated (i.e. `api.Call(SliceGadget{circuit.Path, circuit.Path})`) - arrayZero(tmp.Elem().FieldByName(field_name)) - arrayInit(f, tmp.Elem().FieldByName(field_name), Input{j}) - } else if field_type.Kind() == reflect.Interface { - init := Input{j} - value := reflect.ValueOf(init) - - tmp.Elem().FieldByName(field_name).Set(value) - } else { - fmt.Printf("Skipped type %s\n", field_type.Kind()) + v := reflect.ValueOf(circuit) + for v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface { + v = v.Elem() + } + + return structInit(v, sch.Fields, 0, "") +} + +func structInit(val reflect.Value, fields []schema.Field, offset int, prefix string) []ExArg { + var result []ExArg + for _, f := range fields { + fieldVal := val.FieldByName(f.Name) + switch f.Type { + case schema.Leaf: + input := Input{offset} + value := reflect.ValueOf(input) + fieldVal.Set(value) + result = append(result, ExArg{prefix + f.Name, nil}) + offset++ + case schema.Array: + if fieldVal.Type().Kind() == reflect.Slice { + arrayZero(fieldVal) + } + arrTp := arrayInit(fieldVal, arraySubfield(f), Input{offset}) + result = append(result, ExArg{prefix + f.Name, arrTp}) + offset++ + case schema.Struct: + recurResult := structInit(val.FieldByName(f.Name), f.SubFields, offset, prefix+f.Name+"_") + result = append(result, recurResult...) + offset += len(recurResult) } } + return result } -func circuitArgs(field schema.Field) ExArgType { - // Handling only subfields which are nested arrays - switch len(field.SubFields) { - case 1: - subType := circuitArgs(field.SubFields[0]) - return ExArgType{field.ArraySize, &subType} - case 0: - return ExArgType{field.ArraySize, nil} - default: - panic("Only nested arrays supported in SubFields") +func arraySubfield(field schema.Field) *schema.Field { + if len(field.SubFields) == 0 { + return nil } + return &field.SubFields[0] } -// getExArgs generates a list of ExArg given a `circuit` and a -// list of `Field`. It is used in the Circuit to Lean functions -func getExArgs(circuit any, fields []schema.Field) []ExArg { - args := []ExArg{} - for _, f := range fields { - kind := kindOfField(circuit, f.Name) - arg := ExArg{f.Name, kind, circuitArgs(f)} - args = append(args, arg) +func arrayInit(val reflect.Value, elemField *schema.Field, baseVal Operand) *ExArgType { + var childType *ExArgType + if elemField == nil { + for i := 0; i < val.Len(); i++ { + proj := Proj{baseVal, i, val.Len()} + value := reflect.ValueOf(proj) + val.Index(i).Set(value) + } + } else { + switch elemField.Type { + case schema.Leaf: + panic("Gnark contract broken – leaf inside array should be nil") + case schema.Array: + for i := 0; i < val.Len(); i++ { + childType = arrayInit(val.Index(i), arraySubfield(*elemField), Proj{baseVal, i, val.Len()}) + } + case schema.Struct: + panic("Struct inside array not supported") + } + } - return args + return &ExArgType{val.Len(), childType} } // getSchema is a cloned version of NewSchema without constraints @@ -187,10 +185,10 @@ func genNestedArrays(a ExArgType) string { func genArgs(inAssignment []ExArg) string { args := make([]string, len(inAssignment)) for i, in := range inAssignment { - switch in.Kind { - case reflect.Array, reflect.Slice: - args[i] = fmt.Sprintf("(%s: %s)", in.Name, genNestedArrays(in.Type)) - default: + if in.ArrayType != nil { + args[i] = fmt.Sprintf("(%s: %s)", in.Name, genNestedArrays(*in.ArrayType)) + + } else { args[i] = fmt.Sprintf("(%s: F)", in.Name) } } diff --git a/extractor/misc.go b/extractor/misc.go index 9d100e7..7ee447e 100644 --- a/extractor/misc.go +++ b/extractor/misc.go @@ -1,34 +1,15 @@ package extractor import ( - "errors" - "flag" "fmt" "reflect" - "runtime/debug" "strings" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/schema" "github.com/mitchellh/copystructure" "github.com/reilabs/gnark-lean-extractor/v2/abstractor" ) -// recoverError is used in the top level interface to prevent panic -// caused by any of the methods in the extractor from propagating -// When go is running in test mode, it prints the stack trace to aid -// debugging. -func recoverError() (err error) { - if recover() != nil { - if flag.Lookup("test.v") != nil { - stack := string(debug.Stack()) - fmt.Println(stack) - } - err = errors.New("Panic extracting circuit to Lean") - } - return nil -} - // arrayToSlice returns a slice of elements identical to // the input array `v` func arrayToSlice(v reflect.Value) []frontend.Variable { @@ -82,29 +63,6 @@ func flattenSlice(value reflect.Value) []frontend.Variable { return value.Interface().([]frontend.Variable) } -// arrayInit generates the Proj{} object for each element of v -func arrayInit(f schema.Field, v reflect.Value, op Operand) error { - for i := 0; i < f.ArraySize; i++ { - op := Proj{op, i, f.ArraySize} - switch len(f.SubFields) { - case 1: - arrayInit(f.SubFields[0], v.Index(i), op) - case 0: - if v.Len() != f.ArraySize { - // Slices of this type aren't supported yet [[ ] [ ]] - // gnark newSchema doesn't handle different dimensions - fmt.Printf("Wrong slices dimensions %+v\n", v) - panic("Only slices dimensions not matching") - } - value := reflect.ValueOf(op) - v.Index(i).Set(value) - default: - panic("Only nested arrays supported in SubFields") - } - } - return nil -} - // arrayZero sets all the elements of the input slice v to nil. // It is used when initialising a new circuit or gadget to ensure // the object is clean @@ -119,8 +77,8 @@ func arrayZero(v reflect.Value) { arrayZero(v.Addr().Elem().Index(i)) } } else { - zero_array := make([]frontend.Variable, v.Len(), v.Len()) - v.Set(reflect.ValueOf(&zero_array).Elem()) + zeroArray := make([]frontend.Variable, v.Len()) + v.Set(reflect.ValueOf(zeroArray)) } } default: @@ -128,13 +86,6 @@ func arrayZero(v reflect.Value) { } } -// kindOfField returns the Kind of field in struct a -func kindOfField(a any, field string) reflect.Kind { - v := reflect.ValueOf(a).Elem() - f := v.FieldByName(field) - return f.Kind() -} - // getStructName returns the name of struct a func getStructName(a any) string { return reflect.TypeOf(a).Elem().Name() @@ -221,9 +172,9 @@ func cloneGadget(gadget abstractor.GadgetDefinition) abstractor.GadgetDefinition func generateUniqueName(element any, args []ExArg) string { suffix := "" for _, a := range args { - if a.Kind == reflect.Array || a.Kind == reflect.Slice { + if a.ArrayType != nil { suffix += "_" - suffix += strings.Join(getSizeGadgetArgs(a.Type), "_") + suffix += strings.Join(getSizeGadgetArgs(*a.ArrayType), "_") } } diff --git a/extractor/test/nested_test.go b/extractor/test/nested_test.go new file mode 100644 index 0000000..6c92c6e --- /dev/null +++ b/extractor/test/nested_test.go @@ -0,0 +1,40 @@ +package extractor_test + +import ( + "log" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/reilabs/gnark-lean-extractor/v2/extractor" +) + +type Nest1 struct { + In1 frontend.Variable + In2 [5]frontend.Variable +} + +type Nest2 struct { + In1 [4][4]frontend.Variable + In2 [3]frontend.Variable +} + +type NestedCircuit struct { + N1 Nest1 + N2 Nest2 +} + +func (circuit *NestedCircuit) Define(api frontend.API) error { + sum := api.Add(circuit.N1.In2[2], circuit.N2.In2[0]) + api.AssertIsEqual(sum, circuit.N1.In2[1]) + return nil +} + +func TestNestedCircuit(t *testing.T) { + circuit := NestedCircuit{} + out, err := extractor.CircuitToLean(&circuit, ecc.BN254) + if err != nil { + log.Fatal(err) + } + checkOutput(t, out) +} diff --git a/extractor/test/nested_with_slice_reuse_test.go b/extractor/test/nested_with_slice_reuse_test.go new file mode 100644 index 0000000..6e0b8f7 --- /dev/null +++ b/extractor/test/nested_with_slice_reuse_test.go @@ -0,0 +1,41 @@ +package extractor_test + +import ( + "log" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/reilabs/gnark-lean-extractor/v2/extractor" +) + +type N1 struct { + In1 []frontend.Variable +} + +type N2 struct { + In1 []frontend.Variable +} + +type NestedReusedCircuit struct { + N1 N1 + N2 N2 +} + +func (circuit *NestedReusedCircuit) Define(api frontend.API) error { + sum := api.Add(circuit.N1.In1[0], circuit.N2.In1[0]) + api.AssertIsEqual(sum, 0) + return nil +} + +func TestNestedReusedCircuit(t *testing.T) { + circuit := NestedReusedCircuit{} + ins := make([]frontend.Variable, 1) + circuit.N2.In1 = ins + circuit.N1.In1 = ins + out, err := extractor.CircuitToLean(&circuit, ecc.BN254) + if err != nil { + log.Fatal(err) + } + checkOutput(t, out) +} diff --git a/test/TestNestedCircuit.lean b/test/TestNestedCircuit.lean new file mode 100644 index 0000000..785fe22 --- /dev/null +++ b/test/TestNestedCircuit.lean @@ -0,0 +1,20 @@ +import ProvenZk.Gates +import ProvenZk.Ext.Vector + +set_option linter.unusedVariables false + +namespace NestedCircuit + +def Order : ℕ := 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 +variable [Fact (Nat.Prime Order)] +abbrev F := ZMod Order +abbrev Gates := GatesGnark8 Order + + + +def circuit (N1_In1: F) (N1_In2: Vector F 5) (N2_In1: Vector (Vector F 4) 4) (N2_In2: Vector F 3): Prop := + ∃gate_0, gate_0 = Gates.add N1_In2[2] N2_In2[0] ∧ + Gates.eq gate_0 N1_In2[1] ∧ + True + +end NestedCircuit \ No newline at end of file diff --git a/test/TestNestedReusedCircuit.lean b/test/TestNestedReusedCircuit.lean new file mode 100644 index 0000000..e33de86 --- /dev/null +++ b/test/TestNestedReusedCircuit.lean @@ -0,0 +1,20 @@ +import ProvenZk.Gates +import ProvenZk.Ext.Vector + +set_option linter.unusedVariables false + +namespace NestedReusedCircuit + +def Order : ℕ := 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 +variable [Fact (Nat.Prime Order)] +abbrev F := ZMod Order +abbrev Gates := GatesGnark8 Order + + + +def circuit (N1_In1: Vector F 1) (N2_In1: Vector F 1): Prop := + ∃gate_0, gate_0 = Gates.add N1_In1[0] N2_In1[0] ∧ + Gates.eq gate_0 (0:F) ∧ + True + +end NestedReusedCircuit \ No newline at end of file