Skip to content

Commit

Permalink
feat: Removing definegadget (#18)
Browse files Browse the repository at this point in the history
* Bug fixes for Array fields

* Replaced Vect with Vector

* Fixed assignGateVars for ProjArray

* Removing definegadget call
  • Loading branch information
Eagle941 authored Jul 21, 2023
1 parent b3efdf2 commit a439b70
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 63 deletions.
3 changes: 3 additions & 0 deletions abstractor/abstractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ type GadgetDefinition interface {
type API interface {
frontend.API
DefineGadget(gadget GadgetDefinition) Gadget

frontend.API
Call(gadget GadgetDefinition) []frontend.Variable
}

type Circuit interface {
Expand Down
4 changes: 4 additions & 0 deletions abstractor/concretizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,8 @@ func (c *Concretizer) DefineGadget(gadget GadgetDefinition) Gadget {
return &ConcreteGadget{c}
}

func (c *Concretizer) Call(gadget GadgetDefinition) []frontend.Variable {
return c.Call(gadget)
}

var _ API = &(Concretizer{})
43 changes: 28 additions & 15 deletions extractor/extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,16 @@ type ExGadget struct {

func (g *ExGadget) isOp() {}

func ArrayToSlice(v reflect.Value) []frontend.Variable {
res := make([]frontend.Variable, v.Len())

for i := 0; i < v.Len(); i++ {
res[i] = v.Index(i).Elem().Interface().(frontend.Variable)
}

return res
}

func (g *ExGadget) Call(gadget abstractor.GadgetDefinition) []frontend.Variable {
args := []frontend.Variable{}

Expand All @@ -113,8 +123,12 @@ func (g *ExGadget) Call(gadget abstractor.GadgetDefinition) []frontend.Variable
for i := 0; i < rt.NumField(); i++ {
fld := rt.Field(i)
v := rv.FieldByName(fld.Name)
if v.Kind() == reflect.Slice || v.Kind() == reflect.Array {
if v.Kind() == reflect.Slice {
args = append(args, v.Interface().([]frontend.Variable))
} else if v.Kind() == reflect.Array {
// I can't convert from array to slice using Reflect because
// the field is unaddressable.
args = append(args, ArrayToSlice(v))
} else {
args = append(args, v.Elem().Interface().(frontend.Variable))
}
Expand All @@ -132,6 +146,15 @@ func (g *ExGadget) Call(gadget abstractor.GadgetDefinition) []frontend.Variable
return outs
}

func (ce *CodeExtractor) Call(gadget abstractor.GadgetDefinition) []frontend.Variable {
// Copying `gadget` because `DefineGadget` needs to manipulate the input
v := reflect.ValueOf(gadget)
tmp_gadget := reflect.New(v.Type())
tmp_gadget.Elem().Set(v)
g := ce.DefineGadget(tmp_gadget.Interface().(abstractor.GadgetDefinition))
return g.Call(gadget)
}

type ExArgType struct {
Size int
Type *ExArgType
Expand All @@ -157,19 +180,6 @@ type CodeExtractor struct {
Field ecc.ID
}

func operandFromArray(args []frontend.Variable) []Operand {
ops := make([]Operand, len(args))
for i, arg := range args {
switch arg.(type) {
case Input, Gate, Proj, Const:
ops[i] = arg.(Operand)
default:
ops[i] = arg.(Proj).Operand
}
}
return ops
}

func sanitizeVars(args ...frontend.Variable) []Operand {
ops := []Operand{}
for _, arg := range args {
Expand All @@ -182,7 +192,7 @@ func sanitizeVars(args ...frontend.Variable) []Operand {
casted := arg.(big.Int)
ops = append(ops, Const{&casted})
case []frontend.Variable:
opsArray := operandFromArray(arg.([]frontend.Variable))
opsArray := sanitizeVars(arg.([]frontend.Variable)...)
ops = append(ops, ProjArray{opsArray})
default:
fmt.Printf("invalid argument of type %T\n%#v\n", arg, arg)
Expand Down Expand Up @@ -334,6 +344,9 @@ func getGadgetByName(gadgets []ExGadget, name string) abstractor.Gadget {
}

func (ce *CodeExtractor) DefineGadget(gadget abstractor.GadgetDefinition) abstractor.Gadget {
if reflect.ValueOf(gadget).Kind() != reflect.Ptr {
panic("DefineGadget only takes pointers to the gadget")
}
schema, _ := GetSchema(gadget)
CircuitInit(gadget, schema)
// Can't use `schema.NbPublic + schema.NbSecret`
Expand Down
28 changes: 6 additions & 22 deletions extractor/extractor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,26 +35,16 @@ type CircuitWithParameter struct {
}

func (circuit *CircuitWithParameter) AbsDefine(api abstractor.API) error {
slice_3 := api.DefineGadget(&SliceGadget{
In_1: make([]frontend.Variable, 3),
In_2: make([]frontend.Variable, 3),
})

slice_2 := api.DefineGadget(&SliceGadget{
In_1: make([]frontend.Variable, 2),
In_2: make([]frontend.Variable, 2),
})

api.FromBinary(circuit.Path...)
bin := api.ToBinary(circuit.In)
bin = api.ToBinary(circuit.Param)

dec := api.FromBinary(bin...)
api.AssertIsEqual(circuit.Param, dec)
slice_3.Call(SliceGadget{circuit.Path, circuit.Path})
api.Call(SliceGadget{circuit.Path, circuit.Path})

api.Mul(circuit.Path[0], circuit.Path[0])
slice_2.Call(SliceGadget{circuit.Tree, circuit.Tree})
api.Call(SliceGadget{circuit.Tree, circuit.Tree})
api.AssertIsEqual(circuit.Param, circuit.In)

return nil
Expand Down Expand Up @@ -95,12 +85,10 @@ type MerkleRecover struct {
}

func (circuit *MerkleRecover) AbsDefine(api abstractor.API) error {
hash := api.DefineGadget(&DummyHash{})

current := circuit.Element
for i := 0; i < len(circuit.Path); i++ {
leftHash := hash.Call(DummyHash{current, circuit.Proof[i]})[0]
rightHash := hash.Call(DummyHash{circuit.Proof[i], current})[0]
leftHash := api.Call(DummyHash{current, circuit.Proof[i]})[0]
rightHash := api.Call(DummyHash{circuit.Proof[i], current})[0]
current = api.Select(circuit.Path[i], rightHash, leftHash)
}
api.AssertIsEqual(current, circuit.Root)
Expand Down Expand Up @@ -140,10 +128,8 @@ type MySecondWidget struct {
}

func (gadget MySecondWidget) DefineGadget(api abstractor.API) []frontend.Variable {
my_widget := api.DefineGadget(&MyWidget{})

mul := api.Mul(gadget.Test_1, gadget.Test_2)
snd := my_widget.Call(MyWidget{gadget.Test_1, gadget.Test_2})[0]
snd := api.Call(MyWidget{gadget.Test_1, gadget.Test_2})[0]
r := api.Mul(mul, snd)
return []frontend.Variable{r}
}
Expand All @@ -154,11 +140,9 @@ type TwoGadgets struct {
}

func (circuit *TwoGadgets) AbsDefine(api abstractor.API) error {
my_snd_widget := api.DefineGadget(&MySecondWidget{})

sum := api.Add(circuit.In_1, circuit.In_2)
prod := api.Mul(circuit.In_1, circuit.In_2)
my_snd_widget.Call(MySecondWidget{sum, prod})
api.Call(MySecondWidget{sum, prod})

return nil
}
Expand Down
64 changes: 38 additions & 26 deletions extractor/lean_export.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func ExportFooter(circuit ExCircuit) string {
func ExportGadget(gadget ExGadget) string {
kArgsType := "F"
if len(gadget.Outputs) > 1 {
kArgsType = fmt.Sprintf("Vect F %d", len(gadget.Outputs))
kArgsType = fmt.Sprintf("Vector F %d", len(gadget.Outputs))
}
inAssignment := gadget.Args
return fmt.Sprintf("def %s %s (k: %s -> Prop): Prop :=\n%s", gadget.Name, genArgs(inAssignment), kArgsType, genGadgetBody(inAssignment, gadget))
Expand Down Expand Up @@ -84,28 +84,31 @@ func CircuitInit(class any, schema *schema.Schema) error {
temp.Set(v)
}

tmp_c := reflect.ValueOf(&class).Elem().Elem()
tmp := reflect.New(tmp_c.Type()).Elem()
tmp.Set(tmp_c)
for j, f := range schema.Fields {
field_name := f.Name
field := v.FieldByName(field_name)
field_type := field.Type()

// Can't assign an array to another array, therefore
// initialise each element in the array
if field_type.Kind() == reflect.Array || field_type.Kind() == reflect.Slice {
tmp_c := reflect.ValueOf(&class).Elem()
tmp := reflect.New(tmp_c.Elem().Type()).Elem()
tmp.Set(tmp_c.Elem())

if field_type.Kind() == reflect.Array {
ArrayInit(f, tmp.Elem().FieldByName(field_name), Input{j})
} else if field_type.Kind() == reflect.Slice {
// Recreate a zeroed array to remove overlapping pointers if input
// arguments are duplicated (i.e. `api.Call(SliceGadget{circuit.Path, circuit.Path})`)
zero_array := make([]frontend.Variable, f.ArraySize, f.ArraySize)
tmp.Elem().FieldByName(field_name).Set(reflect.ValueOf(&zero_array).Elem())

ArrayInit(f, tmp.Elem().FieldByName(field_name), Input{j})
tmp_c.Set(tmp)
} else if field_type.Kind() == reflect.Interface {
init := Input{j}
value := reflect.ValueOf(init)

tmp_c := reflect.ValueOf(&class).Elem()
tmp := reflect.New(tmp_c.Elem().Type()).Elem()
tmp.Set(tmp_c.Elem())
tmp.Elem().FieldByName(field_name).Set(value)
tmp_c.Set(tmp)
} else {
fmt.Printf("Skipped type %s\n", field_type.Kind())
}
Expand Down Expand Up @@ -204,41 +207,50 @@ func genArgs(inAssignment []ExArg) string {
return strings.Join(args, " ")
}

func extractBaseArg(arg Operand) Operand {
func extractGateVars(arg Operand) []Operand {
switch arg.(type) {
case Proj:
return extractBaseArg(arg.(Proj).Operand)
return extractGateVars(arg.(Proj).Operand)
case ProjArray:
return extractBaseArg(arg.(ProjArray).Proj[0])
res := []Operand{}
for i := range arg.(ProjArray).Proj {
res = append(res, extractGateVars(arg.(ProjArray).Proj[i])...)
}
return res
default:
return arg
return []Operand{arg}
}
}

func assignGateVars(code []App, additional ...Operand) []string {
gateVars := make([]string, len(code))
for _, app := range code {
for _, arg := range app.Args {
base := extractBaseArg(arg)
switch base.(type) {
case Gate:
ix := base.(Gate).Index
if gateVars[ix] == "" {
gateVars[ix] = fmt.Sprintf("gate_%d", ix)
bases := extractGateVars(arg)
for _, base := range bases {
switch base.(type) {
case Gate:
ix := base.(Gate).Index
if gateVars[ix] == "" {
gateVars[ix] = fmt.Sprintf("gate_%d", ix)
}
}
}
}
}
for _, out := range additional {
outBase := extractBaseArg(out)
switch outBase.(type) {
case Gate:
ix := outBase.(Gate).Index
if gateVars[ix] == "" {
gateVars[ix] = fmt.Sprintf("gate_%d", ix)
outBases := extractGateVars(out)
for _, outBase := range outBases {
switch outBase.(type) {
case Gate:
ix := outBase.(Gate).Index
if gateVars[ix] == "" {
gateVars[ix] = fmt.Sprintf("gate_%d", ix)
}
}
}
}

return gateVars
}

Expand Down

0 comments on commit a439b70

Please sign in to comment.