diff --git a/expression/expr_to_pb_test.go b/expression/expr_to_pb_test.go index 784ae74e171b3..b208ae78a44fd 100644 --- a/expression/expr_to_pb_test.go +++ b/expression/expr_to_pb_test.go @@ -606,17 +606,33 @@ func (s *testEvaluatorSuite) TestExprPushDownToFlash(c *C) { client := new(mock.Client) dg := new(dataGen4Expr2PbTest) exprs := make([]Expression, 0) - function, err := NewFunction(mock.NewContext(), ast.JSONLength, types.NewFieldType(mysql.TypeLonglong), dg.genColumn(mysql.TypeJSON, 1)) + + jsonColumn := dg.genColumn(mysql.TypeJSON, 1) + intColumn := dg.genColumn(mysql.TypeLonglong, 2) + function, err := NewFunction(mock.NewContext(), ast.JSONLength, types.NewFieldType(mysql.TypeLonglong), jsonColumn) + c.Assert(err, IsNil) + exprs = append(exprs, function) + + function, err = NewFunction(mock.NewContext(), ast.If, types.NewFieldType(mysql.TypeLonglong), intColumn, intColumn, intColumn) + c.Assert(err, IsNil) + exprs = append(exprs, function) + + function, err = NewFunction(mock.NewContext(), ast.BitNeg, types.NewFieldType(mysql.TypeLonglong), intColumn) + c.Assert(err, IsNil) + exprs = append(exprs, function) + + function, err = NewFunction(mock.NewContext(), ast.Xor, types.NewFieldType(mysql.TypeLonglong), intColumn, intColumn) c.Assert(err, IsNil) exprs = append(exprs, function) + canPush := CanExprsPushDown(sc, exprs, client, kv.TiFlash) c.Assert(canPush, Equals, true) - function, err = NewFunction(mock.NewContext(), ast.JSONDepth, types.NewFieldType(mysql.TypeLonglong), dg.genColumn(mysql.TypeJSON, 2)) + function, err = NewFunction(mock.NewContext(), ast.JSONDepth, types.NewFieldType(mysql.TypeLonglong), jsonColumn) c.Assert(err, IsNil) exprs = append(exprs, function) pushed, remained := PushDownExprs(sc, exprs, client, kv.TiFlash) - c.Assert(len(pushed), Equals, 1) + c.Assert(len(pushed), Equals, len(exprs)-1) c.Assert(len(remained), Equals, 1) } diff --git a/expression/expression.go b/expression/expression.go index 96f243d6a2394..cf8ecc59ec1d8 100644 --- a/expression/expression.go +++ b/expression/expression.go @@ -1114,13 +1114,12 @@ func scalarExprSupportedByTiDB(function *ScalarFunction) bool { func scalarExprSupportedByFlash(function *ScalarFunction) bool { switch function.FuncName.L { - case ast.Plus, ast.Minus, ast.Div, ast.Mul, - ast.NullEQ, ast.GE, ast.LE, ast.EQ, ast.NE, - ast.LT, ast.GT, ast.Ifnull, ast.IsNull, ast.Or, - ast.In, ast.Mod, ast.And, ast.LogicOr, ast.LogicAnd, + case ast.Plus, ast.Minus, ast.Div, ast.Mul, ast.GE, ast.LE, + ast.EQ, ast.NE, ast.LT, ast.GT, ast.Ifnull, ast.IsNull, + ast.Or, ast.In, ast.Mod, ast.And, ast.LogicOr, ast.LogicAnd, ast.Like, ast.UnaryNot, ast.Case, ast.Month, ast.Substr, ast.Substring, ast.TimestampDiff, ast.DateFormat, ast.FromUnixTime, - ast.JSONLength: + ast.JSONLength, ast.If, ast.BitNeg, ast.Xor: return true case ast.Cast: switch function.Function.PbCode() {