From 4ac7e569c57a0ccb53257ca3e9f340328c0f47a8 Mon Sep 17 00:00:00 2001 From: Eagle941 <8973725+Eagle941@users.noreply.github.com> Date: Mon, 25 Sep 2023 16:42:44 +0100 Subject: [PATCH] feat: `void` return gadgets, multiple circuits export and `Vector` optimisation (#38) * Added check in CallGadget * Added support for gadgets with no return types * Fixed const type coercion * Added support for extraction of multiple gadgets * Added support of Compiler interface * Improved gadget naming * Fixes to nested array of int in gadgets * Added uninitialised structs * Fixed sanitizeVars to support all flavours of ints * Added extraction of multiple circuits * Added vector optimisation * Added vector optimisation in gadget return * Added size check in isVectorComplete * Added support for array of ints as parameters * Fixed debug * Fixed use of append * Improved gotest * Fixed gate vector optimisation * Fixed clone gadget function * Added 2dim slices optimisation * Added tests * Added support of n-dim slices optimisation * Added support for multiple gadget return types * Fix replaceArg * Fixed nested slices * Moved public mathods to interface.go * Tests moved to test folder * Moving generic functions to file misc.go * Added comments and refactoring * Added comments and error handling * Gofmt --- README.md | 8 +- abstractor/abstractor.go | 15 +- abstractor/concretizer.go | 4 +- extractor/extractor.go | 252 ++++++---- extractor/extractor_test.go | 241 ---------- extractor/interface.go | 186 ++++++++ extractor/lean_export.go | 443 +++++++++++------- extractor/misc.go | 264 +++++++++++ extractor/test/another_circuit_test.go | 62 +++ extractor/test/circuit_with_parameter_test.go | 92 ++++ extractor/test/deletion_mbu_circuit_test.go | 86 ++++ extractor/test/merkle_recover_test.go | 54 +++ extractor/test/my_circuit_test.go | 37 ++ extractor/test/slices_optimisation_test.go | 110 +++++ extractor/test/to_binary_circuit_test.go | 91 ++++ extractor/test/two_gadgets_test.go | 121 +++++ extractor/test/utils_test.go | 62 +++ go.mod | 5 +- go.sum | 10 +- test/TestAnotherCircuit.lean | 24 + test/TestCircuitWithParameter.lean | 53 +++ test/TestDeletionMbuCircuit.lean | 20 + test/TestExtractCircuits.lean | 134 ++++++ test/TestExtractGadgets.lean | 42 ++ test/TestExtractGadgetsVectors.lean | 28 ++ test/TestGadgetExtraction.lean | 18 + test/TestMerkleRecover.lean | 80 ++++ test/TestMyCircuit.lean | 19 + test/TestSlicesOptimisation.lean | 35 ++ test/TestToBinaryCircuit.lean | 28 ++ test/TestTwoGadgets.lean | 31 ++ 31 files changed, 2136 insertions(+), 519 deletions(-) delete mode 100644 extractor/extractor_test.go create mode 100644 extractor/interface.go create mode 100644 extractor/misc.go create mode 100644 extractor/test/another_circuit_test.go create mode 100644 extractor/test/circuit_with_parameter_test.go create mode 100644 extractor/test/deletion_mbu_circuit_test.go create mode 100644 extractor/test/merkle_recover_test.go create mode 100644 extractor/test/my_circuit_test.go create mode 100644 extractor/test/slices_optimisation_test.go create mode 100644 extractor/test/to_binary_circuit_test.go create mode 100644 extractor/test/two_gadgets_test.go create mode 100644 extractor/test/utils_test.go create mode 100644 test/TestAnotherCircuit.lean create mode 100644 test/TestCircuitWithParameter.lean create mode 100644 test/TestDeletionMbuCircuit.lean create mode 100644 test/TestExtractCircuits.lean create mode 100644 test/TestExtractGadgets.lean create mode 100644 test/TestExtractGadgetsVectors.lean create mode 100644 test/TestGadgetExtraction.lean create mode 100644 test/TestMerkleRecover.lean create mode 100644 test/TestMyCircuit.lean create mode 100644 test/TestSlicesOptimisation.lean create mode 100644 test/TestToBinaryCircuit.lean create mode 100644 test/TestTwoGadgets.lean diff --git a/README.md b/README.md index a14c09e..5dbc80e 100644 --- a/README.md +++ b/README.md @@ -40,15 +40,15 @@ func (circuit *MyCircuit) AbsDefine(api abstractor.API) error { return nil } -func (circuit *MyCircuit) Define(api frontend.API) error { - return abstractor.Concretize(api, circuit) +func (circuit MyCircuit) Define(api frontend.API) error { + return abstractor.Concretize(api, &circuit) } ``` Once you export this to Lean, you get a definition as follows: ```lean -namespace DummyCircuit +namespace MyCircuit def Order : ℕ := 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 variable [Fact (Nat.Prime Order)] @@ -59,7 +59,7 @@ def circuit (In_1: F) (In_2: F) (Out: F): Prop := Gates.eq gate_0 Out ∧ True -end DummyCircuit +end MyCircuit ``` Further examples of this process with various levels of complexity can be seen diff --git a/abstractor/abstractor.go b/abstractor/abstractor.go index 169cf5c..4b8bb85 100644 --- a/abstractor/abstractor.go +++ b/abstractor/abstractor.go @@ -3,11 +3,11 @@ package abstractor import "github.com/consensys/gnark/frontend" type Gadget interface { - Call(gadget GadgetDefinition) []frontend.Variable + Call(gadget GadgetDefinition) interface{} } type GadgetDefinition interface { - DefineGadget(api API) []frontend.Variable + DefineGadget(api API) interface{} } type API interface { @@ -15,7 +15,7 @@ type API interface { DefineGadget(gadget GadgetDefinition) Gadget frontend.API - Call(gadget GadgetDefinition) []frontend.Variable + Call(gadget GadgetDefinition) interface{} } type Circuit interface { @@ -27,6 +27,13 @@ func Concretize(api frontend.API, circuit Circuit) error { return circuit.AbsDefine(&Concretizer{api}) } -func CallGadget(api frontend.API, circuit GadgetDefinition) []frontend.Variable { +func CallGadget(api frontend.API, circuit GadgetDefinition) interface{} { + _, ok := api.(API) + if ok { + // The consequence of calling CallGadget with abstractor.API is that + // the circuit is extracted as a single function instead of + // splitting in sub-circuits + panic("abstractor.CallGadget can't be called with abstractor.API") + } return circuit.DefineGadget(&Concretizer{api}) } diff --git a/abstractor/concretizer.go b/abstractor/concretizer.go index 717e6c0..bd76575 100644 --- a/abstractor/concretizer.go +++ b/abstractor/concretizer.go @@ -10,7 +10,7 @@ type ConcreteGadget struct { api API } -func (g *ConcreteGadget) Call(gadget GadgetDefinition) []frontend.Variable { +func (g *ConcreteGadget) Call(gadget GadgetDefinition) interface{} { return gadget.DefineGadget(g.api) } @@ -122,7 +122,7 @@ func (c *Concretizer) DefineGadget(gadget GadgetDefinition) Gadget { return &ConcreteGadget{c} } -func (c *Concretizer) Call(gadget GadgetDefinition) []frontend.Variable { +func (c *Concretizer) Call(gadget GadgetDefinition) interface{} { return c.DefineGadget(gadget).Call(gadget) } diff --git a/extractor/extractor.go b/extractor/extractor.go index afdcb31..586d8db 100644 --- a/extractor/extractor.go +++ b/extractor/extractor.go @@ -4,7 +4,6 @@ import ( "fmt" "math/big" "reflect" - "strings" "github.com/reilabs/gnark-lean-extractor/abstractor" @@ -24,6 +23,16 @@ type Const struct { func (_ Const) isOperand() {} +// Integer struct is used to distinguish between a constant in +// place of a frontend.Variable and an integer where an integer +// is the only type allowed. Integer sruct is currently only +// used for the length of the result in ToBinary function. +type Integer struct { + Value *big.Int +} + +func (_ Integer) isOperand() {} + type Gate struct { Index int } @@ -40,15 +49,20 @@ func (_ Input) isOperand() {} // Index is the index to be accessed in the array // Operand[Index] +// Size is a placeholder to keep track of the whole +// array size. It is essential to know if the whole +// vector or only a slice is used as function +// argument. type Proj struct { Operand Operand Index int + Size int } func (_ Proj) isOperand() {} type ProjArray struct { - Proj []Operand + Projs []Operand } func (_ ProjArray) isOperand() {} @@ -95,68 +109,57 @@ type Code struct { } type ExGadget struct { - Name string - Arity int - Code []App - Outputs []Operand - Extractor *CodeExtractor - Fields []schema.Field - Args []ExArg + Name string + Arity int + Code []App + OutputsFlat []Operand + Outputs interface{} + Extractor *CodeExtractor + Fields []schema.Field + Args []ExArg } func (g *ExGadget) isOp() {} -func ArrayToSlice(v reflect.Value) []frontend.Variable { - res := make([]frontend.Variable, v.Len()) - - for i := 0; i < v.Len(); i++ { - res[i] = v.Index(i).Elem().Interface().(frontend.Variable) - } - - return res -} - -func (g *ExGadget) Call(gadget abstractor.GadgetDefinition) []frontend.Variable { +func (g *ExGadget) Call(gadget abstractor.GadgetDefinition) interface{} { args := []frontend.Variable{} rv := reflect.Indirect(reflect.ValueOf(gadget)) rt := rv.Type() + // Looping through the circuit fields only. for i := 0; i < rt.NumField(); i++ { fld := rt.Field(i) v := rv.FieldByName(fld.Name) switch v.Kind() { case reflect.Slice: - args = append(args, v.Interface().([]frontend.Variable)) + arg := flattenSlice(v) + if len(arg) != 0 { + args = append(args, arg) + } case reflect.Array: // I can't convert from array to slice using Reflect because - // the field is unaddressable. - args = append(args, ArrayToSlice(v)) + // the field is unaddressable. Therefore I recreate a slice + // with the same elements as the input array. + arg := arrayToSlice(v) + // Checking length != 0 because I need to keep nested slices + // as nested slices, but not empty slices + if len(arg) != 0 { + args = append(args, arg) + } case reflect.Interface: args = append(args, v.Elem().Interface().(frontend.Variable)) } } - gate := g.Extractor.AddApp(g, args...) - outs := make([]frontend.Variable, len(g.Outputs)) - if len(g.Outputs) == 1 { - outs[0] = gate - } else { - for i := range g.Outputs { - outs[i] = Proj{gate, i} - } - } - return outs -} -func cloneGadget(gadget abstractor.GadgetDefinition) abstractor.GadgetDefinition { - v := reflect.ValueOf(gadget) - tmp_gadget := reflect.New(v.Type()) - tmp_gadget.Elem().Set(v) - return tmp_gadget.Interface().(abstractor.GadgetDefinition) + res := replaceArg(g.Outputs, gate) + return res } -func (ce *CodeExtractor) Call(gadget abstractor.GadgetDefinition) []frontend.Variable { - // Copying `gadget` because `DefineGadget` needs to manipulate the input +func (ce *CodeExtractor) Call(gadget abstractor.GadgetDefinition) interface{} { + // Deep copying `gadget` because `DefineGadget` needs to modify the gadget fields. + // This was done as a replacement to the initial method of declaring gadgets using + // a direct call to `Define Gadget` within the circuit and then calling GadgetDefinition.Call clonedGadget := cloneGadget(gadget) g := ce.DefineGadget(clonedGadget) return g.Call(gadget) @@ -178,13 +181,12 @@ type ExCircuit struct { Gadgets []ExGadget Code []App Field ecc.ID - Name string } type CodeExtractor struct { Code []App Gadgets []ExGadget - Field ecc.ID + FieldID ecc.ID } func sanitizeVars(args ...frontend.Variable) []Operand { @@ -193,17 +195,41 @@ func sanitizeVars(args ...frontend.Variable) []Operand { switch arg.(type) { case Input, Gate, Proj, Const: ops = append(ops, arg.(Operand)) + case Integer: + ops = append(ops, arg.(Operand)) case int: - ops = append(ops, Const{big.NewInt(int64(arg.(int)))}) + ops = append(ops, Const{new(big.Int).SetInt64(int64(arg.(int)))}) + case int8: + ops = append(ops, Const{new(big.Int).SetInt64(int64(arg.(int8)))}) + case int16: + ops = append(ops, Const{new(big.Int).SetInt64(int64(arg.(int16)))}) + case int32: + ops = append(ops, Const{new(big.Int).SetInt64(int64(arg.(int32)))}) + case int64: + ops = append(ops, Const{new(big.Int).SetInt64(arg.(int64))}) + case uint: + ops = append(ops, Const{new(big.Int).SetUint64(uint64(arg.(uint)))}) + case uint8: + ops = append(ops, Const{new(big.Int).SetUint64(uint64(arg.(uint8)))}) + case uint16: + ops = append(ops, Const{new(big.Int).SetUint64(uint64(arg.(uint16)))}) + case uint32: + ops = append(ops, Const{new(big.Int).SetUint64(uint64(arg.(uint32)))}) + case uint64: + ops = append(ops, Const{new(big.Int).SetUint64(arg.(uint64))}) case big.Int: casted := arg.(big.Int) ops = append(ops, Const{&casted}) case []frontend.Variable: opsArray := sanitizeVars(arg.([]frontend.Variable)...) ops = append(ops, ProjArray{opsArray}) + case nil: + // This takes care of uninitialised fields that are + // passed to gadgets + ops = append(ops, Const{big.NewInt(int64(0))}) default: - fmt.Printf("invalid argument of type %T\n%#v\n", arg, arg) - panic("invalid argument") + fmt.Printf("sanitizeVars invalid argument of type %T\n%#v\n", arg, arg) + panic("sanitizeVars invalid argument") } } return ops @@ -248,7 +274,7 @@ func (ce *CodeExtractor) Inverse(i1 frontend.Variable) frontend.Variable { } func (ce *CodeExtractor) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable { - nbBits := ce.Field.ScalarField().BitLen() + nbBits := ce.FieldID.ScalarField().BitLen() if len(n) == 1 { nbBits = n[0] if nbBits < 0 { @@ -256,16 +282,22 @@ func (ce *CodeExtractor) ToBinary(i1 frontend.Variable, n ...int) []frontend.Var } } - gate := ce.AddApp(OpToBinary, i1, nbBits) + gate := ce.AddApp(OpToBinary, i1, Integer{big.NewInt(int64(nbBits))}) outs := make([]frontend.Variable, nbBits) for i := range outs { - outs[i] = Proj{gate, i} + outs[i] = Proj{gate, i, len(outs)} } return outs } func (ce *CodeExtractor) FromBinary(b ...frontend.Variable) frontend.Variable { // Packs in little-endian + if len(b) == 0 { + panic("FromBinary has to have at least one argument!") + } + if reflect.TypeOf(b[0]) == reflect.TypeOf([]frontend.Variable{}) { + panic("Pass operators to FromBinary using ellipsis") + } return ce.AddApp(OpFromBinary, append([]frontend.Variable{}, b...)...) } @@ -318,6 +350,27 @@ func (ce *CodeExtractor) Println(a ...frontend.Variable) { } func (ce *CodeExtractor) Compiler() frontend.Compiler { + return ce +} + +func (ce *CodeExtractor) MarkBoolean(v frontend.Variable) { + panic("implement me") +} + +func (ce *CodeExtractor) IsBoolean(v frontend.Variable) bool { + panic("implement me") +} + +func (ce *CodeExtractor) Field() *big.Int { + scalarField := ce.FieldID.ScalarField() + return new(big.Int).Set(scalarField) +} + +func (ce *CodeExtractor) FieldBitLen() int { + return ce.FieldID.ScalarField().BitLen() +} + +func (ce *CodeExtractor) Commit(...frontend.Variable) (frontend.Variable, error) { panic("implement me") } @@ -330,14 +383,34 @@ func (ce *CodeExtractor) ConstantValue(v frontend.Variable) (*big.Int, bool) { case Const: return v.(Const).Value, true case Proj: - switch v.(Proj).Operand.(type) { - case Const: - return v.(Proj).Operand.(Const).Value, true - default: - return nil, false + { + switch v.(Proj).Operand.(type) { + case Const: + return v.(Proj).Operand.(Const).Value, true + default: + return nil, false + } } + case int: + return new(big.Int).SetInt64(int64(v.(int))), true + case int8: + return new(big.Int).SetInt64(int64(v.(int8))), true + case int16: + return new(big.Int).SetInt64(int64(v.(int16))), true + case int32: + return new(big.Int).SetInt64(int64(v.(int32))), true case int64: - return big.NewInt(v.(int64)), true + return new(big.Int).SetInt64(v.(int64)), true + case uint: + return new(big.Int).SetUint64(uint64(v.(uint))), true + case uint8: + return new(big.Int).SetUint64(uint64(v.(uint8))), true + case uint16: + return new(big.Int).SetUint64(uint64(v.(uint16))), true + case uint32: + return new(big.Int).SetUint64(uint64(v.(uint32))), true + case uint64: + return new(big.Int).SetUint64(v.(uint64)), true case big.Int: casted := v.(big.Int) return &casted, true @@ -346,46 +419,19 @@ func (ce *CodeExtractor) ConstantValue(v frontend.Variable) (*big.Int, bool) { } } -func getGadgetByName(gadgets []ExGadget, name string) abstractor.Gadget { - for _, gadget := range gadgets { - if gadget.Name == name { - return &gadget - } - } - return nil -} - -func getSize(elem ExArgType) []string { - if elem.Type == nil { - return []string{fmt.Sprintf("%d", elem.Size)} - } - return append(getSize(*elem.Type), fmt.Sprintf("%d", elem.Size)) -} - func (ce *CodeExtractor) DefineGadget(gadget abstractor.GadgetDefinition) abstractor.Gadget { if reflect.ValueOf(gadget).Kind() != reflect.Ptr { panic("DefineGadget only takes pointers to the gadget") } - schema, _ := GetSchema(gadget) - CircuitInit(gadget, schema) + schema, _ := getSchema(gadget) + circuitInit(gadget, schema) // Can't use `schema.NbPublic + schema.NbSecret` // for arity because each array element is considered // a parameter arity := len(schema.Fields) - args := GetExArgs(gadget, schema.Fields) - - // To distinguish between gadgets instantiated with different array - // sizes, add a suffix to the name. The suffix of each instantiation - // is made up of the concatenation of the length of all the array - // fields in the gadget - suffix := "" - for _, a := range args { - if a.Kind == reflect.Array || a.Kind == reflect.Slice { - suffix += "_" - suffix += strings.Join(getSize(a.Type), "_") - } - } - name := fmt.Sprintf("%s%s", reflect.TypeOf(gadget).Elem().Name(), suffix) + args := getExArgs(gadget, schema.Fields) + + name := generateUniqueName(gadget, args) ptr_gadget := getGadgetByName(ce.Gadgets, name) if ptr_gadget != nil { @@ -395,16 +441,34 @@ func (ce *CodeExtractor) DefineGadget(gadget abstractor.GadgetDefinition) abstra oldCode := ce.Code ce.Code = make([]App, 0) outputs := gadget.DefineGadget(ce) + + // Handle gadgets returning nil. + // Without the if-statement, the nil would be replaced with (0:F) + // due to the case in sanitizeVars + if outputs == nil { + outputs = []frontend.Variable{} + } + + // flattenSlice needs to be called only if there are nested + // slices in order to generate a slice of Operand. + // TODO: remove `OutputsFlat` field and use only `Outputs` + flatOutput := []frontend.Variable{outputs} + vOutputs := reflect.ValueOf(outputs) + if vOutputs.Kind() == reflect.Slice { + flatOutput = flattenSlice(vOutputs) + } + newCode := ce.Code ce.Code = oldCode exGadget := ExGadget{ - Name: name, - Arity: arity, - Code: newCode, - Outputs: sanitizeVars(outputs...), - Extractor: ce, - Fields: schema.Fields, - Args: args, + Name: name, + Arity: arity, + Code: newCode, + OutputsFlat: sanitizeVars(flatOutput...), + Outputs: outputs, + Extractor: ce, + Fields: schema.Fields, + Args: args, } ce.Gadgets = append(ce.Gadgets, exGadget) return &exGadget diff --git a/extractor/extractor_test.go b/extractor/extractor_test.go deleted file mode 100644 index 6c27bf6..0000000 --- a/extractor/extractor_test.go +++ /dev/null @@ -1,241 +0,0 @@ -package extractor - -import ( - "fmt" - "log" - "testing" - - "github.com/reilabs/gnark-lean-extractor/abstractor" - - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/frontend" - "github.com/stretchr/testify/assert" -) - -// Example: ToBinary behaviour and nested Slice -type VectorGadget struct { - In_1 []frontend.Variable - In_2 []frontend.Variable -} - -func (gadget VectorGadget) DefineGadget(api abstractor.API) []frontend.Variable { - var sum frontend.Variable - for i := 0; i < len(gadget.In_1); i++ { - sum = api.Mul(gadget.In_1[i], gadget.In_2[i]) - } - return []frontend.Variable{sum, sum, sum} -} - -type ToBinaryCircuit struct { - In frontend.Variable `gnark:",public"` - Out frontend.Variable `gnark:",public"` - Double [][]frontend.Variable `gnark:",public"` -} - -func (circuit *ToBinaryCircuit) AbsDefine(api abstractor.API) error { - bin := api.ToBinary(circuit.In, 3) - bout := api.ToBinary(circuit.Out, 3) - - api.Add(circuit.Double[2][2], circuit.Double[1][1], circuit.Double[0][0]) - api.Mul(bin[1], bout[1]) - d := api.Call(VectorGadget{circuit.Double[2][:], circuit.Double[0][:]}) - api.Mul(d[2], d[1]) - - return nil -} - -func (circuit ToBinaryCircuit) Define(api frontend.API) error { - return abstractor.Concretize(api, &circuit) -} - -func TestToBinaryCircuit(t *testing.T) { - dim_1 := 3 - dim_2 := 3 - doubleSlice := make([][]frontend.Variable, dim_1) - for i := 0; i < int(dim_1); i++ { - doubleSlice[i] = make([]frontend.Variable, dim_2) - } - assignment := ToBinaryCircuit{Double: doubleSlice} - out, err := CircuitToLean(&assignment, ecc.BN254) - if err != nil { - log.Fatal(err) - } - fmt.Println(out) -} - -// Example: readme circuit -type DummyCircuit struct { - In_1 frontend.Variable - In_2 frontend.Variable - Out frontend.Variable -} - -func (circuit *DummyCircuit) AbsDefine(api abstractor.API) error { - sum := api.Add(circuit.In_1, circuit.In_2) - api.AssertIsEqual(sum, circuit.Out) - return nil -} - -func (circuit DummyCircuit) Define(api frontend.API) error { - return abstractor.Concretize(api, &circuit) -} - -func TestDummyCircuit(t *testing.T) { - assignment := DummyCircuit{} - out, err := CircuitToLean(&assignment, ecc.BN254) - if err != nil { - log.Fatal(err) - } - fmt.Println(out) -} - -// Example: circuit with constant parameter -type SliceGadget struct { - In_1 []frontend.Variable - In_2 []frontend.Variable -} - -func (gadget SliceGadget) DefineGadget(api abstractor.API) []frontend.Variable { - for i := 0; i < len(gadget.In_1); i++ { - api.Mul(gadget.In_1[i], gadget.In_2[i]) - } - - r := api.FromBinary(gadget.In_1...) - return []frontend.Variable{r} -} - -type CircuitWithParameter struct { - In frontend.Variable `gnark:",public"` - Path []frontend.Variable `gnark:",public"` - Tree []frontend.Variable `gnark:",public"` - Param int -} - -func (circuit *CircuitWithParameter) AbsDefine(api abstractor.API) error { - api.FromBinary(circuit.Path...) - bin := api.ToBinary(circuit.In) - bin = api.ToBinary(circuit.Param) - - dec := api.FromBinary(bin...) - api.AssertIsEqual(circuit.Param, dec) - api.Call(SliceGadget{circuit.Path, circuit.Path}) - - api.Mul(circuit.Path[0], circuit.Path[0]) - api.Call(SliceGadget{circuit.Tree, circuit.Tree}) - api.AssertIsEqual(circuit.Param, circuit.In) - - return nil -} - -func (circuit CircuitWithParameter) Define(api frontend.API) error { - return abstractor.Concretize(api, &circuit) -} - -func TestCircuitWithParameter(t *testing.T) { - paramValue := 20 - assignment := CircuitWithParameter{Path: make([]frontend.Variable, 3), Tree: make([]frontend.Variable, 2)} - assignment.Param = paramValue - assert.Equal(t, assignment.Param, paramValue, "assignment.Param is a const and should be 20.") - out, err := CircuitToLean(&assignment, ecc.BN254) - if err != nil { - log.Fatal(err) - } - fmt.Println(out) -} - -// Example: circuit with arrays and gadget -type DummyHash struct { - In_1 frontend.Variable - In_2 frontend.Variable -} - -func (gadget DummyHash) DefineGadget(api abstractor.API) []frontend.Variable { - r := api.Mul(gadget.In_1, gadget.In_2) - return []frontend.Variable{r} -} - -type MerkleRecover struct { - Root frontend.Variable `gnark:",public"` - Element frontend.Variable `gnark:",public"` - Path [20]frontend.Variable `gnark:",secret"` - Proof [20]frontend.Variable `gnark:",secret"` -} - -func (circuit *MerkleRecover) AbsDefine(api abstractor.API) error { - current := circuit.Element - for i := 0; i < len(circuit.Path); i++ { - leftHash := api.Call(DummyHash{current, circuit.Proof[i]})[0] - rightHash := api.Call(DummyHash{circuit.Proof[i], current})[0] - current = api.Select(circuit.Path[i], rightHash, leftHash) - } - api.AssertIsEqual(current, circuit.Root) - - return nil -} - -func (circuit MerkleRecover) Define(api frontend.API) error { - return abstractor.Concretize(api, &circuit) -} - -func TestMerkleRecover(t *testing.T) { - assignment := MerkleRecover{} - out, err := CircuitToLean(&assignment, ecc.BN254) - if err != nil { - log.Fatal(err) - } - fmt.Println(out) -} - -// Example: circuit with multiple gadgets -type MyWidget struct { - Test_1 frontend.Variable - Test_2 frontend.Variable - Num int -} - -func (gadget MyWidget) DefineGadget(api abstractor.API) []frontend.Variable { - sum := api.Add(gadget.Test_1, gadget.Test_2) - mul := api.Mul(gadget.Test_1, gadget.Test_2) - r := api.Div(sum, mul) - api.AssertIsBoolean(gadget.Num) - return []frontend.Variable{r} -} - -type MySecondWidget struct { - Test_1 frontend.Variable - Test_2 frontend.Variable - Num int -} - -func (gadget MySecondWidget) DefineGadget(api abstractor.API) []frontend.Variable { - mul := api.Mul(gadget.Test_1, gadget.Test_2) - snd := api.Call(MyWidget{gadget.Test_1, gadget.Test_2, gadget.Num})[0] - r := api.Mul(mul, snd) - return []frontend.Variable{r} -} - -type TwoGadgets struct { - In_1 frontend.Variable - In_2 frontend.Variable - Num int -} - -func (circuit *TwoGadgets) AbsDefine(api abstractor.API) error { - sum := api.Add(circuit.In_1, circuit.In_2) - prod := api.Mul(circuit.In_1, circuit.In_2) - api.Call(MySecondWidget{sum, prod, circuit.Num}) - return nil -} - -func (circuit TwoGadgets) Define(api frontend.API) error { - return abstractor.Concretize(api, &circuit) -} - -func TestTwoGadgets(t *testing.T) { - assignment := TwoGadgets{Num: 11} - out, err := CircuitToLean(&assignment, ecc.BN254) - if err != nil { - log.Fatal(err) - } - fmt.Println(out) -} diff --git a/extractor/interface.go b/extractor/interface.go new file mode 100644 index 0000000..115b4d8 --- /dev/null +++ b/extractor/interface.go @@ -0,0 +1,186 @@ +// This file contains the public API for using the extractor. +// The Call functions are used to call gadgets and get their returnd object. +// These methods are prepared for doing automated casting from interface{}. +// Alternatively it's possible to do manual casting by calling +// abstractor.API.Call() and casting the result to the needed type. +package extractor + +import ( + "fmt" + "strings" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/reilabs/gnark-lean-extractor/abstractor" + "golang.org/x/exp/slices" +) + +// CallVoid is used to call a Gadget which doesn't return anything +func CallVoid(api abstractor.API, gadget abstractor.GadgetDefinition) { + api.Call(gadget) +} + +// Call is used to call a Gadget which returns frontend.Variable (i.e. a single element `F` in Lean) +func Call(api abstractor.API, gadget abstractor.GadgetDefinition) frontend.Variable { + return api.Call(gadget).(frontend.Variable) +} + +// Call1 is used to call a Gadget which returns []frontend.Variable (i.e. `Vector F d` in Lean) +func Call1(api abstractor.API, gadget abstractor.GadgetDefinition) []frontend.Variable { + return api.Call(gadget).([]frontend.Variable) +} + +// Call2 is used to call a Gadget which returns a [][]frontend.Variable +// (i.e. `Vector (Vector F a) b` in Lean) +func Call2(api abstractor.API, gadget abstractor.GadgetDefinition) [][]frontend.Variable { + return api.Call(gadget).([][]frontend.Variable) +} + +// Call3 is used to call a Gadget which returns a [][][]frontend.Variable +// (i.e. `Vector (Vector (Vector F a) b) c` in Lean) +func Call3(api abstractor.API, gadget abstractor.GadgetDefinition) [][][]frontend.Variable { + return api.Call(gadget).([][][]frontend.Variable) +} + +// CircuitToLeanWithName exports a `circuit` to Lean over a `field` with `namespace` +// CircuitToLeanWithName and CircuitToLean aren't joined in a single function +// CircuitToLean(circuit abstractor.Circuit, field ecc.ID, namespace ...string) because the long term view +// is to add an optional parameter to support custom `set_option` directives in the header. +func CircuitToLeanWithName(circuit abstractor.Circuit, field ecc.ID, namespace string) (out string, err error) { + defer recoverError() + + schema, err := getSchema(circuit) + if err != nil { + return "", err + } + + circuitInit(circuit, schema) + + api := CodeExtractor{ + Code: []App{}, + Gadgets: []ExGadget{}, + FieldID: field, + } + + err = circuit.AbsDefine(&api) + if err != nil { + return "", err + } + + extractorCircuit := ExCircuit{ + Inputs: getExArgs(circuit, schema.Fields), + Gadgets: api.Gadgets, + Code: api.Code, + Field: api.FieldID, + } + out = exportCircuit(extractorCircuit, namespace) + return out, nil +} + +// CircuitToLean exports a `circuit` to Lean over a `field` with the namespace being the +// struct name of `circuit` +// When the namespace argument is not defined, it uses the name of the struct circuit +func CircuitToLean(circuit abstractor.Circuit, field ecc.ID) (string, error) { + name := getStructName(circuit) + return CircuitToLeanWithName(circuit, field, name) +} + +// GadgetToLeanWithName exports a `gadget` to Lean over a `field` with `namespace` +// Same notes written for CircuitToLeanWithName apply to GadgetToLeanWithName and GadgetToLean +func GadgetToLeanWithName(gadget abstractor.GadgetDefinition, field ecc.ID, namespace string) (out string, err error) { + defer recoverError() + + api := CodeExtractor{ + Code: []App{}, + Gadgets: []ExGadget{}, + FieldID: field, + } + + api.DefineGadget(gadget) + gadgets := exportGadgets(api.Gadgets) + prelude := exportPrelude(namespace, api.FieldID.ScalarField()) + footer := exportFooter(namespace) + return fmt.Sprintf("%s\n\n%s\n\n%s", prelude, gadgets, footer), nil +} + +// GadgetToLean exports a `gadget` to Lean over a `field` +func GadgetToLean(gadget abstractor.GadgetDefinition, field ecc.ID) (string, error) { + name := getStructName(gadget) + return GadgetToLeanWithName(gadget, field, name) +} + +// ExtractCircuits is used to export a series of `circuits` to Lean over a `field` under `namespace`. +func ExtractCircuits(namespace string, field ecc.ID, circuits ...abstractor.Circuit) (out string, err error) { + defer recoverError() + + api := CodeExtractor{ + Code: []App{}, + Gadgets: []ExGadget{}, + FieldID: field, + } + + var circuits_extracted []string + var past_circuits []string + + extractorCircuit := ExCircuit{ + Inputs: []ExArg{}, + Gadgets: []ExGadget{}, + Code: []App{}, + Field: api.FieldID, + } + + for _, circuit := range circuits { + schema, err := getSchema(circuit) + if err != nil { + return "", err + } + args := getExArgs(circuit, schema.Fields) + name := generateUniqueName(circuit, args) + if slices.Contains(past_circuits, name) { + continue + } + past_circuits = append(past_circuits, name) + + circuitInit(circuit, schema) + err = circuit.AbsDefine(&api) + if err != nil { + return "", err + } + + extractorCircuit.Inputs = args + extractorCircuit.Code = api.Code + + circ := fmt.Sprintf("def %s %s: Prop :=\n%s", name, genArgs(extractorCircuit.Inputs), genCircuitBody(extractorCircuit)) + circuits_extracted = append(circuits_extracted, circ) + + // Resetting elements for next circuit + extractorCircuit.Inputs = []ExArg{} + extractorCircuit.Code = []App{} + api.Code = []App{} + } + + prelude := exportPrelude(namespace, extractorCircuit.Field.ScalarField()) + gadgets := exportGadgets(api.Gadgets) + footer := exportFooter(namespace) + return fmt.Sprintf("%s\n\n%s\n\n%s\n\n%s", prelude, gadgets, strings.Join(circuits_extracted, "\n\n"), footer), nil +} + +// ExtractGadgets is used to export a series of `gadgets` to Lean over a `field` under `namespace`. +func ExtractGadgets(namespace string, field ecc.ID, gadgets ...abstractor.GadgetDefinition) (out string, err error) { + defer recoverError() + + api := CodeExtractor{ + Code: []App{}, + Gadgets: []ExGadget{}, + FieldID: field, + } + + for _, gadget := range gadgets { + api.DefineGadget(gadget) + } + + gadgets_string := exportGadgets(api.Gadgets) + prelude := exportPrelude(namespace, api.FieldID.ScalarField()) + footer := exportFooter(namespace) + return fmt.Sprintf("%s\n\n%s\n\n%s", prelude, gadgets_string, footer), nil +} diff --git a/extractor/lean_export.go b/extractor/lean_export.go index 6b09a68..3bb487d 100644 --- a/extractor/lean_export.go +++ b/extractor/lean_export.go @@ -4,95 +4,101 @@ import ( "fmt" "math/big" "reflect" + "regexp" "strings" - "github.com/reilabs/gnark-lean-extractor/abstractor" - - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/schema" ) -func ExportPrelude(name string, order *big.Int) string { +// isWhitespacePresent checks there are no whitespaces in the middle +// of a string. It's used to avoid a user requesting a `namespace` that +// contains whitespaces (which wouldn't be compliant with Lean4 syntax) +func isWhitespacePresent(input string) bool { + return regexp.MustCompile(`\s`).MatchString(input) +} + +// exportPrelude generates the string to put at the beginning of the +// autogenerated Lean4 code. It includes the relevant `provenZK` library +// import. +func exportPrelude(name string, order *big.Int) string { + trimmedName := strings.TrimSpace(name) + if isWhitespacePresent(trimmedName) { + panic("Whitespace isn't allowed in namespace tag") + } s := fmt.Sprintf(`import ProvenZk.Gates import ProvenZk.Ext.Vector +set_option linter.unusedVariables false + namespace %s def Order : ℕ := 0x%s variable [Fact (Nat.Prime Order)] -abbrev F := ZMod Order`, name, order.Text(16)) +abbrev F := ZMod Order`, trimmedName, order.Text(16)) return s } -func ExportFooter(name string) string { - s := fmt.Sprintf(`end %s`, name) +// exportFooter generates the string to put at the end of the +// autogenerated Lean4 code. At the moment it only closes +// the namespace. +func exportFooter(name string) string { + trimmedName := strings.TrimSpace(name) + if isWhitespacePresent(trimmedName) { + panic("Whitespace isn't allowed in namespace tag") + } + s := fmt.Sprintf(`end %s`, trimmedName) return s } -func ExportGadget(gadget ExGadget) string { - kArgsType := "F" - if len(gadget.Outputs) > 1 { - kArgsType = fmt.Sprintf("Vector F %d", len(gadget.Outputs)) +// genKTypeSignature generates the type signature of the `k` +// argument of exported gadgets. +func genKTypeSignature(output reflect.Value) string { + if output.Kind() != reflect.Slice { + return "" + } + if output.Index(0).Kind() == reflect.Slice { + innerType := genKTypeSignature(output.Index(0)) + return fmt.Sprintf("Vector (%s) %d", innerType, output.Len()) + } + return fmt.Sprintf("Vector F %d", output.Len()) +} + +// exportGadget generates the `gadget` function in Lean +func exportGadget(gadget ExGadget) string { + kArgs := "" + if len(gadget.OutputsFlat) == 1 { + kArgs = "(k: F -> Prop)" + } else if len(gadget.OutputsFlat) > 1 { + outputType := genKTypeSignature(reflect.ValueOf(gadget.Outputs)) + kArgs = fmt.Sprintf("(k: %s -> Prop)", outputType) } inAssignment := gadget.Args - return fmt.Sprintf("def %s %s (k: %s -> Prop): Prop :=\n%s", gadget.Name, genArgs(inAssignment), kArgsType, genGadgetBody(inAssignment, gadget)) + + return fmt.Sprintf("def %s %s %s: Prop :=\n%s", gadget.Name, genArgs(inAssignment), kArgs, genGadgetBody(inAssignment, gadget)) } -func ExportGadgets(exGadgets []ExGadget) string { +func exportGadgets(exGadgets []ExGadget) string { gadgets := make([]string, len(exGadgets)) for i, gadget := range exGadgets { - gadgets[i] = ExportGadget(gadget) + gadgets[i] = exportGadget(gadget) } return strings.Join(gadgets, "\n\n") } -func ExportCircuit(circuit ExCircuit, name string) string { - gadgets := ExportGadgets(circuit.Gadgets) +// exportCircuit generates the `circuit` function in Lean +func exportCircuit(circuit ExCircuit, name string) string { + gadgets := exportGadgets(circuit.Gadgets) circ := fmt.Sprintf("def circuit %s: Prop :=\n%s", genArgs(circuit.Inputs), genCircuitBody(circuit)) - prelude := ExportPrelude(name, circuit.Field.ScalarField()) - footer := ExportFooter(name) + prelude := exportPrelude(name, circuit.Field.ScalarField()) + footer := exportFooter(name) return fmt.Sprintf("%s\n\n%s\n\n%s\n\n%s", prelude, gadgets, circ, footer) } -func ArrayInit(f schema.Field, v reflect.Value, op Operand) error { - for i := 0; i < f.ArraySize; i++ { - op := Proj{op, i} - switch len(f.SubFields) { - case 1: - ArrayInit(f.SubFields[0], v.Index(i), op) - case 0: - value := reflect.ValueOf(op) - v.Index(i).Set(value) - default: - panic("Only nested arrays supported in SubFields") - } - } - return nil -} - -func ArrayZero(v reflect.Value) { - switch v.Kind() { - case reflect.Slice: - if v.Len() != 0 { - // Check if there are nested arrays. If yes, continue recursion - // until most nested array - if v.Addr().Elem().Index(0).Kind() == reflect.Slice { - for i := 0; i < v.Len(); i++ { - ArrayZero(v.Addr().Elem().Index(i)) - } - } else { - zero_array := make([]frontend.Variable, v.Len(), v.Len()) - v.Set(reflect.ValueOf(&zero_array).Elem()) - } - } - default: - panic("Only nested slices supported in SubFields of slices") - } -} - -func CircuitInit(class any, schema *schema.Schema) error { +// circuitInit takes struct and a schema to populate all the +// circuit/gagdget fields with Operand. +func circuitInit(class any, schema *schema.Schema) { // https://stackoverflow.com/a/49704408 // https://stackoverflow.com/a/14162161 // https://stackoverflow.com/a/63422049 @@ -122,12 +128,12 @@ func CircuitInit(class any, schema *schema.Schema) error { // initialise each element in the array if field_type.Kind() == reflect.Array { - ArrayInit(f, tmp.Elem().FieldByName(field_name), Input{j}) + arrayInit(f, tmp.Elem().FieldByName(field_name), Input{j}) } else if field_type.Kind() == reflect.Slice { // Recreate a zeroed array to remove overlapping pointers if input // arguments are duplicated (i.e. `api.Call(SliceGadget{circuit.Path, circuit.Path})`) - ArrayZero(tmp.Elem().FieldByName(field_name)) - ArrayInit(f, tmp.Elem().FieldByName(field_name), Input{j}) + arrayZero(tmp.Elem().FieldByName(field_name)) + arrayInit(f, tmp.Elem().FieldByName(field_name), Input{j}) } else if field_type.Kind() == reflect.Interface { init := Input{j} value := reflect.ValueOf(init) @@ -137,20 +143,13 @@ func CircuitInit(class any, schema *schema.Schema) error { fmt.Printf("Skipped type %s\n", field_type.Kind()) } } - return nil -} - -func KindOfField(a any, s string) reflect.Kind { - v := reflect.ValueOf(a).Elem() - f := v.FieldByName(s) - return f.Kind() } -func CircuitArgs(field schema.Field) ExArgType { +func circuitArgs(field schema.Field) ExArgType { // Handling only subfields which are nested arrays switch len(field.SubFields) { case 1: - subType := CircuitArgs(field.SubFields[0]) + subType := circuitArgs(field.SubFields[0]) return ExArgType{field.ArraySize, &subType} case 0: return ExArgType{field.ArraySize, nil} @@ -159,86 +158,24 @@ func CircuitArgs(field schema.Field) ExArgType { } } -func GetExArgs(circuit any, fields []schema.Field) []ExArg { +// getExArgs generates a list of ExArg given a `circuit` and a +// list of `Field`. It is used in the Circuit to Lean functions +func getExArgs(circuit any, fields []schema.Field) []ExArg { args := []ExArg{} for _, f := range fields { - kind := KindOfField(circuit, f.Name) - arg := ExArg{f.Name, kind, CircuitArgs(f)} + kind := kindOfField(circuit, f.Name) + arg := ExArg{f.Name, kind, circuitArgs(f)} args = append(args, arg) } return args } -// Cloned version of NewSchema without constraints -func GetSchema(circuit any) (*schema.Schema, error) { +// getSchema is a cloned version of NewSchema without constraints +func getSchema(circuit any) (*schema.Schema, error) { tVariable := reflect.ValueOf(struct{ A frontend.Variable }{}).FieldByName("A").Type() return schema.New(circuit, tVariable) } -func getStructName(circuit any) string { - return reflect.TypeOf(circuit).Elem().Name() -} - -func CircuitToLeanWithName(circuit abstractor.Circuit, field ecc.ID, namespace string) (string, error) { - schema, err := GetSchema(circuit) - if err != nil { - return "", err - } - - err = CircuitInit(circuit, schema) - if err != nil { - fmt.Println("CircuitInit error!") - fmt.Println(err.Error()) - } - - api := CodeExtractor{ - Code: []App{}, - Gadgets: []ExGadget{}, - Field: field, - } - - err = circuit.AbsDefine(&api) - if err != nil { - return "", err - } - - name := getStructName(circuit) - - extractorCircuit := ExCircuit{ - Inputs: GetExArgs(circuit, schema.Fields), - Gadgets: api.Gadgets, - Code: api.Code, - Field: api.Field, - Name: name, - } - out := ExportCircuit(extractorCircuit, namespace) - return out, nil -} - -func CircuitToLean(circuit abstractor.Circuit, field ecc.ID) (string, error) { - name := getStructName(circuit) - return CircuitToLeanWithName(circuit, field, name) -} - -func GadgetToLeanWithName(circuit abstractor.GadgetDefinition, field ecc.ID, namespace string) (string, error) { - api := CodeExtractor{ - Code: []App{}, - Gadgets: []ExGadget{}, - Field: field, - } - - api.DefineGadget(circuit) - gadgets := ExportGadgets(api.Gadgets) - prelude := ExportPrelude(namespace, api.Field.ScalarField()) - footer := ExportFooter(namespace) - return fmt.Sprintf("%s\n\n%s\n\n%s", prelude, gadgets, footer), nil -} - -func GadgetToLean(circuit abstractor.GadgetDefinition, field ecc.ID) (string, error) { - name := getStructName(circuit) - return GadgetToLeanWithName(circuit, field, name) -} - func genNestedArrays(a ExArgType) string { if a.Type != nil { return fmt.Sprintf("Vector (%s) %d", genNestedArrays(*a.Type), a.Size) @@ -265,8 +202,8 @@ func extractGateVars(arg Operand) []Operand { return extractGateVars(arg.(Proj).Operand) case ProjArray: res := []Operand{} - for i := range arg.(ProjArray).Proj { - res = append(res, extractGateVars(arg.(ProjArray).Proj[i])...) + for i := range arg.(ProjArray).Projs { + res = append(res, extractGateVars(arg.(ProjArray).Projs[i])...) } return res default: @@ -309,11 +246,14 @@ func assignGateVars(code []App, additional ...Operand) []string { func genGadgetCall(gateVar string, inAssignment []ExArg, gateVars []string, gadget *ExGadget, args []Operand) string { name := gadget.Name operands := operandExprs(args, inAssignment, gateVars) - binder := "_" - if gateVar != "" { - binder = gateVar + binder := "∧" + if len(gadget.OutputsFlat) > 0 { + binder = "fun _ =>" + if gateVar != "" { + binder = fmt.Sprintf("fun %s =>", gateVar) + } } - return fmt.Sprintf(" %s %s fun %s =>\n", name, strings.Join(operands, " "), binder) + return fmt.Sprintf(" %s %s %s\n", name, strings.Join(operands, " "), binder) } func genGateOp(op Op) string { @@ -388,23 +328,7 @@ func genFunctionalGate(gateVar string, op Op, operands []string) string { func genCallbackGate(gateVar string, op Op, operands []string, args []Operand) string { gateName := getGateName(gateVar, false) - switch op { - case OpFromBinary: - is_gate := reflect.TypeOf(args[0]) == reflect.TypeOf(Gate{}) - if len(args) == 1 && is_gate { - return fmt.Sprintf(" ∃%s, %s %s %s ∧\n", gateName, genGateOp(op), strings.Join(operands, " "), gateName) - } - return fmt.Sprintf(" ∃%s, %s vec![%s] %s ∧\n", gateName, genGateOp(op), strings.Join(operands, ", "), gateName) - case OpToBinary: - is_const := reflect.TypeOf(args[0]) == reflect.TypeOf(Const{}) - if is_const { - operands[0] = fmt.Sprintf("(%s:F)", operands[0]) - return fmt.Sprintf(" ∃%s, %s %s %s ∧\n", gateName, genGateOp(op), strings.Join(operands, " "), gateName) - } - return fmt.Sprintf(" ∃%s, %s %s %s ∧\n", gateName, genGateOp(op), strings.Join(operands, " "), gateName) - default: - return fmt.Sprintf(" ∃%s, %s %s %s ∧\n", gateName, genGateOp(op), strings.Join(operands, " "), gateName) - } + return fmt.Sprintf(" ∃%s, %s %s %s ∧\n", gateName, genGateOp(op), strings.Join(operands, " "), gateName) } func genGenericGate(op Op, operands []string) string { @@ -423,6 +347,13 @@ func genOpCall(gateVar string, inAssignment []ExArg, gateVars []string, op Op, a } operands := operandExprs(args, inAssignment, gateVars) + if op == OpFromBinary { + // OpFromBinary takes only one argument which is represented as list of Proj. For this reason we can + // safely wrap it in a ProjArray and call operandExpr directly. + projArray := ProjArray{args} + operands = []string{operandExpr(projArray, inAssignment, gateVars)} + } + if functional { // if an operation supports infinite length of arguments, // turn it into a chain of operations @@ -458,18 +389,27 @@ func genLine(app App, gateVar string, inAssignment []ExArg, gateVars []string) s } func genGadgetBody(inAssignment []ExArg, gadget ExGadget) string { - gateVars := assignGateVars(gadget.Code, gadget.Outputs...) + gateVars := assignGateVars(gadget.Code, gadget.OutputsFlat...) lines := make([]string, len(gadget.Code)) for i, app := range gadget.Code { lines[i] = genLine(app, gateVars[i], inAssignment, gateVars) } - outs := operandExprs(gadget.Outputs, inAssignment, gateVars) - result := outs[0] - if len(gadget.Outputs) > 1 { - result = fmt.Sprintf("vec![%s]", strings.Join(outs, ", ")) + + switch len(gadget.OutputsFlat) { + case 0: + lastLine := " True" + return strings.Join(append(lines, lastLine), "") + case 1: + // The case statement ensures there is index 0 (and only 0) + result := operandExpr(gadget.OutputsFlat[0], inAssignment, gateVars) + lastLine := fmt.Sprintf(" k %s", result) + return strings.Join(append(lines, lastLine), "") + default: + // Same trick used for OpFromBinary in genOpCall + result := operandExpr(ProjArray{gadget.OutputsFlat}, inAssignment, gateVars) + lastLine := fmt.Sprintf(" k %s", result) + return strings.Join(append(lines, lastLine), "") } - lastLine := fmt.Sprintf(" k %s", result) - return strings.Join(append(lines, lastLine), "") } func genCircuitBody(circuit ExCircuit) string { @@ -482,6 +422,161 @@ func genCircuitBody(circuit ExCircuit) string { return strings.Join(append(lines, lastLine), "") } +func getArgIndex(operand ProjArray) int { + if reflect.TypeOf(operand.Projs[0]) == reflect.TypeOf(Proj{}) { + switch op := operand.Projs[0].(Proj).Operand.(type) { + case Input: + return op.Index + case Gate: + return op.Index + case Proj: + return getArgIndex(ProjArray{[]Operand{op}}) + default: + return -1 + } + } else if (reflect.TypeOf(operand.Projs[0]) == reflect.TypeOf(ProjArray{})) { + return getArgIndex(operand.Projs[0].(ProjArray)) + } else { + return -1 + } +} + +func checkVector(operand ProjArray, argIdx int) (bool, Operand) { + // Check correct length + if operand.Projs[0].(Proj).Size != len(operand.Projs) { + return false, operand + } + + // Check index starts at 0 + lastIndex := operand.Projs[0].(Proj).Index + if lastIndex != 0 { + return false, operand + } + // Check always same Operand + firstOperand := operand.Projs[0].(Proj).Operand + + // Check indices are in ascending order + // on the same argIdx + for _, op := range operand.Projs[1:] { + if lastIndex != op.(Proj).Index-1 { + return false, operand + } + lastIndex += 1 + if firstOperand != op.(Proj).Operand { + return false, operand + } + } + return true, operand.Projs[0].(Proj).Operand +} + +// getStack returns the dimension of each of the nested ProjArray.Projs in `operand`. +// Outermost dimension is at index 0 +func getStack(operand ProjArray) []int { + if reflect.TypeOf(operand.Projs[0]) == reflect.TypeOf(ProjArray{}) { + return getStack(operand.Projs[0].(ProjArray)) + } else if reflect.TypeOf(operand.Projs[0]) == reflect.TypeOf(Proj{}) { + proj := operand.Projs[0].(Proj) + if reflect.TypeOf(proj.Operand) == reflect.TypeOf(Proj{}) { + return append(getStack(ProjArray{[]Operand{proj.Operand}}), proj.Size) + } else { + return []int{proj.Size} + } + } else { + return []int{} + } +} + +// expectedOperand checks that `op` has the Operand of `argIndex` +// and the last element of `indices` matches `op.Index` +func expectedOperand(op Proj, argIndex Operand, indices []int) bool { + if op.Index != indices[len(indices)-1] { + return false + } + if reflect.TypeOf(op.Operand) == reflect.TypeOf(Proj{}) { + return expectedOperand(op.Operand.(Proj), argIndex, indices[0:len(indices)-1]) + } + return op.Operand == argIndex +} + +// checkDimensions checks that the list of Proj is from the same Operand, with increasing Index, +// with the first Index being 0 and with the number of elemetns in ProjArray.Projs matching length[0] +func checkDimensions(operand ProjArray, length []int, argIndex Operand, pastIndices ...int) bool { + if len(operand.Projs) != length[0] { + return false + } + for i, p := range operand.Projs { + if len(length[1:]) >= 1 { + past := append(pastIndices, i) + if !checkDimensions(p.(ProjArray), length[1:], argIndex, past...) { + return false + } + } else { + if !expectedOperand(p.(Proj), argIndex, append(pastIndices, i)) { + return false + } + } + } + return true +} + +func getFirstOperand(operand ProjArray) Operand { + if reflect.TypeOf(operand.Projs[0]) == reflect.TypeOf(ProjArray{}) { + return getFirstOperand(operand.Projs[0].(ProjArray)) + } else if reflect.TypeOf(operand.Projs[0]) == reflect.TypeOf(Proj{}) { + return operand.Projs[0].(Proj) + } else { + fmt.Printf("getFirstOperand %+v\n", operand) + panic("Error in getFirstOperand.") + } +} + +func getIndex(operand Operand) Operand { + if reflect.TypeOf(operand) != reflect.TypeOf(Proj{}) { + return operand + } + return getIndex(operand.(Proj).Operand) +} + +// isVectorComplete determines if `operand` can be optimised +// without recreating the Vector element by element. The +// Operand returned is the simplified operand for parsing +// by `operandExpr` +func isVectorComplete(operand ProjArray) (bool, Operand) { + if len(operand.Projs) == 0 { + return false, operand + } + + // To check that ProjArray{} is complete, we first collect the Size + // parameter from index 0 of the Projs list. Then we retrieve the first + // Operand in ProjArray and extract the Input or Gate to use it in + // `checkDimensions` to verify that it's the same across all the + // elements. `checkDimensions` iterates through all the elements + // in `operand` to verify that `Operand` in `Proj` are all from the + // same Input/Gate, with Index starting from 0 and in ascending order + // and the length of Operand matches Proj.Size. + if reflect.TypeOf(operand.Projs[0]) == reflect.TypeOf(ProjArray{}) { + sliceDimensions := getStack(operand) + if len(sliceDimensions) == 0 { + return false, operand + } + firstOperand := getFirstOperand(operand) + argIdx := getIndex(firstOperand) + if !checkDimensions(operand, sliceDimensions, argIdx) { + return false, operand + } + return true, argIdx + } + + // checkVector is used for Proj and it does the same checks + // as for ProjArray + argIdx := getArgIndex(operand) + if argIdx == -1 { + return false, operand + } + + return checkVector(operand, argIdx) +} + func operandExpr(operand Operand, inAssignment []ExArg, gateVars []string) string { switch operand.(type) { case Input: @@ -491,11 +586,17 @@ func operandExpr(operand Operand, inAssignment []ExArg, gateVars []string) strin case Proj: return fmt.Sprintf("%s[%d]", operandExpr(operand.(Proj).Operand, inAssignment, gateVars), operand.(Proj).Index) case ProjArray: - opArray := operandExprs(operand.(ProjArray).Proj, inAssignment, gateVars) + isComplete, newOperand := isVectorComplete(operand.(ProjArray)) + if isComplete { + return operandExpr(newOperand, inAssignment, gateVars) + } + opArray := operandExprs(operand.(ProjArray).Projs, inAssignment, gateVars) opArray = []string{strings.Join(opArray, ", ")} return fmt.Sprintf("vec!%s", opArray) case Const: - return operand.(Const).Value.Text(10) + return fmt.Sprintf("(%s:F)", operand.(Const).Value.Text(10)) + case Integer: + return operand.(Integer).Value.Text(10) default: fmt.Printf("Type %T\n", operand) panic("not yet supported") @@ -503,9 +604,9 @@ func operandExpr(operand Operand, inAssignment []ExArg, gateVars []string) strin } func operandExprs(operands []Operand, inAssignment []ExArg, gateVars []string) []string { - exprs := make([]string, len(operands)) - for i, operand := range operands { - exprs[i] = operandExpr(operand, inAssignment, gateVars) + exprs := []string{} + for _, operand := range operands { + exprs = append(exprs, operandExpr(operand, inAssignment, gateVars)) } return exprs } diff --git a/extractor/misc.go b/extractor/misc.go new file mode 100644 index 0000000..754c46f --- /dev/null +++ b/extractor/misc.go @@ -0,0 +1,264 @@ +package extractor + +import ( + "errors" + "flag" + "fmt" + "reflect" + "runtime/debug" + "strings" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/schema" + "github.com/mitchellh/copystructure" + "github.com/reilabs/gnark-lean-extractor/abstractor" +) + +// recoverError is used in the top level interface to prevent panic +// caused by any of the methods in the extractor from propagating +// When go is running in test mode, it prints the stack trace to aid +// debugging. +func recoverError() (err error) { + if recover() != nil { + if flag.Lookup("test.v") != nil { + stack := string(debug.Stack()) + fmt.Println(stack) + } + err = errors.New("Panic extracting circuit to Lean") + } + return nil +} + +// arrayToSlice returns a slice of elements identical to +// the input array `v` +func arrayToSlice(v reflect.Value) []frontend.Variable { + if v.Len() == 0 { + return []frontend.Variable{} + } + + switch v.Index(0).Kind() { + case reflect.Array: + args := []frontend.Variable{} + for i := 0; i < v.Len(); i++ { + arg := arrayToSlice(v.Index(i)) + // The reason to check for len != 0 is to avoid generating + // lists of empty nested lists + if len(arg) != 0 { + args = append(args, arg) + } + } + return args + case reflect.Interface: + res := []frontend.Variable{} + for i := 0; i < v.Len(); i++ { + res = append(res, v.Index(i).Elem().Interface().(frontend.Variable)) + } + return res + default: + return []frontend.Variable{} + } +} + +// flattenSlice takes a slice and returns a single dimension +// slice of frontend.Variable. This is needed to transform +// nested slices into single dimensional slices to be +// processed by sanitizeVars. +func flattenSlice(value reflect.Value) []frontend.Variable { + if value.Len() == 0 { + return []frontend.Variable{} + } + if value.Index(0).Kind() == reflect.Slice { + args := []frontend.Variable{} + for i := 0; i < value.Len(); i++ { + arg := flattenSlice(value.Index(i)) + // The reason to check for len != 0 is to avoid generating + // lists of empty nested lists + if len(arg) != 0 { + args = append(args, arg) + } + } + return args + } + return value.Interface().([]frontend.Variable) +} + +// arrayInit generates the Proj{} object for each element of v +func arrayInit(f schema.Field, v reflect.Value, op Operand) error { + for i := 0; i < f.ArraySize; i++ { + op := Proj{op, i, f.ArraySize} + switch len(f.SubFields) { + case 1: + arrayInit(f.SubFields[0], v.Index(i), op) + case 0: + if v.Len() != f.ArraySize { + // Slices of this type aren't supported yet [[ ] [ ]] + // gnark newSchema doesn't handle different dimensions + fmt.Printf("Wrong slices dimensions %+v\n", v) + panic("Only slices dimensions not matching") + } + value := reflect.ValueOf(op) + v.Index(i).Set(value) + default: + panic("Only nested arrays supported in SubFields") + } + } + return nil +} + +// arrayZero sets all the elements of the input slice v to nil. +// It is used when initialising a new circuit or gadget to ensure +// the object is clean +func arrayZero(v reflect.Value) { + switch v.Kind() { + case reflect.Slice: + if v.Len() != 0 { + // Check if there are nested arrays. If yes, continue recursion + // until most nested array + if v.Addr().Elem().Index(0).Kind() == reflect.Slice { + for i := 0; i < v.Len(); i++ { + arrayZero(v.Addr().Elem().Index(i)) + } + } else { + zero_array := make([]frontend.Variable, v.Len(), v.Len()) + v.Set(reflect.ValueOf(&zero_array).Elem()) + } + } + default: + panic("Only nested slices supported in SubFields of slices") + } +} + +// kindOfField returns the Kind of field in struct a +func kindOfField(a any, field string) reflect.Kind { + v := reflect.ValueOf(a).Elem() + f := v.FieldByName(field) + return f.Kind() +} + +// getStructName returns the name of struct a +func getStructName(a any) string { + return reflect.TypeOf(a).Elem().Name() +} + +// updateProj recursively creates a Proj object using the `Index` and `Size` from the +// optional argument `extra`. It uses the argument `gate` as Operand for the innermost Proj. +// The `extra` optional argument contains the `Index` in even indices and the `Size` in odd indices, +// elements are discarded from the end. +func updateProj(gate Operand, extra ...int) Proj { + if len(extra) == 2 { + return Proj{gate, extra[0], extra[1]} + } else if len(extra) > 0 && len(extra)%2 == 0 { + return Proj{updateProj(gate, extra[:len(extra)-2]...), extra[len(extra)-2], extra[len(extra)-1]} + } + fmt.Printf("updateProj gate: %#v | extra: %+v", gate, extra) + panic("updateProj called with wrong number of elements in extra") +} + +// replaceArg generates the object returned when calling the gadget in a circuit. +// The object returned has the same structure as ExGadget.OutputsFlat but it needs +// to have updated `Proj` fields. gate argument corresponds to the `Gate` object of the +// gadget call. extra argument keeps track of the `Size` and `Index` elements of the nested +// Proj. These need to be replaced because the output of a gadget is a combination +// of Proj. +func replaceArg(gOutputs interface{}, gate Operand, extra ...int) interface{} { + // extra[0] -> i + // extra[1] -> len + switch v := (gOutputs).(type) { + case Input, Gate: + if len(extra) == 2 { + return Proj{gate, extra[0], extra[1]} + } + return gate + case Proj: + if len(extra) >= 2 { + return updateProj(gate, extra...) + } + return gate + case []frontend.Variable: + res := make([]frontend.Variable, len(v)) + for i, o := range v { + res[i] = replaceArg(o, gate, append(extra, []int{i, len(v)}...)...) + } + return res + case [][]frontend.Variable: + res := make([][]frontend.Variable, len(v)) + for i, o := range v { + res[i] = replaceArg(o, gate, append(extra, []int{i, len(v)}...)...).([]frontend.Variable) + } + return res + case [][][]frontend.Variable: + res := make([][][]frontend.Variable, len(v)) + for i, o := range v { + res[i] = replaceArg(o, gate, append(extra, []int{i, len(v)}...)...).([][]frontend.Variable) + } + return res + case nil: + return []frontend.Variable{} + default: + fmt.Printf("replaceArg invalid argument of type %T %#v\n", gOutputs, gOutputs) + panic("replaceArg invalid argument") + } +} + +// cloneGadget performs deep cloning of `gadget` +func cloneGadget(gadget abstractor.GadgetDefinition) abstractor.GadgetDefinition { + dup, err := copystructure.Copy(gadget) + if err != nil { + panic(err) + } + // The reason for the following lines is to generate a reflect.Ptr to the interface + v := reflect.ValueOf(dup) + tmp_gadget := reflect.New(v.Type()) + tmp_gadget.Elem().Set(v) + return tmp_gadget.Interface().(abstractor.GadgetDefinition) +} + +// generateUniqueName is a function that generates the gadget function name in Lean +// To distinguish between gadgets instantiated with different array +// sizes, add a suffix to the name. The suffix of each instantiation +// is made up of the concatenation of the length of all the array +// fields in the gadget +func generateUniqueName(element any, args []ExArg) string { + suffix := "" + for _, a := range args { + if a.Kind == reflect.Array || a.Kind == reflect.Slice { + suffix += "_" + suffix += strings.Join(getSizeGadgetArgs(a.Type), "_") + } + } + + val := reflect.ValueOf(element).Elem() + for i := 0; i < val.NumField(); i++ { + switch val.Field(i).Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + suffix += fmt.Sprintf("_%d", val.Field(i).Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + suffix += fmt.Sprintf("_%d", val.Field(i).Uint()) + case reflect.Uintptr, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: + fmt.Printf("-- Gadget name doesn't differentiate yet between different values of type %+v.\n", val.Field(i).Kind()) + fmt.Println("-- Proceed with caution") + } + } + return fmt.Sprintf("%s%s", reflect.TypeOf(element).Elem().Name(), suffix) +} + +// getGadgetByName checks if `name` matches the ExGadget.Name of one of +// the elements in `gadgets` +func getGadgetByName(gadgets []ExGadget, name string) abstractor.Gadget { + for _, gadget := range gadgets { + if gadget.Name == name { + return &gadget + } + } + return nil +} + +// getSizeGadgetArgs generates the concatenation of dimensions of +// a slice/array (i.e. [3][2]frontend.Variable --> ["3","2"]) +// It is used to generate a unique gadget name +func getSizeGadgetArgs(elem ExArgType) []string { + if elem.Type == nil { + return []string{fmt.Sprintf("%d", elem.Size)} + } + return append(getSizeGadgetArgs(*elem.Type), fmt.Sprintf("%d", elem.Size)) +} diff --git a/extractor/test/another_circuit_test.go b/extractor/test/another_circuit_test.go new file mode 100644 index 0000000..8fd8140 --- /dev/null +++ b/extractor/test/another_circuit_test.go @@ -0,0 +1,62 @@ +package extractor_test + +import ( + "log" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/reilabs/gnark-lean-extractor/abstractor" + "github.com/reilabs/gnark-lean-extractor/extractor" +) + +// Example: Gadget with nested array of int +type IntArrayGadget struct { + In []frontend.Variable + Matrix [2]int + NestedMatrix [2][2]int +} + +func (gadget IntArrayGadget) DefineGadget(api abstractor.API) interface{} { + r := api.FromBinary(gadget.In...) + api.Mul(gadget.Matrix[0], gadget.Matrix[1]) + return []frontend.Variable{r, r, r} +} + +type AnotherCircuit struct { + In []frontend.Variable + Matrix [2][2]int +} + +func (circuit *AnotherCircuit) AbsDefine(api abstractor.API) error { + r := extractor.Call1(api, IntArrayGadget{ + circuit.In, + circuit.Matrix[0], + circuit.Matrix, + }) + + api.FromBinary(r[1:3]...) + api.FromBinary(r[0:2]...) + api.FromBinary(r...) + return nil +} + +func (circuit AnotherCircuit) Define(api frontend.API) error { + return abstractor.Concretize(api, &circuit) +} + +func TestAnotherCircuit(t *testing.T) { + m := [2][2]int{ + {0, 36}, + {1, 44}, + } + assignment := AnotherCircuit{ + In: make([]frontend.Variable, 4), + Matrix: m, + } + out, err := extractor.CircuitToLean(&assignment, ecc.BN254) + if err != nil { + log.Fatal(err) + } + checkOutput(t, out) +} diff --git a/extractor/test/circuit_with_parameter_test.go b/extractor/test/circuit_with_parameter_test.go new file mode 100644 index 0000000..56002dd --- /dev/null +++ b/extractor/test/circuit_with_parameter_test.go @@ -0,0 +1,92 @@ +package extractor_test + +import ( + "log" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/reilabs/gnark-lean-extractor/abstractor" + "github.com/reilabs/gnark-lean-extractor/extractor" + "github.com/stretchr/testify/assert" +) + +// Example: circuit with constant parameter +type ReturnItself struct { + In_1 []frontend.Variable + Out []frontend.Variable +} + +func (gadget ReturnItself) DefineGadget(api abstractor.API) interface{} { + for i := 0; i < len(gadget.In_1); i++ { + gadget.Out[i] = api.Mul(gadget.In_1[i], gadget.In_1[i]) + } + + return gadget.Out +} + +type SliceGadget struct { + In_1 []frontend.Variable + In_2 []frontend.Variable +} + +func (gadget SliceGadget) DefineGadget(api abstractor.API) interface{} { + for i := 0; i < len(gadget.In_1); i++ { + api.Mul(gadget.In_1[i], gadget.In_2[i]) + } + + r := api.FromBinary(gadget.In_1...) + return r +} + +type CircuitWithParameter struct { + In frontend.Variable `gnark:",public"` + Path []frontend.Variable `gnark:",public"` + Tree []frontend.Variable `gnark:",public"` + Param int +} + +func (circuit *CircuitWithParameter) AbsDefine(api abstractor.API) error { + D := make([]frontend.Variable, 3) + for i := 0; i < len(circuit.Path); i++ { + D = extractor.Call1(api, ReturnItself{ + In_1: circuit.Path, + Out: D, + }) + api.AssertIsEqual(D[1], D[2]) + } + + api.FromBinary(circuit.Path...) + api.FromBinary(D...) + api.FromBinary(D[1], D[2], D[0]) + api.FromBinary(D[1], 0, D[0]) + api.FromBinary(D[1:3]...) + bin := api.ToBinary(circuit.In) + bin = api.ToBinary(circuit.Param) + + dec := api.FromBinary(bin...) + api.AssertIsEqual(circuit.Param, dec) + extractor.Call(api, SliceGadget{circuit.Path, circuit.Path}) + + api.Mul(circuit.Path[0], circuit.Path[0]) + extractor.Call(api, SliceGadget{circuit.Tree, circuit.Tree}) + api.AssertIsEqual(circuit.Param, circuit.In) + + return nil +} + +func (circuit CircuitWithParameter) Define(api frontend.API) error { + return abstractor.Concretize(api, &circuit) +} + +func TestCircuitWithParameter(t *testing.T) { + paramValue := 20 + assignment := CircuitWithParameter{Path: make([]frontend.Variable, 3), Tree: make([]frontend.Variable, 2)} + assignment.Param = paramValue + assert.Equal(t, assignment.Param, paramValue, "assignment.Param is a const and should be 20.") + out, err := extractor.CircuitToLean(&assignment, ecc.BN254) + if err != nil { + log.Fatal(err) + } + checkOutput(t, out) +} diff --git a/extractor/test/deletion_mbu_circuit_test.go b/extractor/test/deletion_mbu_circuit_test.go new file mode 100644 index 0000000..bd2221e --- /dev/null +++ b/extractor/test/deletion_mbu_circuit_test.go @@ -0,0 +1,86 @@ +package extractor_test + +import ( + "log" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/reilabs/gnark-lean-extractor/abstractor" + "github.com/reilabs/gnark-lean-extractor/extractor" +) + +// Example: Mismatched arguments error +type DeletionProof struct { + DeletionIndices []frontend.Variable + PreRoot frontend.Variable + IdComms []frontend.Variable + MerkleProofs [][]frontend.Variable + + BatchSize int + Depth int +} + +func (gadget DeletionProof) DefineGadget(api abstractor.API) interface{} { + return gadget.PreRoot +} + +type DeletionMbuCircuit struct { + // single public input + InputHash frontend.Variable `gnark:",public"` + + // private inputs, but used as public inputs + DeletionIndices []frontend.Variable `gnark:"input"` + PreRoot frontend.Variable `gnark:"input"` + PostRoot frontend.Variable `gnark:"input"` + + // private inputs + IdComms []frontend.Variable `gnark:"input"` + MerkleProofs [][]frontend.Variable `gnark:"input"` + + BatchSize int + Depth int +} + +func (circuit *DeletionMbuCircuit) AbsDefine(api abstractor.API) error { + root := extractor.Call(api, DeletionProof{ + DeletionIndices: circuit.DeletionIndices, + PreRoot: circuit.PreRoot, + IdComms: circuit.IdComms, + MerkleProofs: circuit.MerkleProofs, + BatchSize: circuit.BatchSize, + Depth: circuit.Depth, + }) + + // Final root needs to match. + api.AssertIsEqual(root, circuit.PostRoot) + + return nil +} + +func (circuit DeletionMbuCircuit) Define(api frontend.API) error { + return abstractor.Concretize(api, &circuit) +} + +func TestDeletionMbuCircuit(t *testing.T) { + batchSize := 2 + treeDepth := 3 + proofs := make([][]frontend.Variable, batchSize) + for i := 0; i < int(batchSize); i++ { + proofs[i] = make([]frontend.Variable, treeDepth) + } + + assignment := DeletionMbuCircuit{ + DeletionIndices: make([]frontend.Variable, batchSize), + IdComms: make([]frontend.Variable, batchSize), + MerkleProofs: proofs, + + BatchSize: int(batchSize), + Depth: int(treeDepth), + } + out, err := extractor.CircuitToLean(&assignment, ecc.BN254) + if err != nil { + log.Fatal(err) + } + checkOutput(t, out) +} diff --git a/extractor/test/merkle_recover_test.go b/extractor/test/merkle_recover_test.go new file mode 100644 index 0000000..a798459 --- /dev/null +++ b/extractor/test/merkle_recover_test.go @@ -0,0 +1,54 @@ +package extractor_test + +import ( + "log" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/reilabs/gnark-lean-extractor/abstractor" + "github.com/reilabs/gnark-lean-extractor/extractor" +) + +// Example: circuit with arrays and gadget +type DummyHash struct { + In_1 frontend.Variable + In_2 frontend.Variable +} + +func (gadget DummyHash) DefineGadget(api abstractor.API) interface{} { + r := api.Mul(gadget.In_1, gadget.In_2) + return r +} + +type MerkleRecover struct { + Root frontend.Variable `gnark:",public"` + Element frontend.Variable `gnark:",public"` + Path [20]frontend.Variable `gnark:",secret"` + Proof [20]frontend.Variable `gnark:",secret"` +} + +func (circuit *MerkleRecover) AbsDefine(api abstractor.API) error { + current := circuit.Element + for i := 0; i < len(circuit.Path); i++ { + leftHash := extractor.Call(api, DummyHash{current, circuit.Proof[i]}) + rightHash := extractor.Call(api, DummyHash{circuit.Proof[i], current}) + current = api.Select(circuit.Path[i], rightHash, leftHash) + } + api.AssertIsEqual(current, circuit.Root) + + return nil +} + +func (circuit MerkleRecover) Define(api frontend.API) error { + return abstractor.Concretize(api, &circuit) +} + +func TestMerkleRecover(t *testing.T) { + assignment := MerkleRecover{} + out, err := extractor.CircuitToLean(&assignment, ecc.BN254) + if err != nil { + log.Fatal(err) + } + checkOutput(t, out) +} diff --git a/extractor/test/my_circuit_test.go b/extractor/test/my_circuit_test.go new file mode 100644 index 0000000..5a86ffd --- /dev/null +++ b/extractor/test/my_circuit_test.go @@ -0,0 +1,37 @@ +package extractor_test + +import ( + "log" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/reilabs/gnark-lean-extractor/abstractor" + "github.com/reilabs/gnark-lean-extractor/extractor" +) + +// Example: readme circuit +type MyCircuit struct { + In_1 frontend.Variable + In_2 frontend.Variable + Out frontend.Variable +} + +func (circuit *MyCircuit) AbsDefine(api abstractor.API) error { + sum := api.Add(circuit.In_1, circuit.In_2) + api.AssertIsEqual(sum, circuit.Out) + return nil +} + +func (circuit MyCircuit) Define(api frontend.API) error { + return abstractor.Concretize(api, &circuit) +} + +func TestMyCircuit(t *testing.T) { + assignment := MyCircuit{} + out, err := extractor.CircuitToLean(&assignment, ecc.BN254) + if err != nil { + log.Fatal(err) + } + checkOutput(t, out) +} diff --git a/extractor/test/slices_optimisation_test.go b/extractor/test/slices_optimisation_test.go new file mode 100644 index 0000000..51b59fa --- /dev/null +++ b/extractor/test/slices_optimisation_test.go @@ -0,0 +1,110 @@ +package extractor_test + +import ( + "log" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/reilabs/gnark-lean-extractor/abstractor" + "github.com/reilabs/gnark-lean-extractor/extractor" +) + +// Example: checking slices optimisation +type TwoSlices struct { + TwoDim [][]frontend.Variable +} + +func (gadget TwoSlices) DefineGadget(api abstractor.API) interface{} { + return gadget.TwoDim +} + +type ThreeSlices struct { + ThreeDim [][][]frontend.Variable +} + +func (gadget ThreeSlices) DefineGadget(api abstractor.API) interface{} { + return gadget.ThreeDim +} + +type SlicesGadget struct { + TwoDim [][]frontend.Variable + ThreeDim [][][]frontend.Variable +} + +func (gadget SlicesGadget) DefineGadget(api abstractor.API) interface{} { + return append(gadget.ThreeDim[0][0], gadget.TwoDim[0]...) +} + +type SlicesOptimisation struct { + Test frontend.Variable + Id []frontend.Variable + TwoDim [][]frontend.Variable + ThreeDim [][][]frontend.Variable +} + +func (circuit *SlicesOptimisation) AbsDefine(api abstractor.API) error { + extractor.Call1(api, SlicesGadget{ + TwoDim: circuit.TwoDim, + ThreeDim: circuit.ThreeDim, + }) + extractor.Call1(api, SlicesGadget{ + TwoDim: [][]frontend.Variable{circuit.TwoDim[1], circuit.TwoDim[0]}, + ThreeDim: [][][]frontend.Variable{circuit.ThreeDim[1], circuit.ThreeDim[0]}, + }) + extractor.Call1(api, SlicesGadget{ + TwoDim: [][]frontend.Variable{{circuit.TwoDim[1][1]}, {circuit.TwoDim[1][0]}}, + ThreeDim: [][][]frontend.Variable{circuit.ThreeDim[1], circuit.ThreeDim[0], circuit.ThreeDim[1]}, + }) + extractor.Call1(api, SlicesGadget{ + TwoDim: [][]frontend.Variable{circuit.TwoDim[1], {circuit.TwoDim[1][0], circuit.TwoDim[0][0], circuit.TwoDim[1][1]}}, + ThreeDim: circuit.ThreeDim, + }) + extractor.Call2(api, TwoSlices{ + TwoDim: circuit.TwoDim, + }) + a := extractor.Call3(api, ThreeSlices{ + ThreeDim: circuit.ThreeDim, + }) + b := extractor.Call3(api, ThreeSlices{ + ThreeDim: a, + }) + extractor.Call3(api, ThreeSlices{ + ThreeDim: b, + }) + + return nil +} + +func (circuit SlicesOptimisation) Define(api frontend.API) error { + return abstractor.Concretize(api, &circuit) +} + +func TestSlicesOptimisation(t *testing.T) { + depthOne := 2 + depthTwo := 3 + depthThree := 4 + twoSlice := make([][]frontend.Variable, depthOne) + for i := 0; i < int(depthOne); i++ { + twoSlice[i] = make([]frontend.Variable, depthTwo) + } + + threeSlice := make([][][]frontend.Variable, depthOne) + for x := 0; x < int(depthOne); x++ { + threeSlice[x] = make([][]frontend.Variable, depthTwo) + for y := 0; y < int(depthTwo); y++ { + threeSlice[x][y] = make([]frontend.Variable, depthThree) + } + } + + assignment := SlicesOptimisation{ + Id: make([]frontend.Variable, depthTwo), + TwoDim: twoSlice, + ThreeDim: threeSlice, + } + out, err := extractor.CircuitToLean(&assignment, ecc.BN254) + if err != nil { + log.Fatal(err) + } + checkOutput(t, out) +} diff --git a/extractor/test/to_binary_circuit_test.go b/extractor/test/to_binary_circuit_test.go new file mode 100644 index 0000000..530f959 --- /dev/null +++ b/extractor/test/to_binary_circuit_test.go @@ -0,0 +1,91 @@ +package extractor_test + +import ( + "log" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/reilabs/gnark-lean-extractor/abstractor" + "github.com/reilabs/gnark-lean-extractor/extractor" +) + +// Example: Gadget that returns a vector +type OptimisedVectorGadget struct { + In frontend.Variable +} + +func (gadget OptimisedVectorGadget) DefineGadget(api abstractor.API) interface{} { + return api.ToBinary(gadget.In, 3) +} + +// Example: ToBinary behaviour and nested Slice +type VectorGadget struct { + In_1 []frontend.Variable + In_2 []frontend.Variable + Nested [][]frontend.Variable +} + +func (gadget VectorGadget) DefineGadget(api abstractor.API) interface{} { + var sum frontend.Variable + for i := 0; i < len(gadget.In_1); i++ { + sum = api.Mul(gadget.In_1[i], gadget.In_2[i]) + } + return []frontend.Variable{sum, sum, sum} +} + +type ToBinaryCircuit struct { + In frontend.Variable `gnark:",public"` + Out frontend.Variable `gnark:",public"` + Double [][]frontend.Variable `gnark:",public"` +} + +func (circuit *ToBinaryCircuit) AbsDefine(api abstractor.API) error { + bin := api.ToBinary(circuit.In, 3) + bout := api.ToBinary(circuit.Out, 3) + + api.Add(circuit.Double[2][2], circuit.Double[1][1], circuit.Double[0][0]) + api.Mul(bin[1], bout[1]) + d := extractor.Call1(api, VectorGadget{circuit.Double[2][:], circuit.Double[0][:], circuit.Double}) + api.Mul(d[2], d[1]) + + return nil +} + +func (circuit ToBinaryCircuit) Define(api frontend.API) error { + return abstractor.Concretize(api, &circuit) +} + +func TestGadgetExtraction(t *testing.T) { + dim_1 := 3 + dim_2 := 3 + doubleSlice := make([][]frontend.Variable, dim_1) + for i := 0; i < int(dim_1); i++ { + doubleSlice[i] = make([]frontend.Variable, dim_2) + } + assignment := VectorGadget{ + In_1: make([]frontend.Variable, dim_2), + In_2: make([]frontend.Variable, dim_2), + Nested: doubleSlice, + } + out, err := extractor.GadgetToLean(&assignment, ecc.BN254) + if err != nil { + log.Fatal(err) + } + checkOutput(t, out) +} + +func TestToBinaryCircuit(t *testing.T) { + dim_1 := 3 + dim_2 := 3 + doubleSlice := make([][]frontend.Variable, dim_1) + for i := 0; i < int(dim_1); i++ { + doubleSlice[i] = make([]frontend.Variable, dim_2) + } + assignment := ToBinaryCircuit{Double: doubleSlice} + out, err := extractor.CircuitToLean(&assignment, ecc.BN254) + if err != nil { + log.Fatal(err) + } + checkOutput(t, out) +} diff --git a/extractor/test/two_gadgets_test.go b/extractor/test/two_gadgets_test.go new file mode 100644 index 0000000..405d9cb --- /dev/null +++ b/extractor/test/two_gadgets_test.go @@ -0,0 +1,121 @@ +package extractor_test + +import ( + "log" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/reilabs/gnark-lean-extractor/abstractor" + "github.com/reilabs/gnark-lean-extractor/extractor" +) + +// Example: circuit with multiple gadgets +type MyWidget struct { + Test_1 frontend.Variable + Test_2 frontend.Variable + Num uint32 +} + +func (gadget MyWidget) DefineGadget(api abstractor.API) interface{} { + sum := api.Add(gadget.Test_1, gadget.Test_2) + mul := api.Mul(gadget.Test_1, gadget.Test_2) + r := api.Div(sum, mul) + api.AssertIsBoolean(gadget.Num) + return r +} + +type MySecondWidget struct { + Test_1 frontend.Variable + Test_2 frontend.Variable + Num int +} + +func (gadget MySecondWidget) DefineGadget(api abstractor.API) interface{} { + mul := api.Mul(gadget.Test_1, gadget.Test_2) + snd := extractor.Call(api, MyWidget{gadget.Test_1, gadget.Test_2, uint32(gadget.Num)}) + api.Mul(mul, snd) + return nil +} + +type TwoGadgets struct { + In_1 frontend.Variable + In_2 frontend.Variable + Num int +} + +func (circuit *TwoGadgets) AbsDefine(api abstractor.API) error { + sum := api.Add(circuit.In_1, circuit.In_2) + prod := api.Mul(circuit.In_1, circuit.In_2) + extractor.CallVoid(api, MySecondWidget{sum, prod, circuit.Num}) + return nil +} + +func (circuit TwoGadgets) Define(api frontend.API) error { + return abstractor.Concretize(api, &circuit) +} + +func TestTwoGadgets(t *testing.T) { + assignment := TwoGadgets{Num: 11} + out, err := extractor.CircuitToLean(&assignment, ecc.BN254) + if err != nil { + log.Fatal(err) + } + checkOutput(t, out) +} + +func TestExtractGadgets(t *testing.T) { + assignment_1 := DummyHash{} + assignment_2 := MySecondWidget{Num: 11} + assignment_3 := MySecondWidget{Num: 9} + out, err := extractor.ExtractGadgets("MultipleGadgets", ecc.BN254, &assignment_1, &assignment_2, &assignment_3) + if err != nil { + log.Fatal(err) + } + checkOutput(t, out) +} + +func TestExtractGadgetsVectors(t *testing.T) { + dim_1 := 3 + dim_2 := 3 + doubleSlice := make([][]frontend.Variable, dim_1) + for i := 0; i < int(dim_1); i++ { + doubleSlice[i] = make([]frontend.Variable, dim_2) + } + assignment_1 := VectorGadget{ + In_1: make([]frontend.Variable, dim_2), + In_2: make([]frontend.Variable, dim_2), + Nested: doubleSlice, + } + assignment_2 := ReturnItself{ + In_1: make([]frontend.Variable, dim_1), + Out: make([]frontend.Variable, dim_1), + } + assignment_3 := OptimisedVectorGadget{} + out, err := extractor.ExtractGadgets("MultipleGadgetsVectors", ecc.BN254, &assignment_1, &assignment_2, &assignment_3) + if err != nil { + log.Fatal(err) + } + checkOutput(t, out) +} + +func TestExtractCircuits(t *testing.T) { + assignment_1 := TwoGadgets{Num: 11} + assignment_2 := MerkleRecover{} + + dim_1 := 3 + dim_2 := 3 + doubleSlice := make([][]frontend.Variable, dim_1) + for i := 0; i < int(dim_1); i++ { + doubleSlice[i] = make([]frontend.Variable, dim_2) + } + assignment_3 := ToBinaryCircuit{Double: doubleSlice} + assignment_4 := TwoGadgets{Num: 6} + assignment_5 := TwoGadgets{Num: 6} + + out, err := extractor.ExtractCircuits("MultipleCircuits", ecc.BN254, &assignment_3, &assignment_2, &assignment_1, &assignment_4, &assignment_5) + if err != nil { + log.Fatal(err) + } + checkOutput(t, out) +} diff --git a/extractor/test/utils_test.go b/extractor/test/utils_test.go new file mode 100644 index 0000000..03c2baf --- /dev/null +++ b/extractor/test/utils_test.go @@ -0,0 +1,62 @@ +package extractor_test + +import ( + "bytes" + "crypto/sha256" + "errors" + "fmt" + "io" + "log" + "os" + "testing" +) + +// saveOutput can be called once when creating/changing a test to generate +// the reference result +func saveOutput(filename string, testOutput string) { + f, err := os.Create(filename) + if err != nil { + log.Fatal(err) + } + defer f.Close() + + _, err = f.WriteString(testOutput) + if err != nil { + log.Fatal(err) + } +} + +// checkOutput performs a check of the circuit generated by the extractor. +// If the hashes don't match, the circuit generated by the extractor is printed. +func checkOutput(t *testing.T, testOutput string) { + // I assume tests are executed from the extractor/test directory + filename := fmt.Sprintf("../../test/%s.lean", t.Name()) + + // https://stackoverflow.com/a/66405130 + if _, err := os.Stat(filename); errors.Is(err, os.ErrNotExist) { + saveOutput(filename, testOutput) + } + + f, err := os.Open(filename) + if err != nil { + log.Fatalf("Error checking test output\n\n%s\n\n%s\n\n", err, testOutput) + } + defer f.Close() + + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + log.Fatal(err) + } + + correctHash := h.Sum(nil) + + h.Reset() + if _, err := h.Write([]byte(testOutput)); err != nil { + log.Fatal(err) + } + testResultHash := h.Sum(nil) + if !bytes.Equal(correctHash, testResultHash) { + t.Logf("This circuit doesn't match the result in the test folder\n\n%s", testOutput) + t.Fail() + } +} diff --git a/go.mod b/go.mod index 6f794c9..96b5370 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,9 @@ go 1.20 require ( github.com/consensys/gnark v0.8.0 github.com/consensys/gnark-crypto v0.9.1 + github.com/mitchellh/copystructure v1.2.0 github.com/stretchr/testify v1.8.1 + golang.org/x/exp v0.0.0-20230905200255-921286631fa9 ) require ( @@ -15,11 +17,12 @@ require ( github.com/kr/text v0.2.0 // indirect github.com/mattn/go-colorable v0.1.12 // indirect github.com/mattn/go-isatty v0.0.14 // indirect + github.com/mitchellh/reflectwalk v1.0.2 // indirect github.com/mmcloughlin/addchain v0.4.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.10.0 // indirect github.com/rs/zerolog v1.29.0 // indirect - golang.org/x/sys v0.5.0 // indirect + golang.org/x/sys v0.12.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect rsc.io/tmplfunc v0.0.3 // indirect ) diff --git a/go.sum b/go.sum index 3a3dfa8..3b95434 100644 --- a/go.sum +++ b/go.sum @@ -23,6 +23,10 @@ github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZb github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw= +github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s= +github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ= +github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= github.com/mmcloughlin/addchain v0.4.0 h1:SobOdjm2xLj1KkXN5/n0xTIWyZA2+s99UCY1iPfkHRY= github.com/mmcloughlin/addchain v0.4.0/go.mod h1:A86O+tHqZLMNO4w6ZZ4FlVQEadcoqkyU72HC5wJ4RlU= github.com/mmcloughlin/profile v0.1.1/go.mod h1:IhHD7q1ooxgwTgjxQYkACGA77oFTDdFVejUS1/tS/qU= @@ -42,10 +46,12 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/test/TestAnotherCircuit.lean b/test/TestAnotherCircuit.lean new file mode 100644 index 0000000..fb065b5 --- /dev/null +++ b/test/TestAnotherCircuit.lean @@ -0,0 +1,24 @@ +import ProvenZk.Gates +import ProvenZk.Ext.Vector + +set_option linter.unusedVariables false + +namespace AnotherCircuit + +def Order : ℕ := 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 +variable [Fact (Nat.Prime Order)] +abbrev F := ZMod Order + +def IntArrayGadget_4 (In: Vector F 4) (k: Vector F 3 -> Prop): Prop := + ∃gate_0, Gates.from_binary In gate_0 ∧ + ∃_ignored_, _ignored_ = Gates.mul (0:F) (36:F) ∧ + k vec![gate_0, gate_0, gate_0] + +def circuit (In: Vector F 4): Prop := + IntArrayGadget_4 In fun gate_0 => + ∃_ignored_, Gates.from_binary vec![gate_0[1], gate_0[2]] _ignored_ ∧ + ∃_ignored_, Gates.from_binary vec![gate_0[0], gate_0[1]] _ignored_ ∧ + ∃_ignored_, Gates.from_binary gate_0 _ignored_ ∧ + True + +end AnotherCircuit \ No newline at end of file diff --git a/test/TestCircuitWithParameter.lean b/test/TestCircuitWithParameter.lean new file mode 100644 index 0000000..2fdcd96 --- /dev/null +++ b/test/TestCircuitWithParameter.lean @@ -0,0 +1,53 @@ +import ProvenZk.Gates +import ProvenZk.Ext.Vector + +set_option linter.unusedVariables false + +namespace CircuitWithParameter + +def Order : ℕ := 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 +variable [Fact (Nat.Prime Order)] +abbrev F := ZMod Order + +def ReturnItself_3_3 (In_1: Vector F 3) (Out: Vector F 3) (k: Vector F 3 -> Prop): Prop := + ∃gate_0, gate_0 = Gates.mul In_1[0] In_1[0] ∧ + ∃gate_1, gate_1 = Gates.mul In_1[1] In_1[1] ∧ + ∃gate_2, gate_2 = Gates.mul In_1[2] In_1[2] ∧ + k vec![gate_0, gate_1, gate_2] + +def SliceGadget_3_3 (In_1: Vector F 3) (In_2: Vector F 3) (k: F -> Prop): Prop := + ∃_ignored_, _ignored_ = Gates.mul In_1[0] In_2[0] ∧ + ∃_ignored_, _ignored_ = Gates.mul In_1[1] In_2[1] ∧ + ∃_ignored_, _ignored_ = Gates.mul In_1[2] In_2[2] ∧ + ∃gate_3, Gates.from_binary In_1 gate_3 ∧ + k gate_3 + +def SliceGadget_2_2 (In_1: Vector F 2) (In_2: Vector F 2) (k: F -> Prop): Prop := + ∃_ignored_, _ignored_ = Gates.mul In_1[0] In_2[0] ∧ + ∃_ignored_, _ignored_ = Gates.mul In_1[1] In_2[1] ∧ + ∃gate_2, Gates.from_binary In_1 gate_2 ∧ + k gate_2 + +def circuit (In: F) (Path: Vector F 3) (Tree: Vector F 2): Prop := + ReturnItself_3_3 Path vec![(0:F), (0:F), (0:F)] fun gate_0 => + Gates.eq gate_0[1] gate_0[2] ∧ + ReturnItself_3_3 Path gate_0 fun gate_2 => + Gates.eq gate_2[1] gate_2[2] ∧ + ReturnItself_3_3 Path gate_2 fun gate_4 => + Gates.eq gate_4[1] gate_4[2] ∧ + ∃_ignored_, Gates.from_binary Path _ignored_ ∧ + ∃_ignored_, Gates.from_binary gate_4 _ignored_ ∧ + ∃_ignored_, Gates.from_binary vec![gate_4[1], gate_4[2], gate_4[0]] _ignored_ ∧ + ∃_ignored_, Gates.from_binary vec![gate_4[1], (0:F), gate_4[0]] _ignored_ ∧ + ∃_ignored_, Gates.from_binary vec![gate_4[1], gate_4[2]] _ignored_ ∧ + ∃_ignored_, Gates.to_binary In 254 _ignored_ ∧ + ∃gate_12, Gates.to_binary (20:F) 254 gate_12 ∧ + ∃gate_13, Gates.from_binary gate_12 gate_13 ∧ + Gates.eq (20:F) gate_13 ∧ + SliceGadget_3_3 Path Path fun _ => + ∃_ignored_, _ignored_ = Gates.mul Path[0] Path[0] ∧ + SliceGadget_2_2 Tree Tree fun _ => + Gates.eq (20:F) In ∧ + True + +end CircuitWithParameter \ No newline at end of file diff --git a/test/TestDeletionMbuCircuit.lean b/test/TestDeletionMbuCircuit.lean new file mode 100644 index 0000000..680a783 --- /dev/null +++ b/test/TestDeletionMbuCircuit.lean @@ -0,0 +1,20 @@ +import ProvenZk.Gates +import ProvenZk.Ext.Vector + +set_option linter.unusedVariables false + +namespace DeletionMbuCircuit + +def Order : ℕ := 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 +variable [Fact (Nat.Prime Order)] +abbrev F := ZMod Order + +def DeletionProof_2_2_3_2_2_3 (DeletionIndices: Vector F 2) (PreRoot: F) (IdComms: Vector F 2) (MerkleProofs: Vector (Vector F 3) 2) (k: F -> Prop): Prop := + k PreRoot + +def circuit (InputHash: F) (DeletionIndices: Vector F 2) (PreRoot: F) (PostRoot: F) (IdComms: Vector F 2) (MerkleProofs: Vector (Vector F 3) 2): Prop := + DeletionProof_2_2_3_2_2_3 DeletionIndices PreRoot IdComms MerkleProofs fun gate_0 => + Gates.eq gate_0 PostRoot ∧ + True + +end DeletionMbuCircuit \ No newline at end of file diff --git a/test/TestExtractCircuits.lean b/test/TestExtractCircuits.lean new file mode 100644 index 0000000..dea3173 --- /dev/null +++ b/test/TestExtractCircuits.lean @@ -0,0 +1,134 @@ +import ProvenZk.Gates +import ProvenZk.Ext.Vector + +set_option linter.unusedVariables false + +namespace MultipleCircuits + +def Order : ℕ := 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 +variable [Fact (Nat.Prime Order)] +abbrev F := ZMod Order + +def VectorGadget_3_3_3_3 (In_1: Vector F 3) (In_2: Vector F 3) (Nested: Vector (Vector F 3) 3) (k: Vector F 3 -> Prop): Prop := + ∃_ignored_, _ignored_ = Gates.mul In_1[0] In_2[0] ∧ + ∃_ignored_, _ignored_ = Gates.mul In_1[1] In_2[1] ∧ + ∃gate_2, gate_2 = Gates.mul In_1[2] In_2[2] ∧ + k vec![gate_2, gate_2, gate_2] + +def DummyHash (In_1: F) (In_2: F) (k: F -> Prop): Prop := + ∃gate_0, gate_0 = Gates.mul In_1 In_2 ∧ + k gate_0 + +def MyWidget_11 (Test_1: F) (Test_2: F) (k: F -> Prop): Prop := + ∃gate_0, gate_0 = Gates.add Test_1 Test_2 ∧ + ∃gate_1, gate_1 = Gates.mul Test_1 Test_2 ∧ + ∃gate_2, Gates.div gate_0 gate_1 gate_2 ∧ + Gates.is_bool (11:F) ∧ + k gate_2 + +def MySecondWidget_11 (Test_1: F) (Test_2: F) : Prop := + ∃gate_0, gate_0 = Gates.mul Test_1 Test_2 ∧ + MyWidget_11 Test_1 Test_2 fun gate_1 => + ∃_ignored_, _ignored_ = Gates.mul gate_0 gate_1 ∧ + True + +def MyWidget_6 (Test_1: F) (Test_2: F) (k: F -> Prop): Prop := + ∃gate_0, gate_0 = Gates.add Test_1 Test_2 ∧ + ∃gate_1, gate_1 = Gates.mul Test_1 Test_2 ∧ + ∃gate_2, Gates.div gate_0 gate_1 gate_2 ∧ + Gates.is_bool (6:F) ∧ + k gate_2 + +def MySecondWidget_6 (Test_1: F) (Test_2: F) : Prop := + ∃gate_0, gate_0 = Gates.mul Test_1 Test_2 ∧ + MyWidget_6 Test_1 Test_2 fun gate_1 => + ∃_ignored_, _ignored_ = Gates.mul gate_0 gate_1 ∧ + True + +def ToBinaryCircuit_3_3 (In: F) (Out: F) (Double: Vector (Vector F 3) 3): Prop := + ∃gate_0, Gates.to_binary In 3 gate_0 ∧ + ∃gate_1, Gates.to_binary Out 3 gate_1 ∧ + ∃_ignored_, _ignored_ = Gates.add Double[2][2] Double[1][1] ∧ + ∃_ignored_, _ignored_ = Gates.add _ignored_ Double[0][0] ∧ + ∃_ignored_, _ignored_ = Gates.mul gate_0[1] gate_1[1] ∧ + VectorGadget_3_3_3_3 Double[2] Double[0] Double fun gate_4 => + ∃_ignored_, _ignored_ = Gates.mul gate_4[2] gate_4[1] ∧ + True + +def MerkleRecover_20_20 (Root: F) (Element: F) (Path: Vector F 20) (Proof: Vector F 20): Prop := + DummyHash Element Proof[0] fun gate_0 => + DummyHash Proof[0] Element fun gate_1 => + ∃gate_2, Gates.select Path[0] gate_1 gate_0 gate_2 ∧ + DummyHash gate_2 Proof[1] fun gate_3 => + DummyHash Proof[1] gate_2 fun gate_4 => + ∃gate_5, Gates.select Path[1] gate_4 gate_3 gate_5 ∧ + DummyHash gate_5 Proof[2] fun gate_6 => + DummyHash Proof[2] gate_5 fun gate_7 => + ∃gate_8, Gates.select Path[2] gate_7 gate_6 gate_8 ∧ + DummyHash gate_8 Proof[3] fun gate_9 => + DummyHash Proof[3] gate_8 fun gate_10 => + ∃gate_11, Gates.select Path[3] gate_10 gate_9 gate_11 ∧ + DummyHash gate_11 Proof[4] fun gate_12 => + DummyHash Proof[4] gate_11 fun gate_13 => + ∃gate_14, Gates.select Path[4] gate_13 gate_12 gate_14 ∧ + DummyHash gate_14 Proof[5] fun gate_15 => + DummyHash Proof[5] gate_14 fun gate_16 => + ∃gate_17, Gates.select Path[5] gate_16 gate_15 gate_17 ∧ + DummyHash gate_17 Proof[6] fun gate_18 => + DummyHash Proof[6] gate_17 fun gate_19 => + ∃gate_20, Gates.select Path[6] gate_19 gate_18 gate_20 ∧ + DummyHash gate_20 Proof[7] fun gate_21 => + DummyHash Proof[7] gate_20 fun gate_22 => + ∃gate_23, Gates.select Path[7] gate_22 gate_21 gate_23 ∧ + DummyHash gate_23 Proof[8] fun gate_24 => + DummyHash Proof[8] gate_23 fun gate_25 => + ∃gate_26, Gates.select Path[8] gate_25 gate_24 gate_26 ∧ + DummyHash gate_26 Proof[9] fun gate_27 => + DummyHash Proof[9] gate_26 fun gate_28 => + ∃gate_29, Gates.select Path[9] gate_28 gate_27 gate_29 ∧ + DummyHash gate_29 Proof[10] fun gate_30 => + DummyHash Proof[10] gate_29 fun gate_31 => + ∃gate_32, Gates.select Path[10] gate_31 gate_30 gate_32 ∧ + DummyHash gate_32 Proof[11] fun gate_33 => + DummyHash Proof[11] gate_32 fun gate_34 => + ∃gate_35, Gates.select Path[11] gate_34 gate_33 gate_35 ∧ + DummyHash gate_35 Proof[12] fun gate_36 => + DummyHash Proof[12] gate_35 fun gate_37 => + ∃gate_38, Gates.select Path[12] gate_37 gate_36 gate_38 ∧ + DummyHash gate_38 Proof[13] fun gate_39 => + DummyHash Proof[13] gate_38 fun gate_40 => + ∃gate_41, Gates.select Path[13] gate_40 gate_39 gate_41 ∧ + DummyHash gate_41 Proof[14] fun gate_42 => + DummyHash Proof[14] gate_41 fun gate_43 => + ∃gate_44, Gates.select Path[14] gate_43 gate_42 gate_44 ∧ + DummyHash gate_44 Proof[15] fun gate_45 => + DummyHash Proof[15] gate_44 fun gate_46 => + ∃gate_47, Gates.select Path[15] gate_46 gate_45 gate_47 ∧ + DummyHash gate_47 Proof[16] fun gate_48 => + DummyHash Proof[16] gate_47 fun gate_49 => + ∃gate_50, Gates.select Path[16] gate_49 gate_48 gate_50 ∧ + DummyHash gate_50 Proof[17] fun gate_51 => + DummyHash Proof[17] gate_50 fun gate_52 => + ∃gate_53, Gates.select Path[17] gate_52 gate_51 gate_53 ∧ + DummyHash gate_53 Proof[18] fun gate_54 => + DummyHash Proof[18] gate_53 fun gate_55 => + ∃gate_56, Gates.select Path[18] gate_55 gate_54 gate_56 ∧ + DummyHash gate_56 Proof[19] fun gate_57 => + DummyHash Proof[19] gate_56 fun gate_58 => + ∃gate_59, Gates.select Path[19] gate_58 gate_57 gate_59 ∧ + Gates.eq gate_59 Root ∧ + True + +def TwoGadgets_11 (In_1: F) (In_2: F): Prop := + ∃gate_0, gate_0 = Gates.add In_1 In_2 ∧ + ∃gate_1, gate_1 = Gates.mul In_1 In_2 ∧ + MySecondWidget_11 gate_0 gate_1 ∧ + True + +def TwoGadgets_6 (In_1: F) (In_2: F): Prop := + ∃gate_0, gate_0 = Gates.add In_1 In_2 ∧ + ∃gate_1, gate_1 = Gates.mul In_1 In_2 ∧ + MySecondWidget_6 gate_0 gate_1 ∧ + True + +end MultipleCircuits \ No newline at end of file diff --git a/test/TestExtractGadgets.lean b/test/TestExtractGadgets.lean new file mode 100644 index 0000000..128d0ac --- /dev/null +++ b/test/TestExtractGadgets.lean @@ -0,0 +1,42 @@ +import ProvenZk.Gates +import ProvenZk.Ext.Vector + +set_option linter.unusedVariables false + +namespace MultipleGadgets + +def Order : ℕ := 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 +variable [Fact (Nat.Prime Order)] +abbrev F := ZMod Order + +def DummyHash (In_1: F) (In_2: F) (k: F -> Prop): Prop := + ∃gate_0, gate_0 = Gates.mul In_1 In_2 ∧ + k gate_0 + +def MyWidget_11 (Test_1: F) (Test_2: F) (k: F -> Prop): Prop := + ∃gate_0, gate_0 = Gates.add Test_1 Test_2 ∧ + ∃gate_1, gate_1 = Gates.mul Test_1 Test_2 ∧ + ∃gate_2, Gates.div gate_0 gate_1 gate_2 ∧ + Gates.is_bool (11:F) ∧ + k gate_2 + +def MySecondWidget_11 (Test_1: F) (Test_2: F) : Prop := + ∃gate_0, gate_0 = Gates.mul Test_1 Test_2 ∧ + MyWidget_11 Test_1 Test_2 fun gate_1 => + ∃_ignored_, _ignored_ = Gates.mul gate_0 gate_1 ∧ + True + +def MyWidget_9 (Test_1: F) (Test_2: F) (k: F -> Prop): Prop := + ∃gate_0, gate_0 = Gates.add Test_1 Test_2 ∧ + ∃gate_1, gate_1 = Gates.mul Test_1 Test_2 ∧ + ∃gate_2, Gates.div gate_0 gate_1 gate_2 ∧ + Gates.is_bool (9:F) ∧ + k gate_2 + +def MySecondWidget_9 (Test_1: F) (Test_2: F) : Prop := + ∃gate_0, gate_0 = Gates.mul Test_1 Test_2 ∧ + MyWidget_9 Test_1 Test_2 fun gate_1 => + ∃_ignored_, _ignored_ = Gates.mul gate_0 gate_1 ∧ + True + +end MultipleGadgets \ No newline at end of file diff --git a/test/TestExtractGadgetsVectors.lean b/test/TestExtractGadgetsVectors.lean new file mode 100644 index 0000000..19d8286 --- /dev/null +++ b/test/TestExtractGadgetsVectors.lean @@ -0,0 +1,28 @@ +import ProvenZk.Gates +import ProvenZk.Ext.Vector + +set_option linter.unusedVariables false + +namespace MultipleGadgetsVectors + +def Order : ℕ := 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 +variable [Fact (Nat.Prime Order)] +abbrev F := ZMod Order + +def VectorGadget_3_3_3_3 (In_1: Vector F 3) (In_2: Vector F 3) (Nested: Vector (Vector F 3) 3) (k: Vector F 3 -> Prop): Prop := + ∃_ignored_, _ignored_ = Gates.mul In_1[0] In_2[0] ∧ + ∃_ignored_, _ignored_ = Gates.mul In_1[1] In_2[1] ∧ + ∃gate_2, gate_2 = Gates.mul In_1[2] In_2[2] ∧ + k vec![gate_2, gate_2, gate_2] + +def ReturnItself_3_3 (In_1: Vector F 3) (Out: Vector F 3) (k: Vector F 3 -> Prop): Prop := + ∃gate_0, gate_0 = Gates.mul In_1[0] In_1[0] ∧ + ∃gate_1, gate_1 = Gates.mul In_1[1] In_1[1] ∧ + ∃gate_2, gate_2 = Gates.mul In_1[2] In_1[2] ∧ + k vec![gate_0, gate_1, gate_2] + +def OptimisedVectorGadget (In: F) (k: Vector F 3 -> Prop): Prop := + ∃gate_0, Gates.to_binary In 3 gate_0 ∧ + k gate_0 + +end MultipleGadgetsVectors \ No newline at end of file diff --git a/test/TestGadgetExtraction.lean b/test/TestGadgetExtraction.lean new file mode 100644 index 0000000..5faa800 --- /dev/null +++ b/test/TestGadgetExtraction.lean @@ -0,0 +1,18 @@ +import ProvenZk.Gates +import ProvenZk.Ext.Vector + +set_option linter.unusedVariables false + +namespace VectorGadget + +def Order : ℕ := 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 +variable [Fact (Nat.Prime Order)] +abbrev F := ZMod Order + +def VectorGadget_3_3_3_3 (In_1: Vector F 3) (In_2: Vector F 3) (Nested: Vector (Vector F 3) 3) (k: Vector F 3 -> Prop): Prop := + ∃_ignored_, _ignored_ = Gates.mul In_1[0] In_2[0] ∧ + ∃_ignored_, _ignored_ = Gates.mul In_1[1] In_2[1] ∧ + ∃gate_2, gate_2 = Gates.mul In_1[2] In_2[2] ∧ + k vec![gate_2, gate_2, gate_2] + +end VectorGadget \ No newline at end of file diff --git a/test/TestMerkleRecover.lean b/test/TestMerkleRecover.lean new file mode 100644 index 0000000..e70d028 --- /dev/null +++ b/test/TestMerkleRecover.lean @@ -0,0 +1,80 @@ +import ProvenZk.Gates +import ProvenZk.Ext.Vector + +set_option linter.unusedVariables false + +namespace MerkleRecover + +def Order : ℕ := 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 +variable [Fact (Nat.Prime Order)] +abbrev F := ZMod Order + +def DummyHash (In_1: F) (In_2: F) (k: F -> Prop): Prop := + ∃gate_0, gate_0 = Gates.mul In_1 In_2 ∧ + k gate_0 + +def circuit (Root: F) (Element: F) (Path: Vector F 20) (Proof: Vector F 20): Prop := + DummyHash Element Proof[0] fun gate_0 => + DummyHash Proof[0] Element fun gate_1 => + ∃gate_2, Gates.select Path[0] gate_1 gate_0 gate_2 ∧ + DummyHash gate_2 Proof[1] fun gate_3 => + DummyHash Proof[1] gate_2 fun gate_4 => + ∃gate_5, Gates.select Path[1] gate_4 gate_3 gate_5 ∧ + DummyHash gate_5 Proof[2] fun gate_6 => + DummyHash Proof[2] gate_5 fun gate_7 => + ∃gate_8, Gates.select Path[2] gate_7 gate_6 gate_8 ∧ + DummyHash gate_8 Proof[3] fun gate_9 => + DummyHash Proof[3] gate_8 fun gate_10 => + ∃gate_11, Gates.select Path[3] gate_10 gate_9 gate_11 ∧ + DummyHash gate_11 Proof[4] fun gate_12 => + DummyHash Proof[4] gate_11 fun gate_13 => + ∃gate_14, Gates.select Path[4] gate_13 gate_12 gate_14 ∧ + DummyHash gate_14 Proof[5] fun gate_15 => + DummyHash Proof[5] gate_14 fun gate_16 => + ∃gate_17, Gates.select Path[5] gate_16 gate_15 gate_17 ∧ + DummyHash gate_17 Proof[6] fun gate_18 => + DummyHash Proof[6] gate_17 fun gate_19 => + ∃gate_20, Gates.select Path[6] gate_19 gate_18 gate_20 ∧ + DummyHash gate_20 Proof[7] fun gate_21 => + DummyHash Proof[7] gate_20 fun gate_22 => + ∃gate_23, Gates.select Path[7] gate_22 gate_21 gate_23 ∧ + DummyHash gate_23 Proof[8] fun gate_24 => + DummyHash Proof[8] gate_23 fun gate_25 => + ∃gate_26, Gates.select Path[8] gate_25 gate_24 gate_26 ∧ + DummyHash gate_26 Proof[9] fun gate_27 => + DummyHash Proof[9] gate_26 fun gate_28 => + ∃gate_29, Gates.select Path[9] gate_28 gate_27 gate_29 ∧ + DummyHash gate_29 Proof[10] fun gate_30 => + DummyHash Proof[10] gate_29 fun gate_31 => + ∃gate_32, Gates.select Path[10] gate_31 gate_30 gate_32 ∧ + DummyHash gate_32 Proof[11] fun gate_33 => + DummyHash Proof[11] gate_32 fun gate_34 => + ∃gate_35, Gates.select Path[11] gate_34 gate_33 gate_35 ∧ + DummyHash gate_35 Proof[12] fun gate_36 => + DummyHash Proof[12] gate_35 fun gate_37 => + ∃gate_38, Gates.select Path[12] gate_37 gate_36 gate_38 ∧ + DummyHash gate_38 Proof[13] fun gate_39 => + DummyHash Proof[13] gate_38 fun gate_40 => + ∃gate_41, Gates.select Path[13] gate_40 gate_39 gate_41 ∧ + DummyHash gate_41 Proof[14] fun gate_42 => + DummyHash Proof[14] gate_41 fun gate_43 => + ∃gate_44, Gates.select Path[14] gate_43 gate_42 gate_44 ∧ + DummyHash gate_44 Proof[15] fun gate_45 => + DummyHash Proof[15] gate_44 fun gate_46 => + ∃gate_47, Gates.select Path[15] gate_46 gate_45 gate_47 ∧ + DummyHash gate_47 Proof[16] fun gate_48 => + DummyHash Proof[16] gate_47 fun gate_49 => + ∃gate_50, Gates.select Path[16] gate_49 gate_48 gate_50 ∧ + DummyHash gate_50 Proof[17] fun gate_51 => + DummyHash Proof[17] gate_50 fun gate_52 => + ∃gate_53, Gates.select Path[17] gate_52 gate_51 gate_53 ∧ + DummyHash gate_53 Proof[18] fun gate_54 => + DummyHash Proof[18] gate_53 fun gate_55 => + ∃gate_56, Gates.select Path[18] gate_55 gate_54 gate_56 ∧ + DummyHash gate_56 Proof[19] fun gate_57 => + DummyHash Proof[19] gate_56 fun gate_58 => + ∃gate_59, Gates.select Path[19] gate_58 gate_57 gate_59 ∧ + Gates.eq gate_59 Root ∧ + True + +end MerkleRecover \ No newline at end of file diff --git a/test/TestMyCircuit.lean b/test/TestMyCircuit.lean new file mode 100644 index 0000000..431f4f8 --- /dev/null +++ b/test/TestMyCircuit.lean @@ -0,0 +1,19 @@ +import ProvenZk.Gates +import ProvenZk.Ext.Vector + +set_option linter.unusedVariables false + +namespace MyCircuit + +def Order : ℕ := 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 +variable [Fact (Nat.Prime Order)] +abbrev F := ZMod Order + + + +def circuit (In_1: F) (In_2: F) (Out: F): Prop := + ∃gate_0, gate_0 = Gates.add In_1 In_2 ∧ + Gates.eq gate_0 Out ∧ + True + +end MyCircuit \ No newline at end of file diff --git a/test/TestSlicesOptimisation.lean b/test/TestSlicesOptimisation.lean new file mode 100644 index 0000000..15dab1c --- /dev/null +++ b/test/TestSlicesOptimisation.lean @@ -0,0 +1,35 @@ +import ProvenZk.Gates +import ProvenZk.Ext.Vector + +set_option linter.unusedVariables false + +namespace SlicesOptimisation + +def Order : ℕ := 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 +variable [Fact (Nat.Prime Order)] +abbrev F := ZMod Order + +def SlicesGadget_3_2_4_3_2 (TwoDim: Vector (Vector F 3) 2) (ThreeDim: Vector (Vector (Vector F 4) 3) 2) (k: Vector F 7 -> Prop): Prop := + k vec![ThreeDim[0][0][0], ThreeDim[0][0][1], ThreeDim[0][0][2], ThreeDim[0][0][3], TwoDim[0][0], TwoDim[0][1], TwoDim[0][2]] + +def SlicesGadget_1_2_4_3_3 (TwoDim: Vector (Vector F 1) 2) (ThreeDim: Vector (Vector (Vector F 4) 3) 3) (k: Vector F 5 -> Prop): Prop := + k vec![ThreeDim[0][0][0], ThreeDim[0][0][1], ThreeDim[0][0][2], ThreeDim[0][0][3], TwoDim[0][0]] + +def TwoSlices_3_2 (TwoDim: Vector (Vector F 3) 2) (k: Vector (Vector F 3) 2 -> Prop): Prop := + k TwoDim + +def ThreeSlices_4_3_2 (ThreeDim: Vector (Vector (Vector F 4) 3) 2) (k: Vector (Vector (Vector F 4) 3) 2 -> Prop): Prop := + k ThreeDim + +def circuit (Test: F) (Id: Vector F 3) (TwoDim: Vector (Vector F 3) 2) (ThreeDim: Vector (Vector (Vector F 4) 3) 2): Prop := + SlicesGadget_3_2_4_3_2 TwoDim ThreeDim fun _ => + SlicesGadget_3_2_4_3_2 vec![TwoDim[1], TwoDim[0]] vec![vec![ThreeDim[1][0], ThreeDim[1][1], ThreeDim[1][2]], vec![ThreeDim[0][0], ThreeDim[0][1], ThreeDim[0][2]]] fun _ => + SlicesGadget_1_2_4_3_3 vec![vec![TwoDim[1][1]], vec![TwoDim[1][0]]] vec![vec![ThreeDim[1][0], ThreeDim[1][1], ThreeDim[1][2]], vec![ThreeDim[0][0], ThreeDim[0][1], ThreeDim[0][2]], vec![ThreeDim[1][0], ThreeDim[1][1], ThreeDim[1][2]]] fun _ => + SlicesGadget_3_2_4_3_2 vec![TwoDim[1], vec![TwoDim[1][0], TwoDim[0][0], TwoDim[1][1]]] ThreeDim fun _ => + TwoSlices_3_2 TwoDim fun _ => + ThreeSlices_4_3_2 ThreeDim fun gate_5 => + ThreeSlices_4_3_2 gate_5 fun gate_6 => + ThreeSlices_4_3_2 gate_6 fun _ => + True + +end SlicesOptimisation \ No newline at end of file diff --git a/test/TestToBinaryCircuit.lean b/test/TestToBinaryCircuit.lean new file mode 100644 index 0000000..01a2ca9 --- /dev/null +++ b/test/TestToBinaryCircuit.lean @@ -0,0 +1,28 @@ +import ProvenZk.Gates +import ProvenZk.Ext.Vector + +set_option linter.unusedVariables false + +namespace ToBinaryCircuit + +def Order : ℕ := 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 +variable [Fact (Nat.Prime Order)] +abbrev F := ZMod Order + +def VectorGadget_3_3_3_3 (In_1: Vector F 3) (In_2: Vector F 3) (Nested: Vector (Vector F 3) 3) (k: Vector F 3 -> Prop): Prop := + ∃_ignored_, _ignored_ = Gates.mul In_1[0] In_2[0] ∧ + ∃_ignored_, _ignored_ = Gates.mul In_1[1] In_2[1] ∧ + ∃gate_2, gate_2 = Gates.mul In_1[2] In_2[2] ∧ + k vec![gate_2, gate_2, gate_2] + +def circuit (In: F) (Out: F) (Double: Vector (Vector F 3) 3): Prop := + ∃gate_0, Gates.to_binary In 3 gate_0 ∧ + ∃gate_1, Gates.to_binary Out 3 gate_1 ∧ + ∃_ignored_, _ignored_ = Gates.add Double[2][2] Double[1][1] ∧ + ∃_ignored_, _ignored_ = Gates.add _ignored_ Double[0][0] ∧ + ∃_ignored_, _ignored_ = Gates.mul gate_0[1] gate_1[1] ∧ + VectorGadget_3_3_3_3 Double[2] Double[0] Double fun gate_4 => + ∃_ignored_, _ignored_ = Gates.mul gate_4[2] gate_4[1] ∧ + True + +end ToBinaryCircuit \ No newline at end of file diff --git a/test/TestTwoGadgets.lean b/test/TestTwoGadgets.lean new file mode 100644 index 0000000..764688d --- /dev/null +++ b/test/TestTwoGadgets.lean @@ -0,0 +1,31 @@ +import ProvenZk.Gates +import ProvenZk.Ext.Vector + +set_option linter.unusedVariables false + +namespace TwoGadgets + +def Order : ℕ := 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 +variable [Fact (Nat.Prime Order)] +abbrev F := ZMod Order + +def MyWidget_11 (Test_1: F) (Test_2: F) (k: F -> Prop): Prop := + ∃gate_0, gate_0 = Gates.add Test_1 Test_2 ∧ + ∃gate_1, gate_1 = Gates.mul Test_1 Test_2 ∧ + ∃gate_2, Gates.div gate_0 gate_1 gate_2 ∧ + Gates.is_bool (11:F) ∧ + k gate_2 + +def MySecondWidget_11 (Test_1: F) (Test_2: F) : Prop := + ∃gate_0, gate_0 = Gates.mul Test_1 Test_2 ∧ + MyWidget_11 Test_1 Test_2 fun gate_1 => + ∃_ignored_, _ignored_ = Gates.mul gate_0 gate_1 ∧ + True + +def circuit (In_1: F) (In_2: F): Prop := + ∃gate_0, gate_0 = Gates.add In_1 In_2 ∧ + ∃gate_1, gate_1 = Gates.mul In_1 In_2 ∧ + MySecondWidget_11 gate_0 gate_1 ∧ + True + +end TwoGadgets \ No newline at end of file