Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support structs #47

Merged
merged 2 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions extractor/extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,8 @@ type ExArgType struct {
}

type ExArg struct {
Name string
Kind reflect.Kind
Type ExArgType
Name string
ArrayType *ExArgType
}

type ExCircuit struct {
Expand Down Expand Up @@ -423,12 +422,11 @@ func (ce *CodeExtractor) DefineGadget(gadget abstractor.GadgetDefinition) abstra
panic("DefineGadget only takes pointers to the gadget")
}
schema, _ := getSchema(gadget)
circuitInit(gadget, schema)
args := circuitInit(gadget, schema)
// Can't use `schema.NbPublic + schema.NbSecret`
// for arity because each array element is considered
// a parameter
arity := len(schema.Fields)
args := getExArgs(gadget, schema.Fields)

name := generateUniqueName(gadget, args)

Expand Down
15 changes: 4 additions & 11 deletions extractor/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,12 @@ import (
// CircuitToLean(circuit abstractor.Circuit, field ecc.ID, namespace ...string) because the long term view
// is to add an optional parameter to support custom `set_option` directives in the header.
func CircuitToLeanWithName(circuit frontend.Circuit, field ecc.ID, namespace string) (out string, err error) {
defer recoverError()

schema, err := getSchema(circuit)
if err != nil {
return "", err
}

circuitInit(circuit, schema)
inputs := circuitInit(circuit, schema)

api := CodeExtractor{
Code: []App{},
Expand All @@ -37,7 +35,7 @@ func CircuitToLeanWithName(circuit frontend.Circuit, field ecc.ID, namespace str
}

extractorCircuit := ExCircuit{
Inputs: getExArgs(circuit, schema.Fields),
Inputs: inputs,
Gadgets: api.Gadgets,
Code: api.Code,
Field: api.FieldID,
Expand All @@ -57,8 +55,6 @@ func CircuitToLean(circuit frontend.Circuit, field ecc.ID) (string, error) {
// GadgetToLeanWithName exports a `gadget` to Lean over a `field` with `namespace`
// Same notes written for CircuitToLeanWithName apply to GadgetToLeanWithName and GadgetToLean
func GadgetToLeanWithName(gadget abstractor.GadgetDefinition, field ecc.ID, namespace string) (out string, err error) {
defer recoverError()

api := CodeExtractor{
Code: []App{},
Gadgets: []ExGadget{},
Expand All @@ -80,7 +76,6 @@ func GadgetToLean(gadget abstractor.GadgetDefinition, field ecc.ID) (string, err

// ExtractCircuits is used to export a series of `circuits` to Lean over a `field` under `namespace`.
func ExtractCircuits(namespace string, field ecc.ID, circuits ...frontend.Circuit) (out string, err error) {
defer recoverError()

api := CodeExtractor{
Code: []App{},
Expand All @@ -103,14 +98,14 @@ func ExtractCircuits(namespace string, field ecc.ID, circuits ...frontend.Circui
if err != nil {
return "", err
}
args := getExArgs(circuit, schema.Fields)
args := circuitInit(circuit, schema)

name := generateUniqueName(circuit, args)
if slices.Contains(past_circuits, name) {
continue
}
past_circuits = append(past_circuits, name)

circuitInit(circuit, schema)
err = circuit.Define(&api)
if err != nil {
return "", err
Expand All @@ -136,8 +131,6 @@ func ExtractCircuits(namespace string, field ecc.ID, circuits ...frontend.Circui

// ExtractGadgets is used to export a series of `gadgets` to Lean over a `field` under `namespace`.
func ExtractGadgets(namespace string, field ecc.ID, gadgets ...abstractor.GadgetDefinition) (out string, err error) {
defer recoverError()

api := CodeExtractor{
Code: []App{},
Gadgets: []ExGadget{},
Expand Down
124 changes: 61 additions & 63 deletions extractor/lean_export.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,76 +99,74 @@ func exportCircuit(circuit ExCircuit, name string) string {

// circuitInit takes struct and a schema to populate all the
// circuit/gagdget fields with Operand.
func circuitInit(class any, schema *schema.Schema) {
// https://stackoverflow.com/a/49704408
// https://stackoverflow.com/a/14162161
// https://stackoverflow.com/a/63422049

func circuitInit(circuit any, sch *schema.Schema) []ExArg {
// The purpose of this function is to initialise the
// struct fields with Operand interfaces for being
// processed by the Extractor.
v := reflect.ValueOf(class)
if v.Type().Kind() == reflect.Ptr {
ptr := v
v = ptr.Elem()
} else {
ptr := reflect.New(reflect.TypeOf(class))
temp := ptr.Elem()
temp.Set(v)
}

tmp_c := reflect.ValueOf(&class).Elem().Elem()
tmp := reflect.New(tmp_c.Type()).Elem()
tmp.Set(tmp_c)
for j, f := range schema.Fields {
field_name := f.Name
field := v.FieldByName(field_name)
field_type := field.Type()

// Can't assign an array to another array, therefore
// initialise each element in the array

if field_type.Kind() == reflect.Array {
arrayInit(f, tmp.Elem().FieldByName(field_name), Input{j})
} else if field_type.Kind() == reflect.Slice {
// Recreate a zeroed array to remove overlapping pointers if input
// arguments are duplicated (i.e. `api.Call(SliceGadget{circuit.Path, circuit.Path})`)
arrayZero(tmp.Elem().FieldByName(field_name))
arrayInit(f, tmp.Elem().FieldByName(field_name), Input{j})
} else if field_type.Kind() == reflect.Interface {
init := Input{j}
value := reflect.ValueOf(init)

tmp.Elem().FieldByName(field_name).Set(value)
} else {
fmt.Printf("Skipped type %s\n", field_type.Kind())
v := reflect.ValueOf(circuit)
for v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface {
v = v.Elem()
}

return structInit(v, sch.Fields, 0, "")
}

func structInit(val reflect.Value, fields []schema.Field, offset int, prefix string) []ExArg {
var result []ExArg
for _, f := range fields {
fieldVal := val.FieldByName(f.Name)
switch f.Type {
case schema.Leaf:
input := Input{offset}
value := reflect.ValueOf(input)
fieldVal.Set(value)
result = append(result, ExArg{prefix + f.Name, nil})
offset++
case schema.Array:
if val.Type().Kind() == reflect.Slice {
arrayZero(val)
}
arrTp := arrayInit(fieldVal, arraySubfield(f), Input{offset})
result = append(result, ExArg{prefix + f.Name, arrTp})
offset++
case schema.Struct:
recurResult := structInit(val.FieldByName(f.Name), f.SubFields, offset, prefix+f.Name+"_")
result = append(result, recurResult...)
offset += len(recurResult)
}
}
return result
}

func circuitArgs(field schema.Field) ExArgType {
// Handling only subfields which are nested arrays
switch len(field.SubFields) {
case 1:
subType := circuitArgs(field.SubFields[0])
return ExArgType{field.ArraySize, &subType}
case 0:
return ExArgType{field.ArraySize, nil}
default:
panic("Only nested arrays supported in SubFields")
func arraySubfield(field schema.Field) *schema.Field {
if len(field.SubFields) == 0 {
return nil
}
return &field.SubFields[0]
}

// getExArgs generates a list of ExArg given a `circuit` and a
// list of `Field`. It is used in the Circuit to Lean functions
func getExArgs(circuit any, fields []schema.Field) []ExArg {
args := []ExArg{}
for _, f := range fields {
kind := kindOfField(circuit, f.Name)
arg := ExArg{f.Name, kind, circuitArgs(f)}
args = append(args, arg)
func arrayInit(val reflect.Value, elemField *schema.Field, baseVal Operand) *ExArgType {
var childType *ExArgType
if elemField == nil {
for i := 0; i < val.Len(); i++ {
proj := Proj{baseVal, i, val.Len()}
value := reflect.ValueOf(proj)
val.Index(i).Set(value)
}
} else {
switch elemField.Type {
case schema.Leaf:
panic("Gnark contract broken – leaf inside array should be nil")
case schema.Array:
for i := 0; i < val.Len(); i++ {
childType = arrayInit(val.Index(i), arraySubfield(*elemField), Proj{baseVal, i, val.Len()})
}
case schema.Struct:
panic("Struct inside array not supported")
}

}
return args
return &ExArgType{val.Len(), childType}
}

// getSchema is a cloned version of NewSchema without constraints
Expand All @@ -187,10 +185,10 @@ func genNestedArrays(a ExArgType) string {
func genArgs(inAssignment []ExArg) string {
args := make([]string, len(inAssignment))
for i, in := range inAssignment {
switch in.Kind {
case reflect.Array, reflect.Slice:
args[i] = fmt.Sprintf("(%s: %s)", in.Name, genNestedArrays(in.Type))
default:
if in.ArrayType != nil {
args[i] = fmt.Sprintf("(%s: %s)", in.Name, genNestedArrays(*in.ArrayType))

} else {
args[i] = fmt.Sprintf("(%s: F)", in.Name)
}
}
Expand Down
57 changes: 4 additions & 53 deletions extractor/misc.go
Original file line number Diff line number Diff line change
@@ -1,34 +1,15 @@
package extractor

import (
"errors"
"flag"
"fmt"
"reflect"
"runtime/debug"
"strings"

"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/schema"
"github.com/mitchellh/copystructure"
"github.com/reilabs/gnark-lean-extractor/v2/abstractor"
)

// recoverError is used in the top level interface to prevent panic
// caused by any of the methods in the extractor from propagating
// When go is running in test mode, it prints the stack trace to aid
// debugging.
func recoverError() (err error) {
if recover() != nil {
if flag.Lookup("test.v") != nil {
stack := string(debug.Stack())
fmt.Println(stack)
}
err = errors.New("Panic extracting circuit to Lean")
}
return nil
}

// arrayToSlice returns a slice of elements identical to
// the input array `v`
func arrayToSlice(v reflect.Value) []frontend.Variable {
Expand Down Expand Up @@ -82,29 +63,6 @@ func flattenSlice(value reflect.Value) []frontend.Variable {
return value.Interface().([]frontend.Variable)
}

// arrayInit generates the Proj{} object for each element of v
func arrayInit(f schema.Field, v reflect.Value, op Operand) error {
for i := 0; i < f.ArraySize; i++ {
op := Proj{op, i, f.ArraySize}
switch len(f.SubFields) {
case 1:
arrayInit(f.SubFields[0], v.Index(i), op)
case 0:
if v.Len() != f.ArraySize {
// Slices of this type aren't supported yet [[<nil> <nil> <nil>] [<nil> <nil>]]
// gnark newSchema doesn't handle different dimensions
fmt.Printf("Wrong slices dimensions %+v\n", v)
panic("Only slices dimensions not matching")
}
value := reflect.ValueOf(op)
v.Index(i).Set(value)
default:
panic("Only nested arrays supported in SubFields")
}
}
return nil
}

// arrayZero sets all the elements of the input slice v to nil.
// It is used when initialising a new circuit or gadget to ensure
// the object is clean
Expand All @@ -119,22 +77,15 @@ func arrayZero(v reflect.Value) {
arrayZero(v.Addr().Elem().Index(i))
}
} else {
zero_array := make([]frontend.Variable, v.Len(), v.Len())
v.Set(reflect.ValueOf(&zero_array).Elem())
zeroArray := make([]frontend.Variable, v.Len(), v.Len())
v.Set(reflect.ValueOf(&zeroArray).Elem())
}
}
default:
panic("Only nested slices supported in SubFields of slices")
}
}

// kindOfField returns the Kind of field in struct a
func kindOfField(a any, field string) reflect.Kind {
v := reflect.ValueOf(a).Elem()
f := v.FieldByName(field)
return f.Kind()
}

// getStructName returns the name of struct a
func getStructName(a any) string {
return reflect.TypeOf(a).Elem().Name()
Expand Down Expand Up @@ -221,9 +172,9 @@ func cloneGadget(gadget abstractor.GadgetDefinition) abstractor.GadgetDefinition
func generateUniqueName(element any, args []ExArg) string {
suffix := ""
for _, a := range args {
if a.Kind == reflect.Array || a.Kind == reflect.Slice {
if a.ArrayType != nil {
suffix += "_"
suffix += strings.Join(getSizeGadgetArgs(a.Type), "_")
suffix += strings.Join(getSizeGadgetArgs(*a.ArrayType), "_")
}
}

Expand Down
40 changes: 40 additions & 0 deletions extractor/test/nested_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package extractor_test

import (
"log"
"testing"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/frontend"
"github.com/reilabs/gnark-lean-extractor/v2/extractor"
)

type Nest1 struct {
In1 frontend.Variable
In2 [5]frontend.Variable
}

type Nest2 struct {
In1 [4][4]frontend.Variable
In2 [3]frontend.Variable
}

type NestedCircuit struct {
N1 Nest1
N2 Nest2
}

func (circuit *NestedCircuit) Define(api frontend.API) error {
sum := api.Add(circuit.N1.In2[2], circuit.N2.In2[0])
api.AssertIsEqual(sum, circuit.N1.In2[1])
return nil
}

func TestNestedCircuit(t *testing.T) {
circuit := NestedCircuit{}
out, err := extractor.CircuitToLean(&circuit, ecc.BN254)
if err != nil {
log.Fatal(err)
}
checkOutput(t, out)
}
Loading
Loading