Skip to content

Commit

Permalink
planner: change redundantSchema to fullSchema to correctly handle…
Browse files Browse the repository at this point in the history
… natural and "using" joins (#29599) (#30041)

close #29481
  • Loading branch information
ti-srebot authored Feb 22, 2022
1 parent 2843c47 commit b6d9a05
Show file tree
Hide file tree
Showing 6 changed files with 241 additions and 54 deletions.
33 changes: 33 additions & 0 deletions executor/join_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,39 @@ func (s *testSuiteJoin1) TestUsing(c *C) {
tk.MustQuery("select t1.t0, t2.t0 from t1 join t2 using(t0) having t1.t0 > 0").Check(testkit.Rows("1 1"))
}

func (s *testSuiteWithData) TestUsingAndNaturalJoinSchema(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
tk.MustExec("drop table if exists t1, t2, t3, t4")
tk.MustExec("create table t1 (c int, b int);")
tk.MustExec("create table t2 (a int, b int);")
tk.MustExec("create table t3 (b int, c int);")
tk.MustExec("create table t4 (y int, c int);")

tk.MustExec("insert into t1 values (10,1);")
tk.MustExec("insert into t1 values (3 ,1);")
tk.MustExec("insert into t1 values (3 ,2);")
tk.MustExec("insert into t2 values (2, 1);")
tk.MustExec("insert into t3 values (1, 3);")
tk.MustExec("insert into t3 values (1,10);")
tk.MustExec("insert into t4 values (11,3);")
tk.MustExec("insert into t4 values (2, 3);")

var input []string
var output []struct {
SQL string
Res []string
}
s.testData.GetTestCases(c, &input, &output)
for i, tt := range input {
s.testData.OnRecord(func() {
output[i].SQL = tt
output[i].Res = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Sort().Rows())
})
tk.MustQuery(tt).Sort().Check(testkit.Rows(output[i].Res...))
}
}

func (s *testSuiteWithData) TestNaturalJoin(c *C) {
tk := testkit.NewTestKit(c, s.store)

Expand Down
19 changes: 19 additions & 0 deletions executor/testdata/executor_suite_in.json
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,25 @@
"SELECT * FROM t1 NATURAL LEFT JOIN t2 WHERE not(t1.a <=> t2.a)"
]
},
{
"name": "TestUsingAndNaturalJoinSchema",
"cases": [
"select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) natural join (t3 natural join t4);",
"select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) join (t3 natural join t4) using (b);",
"select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) join (t3 natural join t4) using (c);",
"select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) join (t3 natural join t4) using (c,b);",
"select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) left outer join (t3 natural join t4) using (b);",
"select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) left outer join (t3 natural join t4) using (c);",
"select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) left outer join (t3 natural join t4) using (c,b);",
"select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) right outer join (t3 natural join t4) using (b);",
"select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) right outer join (t3 natural join t4) using (c);",
"select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) right outer join (t3 natural join t4) using (c,b);",
"select * from (t1 natural join t2) natural join (t3 natural join t4);",
"select * from (t1 natural join t2) join (t3 natural join t4) using (b);",
"select * from (t1 natural join t2) left outer join (t3 natural join t4) using (b);",
"select * from (t1 natural join t2) right outer join (t3 natural join t4) using (c,b);"
]
},
{
"name": "TestIndexScanWithYearCol",
"cases": [
Expand Down
115 changes: 115 additions & 0 deletions executor/testdata/executor_suite_out.json
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,121 @@
}
]
},
{
"Name": "TestUsingAndNaturalJoinSchema",
"Cases": [
{
"SQL": "select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) natural join (t3 natural join t4);",
"Res": [
"3 1 2 1 1 3 11 3",
"3 1 2 1 1 3 2 3"
]
},
{
"SQL": "select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) join (t3 natural join t4) using (b);",
"Res": [
"10 1 2 1 1 3 11 3",
"10 1 2 1 1 3 2 3",
"3 1 2 1 1 3 11 3",
"3 1 2 1 1 3 2 3"
]
},
{
"SQL": "select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) join (t3 natural join t4) using (c);",
"Res": [
"3 1 2 1 1 3 11 3",
"3 1 2 1 1 3 2 3"
]
},
{
"SQL": "select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) join (t3 natural join t4) using (c,b);",
"Res": [
"3 1 2 1 1 3 11 3",
"3 1 2 1 1 3 2 3"
]
},
{
"SQL": "select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) left outer join (t3 natural join t4) using (b);",
"Res": [
"10 1 2 1 1 3 11 3",
"10 1 2 1 1 3 2 3",
"3 1 2 1 1 3 11 3",
"3 1 2 1 1 3 2 3"
]
},
{
"SQL": "select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) left outer join (t3 natural join t4) using (c);",
"Res": [
"10 1 2 1 <nil> <nil> <nil> <nil>",
"3 1 2 1 1 3 11 3",
"3 1 2 1 1 3 2 3"
]
},
{
"SQL": "select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) left outer join (t3 natural join t4) using (c,b);",
"Res": [
"10 1 2 1 <nil> <nil> <nil> <nil>",
"3 1 2 1 1 3 11 3",
"3 1 2 1 1 3 2 3"
]
},
{
"SQL": "select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) right outer join (t3 natural join t4) using (b);",
"Res": [
"10 1 2 1 1 3 11 3",
"10 1 2 1 1 3 2 3",
"3 1 2 1 1 3 11 3",
"3 1 2 1 1 3 2 3"
]
},
{
"SQL": "select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) right outer join (t3 natural join t4) using (c);",
"Res": [
"3 1 2 1 1 3 11 3",
"3 1 2 1 1 3 2 3"
]
},
{
"SQL": "select t1.*, t2.*, t3.*, t4.* from (t1 natural join t2) right outer join (t3 natural join t4) using (c,b);",
"Res": [
"3 1 2 1 1 3 11 3",
"3 1 2 1 1 3 2 3"
]
},
{
"SQL": "select * from (t1 natural join t2) natural join (t3 natural join t4);",
"Res": [
"1 3 2 11",
"1 3 2 2"
]
},
{
"SQL": "select * from (t1 natural join t2) join (t3 natural join t4) using (b);",
"Res": [
"1 10 2 3 11",
"1 10 2 3 2",
"1 3 2 3 11",
"1 3 2 3 2"
]
},
{
"SQL": "select * from (t1 natural join t2) left outer join (t3 natural join t4) using (b);",
"Res": [
"1 10 2 3 11",
"1 10 2 3 2",
"1 3 2 3 11",
"1 3 2 3 2"
]
},
{
"SQL": "select * from (t1 natural join t2) right outer join (t3 natural join t4) using (c,b);",
"Res": [
"3 1 11 2",
"3 1 2 2"
]
}
]
},
{
"Name": "TestIndexScanWithYearCol",
"Cases": [
Expand Down
6 changes: 3 additions & 3 deletions planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1865,13 +1865,13 @@ func findFieldNameFromNaturalUsingJoin(p LogicalPlan, v *ast.ColumnName) (col *e
case *LogicalLimit, *LogicalSelection, *LogicalTopN, *LogicalSort, *LogicalMaxOneRow:
return findFieldNameFromNaturalUsingJoin(p.Children()[0], v)
case *LogicalJoin:
if x.redundantSchema != nil {
idx, err := expression.FindFieldName(x.redundantNames, v)
if x.fullSchema != nil {
idx, err := expression.FindFieldName(x.fullNames, v)
if err != nil {
return nil, nil, err
}
if idx >= 0 {
return x.redundantSchema.Columns[idx], x.redundantNames[idx], nil
return x.fullSchema.Columns[idx], x.fullNames[idx], nil
}
}
}
Expand Down
102 changes: 55 additions & 47 deletions planner/core/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,8 @@ func (b *PlanBuilder) buildAggregation(ctx context.Context, p LogicalPlan, aggFu
schema4Agg.Append(newCol)
names = append(names, p.OutputNames()[i])
}
if join, isJoin := p.(*LogicalJoin); isJoin && join.redundantSchema != nil {
for i, col := range join.redundantSchema.Columns {
if join, isJoin := p.(*LogicalJoin); isJoin && join.fullSchema != nil {
for i, col := range join.fullSchema.Columns {
if p.Schema().Contains(col) {
continue
}
Expand All @@ -289,7 +289,7 @@ func (b *PlanBuilder) buildAggregation(ctx context.Context, p LogicalPlan, aggFu
newCol, _ := col.Clone().(*expression.Column)
newCol.RetType = newFunc.RetTp
schema4Agg.Append(newCol)
names = append(names, join.redundantNames[i])
names = append(names, join.fullNames[i])
}
}
hasGroupBy := len(gbyItems) > 0
Expand Down Expand Up @@ -720,25 +720,50 @@ func (b *PlanBuilder) buildJoin(ctx context.Context, joinNode *ast.Join) (Logica
joinPlan.JoinType = InnerJoin
}

// Merge sub join's redundantSchema into this join plan. When handle query like
// select t2.a from (t1 join t2 using (a)) join t3 using (a);
// we can simply search in the top level join plan to find redundant column.
// Merge sub-plan's fullSchema into this join plan.
// Please read the comment of LogicalJoin.fullSchema for the details.
var (
lRedundantSchema, rRedundantSchema *expression.Schema
lRedundantNames, rRedundantNames types.NameSlice
lFullSchema, rFullSchema *expression.Schema
lFullNames, rFullNames types.NameSlice
)
if left, ok := leftPlan.(*LogicalJoin); ok && left.redundantSchema != nil {
lRedundantSchema = left.redundantSchema
lRedundantNames = left.redundantNames
if left, ok := leftPlan.(*LogicalJoin); ok && left.fullSchema != nil {
lFullSchema = left.fullSchema
lFullNames = left.fullNames
} else {
lFullSchema = leftPlan.Schema()
lFullNames = leftPlan.OutputNames()
}
if right, ok := rightPlan.(*LogicalJoin); ok && right.fullSchema != nil {
rFullSchema = right.fullSchema
rFullNames = right.fullNames
} else {
rFullSchema = rightPlan.Schema()
rFullNames = rightPlan.OutputNames()
}
if right, ok := rightPlan.(*LogicalJoin); ok && right.redundantSchema != nil {
rRedundantSchema = right.redundantSchema
rRedundantNames = right.redundantNames
if joinNode.Tp == ast.RightJoin {
// Make sure lFullSchema means outer full schema and rFullSchema means inner full schema.
lFullSchema, rFullSchema = rFullSchema, lFullSchema
lFullNames, rFullNames = rFullNames, lFullNames
}
joinPlan.fullSchema = expression.MergeSchema(lFullSchema, rFullSchema)

// Clear NotNull flag for the inner side schema if it's an outer join.
if joinNode.Tp == ast.LeftJoin || joinNode.Tp == ast.RightJoin {
resetNotNullFlag(joinPlan.fullSchema, lFullSchema.Len(), joinPlan.fullSchema.Len())
}

// Merge sub-plan's fullNames into this join plan, similar to the fullSchema logic above.
joinPlan.fullNames = make([]*types.FieldName, 0, len(lFullNames)+len(rFullNames))
for _, lName := range lFullNames {
name := *lName
name.Redundant = true
joinPlan.fullNames = append(joinPlan.fullNames, &name)
}
for _, rName := range rFullNames {
name := *rName
name.Redundant = true
joinPlan.fullNames = append(joinPlan.fullNames, &name)
}
joinPlan.redundantSchema = expression.MergeSchema(lRedundantSchema, rRedundantSchema)
joinPlan.redundantNames = make([]*types.FieldName, len(lRedundantNames)+len(rRedundantNames))
copy(joinPlan.redundantNames, lRedundantNames)
copy(joinPlan.redundantNames[len(lRedundantNames):], rRedundantNames)

// Set preferred join algorithm if some join hints is specified by user.
joinPlan.setPreferredJoinType(b.TableHints())
Expand Down Expand Up @@ -941,21 +966,7 @@ func (b *PlanBuilder) coalesceCommonColumns(p *LogicalJoin, leftPlan, rightPlan

p.SetSchema(expression.NewSchema(schemaCols...))
p.names = names
// We record the full `rightPlan.Schema` as `redundantSchema` in order to
// record the redundant column in `rightPlan` and the output columns order
// of the `rightPlan`.
// For SQL like `select t1.*, t2.* from t1 left join t2 using(a)`, we can
// retrieve the column order of `t2.*` from the `redundantSchema`.
p.redundantSchema = expression.MergeSchema(p.redundantSchema, expression.NewSchema(rightPlan.Schema().Clone().Columns...))
p.redundantNames = p.redundantNames.Shallow()
for _, name := range rightPlan.OutputNames() {
cpyName := *name
cpyName.Redundant = true
p.redundantNames = append(p.redundantNames, &cpyName)
}
if joinTp == ast.RightJoin || joinTp == ast.LeftJoin {
resetNotNullFlag(p.redundantSchema, 0, p.redundantSchema.Len())
}

p.OtherConditions = append(conds, p.OtherConditions...)

return nil
Expand Down Expand Up @@ -1208,9 +1219,9 @@ func findColFromNaturalUsingJoin(p LogicalPlan, col *expression.Column) (name *t
case *LogicalLimit, *LogicalSelection, *LogicalTopN, *LogicalSort, *LogicalMaxOneRow:
return findColFromNaturalUsingJoin(p.Children()[0], col)
case *LogicalJoin:
if x.redundantSchema != nil {
idx := x.redundantSchema.ColumnIndex(col)
return x.redundantNames[idx]
if x.fullSchema != nil {
idx := x.fullSchema.ColumnIndex(col)
return x.fullNames[idx]
}
}
return nil
Expand Down Expand Up @@ -1996,9 +2007,9 @@ func (a *havingWindowAndOrderbyExprResolver) resolveFromPlan(v *ast.ColumnNameEx
case *LogicalLimit, *LogicalSelection, *LogicalTopN, *LogicalSort, *LogicalMaxOneRow:
return a.resolveFromPlan(v, p.Children()[0])
case *LogicalJoin:
if len(x.redundantNames) != 0 {
idx, err = expression.FindFieldName(x.redundantNames, v.Name)
schemaCols, outputNames = x.redundantSchema.Columns, x.redundantNames
if len(x.fullNames) != 0 {
idx, err = expression.FindFieldName(x.fullNames, v.Name)
schemaCols, outputNames = x.fullSchema.Columns, x.fullNames
}
}
if err != nil || idx < 0 {
Expand Down Expand Up @@ -3146,14 +3157,11 @@ func (b *PlanBuilder) unfoldWildStar(p LogicalPlan, selectFields []*ast.SelectFi
return nil, ErrInvalidWildCard
}
list := unfoldWildStar(field, p.OutputNames(), p.Schema().Columns)
// For sql like `select t1.*, t2.* from t1 join t2 using(a)`, we should
// not coalesce the `t2.a` in the output result. Thus we need to unfold
// the wildstar from the underlying join.redundantSchema.
if isJoin && join.redundantSchema != nil && field.WildCard.Table.L != "" {
redundantList := unfoldWildStar(field, join.redundantNames, join.redundantSchema.Columns)
if len(redundantList) > len(list) {
list = redundantList
}
// For sql like `select t1.*, t2.* from t1 join t2 using(a)` or `select t1.*, t2.* from t1 natual join t2`,
// the schema of the Join doesn't contain enough columns because the join keys are coalesced in this schema.
// We should collect the columns from the fullSchema.
if isJoin && join.fullSchema != nil && field.WildCard.Table.L != "" {
list = unfoldWildStar(field, join.fullNames, join.fullSchema.Columns)
}
if len(list) == 0 {
return nil, ErrBadTable.GenWithStackByArgs(field.WildCard.Table)
Expand Down
20 changes: 16 additions & 4 deletions planner/core/logical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,22 @@ type LogicalJoin struct {
// Currently, only `aggregation push down` phase will set this.
DefaultValues []types.Datum

// redundantSchema contains columns which are eliminated in join.
// For select * from a join b using (c); a.c will in output schema, and b.c will only in redundantSchema.
redundantSchema *expression.Schema
redundantNames types.NameSlice
// fullSchema contains all the columns that the Join can output. It's ordered as [outer schema..., inner schema...].
// This is useful for natural joins and "using" joins. In these cases, the join key columns from the
// inner side (or the right side when it's an inner join) will not be in the schema of Join.
// But upper operators should be able to find those "redundant" columns, and the user also can specifically select
// those columns, so we put the "redundant" columns here to make them be able to be found.
//
// For example:
// create table t1(a int, b int); create table t2(a int, b int);
// select * from t1 join t2 using (b);
// schema of the Join will be [t1.b, t1.a, t2.a]; fullSchema will be [t1.a, t1.b, t2.a, t2.b].
//
// We record all columns and keep them ordered is for correctly handling SQLs like
// select t1.*, t2.* from t1 join t2 using (b);
// (*PlanBuilder).unfoldWildStar() handles the schema for such case.
fullSchema *expression.Schema
fullNames types.NameSlice

// equalCondOutCnt indicates the estimated count of joined rows after evaluating `EqualConditions`.
equalCondOutCnt float64
Expand Down

0 comments on commit b6d9a05

Please sign in to comment.