Skip to content

Commit

Permalink
Refactor groupBy and sortBy (#562)
Browse files Browse the repository at this point in the history
* Rearrange opcodes

* Refactor groupBy

* Fix tests

* Super opcodes for GroupBy

* Refactor sortBy builtin

* Fix sortBy bench
  • Loading branch information
antonmedv committed Feb 15, 2024
1 parent 0bc9d99 commit 36f9adb
Show file tree
Hide file tree
Showing 14 changed files with 237 additions and 238 deletions.
2 changes: 1 addition & 1 deletion bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ func Benchmark_sortBy(b *testing.B) {
env["arr"].([]Foo)[i] = Foo{Value: v.(int)}
}

program, err := expr.Compile(`sortBy(arr, "Value")`, expr.Env(env))
program, err := expr.Compile(`sortBy(arr, .Value)`, expr.Env(env))
require.NoError(b, err)

var out any
Expand Down
137 changes: 49 additions & 88 deletions builtin/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ var Builtins = []*Function{
Predicate: true,
Types: types(new(func([]any, func(any) any) map[any][]any)),
},
{
Name: "sortBy",
Predicate: true,
Types: types(new(func([]any, func(any) bool, string) []any)),
},
{
Name: "reduce",
Predicate: true,
Expand Down Expand Up @@ -905,109 +910,65 @@ var Builtins = []*Function{
},
{
Name: "sort",
Func: func(args ...any) (any, error) {
Safe: func(args ...any) (any, uint, error) {
if len(args) != 1 && len(args) != 2 {
return nil, fmt.Errorf("invalid number of arguments (expected 1 or 2, got %d)", len(args))
return nil, 0, fmt.Errorf("invalid number of arguments (expected 1 or 2, got %d)", len(args))
}

v := reflect.ValueOf(args[0])
if v.Kind() != reflect.Slice && v.Kind() != reflect.Array {
return nil, fmt.Errorf("cannot sort %s", v.Kind())
}
var array []any

orderBy := OrderBy{}
if len(args) == 2 {
dir, err := ascOrDesc(args[1])
if err != nil {
return nil, err
switch in := args[0].(type) {
case []any:
array = make([]any, len(in))
copy(array, in)
case []int:
array = make([]any, len(in))
for i, v := range in {
array[i] = v
}
case []float64:
array = make([]any, len(in))
for i, v := range in {
array[i] = v
}
case []string:
array = make([]any, len(in))
for i, v := range in {
array[i] = v
}
orderBy.Desc = dir
}

sortable, err := copyArray(v, orderBy)
if err != nil {
return nil, err
}
sort.Sort(sortable)
return sortable.Array, nil
},
Validate: func(args []reflect.Type) (reflect.Type, error) {
if len(args) != 1 && len(args) != 2 {
return anyType, fmt.Errorf("invalid number of arguments (expected 1 or 2, got %d)", len(args))
}
switch kind(args[0]) {
case reflect.Interface, reflect.Slice, reflect.Array:
default:
return anyType, fmt.Errorf("cannot sort %s", args[0])
}
var desc bool
if len(args) == 2 {
switch kind(args[1]) {
case reflect.String, reflect.Interface:
switch args[1].(string) {
case "asc":
desc = false
case "desc":
desc = true
default:
return anyType, fmt.Errorf("invalid argument for sort (expected string, got %s)", args[1])
return nil, 0, fmt.Errorf("invalid order %s, expected asc or desc", args[1])
}
}
return arrayType, nil
},
},
{
Name: "sortBy",
Func: func(args ...any) (any, error) {
if len(args) != 2 && len(args) != 3 {
return nil, fmt.Errorf("invalid number of arguments (expected 2 or 3, got %d)", len(args))
}

v := reflect.ValueOf(args[0])
if v.Kind() != reflect.Slice && v.Kind() != reflect.Array {
return nil, fmt.Errorf("cannot sort %s", v.Kind())
}

orderBy := OrderBy{}

field, ok := args[1].(string)
if !ok {
return nil, fmt.Errorf("invalid argument for sort (expected string, got %s)", reflect.TypeOf(args[1]))
}
orderBy.Field = field

if len(args) == 3 {
dir, err := ascOrDesc(args[2])
if err != nil {
return nil, err
}
orderBy.Desc = dir
}

sortable, err := copyArray(v, orderBy)
if err != nil {
return nil, err
sortable := &runtime.Sort{
Desc: desc,
Array: array,
}
sort.Sort(sortable)
return sortable.Array, nil
},
Validate: func(args []reflect.Type) (reflect.Type, error) {
if len(args) != 2 && len(args) != 3 {
return anyType, fmt.Errorf("invalid number of arguments (expected 2 or 3, got %d)", len(args))
}
switch kind(args[0]) {
case reflect.Interface, reflect.Slice, reflect.Array:
default:
return anyType, fmt.Errorf("cannot sort %s", args[0])
}
switch kind(args[1]) {
case reflect.String, reflect.Interface:
default:
return anyType, fmt.Errorf("invalid argument for sort (expected string, got %s)", args[1])
}
if len(args) == 3 {
switch kind(args[2]) {
case reflect.String, reflect.Interface:
default:
return anyType, fmt.Errorf("invalid argument for sort (expected string, got %s)", args[1])
}
}
return arrayType, nil

return sortable.Array, uint(len(array)), nil
},
Types: types(
new(func([]any, string) []any),
new(func([]int, string) []any),
new(func([]float64, string) []any),
new(func([]string, string) []any),

new(func([]any) []any),
new(func([]float64) []any),
new(func([]string) []any),
new(func([]int) []any),
),
},
bitFunc("bitand", func(x, y int) (any, error) {
return x & y, nil
Expand Down
18 changes: 16 additions & 2 deletions builtin/builtin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -530,8 +530,8 @@ func TestBuiltin_sort(t *testing.T) {
{`sort(ArrayOfInt)`, []any{1, 2, 3}},
{`sort(ArrayOfFloat)`, []any{1.0, 2.0, 3.0}},
{`sort(ArrayOfInt, 'desc')`, []any{3, 2, 1}},
{`sortBy(ArrayOfFoo, 'Value')`, []any{mock.Foo{Value: "a"}, mock.Foo{Value: "b"}, mock.Foo{Value: "c"}}},
{`sortBy([{id: "a"}, {id: "b"}], "id", "desc")`, []any{map[string]any{"id": "b"}, map[string]any{"id": "a"}}},
{`sortBy(ArrayOfFoo, .Value)`, []any{mock.Foo{Value: "a"}, mock.Foo{Value: "b"}, mock.Foo{Value: "c"}}},
{`sortBy([{id: "a"}, {id: "b"}], .id, "desc")`, []any{map[string]any{"id": "b"}, map[string]any{"id": "a"}}},
}

for _, test := range tests {
Expand All @@ -546,6 +546,20 @@ func TestBuiltin_sort(t *testing.T) {
}
}

func TestBuiltin_sort_i64(t *testing.T) {
env := map[string]any{
"array": []int{1, 2, 3},
"i64": int64(1),
}

program, err := expr.Compile(`sort(map(array, i64))`, expr.Env(env))
require.NoError(t, err)

out, err := expr.Run(program, env)
require.NoError(t, err)
assert.Equal(t, []any{int64(1), int64(1), int64(1)}, out)
}

func TestBuiltin_bitOpsFunc(t *testing.T) {
tests := []struct {
input string
Expand Down
96 changes: 0 additions & 96 deletions builtin/sort.go

This file was deleted.

26 changes: 24 additions & 2 deletions checker/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) {
if isAny(collection) {
return arrayType, info{}
}
return reflect.SliceOf(collection.Elem()), info{}
return arrayType, info{}
}
return v.error(node.Arguments[1], "predicate should has one input and one output param")

Expand All @@ -651,7 +651,7 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) {
closure.NumOut() == 1 &&
closure.NumIn() == 1 && isAny(closure.In(0)) {

return reflect.SliceOf(closure.Out(0)), info{}
return arrayType, info{}
}
return v.error(node.Arguments[1], "predicate should has one input and one output param")

Expand Down Expand Up @@ -739,6 +739,28 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) {
}
return v.error(node.Arguments[1], "predicate should has one input and one output param")

case "sortBy":
collection, _ := v.visit(node.Arguments[0])
if !isArray(collection) && !isAny(collection) {
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
}

v.begin(collection)
closure, _ := v.visit(node.Arguments[1])
v.end()

if len(node.Arguments) == 3 {
_, _ = v.visit(node.Arguments[2])
}

if isFunc(closure) &&
closure.NumOut() == 1 &&
closure.NumIn() == 1 && isAny(closure.In(0)) {

return reflect.TypeOf([]any{}), info{}
}
return v.error(node.Arguments[1], "predicate should has one input and one output param")

case "reduce":
collection, _ := v.visit(node.Arguments[0])
if !isArray(collection) && !isAny(collection) {
Expand Down
10 changes: 0 additions & 10 deletions checker/checker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -400,11 +400,6 @@ invalid operation: < (mismatched types mock.Bar and int) (1:29)
| all(ArrayOfFoo, {#.Method() < 0})
| ............................^
map(Any, {0})[0] + "str"
invalid operation: + (mismatched types int and string) (1:18)
| map(Any, {0})[0] + "str"
| .................^
Variadic()
not enough arguments to call Variadic (1:1)
| Variadic()
Expand Down Expand Up @@ -445,11 +440,6 @@ builtin map takes only array (got int) (1:5)
| map(1, {2})
| ....^
map(filter(ArrayOfFoo, {true}), {.Not})
type mock.Foo has no field Not (1:35)
| map(filter(ArrayOfFoo, {true}), {.Not})
| ..................................^
ArrayOfFoo[Foo]
array elements can only be selected using an integer (got mock.Foo) (1:12)
| ArrayOfFoo[Foo]
Expand Down
Loading

0 comments on commit 36f9adb

Please sign in to comment.