From b6d9a05d958e885b1e8b52f85d5cbb6354f696a9 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Tue, 22 Feb 2022 16:41:43 +0800 Subject: [PATCH] planner: change `redundantSchema` to `fullSchema` to correctly handle natural and "using" joins (#29599) (#30041) close pingcap/tidb#29481 --- executor/join_test.go | 33 +++++++ executor/testdata/executor_suite_in.json | 19 ++++ executor/testdata/executor_suite_out.json | 115 ++++++++++++++++++++++ planner/core/expression_rewriter.go | 6 +- planner/core/logical_plan_builder.go | 102 ++++++++++--------- planner/core/logical_plans.go | 20 +++- 6 files changed, 241 insertions(+), 54 deletions(-) diff --git a/executor/join_test.go b/executor/join_test.go index e55708bfebd8a..03f9c16586aac 100644 --- a/executor/join_test.go +++ b/executor/join_test.go @@ -656,6 +656,39 @@ func (s *testSuiteJoin1) TestUsing(c *C) { tk.MustQuery("select t1.t0, t2.t0 from t1 join t2 using(t0) having t1.t0 > 0").Check(testkit.Rows("1 1")) } +func (s *testSuiteWithData) TestUsingAndNaturalJoinSchema(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1, t2, t3, t4") + tk.MustExec("create table t1 (c int, b int);") + tk.MustExec("create table t2 (a int, b int);") + tk.MustExec("create table t3 (b int, c int);") + tk.MustExec("create table t4 (y int, c int);") + + tk.MustExec("insert into t1 values (10,1);") + tk.MustExec("insert into t1 values (3 ,1);") + tk.MustExec("insert into t1 values (3 ,2);") + tk.MustExec("insert into t2 values (2, 1);") + tk.MustExec("insert into t3 values (1, 3);") + tk.MustExec("insert into t3 values (1,10);") + tk.MustExec("insert into t4 values (11,3);") + tk.MustExec("insert into t4 values (2, 3);") + + var input []string + var output []struct { + SQL string + Res []string + } + s.testData.GetTestCases(c, &input, &output) + for i, tt := range input { + s.testData.OnRecord(func() { + output[i].SQL = tt + output[i].Res = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Sort().Rows()) + }) + tk.MustQuery(tt).Sort().Check(testkit.Rows(output[i].Res...)) + } +} + func (s *testSuiteWithData) TestNaturalJoin(c *C) { tk := testkit.NewTestKit(c, s.store) diff --git a/executor/testdata/executor_suite_in.json b/executor/testdata/executor_suite_in.json index cd8fa234c0117..c8d609cdf8e31 100644 --- a/executor/testdata/executor_suite_in.json +++ b/executor/testdata/executor_suite_in.json @@ -40,6 +40,25 @@ "SELECT * FROM t1 NATURAL LEFT JOIN t2 WHERE not(t1.a <=> t2.a)" ] }, + { + "name": "TestUsingAndNaturalJoinSchema", + "cases": [ + "select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) natural join (t3 natural join t4);", + "select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) join (t3 natural join t4) using (b);", + "select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) join (t3 natural join t4) using (c);", + "select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) join (t3 natural join t4) using (c,b);", + "select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) left outer join (t3 natural join t4) using (b);", + "select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) left outer join (t3 natural join t4) using (c);", + "select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) left outer join (t3 natural join t4) using (c,b);", + "select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) right outer join (t3 natural join t4) using (b);", + "select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) right outer join (t3 natural join t4) using (c);", + "select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) right outer join (t3 natural join t4) using (c,b);", + "select * from (t1 natural join t2) natural join (t3 natural join t4);", + "select * from (t1 natural join t2) join (t3 natural join t4) using (b);", + "select * from (t1 natural join t2) left outer join (t3 natural join t4) using (b);", + "select * from (t1 natural join t2) right outer join (t3 natural join t4) using (c,b);" + ] + }, { "name": "TestIndexScanWithYearCol", "cases": [ diff --git a/executor/testdata/executor_suite_out.json b/executor/testdata/executor_suite_out.json index eab098751f2ba..d2ad1621f6049 100644 --- a/executor/testdata/executor_suite_out.json +++ b/executor/testdata/executor_suite_out.json @@ -519,6 +519,121 @@ } ] }, + { + "Name": "TestUsingAndNaturalJoinSchema", + "Cases": [ + { + "SQL": "select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) natural join (t3 natural join t4);", + "Res": [ + "3 1 2 1 1 3 11 3", + "3 1 2 1 1 3 2 3" + ] + }, + { + "SQL": "select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) join (t3 natural join t4) using (b);", + "Res": [ + "10 1 2 1 1 3 11 3", + "10 1 2 1 1 3 2 3", + "3 1 2 1 1 3 11 3", + "3 1 2 1 1 3 2 3" + ] + }, + { + "SQL": "select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) join (t3 natural join t4) using (c);", + "Res": [ + "3 1 2 1 1 3 11 3", + "3 1 2 1 1 3 2 3" + ] + }, + { + "SQL": "select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) join (t3 natural join t4) using (c,b);", + "Res": [ + "3 1 2 1 1 3 11 3", + "3 1 2 1 1 3 2 3" + ] + }, + { + "SQL": "select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) left outer join (t3 natural join t4) using (b);", + "Res": [ + "10 1 2 1 1 3 11 3", + "10 1 2 1 1 3 2 3", + "3 1 2 1 1 3 11 3", + "3 1 2 1 1 3 2 3" + ] + }, + { + "SQL": "select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) left outer join (t3 natural join t4) using (c);", + "Res": [ + "10 1 2 1 ", + "3 1 2 1 1 3 11 3", + "3 1 2 1 1 3 2 3" + ] + }, + { + "SQL": "select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) left outer join (t3 natural join t4) using (c,b);", + "Res": [ + "10 1 2 1 ", + "3 1 2 1 1 3 11 3", + "3 1 2 1 1 3 2 3" + ] + }, + { + "SQL": "select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) right outer join (t3 natural join t4) using (b);", + "Res": [ + "10 1 2 1 1 3 11 3", + "10 1 2 1 1 3 2 3", + "3 1 2 1 1 3 11 3", + "3 1 2 1 1 3 2 3" + ] + }, + { + "SQL": "select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) right outer join (t3 natural join t4) using (c);", + "Res": [ + "3 1 2 1 1 3 11 3", + "3 1 2 1 1 3 2 3" + ] + }, + { + "SQL": "select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) right outer join (t3 natural join t4) using (c,b);", + "Res": [ + "3 1 2 1 1 3 11 3", + "3 1 2 1 1 3 2 3" + ] + }, + { + "SQL": "select * from (t1 natural join t2) natural join (t3 natural join t4);", + "Res": [ + "1 3 2 11", + "1 3 2 2" + ] + }, + { + "SQL": "select * from (t1 natural join t2) join (t3 natural join t4) using (b);", + "Res": [ + "1 10 2 3 11", + "1 10 2 3 2", + "1 3 2 3 11", + "1 3 2 3 2" + ] + }, + { + "SQL": "select * from (t1 natural join t2) left outer join (t3 natural join t4) using (b);", + "Res": [ + "1 10 2 3 11", + "1 10 2 3 2", + "1 3 2 3 11", + "1 3 2 3 2" + ] + }, + { + "SQL": "select * from (t1 natural join t2) right outer join (t3 natural join t4) using (c,b);", + "Res": [ + "3 1 11 2", + "3 1 2 2" + ] + } + ] + }, { "Name": "TestIndexScanWithYearCol", "Cases": [ diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 0581301d1a791..60f156d265eaf 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -1865,13 +1865,13 @@ func findFieldNameFromNaturalUsingJoin(p LogicalPlan, v *ast.ColumnName) (col *e case *LogicalLimit, *LogicalSelection, *LogicalTopN, *LogicalSort, *LogicalMaxOneRow: return findFieldNameFromNaturalUsingJoin(p.Children()[0], v) case *LogicalJoin: - if x.redundantSchema != nil { - idx, err := expression.FindFieldName(x.redundantNames, v) + if x.fullSchema != nil { + idx, err := expression.FindFieldName(x.fullNames, v) if err != nil { return nil, nil, err } if idx >= 0 { - return x.redundantSchema.Columns[idx], x.redundantNames[idx], nil + return x.fullSchema.Columns[idx], x.fullNames[idx], nil } } } diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index dd1ab40650ac6..693089dde6d17 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -276,8 +276,8 @@ func (b *PlanBuilder) buildAggregation(ctx context.Context, p LogicalPlan, aggFu schema4Agg.Append(newCol) names = append(names, p.OutputNames()[i]) } - if join, isJoin := p.(*LogicalJoin); isJoin && join.redundantSchema != nil { - for i, col := range join.redundantSchema.Columns { + if join, isJoin := p.(*LogicalJoin); isJoin && join.fullSchema != nil { + for i, col := range join.fullSchema.Columns { if p.Schema().Contains(col) { continue } @@ -289,7 +289,7 @@ func (b *PlanBuilder) buildAggregation(ctx context.Context, p LogicalPlan, aggFu newCol, _ := col.Clone().(*expression.Column) newCol.RetType = newFunc.RetTp schema4Agg.Append(newCol) - names = append(names, join.redundantNames[i]) + names = append(names, join.fullNames[i]) } } hasGroupBy := len(gbyItems) > 0 @@ -720,25 +720,50 @@ func (b *PlanBuilder) buildJoin(ctx context.Context, joinNode *ast.Join) (Logica joinPlan.JoinType = InnerJoin } - // Merge sub join's redundantSchema into this join plan. When handle query like - // select t2.a from (t1 join t2 using (a)) join t3 using (a); - // we can simply search in the top level join plan to find redundant column. + // Merge sub-plan's fullSchema into this join plan. + // Please read the comment of LogicalJoin.fullSchema for the details. var ( - lRedundantSchema, rRedundantSchema *expression.Schema - lRedundantNames, rRedundantNames types.NameSlice + lFullSchema, rFullSchema *expression.Schema + lFullNames, rFullNames types.NameSlice ) - if left, ok := leftPlan.(*LogicalJoin); ok && left.redundantSchema != nil { - lRedundantSchema = left.redundantSchema - lRedundantNames = left.redundantNames + if left, ok := leftPlan.(*LogicalJoin); ok && left.fullSchema != nil { + lFullSchema = left.fullSchema + lFullNames = left.fullNames + } else { + lFullSchema = leftPlan.Schema() + lFullNames = leftPlan.OutputNames() + } + if right, ok := rightPlan.(*LogicalJoin); ok && right.fullSchema != nil { + rFullSchema = right.fullSchema + rFullNames = right.fullNames + } else { + rFullSchema = rightPlan.Schema() + rFullNames = rightPlan.OutputNames() } - if right, ok := rightPlan.(*LogicalJoin); ok && right.redundantSchema != nil { - rRedundantSchema = right.redundantSchema - rRedundantNames = right.redundantNames + if joinNode.Tp == ast.RightJoin { + // Make sure lFullSchema means outer full schema and rFullSchema means inner full schema. + lFullSchema, rFullSchema = rFullSchema, lFullSchema + lFullNames, rFullNames = rFullNames, lFullNames + } + joinPlan.fullSchema = expression.MergeSchema(lFullSchema, rFullSchema) + + // Clear NotNull flag for the inner side schema if it's an outer join. + if joinNode.Tp == ast.LeftJoin || joinNode.Tp == ast.RightJoin { + resetNotNullFlag(joinPlan.fullSchema, lFullSchema.Len(), joinPlan.fullSchema.Len()) + } + + // Merge sub-plan's fullNames into this join plan, similar to the fullSchema logic above. + joinPlan.fullNames = make([]*types.FieldName, 0, len(lFullNames)+len(rFullNames)) + for _, lName := range lFullNames { + name := *lName + name.Redundant = true + joinPlan.fullNames = append(joinPlan.fullNames, &name) + } + for _, rName := range rFullNames { + name := *rName + name.Redundant = true + joinPlan.fullNames = append(joinPlan.fullNames, &name) } - joinPlan.redundantSchema = expression.MergeSchema(lRedundantSchema, rRedundantSchema) - joinPlan.redundantNames = make([]*types.FieldName, len(lRedundantNames)+len(rRedundantNames)) - copy(joinPlan.redundantNames, lRedundantNames) - copy(joinPlan.redundantNames[len(lRedundantNames):], rRedundantNames) // Set preferred join algorithm if some join hints is specified by user. joinPlan.setPreferredJoinType(b.TableHints()) @@ -941,21 +966,7 @@ func (b *PlanBuilder) coalesceCommonColumns(p *LogicalJoin, leftPlan, rightPlan p.SetSchema(expression.NewSchema(schemaCols...)) p.names = names - // We record the full `rightPlan.Schema` as `redundantSchema` in order to - // record the redundant column in `rightPlan` and the output columns order - // of the `rightPlan`. - // For SQL like `select t1.*, t2.* from t1 left join t2 using(a)`, we can - // retrieve the column order of `t2.*` from the `redundantSchema`. - p.redundantSchema = expression.MergeSchema(p.redundantSchema, expression.NewSchema(rightPlan.Schema().Clone().Columns...)) - p.redundantNames = p.redundantNames.Shallow() - for _, name := range rightPlan.OutputNames() { - cpyName := *name - cpyName.Redundant = true - p.redundantNames = append(p.redundantNames, &cpyName) - } - if joinTp == ast.RightJoin || joinTp == ast.LeftJoin { - resetNotNullFlag(p.redundantSchema, 0, p.redundantSchema.Len()) - } + p.OtherConditions = append(conds, p.OtherConditions...) return nil @@ -1208,9 +1219,9 @@ func findColFromNaturalUsingJoin(p LogicalPlan, col *expression.Column) (name *t case *LogicalLimit, *LogicalSelection, *LogicalTopN, *LogicalSort, *LogicalMaxOneRow: return findColFromNaturalUsingJoin(p.Children()[0], col) case *LogicalJoin: - if x.redundantSchema != nil { - idx := x.redundantSchema.ColumnIndex(col) - return x.redundantNames[idx] + if x.fullSchema != nil { + idx := x.fullSchema.ColumnIndex(col) + return x.fullNames[idx] } } return nil @@ -1996,9 +2007,9 @@ func (a *havingWindowAndOrderbyExprResolver) resolveFromPlan(v *ast.ColumnNameEx case *LogicalLimit, *LogicalSelection, *LogicalTopN, *LogicalSort, *LogicalMaxOneRow: return a.resolveFromPlan(v, p.Children()[0]) case *LogicalJoin: - if len(x.redundantNames) != 0 { - idx, err = expression.FindFieldName(x.redundantNames, v.Name) - schemaCols, outputNames = x.redundantSchema.Columns, x.redundantNames + if len(x.fullNames) != 0 { + idx, err = expression.FindFieldName(x.fullNames, v.Name) + schemaCols, outputNames = x.fullSchema.Columns, x.fullNames } } if err != nil || idx < 0 { @@ -3146,14 +3157,11 @@ func (b *PlanBuilder) unfoldWildStar(p LogicalPlan, selectFields []*ast.SelectFi return nil, ErrInvalidWildCard } list := unfoldWildStar(field, p.OutputNames(), p.Schema().Columns) - // For sql like `select t1.*, t2.* from t1 join t2 using(a)`, we should - // not coalesce the `t2.a` in the output result. Thus we need to unfold - // the wildstar from the underlying join.redundantSchema. - if isJoin && join.redundantSchema != nil && field.WildCard.Table.L != "" { - redundantList := unfoldWildStar(field, join.redundantNames, join.redundantSchema.Columns) - if len(redundantList) > len(list) { - list = redundantList - } + // For sql like `select t1.*, t2.* from t1 join t2 using(a)` or `select t1.*, t2.* from t1 natual join t2`, + // the schema of the Join doesn't contain enough columns because the join keys are coalesced in this schema. + // We should collect the columns from the fullSchema. + if isJoin && join.fullSchema != nil && field.WildCard.Table.L != "" { + list = unfoldWildStar(field, join.fullNames, join.fullSchema.Columns) } if len(list) == 0 { return nil, ErrBadTable.GenWithStackByArgs(field.WildCard.Table) diff --git a/planner/core/logical_plans.go b/planner/core/logical_plans.go index 173de775558bb..78a41ee2bf826 100644 --- a/planner/core/logical_plans.go +++ b/planner/core/logical_plans.go @@ -146,10 +146,22 @@ type LogicalJoin struct { // Currently, only `aggregation push down` phase will set this. DefaultValues []types.Datum - // redundantSchema contains columns which are eliminated in join. - // For select * from a join b using (c); a.c will in output schema, and b.c will only in redundantSchema. - redundantSchema *expression.Schema - redundantNames types.NameSlice + // fullSchema contains all the columns that the Join can output. It's ordered as [outer schema..., inner schema...]. + // This is useful for natural joins and "using" joins. In these cases, the join key columns from the + // inner side (or the right side when it's an inner join) will not be in the schema of Join. + // But upper operators should be able to find those "redundant" columns, and the user also can specifically select + // those columns, so we put the "redundant" columns here to make them be able to be found. + // + // For example: + // create table t1(a int, b int); create table t2(a int, b int); + // select * from t1 join t2 using (b); + // schema of the Join will be [t1.b, t1.a, t2.a]; fullSchema will be [t1.a, t1.b, t2.a, t2.b]. + // + // We record all columns and keep them ordered is for correctly handling SQLs like + // select t1.*, t2.* from t1 join t2 using (b); + // (*PlanBuilder).unfoldWildStar() handles the schema for such case. + fullSchema *expression.Schema + fullNames types.NameSlice // equalCondOutCnt indicates the estimated count of joined rows after evaluating `EqualConditions`. equalCondOutCnt float64