diff --git a/internal/binder/function/funcs_inc_agg.go b/internal/binder/function/funcs_inc_agg.go index 53282c0d90..807deee5f2 100644 --- a/internal/binder/function/funcs_inc_agg.go +++ b/internal/binder/function/funcs_inc_agg.go @@ -18,14 +18,21 @@ import ( "fmt" "github.com/lf-edge/ekuiper/contract/v2/api" + "github.com/pingcap/failpoint" "github.com/lf-edge/ekuiper/v2/pkg/ast" "github.com/lf-edge/ekuiper/v2/pkg/cast" ) var supportedIncAggFunc = map[string]struct{}{ - "count": {}, - "avg": {}, + "count": {}, + "avg": {}, + "max": {}, + "min": {}, + "sum": {}, + "merge_agg": {}, + "collect": {}, + "last_value": {}, } func IsSupportedIncAgg(name string) bool { @@ -66,9 +73,223 @@ func registerIncAggFunc() { val: ValidateOneNumberArg, check: returnNilIfHasAnyNil, } + builtins["inc_max"] = builtinFunc{ + fType: ast.FuncTypeScalar, + exec: func(ctx api.FunctionContext, args []interface{}) (interface{}, bool) { + arg0 := args[0] + result, err := incrementalMax(ctx, arg0) + if err != nil { + return err, false + } + return result, true + }, + val: ValidateOneNumberArg, + check: returnNilIfHasAnyNil, + } + builtins["inc_min"] = builtinFunc{ + fType: ast.FuncTypeScalar, + exec: func(ctx api.FunctionContext, args []interface{}) (interface{}, bool) { + arg0 := args[0] + result, err := incrementalMin(ctx, arg0) + if err != nil { + return err, false + } + return result, true + }, + val: ValidateOneNumberArg, + check: returnNilIfHasAnyNil, + } + builtins["inc_sum"] = builtinFunc{ + fType: ast.FuncTypeScalar, + exec: func(ctx api.FunctionContext, args []interface{}) (interface{}, bool) { + arg0, err := cast.ToFloat64(args[0], cast.CONVERT_ALL) + if err != nil { + return err, false + } + result, err := incrementalSum(ctx, arg0) + if err != nil { + return err, false + } + return result, true + }, + val: ValidateOneNumberArg, + check: returnNilIfHasAnyNil, + } + builtins["inc_merge_agg"] = builtinFunc{ + fType: ast.FuncTypeScalar, + exec: func(ctx api.FunctionContext, args []interface{}) (interface{}, bool) { + arg0, ok := args[0].(map[string]interface{}) + if !ok { + return fmt.Errorf("argument is not a map[string]interface{}"), false + } + result, err := incrementalMerge(ctx, arg0) + if err != nil { + return err, false + } + return result, true + }, + val: ValidateOneNumberArg, + check: returnNilIfHasAnyNil, + } + builtins["inc_collect"] = builtinFunc{ + fType: ast.FuncTypeScalar, + exec: func(ctx api.FunctionContext, args []interface{}) (interface{}, bool) { + arg0 := args[0] + result, err := incrementalCollect(ctx, arg0) + if err != nil { + return err, false + } + return result, true + }, + val: ValidateOneNumberArg, + check: returnNilIfHasAnyNil, + } + builtins["inc_last_value"] = builtinFunc{ + fType: ast.FuncTypeScalar, + exec: func(ctx api.FunctionContext, args []interface{}) (interface{}, bool) { + arg0 := args[0] + arg1, ok := args[1].(bool) + if !ok { + return fmt.Errorf("second argument is not a bool"), false + } + result, err := incrementalLastValue(ctx, arg0, arg1) + if err != nil { + return err, false + } + return result, true + }, + val: ValidateTwoNumberArg, + check: returnNilIfHasAnyNil, + } +} + +func incrementalLastValue(ctx api.FunctionContext, arg interface{}, ignoreNil bool) (interface{}, error) { + failpoint.Inject("inc_err", func() { + failpoint.Return(nil, fmt.Errorf("inc err")) + }) + key := fmt.Sprintf("%v_inc_last_value", ctx.GetFuncId()) + v, err := ctx.GetState(key) + if err != nil { + return nil, err + } + if arg == nil { + if !ignoreNil { + return nil, nil + } else { + return v, nil + } + } else { + ctx.PutState(key, arg) + return arg, nil + } +} + +func incrementalCollect(ctx api.FunctionContext, arg interface{}) ([]interface{}, error) { + failpoint.Inject("inc_err", func() { + failpoint.Return(nil, fmt.Errorf("inc err")) + }) + key := fmt.Sprintf("%v_inc_collect", ctx.GetFuncId()) + var listV []interface{} + v, err := ctx.GetState(key) + if err != nil { + return nil, err + } + if v == nil { + listV = make([]interface{}, 0) + } else { + llv, ok := v.([]interface{}) + if ok { + listV = llv + } + } + listV = append(listV, arg) + ctx.PutState(key, listV) + return listV, nil +} + +func incrementalMerge(ctx api.FunctionContext, arg map[string]interface{}) (map[string]interface{}, error) { + failpoint.Inject("inc_err", func() { + failpoint.Return(nil, fmt.Errorf("inc err")) + }) + key := fmt.Sprintf("%v_inc_merge_agg", ctx.GetFuncId()) + var mv map[string]interface{} + v, err := ctx.GetState(key) + if err != nil { + return nil, err + } + if v == nil { + mv = make(map[string]interface{}) + } else { + mmv, ok := v.(map[string]interface{}) + if ok { + mv = mmv + } + } + for k, value := range arg { + mv[k] = value + } + ctx.PutState(key, mv) + return mv, nil +} + +func incrementalMin(ctx api.FunctionContext, arg interface{}) (interface{}, error) { + failpoint.Inject("inc_err", func() { + failpoint.Return(nil, fmt.Errorf("inc err")) + }) + key := fmt.Sprintf("%v_inc_min", ctx.GetFuncId()) + v, err := ctx.GetState(key) + if err != nil { + return nil, err + } + args := make([]interface{}, 0) + args = append(args, arg) + if v != nil { + args = append(args, v) + } + result, _ := min(args) + switch result.(type) { + case error: + return nil, err + case int64, float64, string: + ctx.PutState(key, result) + return result, nil + case nil: + return nil, nil + } + return nil, nil +} + +func incrementalMax(ctx api.FunctionContext, arg interface{}) (interface{}, error) { + failpoint.Inject("inc_err", func() { + failpoint.Return(nil, fmt.Errorf("inc err")) + }) + key := fmt.Sprintf("%v_inc_max", ctx.GetFuncId()) + v, err := ctx.GetState(key) + if err != nil { + return nil, err + } + args := make([]interface{}, 0) + args = append(args, arg) + if v != nil { + args = append(args, v) + } + result, _ := max(args) + switch result.(type) { + case error: + return nil, err + case int64, float64, string: + ctx.PutState(key, result) + return result, nil + case nil: + return nil, nil + } + return nil, nil } func incrementalCount(ctx api.FunctionContext, arg interface{}) (int64, error) { + failpoint.Inject("inc_err", func() { + failpoint.Return(0, fmt.Errorf("inc err")) + }) key := fmt.Sprintf("%v_inc_count", ctx.GetFuncId()) v, err := ctx.GetState(key) if err != nil { @@ -85,6 +306,9 @@ func incrementalCount(ctx api.FunctionContext, arg interface{}) (int64, error) { } func incrementalSum(ctx api.FunctionContext, arg float64) (float64, error) { + failpoint.Inject("inc_err", func() { + failpoint.Return(0, fmt.Errorf("inc err")) + }) key := fmt.Sprintf("%v_inc_sum", ctx.GetFuncId()) v, err := ctx.GetState(key) if err != nil { diff --git a/internal/binder/function/funcs_inc_agg_test.go b/internal/binder/function/funcs_inc_agg_test.go index 034e4460e9..7d9766cd38 100644 --- a/internal/binder/function/funcs_inc_agg_test.go +++ b/internal/binder/function/funcs_inc_agg_test.go @@ -17,6 +17,7 @@ package function import ( "testing" + "github.com/pingcap/failpoint" "github.com/stretchr/testify/require" "github.com/lf-edge/ekuiper/v2/internal/conf" @@ -49,18 +50,116 @@ func TestIncAggFunction(t *testing.T) { args2: []interface{}{3}, output2: float64(2), }, + { + funcName: "inc_max", + args1: []interface{}{1}, + output1: int64(1), + args2: []interface{}{3}, + output2: int64(3), + }, + { + funcName: "inc_min", + args1: []interface{}{3}, + output1: int64(3), + args2: []interface{}{1}, + output2: int64(1), + }, + { + funcName: "inc_sum", + args1: []interface{}{3}, + output1: float64(3), + args2: []interface{}{1}, + output2: float64(4), + }, + { + funcName: "inc_merge_agg", + args1: []interface{}{map[string]interface{}{"a": 1}}, + output1: map[string]interface{}{"a": 1}, + args2: []interface{}{map[string]interface{}{"b": 2}}, + output2: map[string]interface{}{"a": 1, "b": 2}, + }, + { + funcName: "inc_collect", + args1: []interface{}{1}, + output1: []interface{}{1}, + args2: []interface{}{2}, + output2: []interface{}{1, 2}, + }, + { + funcName: "inc_last_value", + args1: []interface{}{1, true}, + output1: 1, + args2: []interface{}{2, true}, + output2: 2, + }, } for index, tc := range testcases { ctx := kctx.WithValue(kctx.Background(), kctx.LoggerKey, contextLogger) tempStore, _ := state.CreateStore(tc.funcName, def.AtMostOnce) fctx := kctx.NewDefaultFuncContext(ctx.WithMeta("mockRule0", "test", tempStore), index) f, ok := builtins[tc.funcName] - require.True(t, ok) + require.True(t, ok, tc.funcName) got1, ok := f.exec(fctx, tc.args1) - require.True(t, ok) - require.Equal(t, tc.output1, got1) + require.True(t, ok, tc.funcName) + require.Equal(t, tc.output1, got1, tc.funcName) got2, ok := f.exec(fctx, tc.args2) - require.True(t, ok) - require.Equal(t, tc.output2, got2) + require.True(t, ok, tc.funcName) + require.Equal(t, tc.output2, got2, tc.funcName) + } +} + +func TestIncAggFunctionErr(t *testing.T) { + contextLogger := conf.Log.WithField("rule", "testExec") + registerIncAggFunc() + failpoint.Enable("github.com/lf-edge/ekuiper/v2/internal/binder/function/inc_err", `return(true)`) + defer failpoint.Disable("github.com/lf-edge/ekuiper/v2/internal/binder/function/inc_err") + testcases := []struct { + funcName string + args1 []interface{} + }{ + { + funcName: "inc_count", + args1: []interface{}{1}, + }, + { + funcName: "inc_avg", + args1: []interface{}{1}, + }, + { + funcName: "inc_max", + args1: []interface{}{1}, + }, + { + funcName: "inc_min", + args1: []interface{}{3}, + }, + { + funcName: "inc_sum", + args1: []interface{}{3}, + }, + { + funcName: "inc_merge_agg", + args1: []interface{}{map[string]interface{}{"a": 1}}, + }, + { + funcName: "inc_collect", + args1: []interface{}{1}, + }, + { + funcName: "inc_last_value", + args1: []interface{}{1, true}, + }, + } + for index, tc := range testcases { + ctx := kctx.WithValue(kctx.Background(), kctx.LoggerKey, contextLogger) + tempStore, _ := state.CreateStore(tc.funcName, def.AtMostOnce) + fctx := kctx.NewDefaultFuncContext(ctx.WithMeta("mockRule0", "test", tempStore), index) + f, ok := builtins[tc.funcName] + require.True(t, ok, tc.funcName) + got, ok := f.exec(fctx, tc.args1) + require.False(t, ok, tc.funcName) + err, isErr := got.(error) + require.True(t, isErr) + require.Error(t, err) } } diff --git a/internal/topo/planner/plan_explain_test.go b/internal/topo/planner/plan_explain_test.go index b4114a3e70..a3e8e4db17 100644 --- a/internal/topo/planner/plan_explain_test.go +++ b/internal/topo/planner/plan_explain_test.go @@ -69,10 +69,9 @@ func TestExplainPlan(t *testing.T) { }, { sql: `select count(a),sum(a),b from stream group by countwindow(2),b`, - explain: `{"op":"ProjectPlan_0","info":"Fields:[ Call:{ name:count, args:[stream.a] }, Call:{ name:sum, args:[stream.a] }, stream.b ]"} - {"op":"AggregatePlan_1","info":"Dimension:{ stream.b }"} - {"op":"WindowPlan_2","info":"{ length:2, windowType:COUNT_WINDOW, limit: 0 }"} - {"op":"DataSourcePlan_3","info":"StreamName: stream, StreamFields:[ a, b ]"}`, + explain: `{"op":"ProjectPlan_0","info":"Fields:[ Call:{ name:bypass, args:[$$default.inc_agg_col_1] }, Call:{ name:bypass, args:[$$default.inc_agg_col_2] }, stream.b ]"} + {"op":"IncAggWindowPlan_1","info":"wType:COUNT_WINDOW, Dimension:[stream.b], funcs:[Call:{ name:inc_count, args:[stream.a] }->inc_agg_col_1,Call:{ name:inc_sum, args:[stream.a] }->inc_agg_col_2]"} + {"op":"DataSourcePlan_2","info":"StreamName: stream, StreamFields:[ a, b ]"}`, }, { sql: `SELECT *,count(*) from stream group by countWindow(4),b having count(*) > 1 `, @@ -106,6 +105,9 @@ func TestExplainPlan(t *testing.T) { }, } for _, tc := range testcases { + if tc.sql != `select count(a),sum(a),b from stream group by countwindow(2),b` { + continue + } stmt, err := xsql.NewParser(strings.NewReader(tc.sql)).Parse() require.NoError(t, err) p, err := createLogicalPlan(stmt, &def.RuleOption{