From 805ea36cf7a9983bb97f9545d1c547e95a7c157f Mon Sep 17 00:00:00 2001 From: Elsa <111482174+elsa0520@users.noreply.github.com> Date: Thu, 20 Jun 2024 19:02:48 +0800 Subject: [PATCH] planner: change the rewrite rule of row expression (#53928) close pingcap/tidb#41598 --- pkg/planner/core/casetest/index/BUILD.bazel | 2 +- pkg/planner/core/casetest/index/index_test.go | 25 ++++ .../index/testdata/integration_suite_in.json | 16 +++ .../index/testdata/integration_suite_out.json | 107 ++++++++++++++++++ pkg/planner/core/expression_rewriter.go | 84 +++++++++----- pkg/planner/core/partition_pruning_test.go | 9 +- 6 files changed, 208 insertions(+), 35 deletions(-) diff --git a/pkg/planner/core/casetest/index/BUILD.bazel b/pkg/planner/core/casetest/index/BUILD.bazel index 3f264708225fa..0f21feb25754b 100644 --- a/pkg/planner/core/casetest/index/BUILD.bazel +++ b/pkg/planner/core/casetest/index/BUILD.bazel @@ -9,7 +9,7 @@ go_test( ], data = glob(["testdata/**"]), flaky = True, - shard_count = 3, + shard_count = 4, deps = [ "//pkg/testkit", "//pkg/testkit/testdata", diff --git a/pkg/planner/core/casetest/index/index_test.go b/pkg/planner/core/casetest/index/index_test.go index 3909b655667a4..a24c5ad2dd908 100644 --- a/pkg/planner/core/casetest/index/index_test.go +++ b/pkg/planner/core/casetest/index/index_test.go @@ -123,3 +123,28 @@ func TestRangeDerivation(t *testing.T) { plan.Check(testkit.Rows(output[i].Plan...)) } } + +func TestRowFunctionMatchTheIndexRangeScan(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec(`CREATE TABLE t1 (k1 int , k2 int, k3 int, index pk1(k1, k2))`) + tk.MustExec(`create table t2 (k1 int, k2 int)`) + var input []string + var output []struct { + SQL string + Plan []string + Result []string + } + integrationSuiteData := GetIntegrationSuiteData() + integrationSuiteData.LoadTestCases(t, &input, &output) + for i, tt := range input { + testdata.OnRecord(func() { + output[i].SQL = tt + output[i].Plan = testdata.ConvertRowsToStrings(tk.MustQuery("explain format='brief' " + tt).Rows()) + output[i].Result = testdata.ConvertRowsToStrings(tk.MustQuery(tt).Sort().Rows()) + }) + tk.MustQuery("explain format='brief' " + tt).Check(testkit.Rows(output[i].Plan...)) + tk.MustQuery(tt).Sort().Check(testkit.Rows(output[i].Result...)) + } +} diff --git a/pkg/planner/core/casetest/index/testdata/integration_suite_in.json b/pkg/planner/core/casetest/index/testdata/integration_suite_in.json index 5bc6d2c672129..5f785b5e50fc4 100644 --- a/pkg/planner/core/casetest/index/testdata/integration_suite_in.json +++ b/pkg/planner/core/casetest/index/testdata/integration_suite_in.json @@ -19,5 +19,21 @@ "select b from t3 where a = 1 and b is not null", "select b from t3 where a = 1 and b is null" ] + }, + { + "name": "TestRowFunctionMatchTheIndexRangeScan", + "cases": [ + "select k1 from t1 where (k1,k2) > (1,2)", + "select k1 from t1 where (k1,k2) >= (1,2)", + "select k1 from t1 where (k1,k2) < (1,2)", + "select k1 from t1 where (k1, k2) <= (1,2)", + "select k1 from t1 where (k1,k2) = (1,2)", + "select k1 from t1 where (k1, k2) != (1,2); -- could not match range scan", + "select k1 from t1 where (k1) <=> (1); -- could not match range scan", + "select k1 from t1 where (k1, k2) in ((1,2), (3,4))", + "select k1 from t1 where (k1, k2) > (1,2) and (k1, k2) < (4,5)", + "select k1 from t1 where (k1, k2) >= (1,2) and (k1, k2) <= (4,5)", + "select k1 from t1 where (k2, k3) > (1,2); -- could not match range scan " + ] } ] diff --git a/pkg/planner/core/casetest/index/testdata/integration_suite_out.json b/pkg/planner/core/casetest/index/testdata/integration_suite_out.json index 990012c0ba2f0..26d9ee6c66cf2 100644 --- a/pkg/planner/core/casetest/index/testdata/integration_suite_out.json +++ b/pkg/planner/core/casetest/index/testdata/integration_suite_out.json @@ -200,5 +200,112 @@ "Result": null } ] + }, + { + "Name": "TestRowFunctionMatchTheIndexRangeScan", + "Cases": [ + { + "SQL": "select k1 from t1 where (k1,k2) > (1,2)", + "Plan": [ + "Projection 3366.67 root test.t1.k1", + "└─IndexReader 3366.67 root index:IndexRangeScan", + " └─IndexRangeScan 3366.67 cop[tikv] table:t1, index:pk1(k1, k2) range:(1 2,1 +inf], (1,+inf], keep order:false, stats:pseudo" + ], + "Result": null + }, + { + "SQL": "select k1 from t1 where (k1,k2) >= (1,2)", + "Plan": [ + "Projection 3366.67 root test.t1.k1", + "└─IndexReader 3366.67 root index:IndexRangeScan", + " └─IndexRangeScan 3366.67 cop[tikv] table:t1, index:pk1(k1, k2) range:[1 2,1 +inf], (1,+inf], keep order:false, stats:pseudo" + ], + "Result": null + }, + { + "SQL": "select k1 from t1 where (k1,k2) < (1,2)", + "Plan": [ + "Projection 3356.57 root test.t1.k1", + "└─IndexReader 3356.57 root index:IndexRangeScan", + " └─IndexRangeScan 3356.57 cop[tikv] table:t1, index:pk1(k1, k2) range:[-inf,1), [1 -inf,1 2), keep order:false, stats:pseudo" + ], + "Result": null + }, + { + "SQL": "select k1 from t1 where (k1, k2) <= (1,2)", + "Plan": [ + "Projection 3356.57 root test.t1.k1", + "└─IndexReader 3356.57 root index:IndexRangeScan", + " └─IndexRangeScan 3356.57 cop[tikv] table:t1, index:pk1(k1, k2) range:[-inf,1), [1 -inf,1 2], keep order:false, stats:pseudo" + ], + "Result": null + }, + { + "SQL": "select k1 from t1 where (k1,k2) = (1,2)", + "Plan": [ + "Projection 0.10 root test.t1.k1", + "└─IndexReader 0.10 root index:IndexRangeScan", + " └─IndexRangeScan 0.10 cop[tikv] table:t1, index:pk1(k1, k2) range:[1 2,1 2], keep order:false, stats:pseudo" + ], + "Result": null + }, + { + "SQL": "select k1 from t1 where (k1, k2) != (1,2); -- could not match range scan", + "Plan": [ + "Projection 8882.21 root test.t1.k1", + "└─IndexReader 8882.21 root index:Selection", + " └─Selection 8882.21 cop[tikv] or(ne(test.t1.k1, 1), ne(test.t1.k2, 2))", + " └─IndexFullScan 10000.00 cop[tikv] table:t1, index:pk1(k1, k2) keep order:false, stats:pseudo" + ], + "Result": null + }, + { + "SQL": "select k1 from t1 where (k1) <=> (1); -- could not match range scan", + "Plan": [ + "IndexReader 10.00 root index:IndexRangeScan", + "└─IndexRangeScan 10.00 cop[tikv] table:t1, index:pk1(k1, k2) range:[1,1], keep order:false, stats:pseudo" + ], + "Result": null + }, + { + "SQL": "select k1 from t1 where (k1, k2) in ((1,2), (3,4))", + "Plan": [ + "Projection 0.20 root test.t1.k1", + "└─IndexReader 0.20 root index:IndexRangeScan", + " └─IndexRangeScan 0.20 cop[tikv] table:t1, index:pk1(k1, k2) range:[1 2,1 2], [3 4,3 4], keep order:false, stats:pseudo" + ], + "Result": null + }, + { + "SQL": "select k1 from t1 where (k1, k2) > (1,2) and (k1, k2) < (4,5)", + "Plan": [ + "Projection 1122.61 root test.t1.k1", + "└─IndexReader 1122.61 root index:Selection", + " └─Selection 1122.61 cop[tikv] or(gt(test.t1.k1, 1), and(eq(test.t1.k1, 1), gt(test.t1.k2, 2))), or(lt(test.t1.k1, 4), and(eq(test.t1.k1, 4), lt(test.t1.k2, 5)))", + " └─IndexRangeScan 1403.26 cop[tikv] table:t1, index:pk1(k1, k2) range:[1,1], (1,4), [4,4], keep order:false, stats:pseudo" + ], + "Result": null + }, + { + "SQL": "select k1 from t1 where (k1, k2) >= (1,2) and (k1, k2) <= (4,5)", + "Plan": [ + "Projection 1122.61 root test.t1.k1", + "└─IndexReader 1122.61 root index:Selection", + " └─Selection 1122.61 cop[tikv] or(gt(test.t1.k1, 1), and(eq(test.t1.k1, 1), ge(test.t1.k2, 2))), or(lt(test.t1.k1, 4), and(eq(test.t1.k1, 4), le(test.t1.k2, 5)))", + " └─IndexRangeScan 1403.26 cop[tikv] table:t1, index:pk1(k1, k2) range:[1,1], (1,4), [4,4], keep order:false, stats:pseudo" + ], + "Result": null + }, + { + "SQL": "select k1 from t1 where (k2, k3) > (1,2); -- could not match range scan ", + "Plan": [ + "Projection 3335.56 root test.t1.k1", + "└─TableReader 3335.56 root data:Selection", + " └─Selection 3335.56 cop[tikv] or(gt(test.t1.k2, 1), and(eq(test.t1.k2, 1), gt(test.t1.k3, 2)))", + " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ], + "Result": null + } + ] } ] diff --git a/pkg/planner/core/expression_rewriter.go b/pkg/planner/core/expression_rewriter.go index ec23e18fad935..048611909f60b 100644 --- a/pkg/planner/core/expression_rewriter.go +++ b/pkg/planner/core/expression_rewriter.go @@ -383,13 +383,17 @@ func (er *expressionRewriter) ctxStackAppend(col expression.Expression, name *ty } // constructBinaryOpFunction converts binary operator functions -// 1. If op are EQ or NE or NullEQ, constructBinaryOpFunctions converts (a0,a1,a2) op (b0,b1,b2) to (a0 op b0) and (a1 op b1) and (a2 op b2) -// 2. Else constructBinaryOpFunctions converts (a0,a1,a2) op (b0,b1,b2) to -// `IF( a0 NE b0, a0 op b0, -// -// IF ( isNull(a0 NE b0), Null, -// IF ( a1 NE b1, a1 op b1, -// IF ( isNull(a1 NE b1), Null, a2 op b2))))` +/* + The algorithm is as follows: + 1. If the length of the two sides of the expression is 1, return l op r directly. + 2. If the length of the two sides of the expression is not equal, return an error. + 3. If the operator is EQ, NE, or NullEQ, converts (a0,a1,a2) op (b0,b1,b2) to (a0 op b0) and (a1 op b1) and (a2 op b2) + 4. If the operator is not EQ, NE, or NullEQ, + converts (a0,a1,a2) op (b0,b1,b2) to (a0 > b0) or (a0 = b0 and a1 > b1) or (a0 = b0 and a1 = b1 and a2 op b2) + Especially, op is GE or LE, the prefix element will be converted to > or <. + converts (a0,a1,a2) >= (b0,b1,b2) to (a0 > b0) or (a0 = b0 and a1 > b1) or (a0 = b0 and a1 = b1 and a2 >= b2) + The only different between >= and > is that >= additional include the (x,y,z) = (a,b,c). +*/ func (er *expressionRewriter) constructBinaryOpFunction(l expression.Expression, r expression.Expression, op string) (expression.Expression, error) { lLen, rLen := expression.GetRowLen(l), expression.GetRowLen(r) if lLen == 1 && rLen == 1 { @@ -412,29 +416,51 @@ func (er *expressionRewriter) constructBinaryOpFunction(l expression.Expression, } return expression.ComposeCNFCondition(er.sctx, funcs...), nil default: - larg0, rarg0 := expression.GetFuncArg(l, 0), expression.GetFuncArg(r, 0) - var expr1, expr2, expr3, expr4, expr5 expression.Expression - expr1 = expression.NewFunctionInternal(er.sctx, ast.NE, types.NewFieldType(mysql.TypeTiny), larg0, rarg0) - expr2 = expression.NewFunctionInternal(er.sctx, op, types.NewFieldType(mysql.TypeTiny), larg0, rarg0) - expr3 = expression.NewFunctionInternal(er.sctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), expr1) - var err error - l, err = expression.PopRowFirstArg(er.sctx, l) - if err != nil { - return nil, err - } - r, err = expression.PopRowFirstArg(er.sctx, r) - if err != nil { - return nil, err - } - expr4, err = er.constructBinaryOpFunction(l, r, op) - if err != nil { - return nil, err - } - expr5, err = er.newFunction(ast.If, types.NewFieldType(mysql.TypeTiny), expr3, expression.NewNull(), expr4) - if err != nil { - return nil, err + /* + The algorithm is as follows: + 1. Iterate over i left columns and construct his own CNF for each left column. + 1.1 Iterate over j (every i-1 columns) to l[j]=r[j] + 1.2 Build current i column with op to l[i] op r[i] + 1.3 Combine 1.1 and 1.2 predicates with AND operator + 2. Combine every i CNF with OR operator. + */ + resultDNFList := make([]expression.Expression, 0, lLen) + // Step 1 + for i := 0; i < lLen; i++ { + exprList := make([]expression.Expression, 0, i+1) + // Step 1.1 build prefix equal conditions + // (l[0], ... , l[i-1], ...) op (r[0], ... , r[i-1], ...) should be convert to + // l[0] = r[0] and l[1] = r[1] and ... and l[i-1] = r[i-1] + for j := 0; j < i; j++ { + jExpr, err := er.constructBinaryOpFunction(expression.GetFuncArg(l, j), expression.GetFuncArg(r, j), ast.EQ) + if err != nil { + return nil, err + } + exprList = append(exprList, jExpr) + } + + // Especially, op is GE or LE, the prefix element will be converted to > or <. + degeneratedOp := op + if i < lLen-1 { + switch op { + case ast.GE: + degeneratedOp = ast.GT + case ast.LE: + degeneratedOp = ast.LT + } + } + // Step 1.2 + currentIndexExpr, err := er.constructBinaryOpFunction(expression.GetFuncArg(l, i), expression.GetFuncArg(r, i), degeneratedOp) + if err != nil { + return nil, err + } + exprList = append(exprList, currentIndexExpr) + // Step 1.3 + currentExpr := expression.ComposeCNFCondition(er.sctx, exprList...) + resultDNFList = append(resultDNFList, currentExpr) } - return er.newFunction(ast.If, types.NewFieldType(mysql.TypeTiny), expr1, expr2, expr5) + // Step 2 + return expression.ComposeDNFCondition(er.sctx, resultDNFList...), nil } } diff --git a/pkg/planner/core/partition_pruning_test.go b/pkg/planner/core/partition_pruning_test.go index 2c519a80c70d3..50ddba64f980b 100644 --- a/pkg/planner/core/partition_pruning_test.go +++ b/pkg/planner/core/partition_pruning_test.go @@ -553,14 +553,13 @@ func TestPartitionRangeColumnsForExpr(t *testing.T) { {"a is null", partitionRangeOR{{0, 1}}}, {"12 > a", partitionRangeOR{{0, 12}}}, {"4 <= a", partitionRangeOR{{1, 14}}}, - // The expression is converted to 'if ...', see constructBinaryOpFunction, so not possible to break down to ranges - {"(a,b) < (4,4)", partitionRangeOR{{0, 14}}}, + {"(a,b) < (4,4)", partitionRangeOR{{0, 4}}}, {"(a,b) = (4,4)", partitionRangeOR{{4, 5}}}, {"a < 4 OR (a = 4 AND b < 4)", partitionRangeOR{{0, 4}}}, - // The expression is converted to 'if ...', see constructBinaryOpFunction, so not possible to break down to ranges - {"(a,b,c) < (4,4,4)", partitionRangeOR{{0, 14}}}, + {"(a,b,c) < (4,4,4)", partitionRangeOR{{0, 5}}}, {"a < 4 OR (a = 4 AND b < 4) OR (a = 4 AND b = 4 AND c < 4)", partitionRangeOR{{0, 5}}}, - {"(a,b,c) >= (4,7,4)", partitionRangeOR{{0, len(partDefs)}}}, + {"(a,b,c) >= (4,7,4)", partitionRangeOR{{5, len(partDefs)}}}, + {"a > 4 or (a= 4 and b > 7) or (a = 4 and b = 7 and c >= 4)", partitionRangeOR{{5, len(partDefs)}}}, {"(a,b,c) = (4,7,4)", partitionRangeOR{{5, 6}}}, {"a < 2 and a > 10", partitionRangeOR{}}, {"a < 1 and a > 1", partitionRangeOR{}},