From de32ef0ae38a2dedfdd40f8948b93c835d45a1b1 Mon Sep 17 00:00:00 2001 From: Chengpeng Yan <41809508+Reminiscent@users.noreply.github.com> Date: Wed, 15 Apr 2020 10:46:02 +0800 Subject: [PATCH] Add the check for expression evaluation in some executors (#16339) --- executor/aggregate.go | 4 +- expression/expression.go | 153 +++++++++++++++++++++++++++++----- expression/expression_test.go | 37 ++++++++ expression/util_test.go | 2 +- 4 files changed, 173 insertions(+), 23 deletions(-) diff --git a/executor/aggregate.go b/executor/aggregate.go index b1d5fe0ee8971..3649c1773689c 100644 --- a/executor/aggregate.go +++ b/executor/aggregate.go @@ -448,7 +448,7 @@ func getGroupKey(ctx sessionctx.Context, input *chunk.Chunk, groupKey [][]byte, return nil, err } - if err := expression.VecEval(ctx, item, input, buf); err != nil { + if err := expression.EvalExpr(ctx, item, input, buf); err != nil { expression.PutColumn(buf) return nil, err } @@ -1106,7 +1106,7 @@ func (e *vecGroupChecker) evalGroupItemsAndResolveGroups(item expression.Express return err } defer e.releaseBuffer(col) - err = expression.VecEval(e.ctx, item, chk, col) + err = expression.EvalExpr(e.ctx, item, chk, col) if err != nil { return err } diff --git a/expression/expression.go b/expression/expression.go index cddaf493c9a30..15c1b22e632c2 100644 --- a/expression/expression.go +++ b/expression/expression.go @@ -324,7 +324,7 @@ func VecEvalBool(ctx sessionctx.Context, exprList CNFExprs, input *chunk.Chunk, return nil, nil, err } - if err := VecEval(ctx, expr, input, buf); err != nil { + if err := EvalExpr(ctx, expr, input, buf); err != nil { return nil, nil, err } @@ -464,25 +464,138 @@ func toBool(sc *stmtctx.StatementContext, eType types.EvalType, buf *chunk.Colum return nil } -// VecEval evaluates this expr according to its type. -func VecEval(ctx sessionctx.Context, expr Expression, input *chunk.Chunk, result *chunk.Column) (err error) { - switch expr.GetType().EvalType() { - case types.ETInt: - err = expr.VecEvalInt(ctx, input, result) - case types.ETReal: - err = expr.VecEvalReal(ctx, input, result) - case types.ETDuration: - err = expr.VecEvalDuration(ctx, input, result) - case types.ETDatetime, types.ETTimestamp: - err = expr.VecEvalTime(ctx, input, result) - case types.ETString: - err = expr.VecEvalString(ctx, input, result) - case types.ETJson: - err = expr.VecEvalJSON(ctx, input, result) - case types.ETDecimal: - err = expr.VecEvalDecimal(ctx, input, result) - default: - err = errors.New(fmt.Sprintf("invalid eval type %v", expr.GetType().EvalType())) +// EvalExpr evaluates this expr according to its type. +// And it selects the method for evaluating expression based on +// the environment variables and whether the expression can be vectorized. +func EvalExpr(ctx sessionctx.Context, expr Expression, input *chunk.Chunk, result *chunk.Column) (err error) { + evalType := expr.GetType().EvalType() + if expr.Vectorized() && ctx.GetSessionVars().EnableVectorizedExpression { + switch evalType { + case types.ETInt: + err = expr.VecEvalInt(ctx, input, result) + case types.ETReal: + err = expr.VecEvalReal(ctx, input, result) + case types.ETDuration: + err = expr.VecEvalDuration(ctx, input, result) + case types.ETDatetime, types.ETTimestamp: + err = expr.VecEvalTime(ctx, input, result) + case types.ETString: + err = expr.VecEvalString(ctx, input, result) + case types.ETJson: + err = expr.VecEvalJSON(ctx, input, result) + case types.ETDecimal: + err = expr.VecEvalDecimal(ctx, input, result) + default: + err = errors.New(fmt.Sprintf("invalid eval type %v", expr.GetType().EvalType())) + } + } else { + ind, n := 0, input.NumRows() + iter := chunk.NewIterator4Chunk(input) + switch evalType { + case types.ETInt: + result.ResizeInt64(n, false) + i64s := result.Int64s() + for it := iter.Begin(); it != iter.End(); it = iter.Next() { + value, isNull, err := expr.EvalInt(ctx, it) + if err != nil { + return err + } + if isNull { + result.SetNull(ind, isNull) + } else { + i64s[ind] = value + } + ind++ + } + case types.ETReal: + result.ResizeFloat64(n, false) + f64s := result.Float64s() + for it := iter.Begin(); it != iter.End(); it = iter.Next() { + value, isNull, err := expr.EvalReal(ctx, it) + if err != nil { + return err + } + if isNull { + result.SetNull(ind, isNull) + } else { + f64s[ind] = value + } + ind++ + } + case types.ETDuration: + result.ResizeGoDuration(n, false) + d64s := result.GoDurations() + for it := iter.Begin(); it != iter.End(); it = iter.Next() { + value, isNull, err := expr.EvalDuration(ctx, it) + if err != nil { + return err + } + if isNull { + result.SetNull(ind, isNull) + } else { + d64s[ind] = value.Duration + } + ind++ + } + case types.ETDatetime, types.ETTimestamp: + result.ResizeTime(n, false) + t64s := result.Times() + for it := iter.Begin(); it != iter.End(); it = iter.Next() { + value, isNull, err := expr.EvalTime(ctx, it) + if err != nil { + return err + } + if isNull { + result.SetNull(ind, isNull) + } else { + t64s[ind] = value + } + ind++ + } + case types.ETString: + result.ReserveString(n) + for it := iter.Begin(); it != iter.End(); it = iter.Next() { + value, isNull, err := expr.EvalString(ctx, it) + if err != nil { + return err + } + if isNull { + result.AppendNull() + } else { + result.AppendString(value) + } + } + case types.ETJson: + result.ReserveJSON(n) + for it := iter.Begin(); it != iter.End(); it = iter.Next() { + value, isNull, err := expr.EvalJSON(ctx, it) + if err != nil { + return err + } + if isNull { + result.AppendNull() + } else { + result.AppendJSON(value) + } + } + case types.ETDecimal: + result.ResizeDecimal(n, false) + d64s := result.Decimals() + for it := iter.Begin(); it != iter.End(); it = iter.Next() { + value, isNull, err := expr.EvalDecimal(ctx, it) + if err != nil { + return err + } + if isNull { + result.SetNull(ind, isNull) + } else { + d64s[ind] = *value + } + ind++ + } + default: + err = errors.New(fmt.Sprintf("invalid eval type %v", expr.GetType().EvalType())) + } } return } diff --git a/expression/expression_test.go b/expression/expression_test.go index 3fe9387977176..d9e6645058e89 100644 --- a/expression/expression_test.go +++ b/expression/expression_test.go @@ -22,6 +22,8 @@ import ( "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/mock" ) func (s *testEvaluatorSuite) TestNewValuesFunc(c *C) { @@ -186,3 +188,38 @@ func tableInfoToSchemaForTest(tableInfo *model.TableInfo) *Schema { } return schema } + +func (s *testEvaluatorSuite) TestEvalExpr(c *C) { + ctx := mock.NewContext() + eTypes := []types.EvalType{types.ETInt, types.ETReal, types.ETDecimal, types.ETString, types.ETTimestamp, types.ETDatetime, types.ETDuration} + tNames := []string{"int", "real", "decimal", "string", "timestamp", "datetime", "duration"} + for i := 0; i < len(tNames); i++ { + ft := eType2FieldType(eTypes[i]) + colExpr := &Column{Index: 0, RetType: ft} + input := chunk.New([]*types.FieldType{ft}, 1024, 1024) + fillColumnWithGener(eTypes[i], input, 0, nil) + colBuf := chunk.NewColumn(ft, 1024) + colBuf2 := chunk.NewColumn(ft, 1024) + var err error + c.Assert(colExpr.Vectorized(), IsTrue) + ctx.GetSessionVars().EnableVectorizedExpression = false + err = EvalExpr(ctx, colExpr, input, colBuf) + if err != nil { + c.Fatal(err) + } + ctx.GetSessionVars().EnableVectorizedExpression = true + err = EvalExpr(ctx, colExpr, input, colBuf2) + if err != nil { + c.Fatal(err) + } + for j := 0; j < 1024; j++ { + isNull := colBuf.IsNull(j) + isNull2 := colBuf2.IsNull(j) + c.Assert(isNull, Equals, isNull2) + if isNull { + continue + } + c.Assert(string(colBuf.GetRaw(j)), Equals, string(colBuf2.GetRaw(j))) + } + } +} diff --git a/expression/util_test.go b/expression/util_test.go index c831577b14cfe..a83a934c19db2 100644 --- a/expression/util_test.go +++ b/expression/util_test.go @@ -341,7 +341,7 @@ func (s *testUtilSuite) TestHashGroupKey(c *check.C) { bufs[j] = bufs[j][:0] } var err error - err = VecEval(ctx, colExpr, input, colBuf) + err = EvalExpr(ctx, colExpr, input, colBuf) if err != nil { c.Fatal(err) }