diff --git a/expression/constant_test.go b/expression/constant_test.go index f07047a5f9bbe..e0eba43757412 100644 --- a/expression/constant_test.go +++ b/expression/constant_test.go @@ -32,12 +32,16 @@ var _ = Suite(&testExpressionSuite{}) type testExpressionSuite struct{} func newColumn(id int) *Column { + return newColumnWithType(id, types.NewFieldType(mysql.TypeLonglong)) +} + +func newColumnWithType(id int, t *types.FieldType) *Column { return &Column{ UniqueID: int64(id), ColName: model.NewCIStr(fmt.Sprint(id)), TblName: model.NewCIStr("t"), DBName: model.NewCIStr("test"), - RetType: types.NewFieldType(mysql.TypeLonglong), + RetType: t, } } @@ -48,6 +52,18 @@ func newLonglong(value int64) *Constant { } } +func newDate(year, month, day int) *Constant { + var tmp types.Datum + tmp.SetMysqlTime(types.Time{ + Time: types.FromDate(year, month, day, 0, 0, 0, 0), + Type: mysql.TypeDate, + }) + return &Constant{ + Value: tmp, + RetType: types.NewFieldType(mysql.TypeDate), + } +} + func newFunction(funcName string, args ...Expression) Expression { typeLong := types.NewFieldType(mysql.TypeLonglong) return NewFunctionInternal(mock.NewContext(), funcName, typeLong, args...) @@ -180,6 +196,91 @@ func (*testExpressionSuite) TestConstantPropagation(c *C) { } } +func (*testExpressionSuite) TestConstraintPropagation(c *C) { + defer testleak.AfterTest(c)() + col1 := newColumnWithType(1, types.NewFieldType(mysql.TypeDate)) + tests := []struct { + solver constraintSolver + conditions []Expression + result string + }{ + // Don't propagate this any more, because it makes the code more complex but not + // useful for partition pruning. 
+ // { + // solver: newConstraintSolver(ruleColumnGTConst), + // conditions: []Expression{ + // newFunction(ast.GT, newColumn(0), newLonglong(5)), + // newFunction(ast.GT, newColumn(0), newLonglong(7)), + // }, + // result: "gt(test.t.0, 7)", + // }, + { + solver: newConstraintSolver(ruleColumnOPConst), + conditions: []Expression{ + newFunction(ast.GT, newColumn(0), newLonglong(5)), + newFunction(ast.LT, newColumn(0), newLonglong(5)), + }, + result: "0", + }, + { + solver: newConstraintSolver(ruleColumnOPConst), + conditions: []Expression{ + newFunction(ast.GT, newColumn(0), newLonglong(7)), + newFunction(ast.LT, newColumn(0), newLonglong(5)), + }, + result: "0", + }, + { + solver: newConstraintSolver(ruleColumnOPConst), + // col1 > '2018-12-11' and to_days(col1) < 5 => false + conditions: []Expression{ + newFunction(ast.GT, col1, newDate(2018, 12, 11)), + newFunction(ast.LT, newFunction(ast.ToDays, col1), newLonglong(5)), + }, + result: "0", + }, + { + solver: newConstraintSolver(ruleColumnOPConst), + conditions: []Expression{ + newFunction(ast.LT, newColumn(0), newLonglong(5)), + newFunction(ast.GT, newColumn(0), newLonglong(5)), + }, + result: "0", + }, + { + solver: newConstraintSolver(ruleColumnOPConst), + conditions: []Expression{ + newFunction(ast.LT, newColumn(0), newLonglong(5)), + newFunction(ast.GT, newColumn(0), newLonglong(7)), + }, + result: "0", + }, + { + solver: newConstraintSolver(ruleColumnOPConst), + // col1 < '2018-12-11' and to_days(col1) > 737999 => false + conditions: []Expression{ + newFunction(ast.LT, col1, newDate(2018, 12, 11)), + newFunction(ast.GT, newFunction(ast.ToDays, col1), newLonglong(737999)), + }, + result: "0", + }, + } + for _, tt := range tests { + ctx := mock.NewContext() + conds := make([]Expression, 0, len(tt.conditions)) + for _, cd := range tt.conditions { + conds = append(conds, FoldConstant(cd)) + } + newConds := tt.solver.Solve(ctx, conds) + var result []string + for _, v := range newConds { + result = append(result, 
v.String()) + } + sort.Strings(result) + c.Assert(strings.Join(result, ", "), Equals, tt.result, Commentf("different for expr %s", tt.conditions)) + } +} + func (*testExpressionSuite) TestConstantFolding(c *C) { defer testleak.AfterTest(c)() tests := []struct { diff --git a/expression/constraint_propagation.go b/expression/constraint_propagation.go index 986576aa3261e..aafbf1660d306 100644 --- a/expression/constraint_propagation.go +++ b/expression/constraint_propagation.go @@ -16,6 +16,7 @@ package expression import ( "bytes" + "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" @@ -78,24 +79,35 @@ func newExprSet(conditions []Expression) *exprSet { return &exprs } +type constraintSolver []constraintPropagateRule + +func newConstraintSolver(rules ...constraintPropagateRule) constraintSolver { + return constraintSolver(rules) +} + type pgSolver2 struct{} -// PropagateConstant propagate constant values of deterministic predicates in a condition. func (s pgSolver2) PropagateConstant(ctx sessionctx.Context, conditions []Expression) []Expression { + solver := newConstraintSolver(ruleConstantFalse, ruleColumnEQConst) + return solver.Solve(ctx, conditions) +} + +// Solve propagates constraints according to the rules in the constraintSolver. +func (s constraintSolver) Solve(ctx sessionctx.Context, conditions []Expression) []Expression { exprs := newExprSet(conditions) s.fixPoint(ctx, exprs) return exprs.Slice() } -// fixPoint is the core of the constant propagation algorithm. +// fixPoint is the core of the constraint propagation algorithm. // It will iterate the expression set over and over again, pick two expressions, // apply one to another. // If new conditions can be inferred, they will be append into the expression set. // Until no more conditions can be inferred from the set, the algorithm finish. 
-func (s pgSolver2) fixPoint(ctx sessionctx.Context, exprs *exprSet) { +func (s constraintSolver) fixPoint(ctx sessionctx.Context, exprs *exprSet) { for { saveLen := len(exprs.data) - iterOnce(ctx, exprs) + s.iterOnce(ctx, exprs) if saveLen == len(exprs.data) { break } @@ -104,7 +116,7 @@ func (s pgSolver2) fixPoint(ctx sessionctx.Context, exprs *exprSet) { } // iterOnce picks two expressions from the set, try to propagate new conditions from them. -func iterOnce(ctx sessionctx.Context, exprs *exprSet) { +func (s constraintSolver) iterOnce(ctx sessionctx.Context, exprs *exprSet) { for i := 0; i < len(exprs.data); i++ { if exprs.tombstone[i] { continue @@ -116,24 +128,19 @@ func iterOnce(ctx sessionctx.Context, exprs *exprSet) { if i == j { continue } - solve(ctx, i, j, exprs) + s.solve(ctx, i, j, exprs) } } } // solve uses exprs[i] exprs[j] to propagate new conditions. -func solve(ctx sessionctx.Context, i, j int, exprs *exprSet) { - for _, rule := range rules { +func (s constraintSolver) solve(ctx sessionctx.Context, i, j int, exprs *exprSet) { + for _, rule := range s { rule(ctx, i, j, exprs) } } -type constantPropagateRule func(ctx sessionctx.Context, i, j int, exprs *exprSet) - -var rules = []constantPropagateRule{ - ruleConstantFalse, - ruleColumnEQConst, -} +type constraintPropagateRule func(ctx sessionctx.Context, i, j int, exprs *exprSet) // ruleConstantFalse propagates from CNF condition that false plus anything returns false. // false, a = 1, b = c ... => false @@ -164,3 +171,134 @@ func ruleColumnEQConst(ctx sessionctx.Context, i, j int, exprs *exprSet) { } } } + +// ruleColumnOPConst propagates the "column OP const" condition. 
+func ruleColumnOPConst(ctx sessionctx.Context, i, j int, exprs *exprSet) {
+	// exprs.data[i] must have the shape "col1 op1 con1" with op1 in [>=, >, <=, <].
+	f1, ok := exprs.data[i].(*ScalarFunction)
+	if !ok {
+		return
+	}
+	op1 := f1.FuncName.L
+	switch op1 {
+	case ast.GE, ast.GT, ast.LE, ast.LT:
+	default:
+		return
+	}
+	col1, ok := f1.GetArgs()[0].(*Column)
+	if !ok {
+		return
+	}
+	con1, ok := f1.GetArgs()[1].(*Constant)
+	if !ok {
+		return
+	}
+
+	// exprs.data[j] must be a comparison in the opposite direction, "X op2 con2".
+	f2, ok := exprs.data[j].(*ScalarFunction)
+	if !ok {
+		return
+	}
+	op2 := f2.FuncName.L
+	if !opsiteOP(op1, op2) {
+		return
+	}
+	con2, ok := f2.GetArgs()[1].(*Constant)
+	if !ok {
+		return
+	}
+
+	// The simple case (X is the column itself):
+	// col >= c1, col < c2, c1 >= c2 => false
+	// col >= c1, col <= c2, c1 > c2 => false
+	// col >= c1, col OP c2, c1 ^OP c2, where OP in [< , <=] => false
+	// col OP1 c1 where OP1 in [>= , <], col OP2 c2 where OP1 opposite OP2, c1 ^OP2 c2 => false
+	//
+	// The extended case (X is f(col)):
+	// col >= c1, f(col) < c2, f is monotonic, f(c1) >= c2 => false
+	//
+	// Proof:
+	// col > c1, f is monotonic => f(col) > f(c1)
+	// f(col) > f(c1), f(col) < c2, f(c1) >= c2 => false
+	//
+	// fc1 is what gets compared with con2: con1 itself, or f(con1) in the extended case.
+	var fc1 Expression = con1
+	col2, ok := f2.GetArgs()[0].(*Column)
+	if !ok {
+		// The extended case. f(con1) is rebuilt from con1 alone, so only a unary
+		// function applied directly to a column is accepted.
+		fn, isFunc := f2.GetArgs()[0].(*ScalarFunction)
+		if !isFunc {
+			return
+		}
+		if _, known := monotoneIncFuncs[fn.FuncName.L]; !known {
+			return
+		}
+		fnArgs := fn.GetArgs()
+		if len(fnArgs) != 1 {
+			return
+		}
+		if col2, ok = fnArgs[0].(*Column); !ok {
+			return
+		}
+		var err error
+		fc1, err = NewFunction(ctx, fn.FuncName.L, fn.RetType, con1)
+		if err != nil {
+			log.Warn(err)
+			return
+		}
+	}
+	if !col1.Equal(ctx, col2) {
+		return
+	}
+	// The two conditions contradict each other when "fc1 negOP(op2) con2" holds.
+	v, isNull, err := compareConstant(ctx, negOP(op2), fc1, con2)
+	if err != nil {
+		log.Warn(err)
+		return
+	}
+	if !isNull && v > 0 {
+		exprs.SetConstFalse()
+	}
+}
+
+// opsiteOP reports whether op1 and op2 point in opposite directions (>=, > versus <, <=); used in ruleColumnOPConst.
+func opsiteOP(op1, op2 string) bool {
+	switch op1 {
+	case ast.GE, ast.GT:
+		return op2 == ast.LT || op2 == ast.LE
+	case ast.LE, ast.LT:
+		return op2 == ast.GT || op2 == ast.GE
+	}
+	return false
+}
+
+// negOP returns the negation of the comparison cmp (e.g. < becomes >=),
+// or "" when cmp is not one of <, <=, >, >=.
+func negOP(cmp string) string {
+	switch cmp {
+	case ast.LT:
+		return ast.GE
+	case ast.LE:
+		return ast.GT
+	case ast.GT:
+		return ast.LE
+	case ast.GE:
+		return ast.LT
+	}
+	return ""
+}
+
+// monotoneIncFuncs are those functions that for any x y, if x > y => f(x) > f(y)
+var monotoneIncFuncs = map[string]struct{}{
+	ast.ToDays: {},
+}
+
+// compareConstant evaluates "c1 fn c2" on an empty row. c1 and c2 should be constant with the same type.
+func compareConstant(ctx sessionctx.Context, fn string, c1, c2 Expression) (int64, bool, error) {
+	cmp, err := NewFunction(ctx, fn, types.NewFieldType(mysql.TypeTiny), c1, c2)
+	if err != nil {
+		return 0, false, err
+	}
+	return cmp.EvalInt(ctx, chunk.Row{})
+}