Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: ToBinary and nested Slices #34

Merged
merged 7 commits into from
Sep 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions extractor/extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"math/big"
"reflect"
"strings"

"github.com/reilabs/gnark-lean-extractor/abstractor"

Expand Down Expand Up @@ -123,13 +124,14 @@ func (g *ExGadget) Call(gadget abstractor.GadgetDefinition) []frontend.Variable
for i := 0; i < rt.NumField(); i++ {
fld := rt.Field(i)
v := rv.FieldByName(fld.Name)
if v.Kind() == reflect.Slice {
switch v.Kind() {
case reflect.Slice:
args = append(args, v.Interface().([]frontend.Variable))
} else if v.Kind() == reflect.Array {
case reflect.Array:
// I can't convert from array to slice using Reflect because
// the field is unaddressable.
args = append(args, ArrayToSlice(v))
} else {
case reflect.Interface:
args = append(args, v.Elem().Interface().(frontend.Variable))
}
}
Expand Down Expand Up @@ -253,8 +255,13 @@ func (ce *CodeExtractor) ToBinary(i1 frontend.Variable, n ...int) []frontend.Var
panic("Number of bits in ToBinary must be > 0")
}
}

gate := ce.AddApp(OpToBinary, i1, nbBits)
return []frontend.Variable{gate}
outs := make([]frontend.Variable, nbBits)
for i := range outs {
outs[i] = Proj{gate, i}
}
return outs
}

func (ce *CodeExtractor) FromBinary(b ...frontend.Variable) frontend.Variable {
Expand Down Expand Up @@ -348,6 +355,13 @@ func getGadgetByName(gadgets []ExGadget, name string) abstractor.Gadget {
return nil
}

func getSize(elem ExArgType) []string {
if elem.Type == nil {
return []string{fmt.Sprintf("%d", elem.Size)}
}
return append(getSize(*elem.Type), fmt.Sprintf("%d", elem.Size))
}

func (ce *CodeExtractor) DefineGadget(gadget abstractor.GadgetDefinition) abstractor.Gadget {
if reflect.ValueOf(gadget).Kind() != reflect.Ptr {
panic("DefineGadget only takes pointers to the gadget")
Expand All @@ -367,7 +381,8 @@ func (ce *CodeExtractor) DefineGadget(gadget abstractor.GadgetDefinition) abstra
suffix := ""
for _, a := range args {
if a.Kind == reflect.Array || a.Kind == reflect.Slice {
suffix += fmt.Sprintf("_%d", a.Type.Size)
suffix += "_"
suffix += strings.Join(getSize(a.Type), "_")
}
}
name := fmt.Sprintf("%s%s", reflect.TypeOf(gadget).Elem().Name(), suffix)
Expand Down
62 changes: 58 additions & 4 deletions extractor/extractor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,57 @@ import (
"github.com/stretchr/testify/assert"
)

// Example: ToBinary behaviour and nested Slice
type VectorGadget struct {
In_1 []frontend.Variable
In_2 []frontend.Variable
}

func (gadget VectorGadget) DefineGadget(api abstractor.API) []frontend.Variable {
var sum frontend.Variable
for i := 0; i < len(gadget.In_1); i++ {
sum = api.Mul(gadget.In_1[i], gadget.In_2[i])
}
return []frontend.Variable{sum, sum, sum}
}

type ToBinaryCircuit struct {
In frontend.Variable `gnark:",public"`
Out frontend.Variable `gnark:",public"`
Double [][]frontend.Variable `gnark:",public"`
}

func (circuit *ToBinaryCircuit) AbsDefine(api abstractor.API) error {
bin := api.ToBinary(circuit.In, 3)
bout := api.ToBinary(circuit.Out, 3)

api.Add(circuit.Double[2][2], circuit.Double[1][1], circuit.Double[0][0])
api.Mul(bin[1], bout[1])
d := api.Call(VectorGadget{circuit.Double[2][:], circuit.Double[0][:]})
api.Mul(d[2], d[1])

return nil
}

func (circuit ToBinaryCircuit) Define(api frontend.API) error {
return abstractor.Concretize(api, &circuit)
}

func TestToBinaryCircuit(t *testing.T) {
dim_1 := 3
dim_2 := 3
doubleSlice := make([][]frontend.Variable, dim_1)
for i := 0; i < int(dim_1); i++ {
doubleSlice[i] = make([]frontend.Variable, dim_2)
}
assignment := ToBinaryCircuit{Double: doubleSlice}
out, err := CircuitToLean(&assignment, ecc.BN254)
if err != nil {
log.Fatal(err)
}
fmt.Println(out)
}

// Example: readme circuit
type DummyCircuit struct {
In_1 frontend.Variable
Expand Down Expand Up @@ -139,37 +190,40 @@ func TestMerkleRecover(t *testing.T) {
type MyWidget struct {
Test_1 frontend.Variable
Test_2 frontend.Variable
Num int
}

func (gadget MyWidget) DefineGadget(api abstractor.API) []frontend.Variable {
sum := api.Add(gadget.Test_1, gadget.Test_2)
mul := api.Mul(gadget.Test_1, gadget.Test_2)
r := api.Div(sum, mul)
api.AssertIsBoolean(gadget.Num)
return []frontend.Variable{r}
}

type MySecondWidget struct {
Test_1 frontend.Variable
Test_2 frontend.Variable
Num int
}

func (gadget MySecondWidget) DefineGadget(api abstractor.API) []frontend.Variable {
mul := api.Mul(gadget.Test_1, gadget.Test_2)
snd := api.Call(MyWidget{gadget.Test_1, gadget.Test_2})[0]
snd := api.Call(MyWidget{gadget.Test_1, gadget.Test_2, gadget.Num})[0]
r := api.Mul(mul, snd)
return []frontend.Variable{r}
}

type TwoGadgets struct {
In_1 frontend.Variable
In_2 frontend.Variable
Num int
}

func (circuit *TwoGadgets) AbsDefine(api abstractor.API) error {
sum := api.Add(circuit.In_1, circuit.In_2)
prod := api.Mul(circuit.In_1, circuit.In_2)
api.Call(MySecondWidget{sum, prod})

api.Call(MySecondWidget{sum, prod, circuit.Num})
return nil
}

Expand All @@ -178,7 +232,7 @@ func (circuit TwoGadgets) Define(api frontend.API) error {
}

func TestTwoGadgets(t *testing.T) {
assignment := TwoGadgets{}
assignment := TwoGadgets{Num: 11}
out, err := CircuitToLean(&assignment, ecc.BN254)
if err != nil {
log.Fatal(err)
Expand Down
24 changes: 21 additions & 3 deletions extractor/lean_export.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,26 @@ func ArrayInit(f schema.Field, v reflect.Value, op Operand) error {
return nil
}

func ArrayZero(v reflect.Value) {
switch v.Kind() {
case reflect.Slice:
if v.Len() != 0 {
// Check if there are nested arrays. If yes, continue recursion
// until most nested array
if v.Addr().Elem().Index(0).Kind() == reflect.Slice {
for i := 0; i < v.Len(); i++ {
ArrayZero(v.Addr().Elem().Index(i))
}
} else {
zero_array := make([]frontend.Variable, v.Len(), v.Len())
v.Set(reflect.ValueOf(&zero_array).Elem())
}
}
default:
panic("Only nested slices supported in SubFields of slices")
}
}

func CircuitInit(class any, schema *schema.Schema) error {
// https://stackoverflow.com/a/49704408
// https://stackoverflow.com/a/14162161
Expand Down Expand Up @@ -106,9 +126,7 @@ func CircuitInit(class any, schema *schema.Schema) error {
} else if field_type.Kind() == reflect.Slice {
// Recreate a zeroed array to remove overlapping pointers if input
// arguments are duplicated (i.e. `api.Call(SliceGadget{circuit.Path, circuit.Path})`)
zero_array := make([]frontend.Variable, f.ArraySize, f.ArraySize)
tmp.Elem().FieldByName(field_name).Set(reflect.ValueOf(&zero_array).Elem())

ArrayZero(tmp.Elem().FieldByName(field_name))
ArrayInit(f, tmp.Elem().FieldByName(field_name), Input{j})
} else if field_type.Kind() == reflect.Interface {
init := Input{j}
Expand Down