Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Removing definegadget #18

Merged
merged 15 commits into from
Jul 21, 2023
3 changes: 3 additions & 0 deletions abstractor/abstractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ type GadgetDefinition interface {
type API interface {
frontend.API
DefineGadget(gadget GadgetDefinition) Gadget

frontend.API
Call(gadget GadgetDefinition) []frontend.Variable
}

type Circuit interface {
Expand Down
4 changes: 4 additions & 0 deletions abstractor/concretizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,8 @@ func (c *Concretizer) DefineGadget(gadget GadgetDefinition) Gadget {
return &ConcreteGadget{c}
}

func (c *Concretizer) Call(gadget GadgetDefinition) []frontend.Variable {
return c.Call(gadget)
}

var _ API = &(Concretizer{})
43 changes: 28 additions & 15 deletions extractor/extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,16 @@ type ExGadget struct {

func (g *ExGadget) isOp() {}

func ArrayToSlice(v reflect.Value) []frontend.Variable {
res := make([]frontend.Variable, v.Len())

for i := 0; i < v.Len(); i++ {
res[i] = v.Index(i).Elem().Interface().(frontend.Variable)
}

return res
}

func (g *ExGadget) Call(gadget abstractor.GadgetDefinition) []frontend.Variable {
args := []frontend.Variable{}

Expand All @@ -113,8 +123,12 @@ func (g *ExGadget) Call(gadget abstractor.GadgetDefinition) []frontend.Variable
for i := 0; i < rt.NumField(); i++ {
fld := rt.Field(i)
v := rv.FieldByName(fld.Name)
if v.Kind() == reflect.Slice || v.Kind() == reflect.Array {
if v.Kind() == reflect.Slice {
args = append(args, v.Interface().([]frontend.Variable))
} else if v.Kind() == reflect.Array {
// I can't convert from array to slice using Reflect because
// the field is unaddressable.
args = append(args, ArrayToSlice(v))
} else {
args = append(args, v.Elem().Interface().(frontend.Variable))
}
Expand All @@ -132,6 +146,15 @@ func (g *ExGadget) Call(gadget abstractor.GadgetDefinition) []frontend.Variable
return outs
}

func (ce *CodeExtractor) Call(gadget abstractor.GadgetDefinition) []frontend.Variable {
// Copying `gadget` because `DefineGadget` needs to manipulate the input
v := reflect.ValueOf(gadget)
tmp_gadget := reflect.New(v.Type())
tmp_gadget.Elem().Set(v)
g := ce.DefineGadget(tmp_gadget.Interface().(abstractor.GadgetDefinition))
return g.Call(gadget)
}

type ExArgType struct {
Size int
Type *ExArgType
Expand All @@ -157,19 +180,6 @@ type CodeExtractor struct {
Field ecc.ID
}

func operandFromArray(args []frontend.Variable) []Operand {
ops := make([]Operand, len(args))
for i, arg := range args {
switch arg.(type) {
case Input, Gate, Proj, Const:
ops[i] = arg.(Operand)
default:
ops[i] = arg.(Proj).Operand
}
}
return ops
}

func sanitizeVars(args ...frontend.Variable) []Operand {
ops := []Operand{}
for _, arg := range args {
Expand All @@ -182,7 +192,7 @@ func sanitizeVars(args ...frontend.Variable) []Operand {
casted := arg.(big.Int)
ops = append(ops, Const{&casted})
case []frontend.Variable:
opsArray := operandFromArray(arg.([]frontend.Variable))
opsArray := sanitizeVars(arg.([]frontend.Variable)...)
ops = append(ops, ProjArray{opsArray})
default:
fmt.Printf("invalid argument of type %T\n%#v\n", arg, arg)
Expand Down Expand Up @@ -334,6 +344,9 @@ func getGadgetByName(gadgets []ExGadget, name string) abstractor.Gadget {
}

func (ce *CodeExtractor) DefineGadget(gadget abstractor.GadgetDefinition) abstractor.Gadget {
if reflect.ValueOf(gadget).Kind() != reflect.Ptr {
panic("DefineGadget only takes pointers to the gadget")
}
schema, _ := GetSchema(gadget)
CircuitInit(gadget, schema)
// Can't use `schema.NbPublic + schema.NbSecret`
Expand Down
28 changes: 6 additions & 22 deletions extractor/extractor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,26 +35,16 @@ type CircuitWithParameter struct {
}

func (circuit *CircuitWithParameter) AbsDefine(api abstractor.API) error {
slice_3 := api.DefineGadget(&SliceGadget{
In_1: make([]frontend.Variable, 3),
In_2: make([]frontend.Variable, 3),
})

slice_2 := api.DefineGadget(&SliceGadget{
In_1: make([]frontend.Variable, 2),
In_2: make([]frontend.Variable, 2),
})

api.FromBinary(circuit.Path...)
bin := api.ToBinary(circuit.In)
bin = api.ToBinary(circuit.Param)

dec := api.FromBinary(bin...)
api.AssertIsEqual(circuit.Param, dec)
slice_3.Call(SliceGadget{circuit.Path, circuit.Path})
api.Call(SliceGadget{circuit.Path, circuit.Path})

api.Mul(circuit.Path[0], circuit.Path[0])
slice_2.Call(SliceGadget{circuit.Tree, circuit.Tree})
api.Call(SliceGadget{circuit.Tree, circuit.Tree})
api.AssertIsEqual(circuit.Param, circuit.In)

return nil
Expand Down Expand Up @@ -95,12 +85,10 @@ type MerkleRecover struct {
}

func (circuit *MerkleRecover) AbsDefine(api abstractor.API) error {
hash := api.DefineGadget(&DummyHash{})

current := circuit.Element
for i := 0; i < len(circuit.Path); i++ {
leftHash := hash.Call(DummyHash{current, circuit.Proof[i]})[0]
rightHash := hash.Call(DummyHash{circuit.Proof[i], current})[0]
leftHash := api.Call(DummyHash{current, circuit.Proof[i]})[0]
rightHash := api.Call(DummyHash{circuit.Proof[i], current})[0]
current = api.Select(circuit.Path[i], rightHash, leftHash)
}
api.AssertIsEqual(current, circuit.Root)
Expand Down Expand Up @@ -140,10 +128,8 @@ type MySecondWidget struct {
}

func (gadget MySecondWidget) DefineGadget(api abstractor.API) []frontend.Variable {
my_widget := api.DefineGadget(&MyWidget{})

mul := api.Mul(gadget.Test_1, gadget.Test_2)
snd := my_widget.Call(MyWidget{gadget.Test_1, gadget.Test_2})[0]
snd := api.Call(MyWidget{gadget.Test_1, gadget.Test_2})[0]
r := api.Mul(mul, snd)
return []frontend.Variable{r}
}
Expand All @@ -154,11 +140,9 @@ type TwoGadgets struct {
}

func (circuit *TwoGadgets) AbsDefine(api abstractor.API) error {
my_snd_widget := api.DefineGadget(&MySecondWidget{})

sum := api.Add(circuit.In_1, circuit.In_2)
prod := api.Mul(circuit.In_1, circuit.In_2)
my_snd_widget.Call(MySecondWidget{sum, prod})
api.Call(MySecondWidget{sum, prod})

return nil
}
Expand Down
64 changes: 38 additions & 26 deletions extractor/lean_export.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func ExportFooter(circuit ExCircuit) string {
func ExportGadget(gadget ExGadget) string {
kArgsType := "F"
if len(gadget.Outputs) > 1 {
kArgsType = fmt.Sprintf("Vect F %d", len(gadget.Outputs))
kArgsType = fmt.Sprintf("Vector F %d", len(gadget.Outputs))
}
inAssignment := gadget.Args
return fmt.Sprintf("def %s %s (k: %s -> Prop): Prop :=\n%s", gadget.Name, genArgs(inAssignment), kArgsType, genGadgetBody(inAssignment, gadget))
Expand Down Expand Up @@ -84,28 +84,31 @@ func CircuitInit(class any, schema *schema.Schema) error {
temp.Set(v)
}

tmp_c := reflect.ValueOf(&class).Elem().Elem()
tmp := reflect.New(tmp_c.Type()).Elem()
tmp.Set(tmp_c)
for j, f := range schema.Fields {
field_name := f.Name
field := v.FieldByName(field_name)
field_type := field.Type()

// Can't assign an array to another array, therefore
// initialise each element in the array
if field_type.Kind() == reflect.Array || field_type.Kind() == reflect.Slice {
tmp_c := reflect.ValueOf(&class).Elem()
tmp := reflect.New(tmp_c.Elem().Type()).Elem()
tmp.Set(tmp_c.Elem())

if field_type.Kind() == reflect.Array {
ArrayInit(f, tmp.Elem().FieldByName(field_name), Input{j})
} else if field_type.Kind() == reflect.Slice {
// Recreate a zeroed array to remove overlapping pointers if input
// arguments are duplicated (i.e. `api.Call(SliceGadget{circuit.Path, circuit.Path})`)
zero_array := make([]frontend.Variable, f.ArraySize, f.ArraySize)
tmp.Elem().FieldByName(field_name).Set(reflect.ValueOf(&zero_array).Elem())

ArrayInit(f, tmp.Elem().FieldByName(field_name), Input{j})
tmp_c.Set(tmp)
} else if field_type.Kind() == reflect.Interface {
init := Input{j}
value := reflect.ValueOf(init)

tmp_c := reflect.ValueOf(&class).Elem()
tmp := reflect.New(tmp_c.Elem().Type()).Elem()
tmp.Set(tmp_c.Elem())
tmp.Elem().FieldByName(field_name).Set(value)
tmp_c.Set(tmp)
} else {
fmt.Printf("Skipped type %s\n", field_type.Kind())
}
Expand Down Expand Up @@ -204,41 +207,50 @@ func genArgs(inAssignment []ExArg) string {
return strings.Join(args, " ")
}

func extractBaseArg(arg Operand) Operand {
func extractGateVars(arg Operand) []Operand {
switch arg.(type) {
case Proj:
return extractBaseArg(arg.(Proj).Operand)
return extractGateVars(arg.(Proj).Operand)
case ProjArray:
return extractBaseArg(arg.(ProjArray).Proj[0])
res := []Operand{}
for i := range arg.(ProjArray).Proj {
res = append(res, extractGateVars(arg.(ProjArray).Proj[i])...)
}
return res
default:
return arg
return []Operand{arg}
}
}

func assignGateVars(code []App, additional ...Operand) []string {
gateVars := make([]string, len(code))
for _, app := range code {
for _, arg := range app.Args {
base := extractBaseArg(arg)
switch base.(type) {
case Gate:
ix := base.(Gate).Index
if gateVars[ix] == "" {
gateVars[ix] = fmt.Sprintf("gate_%d", ix)
bases := extractGateVars(arg)
for _, base := range bases {
switch base.(type) {
case Gate:
ix := base.(Gate).Index
if gateVars[ix] == "" {
gateVars[ix] = fmt.Sprintf("gate_%d", ix)
}
}
}
}
}
for _, out := range additional {
outBase := extractBaseArg(out)
switch outBase.(type) {
case Gate:
ix := outBase.(Gate).Index
if gateVars[ix] == "" {
gateVars[ix] = fmt.Sprintf("gate_%d", ix)
outBases := extractGateVars(out)
for _, outBase := range outBases {
switch outBase.(type) {
case Gate:
ix := outBase.(Gate).Index
if gateVars[ix] == "" {
gateVars[ix] = fmt.Sprintf("gate_%d", ix)
}
}
}
}

return gateVars
}

Expand Down