Skip to content

Commit

Permalink
Support structs (#47)
Browse files Browse the repository at this point in the history
  • Loading branch information
kustosz committed Jul 24, 2024
1 parent 37dfe6a commit 11fa256
Show file tree
Hide file tree
Showing 8 changed files with 193 additions and 132 deletions.
8 changes: 3 additions & 5 deletions extractor/extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,8 @@ type ExArgType struct {
}

type ExArg struct {
Name string
Kind reflect.Kind
Type ExArgType
Name string
ArrayType *ExArgType
}

type ExCircuit struct {
Expand Down Expand Up @@ -423,12 +422,11 @@ func (ce *CodeExtractor) DefineGadget(gadget abstractor.GadgetDefinition) abstra
panic("DefineGadget only takes pointers to the gadget")
}
schema, _ := getSchema(gadget)
circuitInit(gadget, schema)
args := circuitInit(gadget, schema)
// Can't use `schema.NbPublic + schema.NbSecret`
// for arity because each array element is considered
// a parameter
arity := len(schema.Fields)
args := getExArgs(gadget, schema.Fields)

name := generateUniqueName(gadget, args)

Expand Down
15 changes: 4 additions & 11 deletions extractor/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,12 @@ import (
// CircuitToLean(circuit abstractor.Circuit, field ecc.ID, namespace ...string) because the long term view
// is to add an optional parameter to support custom `set_option` directives in the header.
func CircuitToLeanWithName(circuit frontend.Circuit, field ecc.ID, namespace string) (out string, err error) {
defer recoverError()

schema, err := getSchema(circuit)
if err != nil {
return "", err
}

circuitInit(circuit, schema)
inputs := circuitInit(circuit, schema)

api := CodeExtractor{
Code: []App{},
Expand All @@ -37,7 +35,7 @@ func CircuitToLeanWithName(circuit frontend.Circuit, field ecc.ID, namespace str
}

extractorCircuit := ExCircuit{
Inputs: getExArgs(circuit, schema.Fields),
Inputs: inputs,
Gadgets: api.Gadgets,
Code: api.Code,
Field: api.FieldID,
Expand All @@ -57,8 +55,6 @@ func CircuitToLean(circuit frontend.Circuit, field ecc.ID) (string, error) {
// GadgetToLeanWithName exports a `gadget` to Lean over a `field` with `namespace`
// Same notes written for CircuitToLeanWithName apply to GadgetToLeanWithName and GadgetToLean
func GadgetToLeanWithName(gadget abstractor.GadgetDefinition, field ecc.ID, namespace string) (out string, err error) {
defer recoverError()

api := CodeExtractor{
Code: []App{},
Gadgets: []ExGadget{},
Expand All @@ -80,7 +76,6 @@ func GadgetToLean(gadget abstractor.GadgetDefinition, field ecc.ID) (string, err

// ExtractCircuits is used to export a series of `circuits` to Lean over a `field` under `namespace`.
func ExtractCircuits(namespace string, field ecc.ID, circuits ...frontend.Circuit) (out string, err error) {
defer recoverError()

api := CodeExtractor{
Code: []App{},
Expand All @@ -103,14 +98,14 @@ func ExtractCircuits(namespace string, field ecc.ID, circuits ...frontend.Circui
if err != nil {
return "", err
}
args := getExArgs(circuit, schema.Fields)
args := circuitInit(circuit, schema)

name := generateUniqueName(circuit, args)
if slices.Contains(past_circuits, name) {
continue
}
past_circuits = append(past_circuits, name)

circuitInit(circuit, schema)
err = circuit.Define(&api)
if err != nil {
return "", err
Expand All @@ -136,8 +131,6 @@ func ExtractCircuits(namespace string, field ecc.ID, circuits ...frontend.Circui

// ExtractGadgets is used to export a series of `gadgets` to Lean over a `field` under `namespace`.
func ExtractGadgets(namespace string, field ecc.ID, gadgets ...abstractor.GadgetDefinition) (out string, err error) {
defer recoverError()

api := CodeExtractor{
Code: []App{},
Gadgets: []ExGadget{},
Expand Down
124 changes: 61 additions & 63 deletions extractor/lean_export.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,76 +99,74 @@ func exportCircuit(circuit ExCircuit, name string) string {

// circuitInit takes struct and a schema to populate all the
// circuit/gagdget fields with Operand.
func circuitInit(class any, schema *schema.Schema) {
// https://stackoverflow.com/a/49704408
// https://stackoverflow.com/a/14162161
// https://stackoverflow.com/a/63422049

func circuitInit(circuit any, sch *schema.Schema) []ExArg {
// The purpose of this function is to initialise the
// struct fields with Operand interfaces for being
// processed by the Extractor.
v := reflect.ValueOf(class)
if v.Type().Kind() == reflect.Ptr {
ptr := v
v = ptr.Elem()
} else {
ptr := reflect.New(reflect.TypeOf(class))
temp := ptr.Elem()
temp.Set(v)
}

tmp_c := reflect.ValueOf(&class).Elem().Elem()
tmp := reflect.New(tmp_c.Type()).Elem()
tmp.Set(tmp_c)
for j, f := range schema.Fields {
field_name := f.Name
field := v.FieldByName(field_name)
field_type := field.Type()

// Can't assign an array to another array, therefore
// initialise each element in the array

if field_type.Kind() == reflect.Array {
arrayInit(f, tmp.Elem().FieldByName(field_name), Input{j})
} else if field_type.Kind() == reflect.Slice {
// Recreate a zeroed array to remove overlapping pointers if input
// arguments are duplicated (i.e. `api.Call(SliceGadget{circuit.Path, circuit.Path})`)
arrayZero(tmp.Elem().FieldByName(field_name))
arrayInit(f, tmp.Elem().FieldByName(field_name), Input{j})
} else if field_type.Kind() == reflect.Interface {
init := Input{j}
value := reflect.ValueOf(init)

tmp.Elem().FieldByName(field_name).Set(value)
} else {
fmt.Printf("Skipped type %s\n", field_type.Kind())
v := reflect.ValueOf(circuit)
for v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface {
v = v.Elem()
}

return structInit(v, sch.Fields, 0, "")
}

func structInit(val reflect.Value, fields []schema.Field, offset int, prefix string) []ExArg {
var result []ExArg
for _, f := range fields {
fieldVal := val.FieldByName(f.Name)
switch f.Type {
case schema.Leaf:
input := Input{offset}
value := reflect.ValueOf(input)
fieldVal.Set(value)
result = append(result, ExArg{prefix + f.Name, nil})
offset++
case schema.Array:
if fieldVal.Type().Kind() == reflect.Slice {
arrayZero(fieldVal)
}
arrTp := arrayInit(fieldVal, arraySubfield(f), Input{offset})
result = append(result, ExArg{prefix + f.Name, arrTp})
offset++
case schema.Struct:
recurResult := structInit(val.FieldByName(f.Name), f.SubFields, offset, prefix+f.Name+"_")
result = append(result, recurResult...)
offset += len(recurResult)
}
}
return result
}

func circuitArgs(field schema.Field) ExArgType {
// Handling only subfields which are nested arrays
switch len(field.SubFields) {
case 1:
subType := circuitArgs(field.SubFields[0])
return ExArgType{field.ArraySize, &subType}
case 0:
return ExArgType{field.ArraySize, nil}
default:
panic("Only nested arrays supported in SubFields")
func arraySubfield(field schema.Field) *schema.Field {
if len(field.SubFields) == 0 {
return nil
}
return &field.SubFields[0]
}

// getExArgs generates a list of ExArg given a `circuit` and a
// list of `Field`. It is used in the Circuit to Lean functions
func getExArgs(circuit any, fields []schema.Field) []ExArg {
args := []ExArg{}
for _, f := range fields {
kind := kindOfField(circuit, f.Name)
arg := ExArg{f.Name, kind, circuitArgs(f)}
args = append(args, arg)
func arrayInit(val reflect.Value, elemField *schema.Field, baseVal Operand) *ExArgType {
var childType *ExArgType
if elemField == nil {
for i := 0; i < val.Len(); i++ {
proj := Proj{baseVal, i, val.Len()}
value := reflect.ValueOf(proj)
val.Index(i).Set(value)
}
} else {
switch elemField.Type {
case schema.Leaf:
panic("Gnark contract broken – leaf inside array should be nil")
case schema.Array:
for i := 0; i < val.Len(); i++ {
childType = arrayInit(val.Index(i), arraySubfield(*elemField), Proj{baseVal, i, val.Len()})
}
case schema.Struct:
panic("Struct inside array not supported")
}

}
return args
return &ExArgType{val.Len(), childType}
}

// getSchema is a cloned version of NewSchema without constraints
Expand All @@ -187,10 +185,10 @@ func genNestedArrays(a ExArgType) string {
func genArgs(inAssignment []ExArg) string {
args := make([]string, len(inAssignment))
for i, in := range inAssignment {
switch in.Kind {
case reflect.Array, reflect.Slice:
args[i] = fmt.Sprintf("(%s: %s)", in.Name, genNestedArrays(in.Type))
default:
if in.ArrayType != nil {
args[i] = fmt.Sprintf("(%s: %s)", in.Name, genNestedArrays(*in.ArrayType))

} else {
args[i] = fmt.Sprintf("(%s: F)", in.Name)
}
}
Expand Down
57 changes: 4 additions & 53 deletions extractor/misc.go
Original file line number Diff line number Diff line change
@@ -1,34 +1,15 @@
package extractor

import (
"errors"
"flag"
"fmt"
"reflect"
"runtime/debug"
"strings"

"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/schema"
"github.com/mitchellh/copystructure"
"github.com/reilabs/gnark-lean-extractor/v2/abstractor"
)

// recoverError is used in the top level interface to prevent panic
// caused by any of the methods in the extractor from propagating
// When go is running in test mode, it prints the stack trace to aid
// debugging.
func recoverError() (err error) {
if recover() != nil {
if flag.Lookup("test.v") != nil {
stack := string(debug.Stack())
fmt.Println(stack)
}
err = errors.New("Panic extracting circuit to Lean")
}
return nil
}

// arrayToSlice returns a slice of elements identical to
// the input array `v`
func arrayToSlice(v reflect.Value) []frontend.Variable {
Expand Down Expand Up @@ -82,29 +63,6 @@ func flattenSlice(value reflect.Value) []frontend.Variable {
return value.Interface().([]frontend.Variable)
}

// arrayInit generates the Proj{} object for each element of v
func arrayInit(f schema.Field, v reflect.Value, op Operand) error {
for i := 0; i < f.ArraySize; i++ {
op := Proj{op, i, f.ArraySize}
switch len(f.SubFields) {
case 1:
arrayInit(f.SubFields[0], v.Index(i), op)
case 0:
if v.Len() != f.ArraySize {
// Slices of this type aren't supported yet [[<nil> <nil> <nil>] [<nil> <nil>]]
// gnark newSchema doesn't handle different dimensions
fmt.Printf("Wrong slices dimensions %+v\n", v)
panic("Only slices dimensions not matching")
}
value := reflect.ValueOf(op)
v.Index(i).Set(value)
default:
panic("Only nested arrays supported in SubFields")
}
}
return nil
}

// arrayZero sets all the elements of the input slice v to nil.
// It is used when initialising a new circuit or gadget to ensure
// the object is clean
Expand All @@ -119,22 +77,15 @@ func arrayZero(v reflect.Value) {
arrayZero(v.Addr().Elem().Index(i))
}
} else {
zero_array := make([]frontend.Variable, v.Len(), v.Len())
v.Set(reflect.ValueOf(&zero_array).Elem())
zeroArray := make([]frontend.Variable, v.Len())
v.Set(reflect.ValueOf(zeroArray))
}
}
default:
panic("Only nested slices supported in SubFields of slices")
}
}

// kindOfField returns the Kind of field in struct a
func kindOfField(a any, field string) reflect.Kind {
v := reflect.ValueOf(a).Elem()
f := v.FieldByName(field)
return f.Kind()
}

// getStructName returns the name of struct a
func getStructName(a any) string {
return reflect.TypeOf(a).Elem().Name()
Expand Down Expand Up @@ -221,9 +172,9 @@ func cloneGadget(gadget abstractor.GadgetDefinition) abstractor.GadgetDefinition
func generateUniqueName(element any, args []ExArg) string {
suffix := ""
for _, a := range args {
if a.Kind == reflect.Array || a.Kind == reflect.Slice {
if a.ArrayType != nil {
suffix += "_"
suffix += strings.Join(getSizeGadgetArgs(a.Type), "_")
suffix += strings.Join(getSizeGadgetArgs(*a.ArrayType), "_")
}
}

Expand Down
40 changes: 40 additions & 0 deletions extractor/test/nested_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package extractor_test

import (
"log"
"testing"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/frontend"
"github.com/reilabs/gnark-lean-extractor/v2/extractor"
)

type Nest1 struct {
In1 frontend.Variable
In2 [5]frontend.Variable
}

type Nest2 struct {
In1 [4][4]frontend.Variable
In2 [3]frontend.Variable
}

type NestedCircuit struct {
N1 Nest1
N2 Nest2
}

func (circuit *NestedCircuit) Define(api frontend.API) error {
sum := api.Add(circuit.N1.In2[2], circuit.N2.In2[0])
api.AssertIsEqual(sum, circuit.N1.In2[1])
return nil
}

func TestNestedCircuit(t *testing.T) {
circuit := NestedCircuit{}
out, err := extractor.CircuitToLean(&circuit, ecc.BN254)
if err != nil {
log.Fatal(err)
}
checkOutput(t, out)
}
Loading

0 comments on commit 11fa256

Please sign in to comment.