Skip to content

Commit

Permalink
Merge join selects do not filter left join (#1568)
Browse files Browse the repository at this point in the history
* Merge join selects do not filter left join

* zach's comments

* comment
  • Loading branch information
max-hoffman authored Jan 31, 2023
1 parent 642c208 commit 8e0c643
Show file tree
Hide file tree
Showing 5 changed files with 241 additions and 142 deletions.
6 changes: 6 additions & 0 deletions enginetest/join_op_tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ var JoinOpTests = []struct {
types: []plan.JoinType{plan.JoinTypeLeftOuterMerge},
exp: []sql.Row{{0, 0, 1, 0}, {1, 0, 1, 0}, {2, 0, 1, 0}, {4, 4, nil, nil}, {5, 4, nil, nil}},
},
{
// extra join condition does not filter left-only rows
q: "select /*+ JOIN_ORDER(rs, xy) */ * from rs left join xy on y = s and y+s = 0 order by 1, 3",
types: []plan.JoinType{plan.JoinTypeLeftOuterMerge},
exp: []sql.Row{{0, 0, 1, 0}, {1, 0, 1, 0}, {2, 0, 1, 0}, {4, 4, nil, nil}, {5, 4, nil, nil}},
},
{
q: "select /*+ JOIN_ORDER(rs, xy) */ * from rs join xy on y = r order by 1, 3",
types: []plan.JoinType{plan.JoinTypeMerge},
Expand Down
59 changes: 28 additions & 31 deletions enginetest/queries/query_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ var PlanTests = []QueryPlanTest{
Query: `select /*+ JOIN_ORDER(rs, xy) */ * from rs left join xy on y = s order by 1, 3`,
ExpectedPlan: "Sort(rs.r:0!null ASC nullsFirst, xy.x:2 ASC nullsFirst)\n" +
" └─ LeftOuterMergeJoin\n" +
" ├─ Eq\n" +
" ├─ cmp: Eq\n" +
" │ ├─ rs.s:1\n" +
" │ └─ xy.y:3\n" +
" ├─ IndexedTableAccess\n" +
Expand Down Expand Up @@ -103,7 +103,7 @@ var PlanTests = []QueryPlanTest{
" │ ├─ outerVisibility: false\n" +
" │ ├─ cacheable: true\n" +
" │ └─ MergeJoin\n" +
" │ ├─ Eq\n" +
" │ ├─ cmp: Eq\n" +
" │ │ ├─ ab.a:0!null\n" +
" │ │ └─ xy.y:3\n" +
" │ ├─ IndexedTableAccess\n" +
Expand Down Expand Up @@ -131,7 +131,7 @@ var PlanTests = []QueryPlanTest{
{
Query: `select /*+ JOIN_ORDER(ab, xy) */ * from ab join xy on y = a`,
ExpectedPlan: "MergeJoin\n" +
" ├─ Eq\n" +
" ├─ cmp: Eq\n" +
" │ ├─ ab.a:0!null\n" +
" │ └─ xy.y:3\n" +
" ├─ IndexedTableAccess\n" +
Expand All @@ -154,7 +154,7 @@ var PlanTests = []QueryPlanTest{
Query: `select /*+ JOIN_ORDER(rs, xy) */ * from rs join xy on y = s order by 1, 3`,
ExpectedPlan: "Sort(rs.r:0!null ASC nullsFirst, xy.x:2!null ASC nullsFirst)\n" +
" └─ MergeJoin\n" +
" ├─ Eq\n" +
" ├─ cmp: Eq\n" +
" │ ├─ rs.s:1\n" +
" │ └─ xy.y:3\n" +
" ├─ IndexedTableAccess\n" +
Expand All @@ -176,7 +176,7 @@ var PlanTests = []QueryPlanTest{
{
Query: `select /*+ JOIN_ORDER(rs, xy) */ * from rs join xy on y = s`,
ExpectedPlan: "MergeJoin\n" +
" ├─ Eq\n" +
" ├─ cmp: Eq\n" +
" │ ├─ rs.s:1\n" +
" │ └─ xy.y:3\n" +
" ├─ IndexedTableAccess\n" +
Expand All @@ -198,7 +198,7 @@ var PlanTests = []QueryPlanTest{
{
Query: `select /*+ JOIN_ORDER(rs, xy) */ * from rs join xy on y+10 = s`,
ExpectedPlan: "MergeJoin\n" +
" ├─ Eq\n" +
" ├─ cmp: Eq\n" +
" │ ├─ rs.s:1\n" +
" │ └─ (xy.y:3 + 10 (tinyint))\n" +
" ├─ IndexedTableAccess\n" +
Expand Down Expand Up @@ -1465,7 +1465,7 @@ inner join pq on true
ExpectedPlan: "Project\n" +
" ├─ columns: [t1.i:0!null]\n" +
" └─ MergeJoin\n" +
" ├─ Eq\n" +
" ├─ cmp: Eq\n" +
" │ ├─ t1.i:0!null\n" +
" │ └─ (t2.i:1!null + 1 (tinyint))\n" +
" ├─ Filter\n" +
Expand Down Expand Up @@ -16863,16 +16863,15 @@ FROM
" │ │ ├─ J4JYP.ZH72S:27\n" +
" │ │ └─ TIZHK.TVNW2:1\n" +
" │ ├─ LeftOuterMergeJoin\n" +
" │ │ ├─ AND\n" +
" │ │ ├─ cmp: Eq\n" +
" │ │ │ ├─ TIZHK.TVNW2:1\n" +
" │ │ │ └─ NHMXW.NOHHR:11!null\n" +
" │ │ ├─ sel: AND\n" +
" │ │ │ ├─ AND\n" +
" │ │ │ │ ├─ AND\n" +
" │ │ │ │ │ ├─ AND\n" +
" │ │ │ │ │ │ ├─ Eq\n" +
" │ │ │ │ │ │ │ ├─ TIZHK.TVNW2:1\n" +
" │ │ │ │ │ │ │ └─ NHMXW.NOHHR:11!null\n" +
" │ │ │ │ │ │ └─ Eq\n" +
" │ │ │ │ │ │ ├─ NHMXW.SWCQV:17!null\n" +
" │ │ │ │ │ │ └─ 0 (tinyint)\n" +
" │ │ │ │ │ ├─ Eq\n" +
" │ │ │ │ │ │ ├─ NHMXW.SWCQV:17!null\n" +
" │ │ │ │ │ │ └─ 0 (tinyint)\n" +
" │ │ │ │ │ └─ Eq\n" +
" │ │ │ │ │ ├─ NHMXW.AVPYF:12!null\n" +
" │ │ │ │ │ └─ TIZHK.ZHITY:2\n" +
Expand Down Expand Up @@ -17125,15 +17124,14 @@ WHERE
" │ │ └─ Project\n" +
" │ │ ├─ columns: [uct.NO52D:7, uct.VYO5E:9, uct.ZH72S:2, I7HCR.FVUCX:17]\n" +
" │ │ └─ LeftOuterMergeJoin\n" +
" │ │ ├─ AND\n" +
" │ │ ├─ cmp: Eq\n" +
" │ │ │ ├─ uct.FTQLQ:1\n" +
" │ │ │ └─ I7HCR.TOFPN:14!null\n" +
" │ │ ├─ sel: AND\n" +
" │ │ │ ├─ AND\n" +
" │ │ │ │ ├─ AND\n" +
" │ │ │ │ │ ├─ Eq\n" +
" │ │ │ │ │ │ ├─ uct.FTQLQ:1\n" +
" │ │ │ │ │ │ └─ I7HCR.TOFPN:14!null\n" +
" │ │ │ │ │ └─ Eq\n" +
" │ │ │ │ │ ├─ I7HCR.SWCQV:18!null\n" +
" │ │ │ │ │ └─ 0 (tinyint)\n" +
" │ │ │ │ ├─ Eq\n" +
" │ │ │ │ │ ├─ I7HCR.SWCQV:18!null\n" +
" │ │ │ │ │ └─ 0 (tinyint)\n" +
" │ │ │ │ └─ Eq\n" +
" │ │ │ │ ├─ I7HCR.SJYN2:15!null\n" +
" │ │ │ │ └─ uct.ZH72S:2\n" +
Expand Down Expand Up @@ -17419,15 +17417,14 @@ WHERE
" │ │ │ └─ N/A (longtext)\n" +
" │ │ │ )) THEN uct.FHCYT:11 ELSE NULL (null) END as FHCYT, uct.ZH72S:2 as K3B6V, uct.LJLUM:5 as BTXC5, I7HCR.FVUCX:17 as H4DMT]\n" +
" │ │ └─ LeftOuterMergeJoin\n" +
" │ │ ├─ AND\n" +
" │ │ ├─ cmp: Eq\n" +
" │ │ │ ├─ uct.FTQLQ:1\n" +
" │ │ │ └─ I7HCR.TOFPN:14!null\n" +
" │ │ ├─ sel: AND\n" +
" │ │ │ ├─ AND\n" +
" │ │ │ │ ├─ AND\n" +
" │ │ │ │ │ ├─ Eq\n" +
" │ │ │ │ │ │ ├─ uct.FTQLQ:1\n" +
" │ │ │ │ │ │ └─ I7HCR.TOFPN:14!null\n" +
" │ │ │ │ │ └─ Eq\n" +
" │ │ │ │ │ ├─ I7HCR.SWCQV:18!null\n" +
" │ │ │ │ │ └─ 0 (tinyint)\n" +
" │ │ │ │ ├─ Eq\n" +
" │ │ │ │ │ ├─ I7HCR.SWCQV:18!null\n" +
" │ │ │ │ │ └─ 0 (tinyint)\n" +
" │ │ │ │ └─ Eq\n" +
" │ │ │ │ ├─ I7HCR.SJYN2:15!null\n" +
" │ │ │ │ └─ uct.ZH72S:2\n" +
Expand Down
36 changes: 29 additions & 7 deletions sql/plan/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import (
"os"
"strings"

"github.com/dolthub/go-mysql-server/sql/expression"

"github.com/dolthub/go-mysql-server/sql"
)

Expand Down Expand Up @@ -100,6 +102,15 @@ func (i JoinType) IsDegenerate() bool {
i == JoinTypeCross
}

// IsMerge reports whether the join type is one of the merge-join variants:
// JoinTypeMerge, JoinTypeSemiMerge, JoinTypeAntiMerge, or
// JoinTypeLeftOuterMerge.
func (i JoinType) IsMerge() bool {
	return i == JoinTypeMerge ||
		i == JoinTypeSemiMerge ||
		i == JoinTypeAntiMerge ||
		i == JoinTypeLeftOuterMerge
}

func (i JoinType) IsRightPartial() bool {
switch i {
case JoinTypeRightSemi, JoinTypeRightSemiLookup:
Expand Down Expand Up @@ -150,11 +161,6 @@ func (i JoinType) IsLookup() bool {
i == JoinTypeLeftOuterLookup
}

// IsMerge reports whether the join type is JoinTypeMerge or
// JoinTypeLeftOuterMerge.
func (i JoinType) IsMerge() bool {
	switch i {
	case JoinTypeMerge, JoinTypeLeftOuterMerge:
		return true
	default:
		return false
	}
}

func (i JoinType) IsCross() bool {
return i == JoinTypeCross
}
Expand Down Expand Up @@ -311,7 +317,15 @@ func (j *JoinNode) String() string {
pr := sql.NewTreePrinter()
var children []string
if j.Filter != nil {
children = append(children, j.Filter.String())
if j.Op.IsMerge() {
filters := expression.SplitConjunction(j.Filter)
children = append(children, fmt.Sprintf("cmp: %s", filters[0]))
if len(filters) > 1 {
children = append(children, fmt.Sprintf("sel: %s", expression.JoinAnd(filters[1:]...)))
}
} else {
children = append(children, j.Filter.String())
}
}
children = append(children, j.left.String(), j.right.String())
pr.WriteNode("%s", j.Op)
Expand All @@ -323,7 +337,15 @@ func (j *JoinNode) DebugString() string {
pr := sql.NewTreePrinter()
var children []string
if j.Filter != nil {
children = append(children, sql.DebugString(j.Filter))
if j.Op.IsMerge() {
filters := expression.SplitConjunction(j.Filter)
children = append(children, fmt.Sprintf("cmp: %s", sql.DebugString(filters[0])))
if len(filters) > 1 {
children = append(children, fmt.Sprintf("sel: %s", sql.DebugString(expression.JoinAnd(filters[1:]...))))
}
} else {
children = append(children, sql.DebugString(j.Filter))
}
}
children = append(children, sql.DebugString(j.left), sql.DebugString(j.right))
pr.WriteNode("%s", j.Op)
Expand Down
132 changes: 79 additions & 53 deletions sql/plan/join_iters.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,11 +269,23 @@ func (i *existsIter) loadSecondary(ctx *sql.Context, left sql.Row) (row sql.Row,
return iter.Next(ctx)
}

// existsState identifies the next step to run in the state machine that
// existsIter.Next loops over.
type existsState uint8

const (
// esIncLeft reads the next row from the primary iterator and opens the
// secondary iterator for it.
esIncLeft existsState = iota
// esIncRight reads the next row from the secondary iterator.
esIncRight
// esRightIterEOF handles the secondary iterator being exhausted: semi joins
// move on to the next left row, all other types proceed to esRet.
esRightIterEOF
// esCompare builds the combined row and evaluates the join condition on it.
esCompare
// esRet returns the output row for the current left row.
esRet
)

func (i *existsIter) Next(ctx *sql.Context) (sql.Row, error) {
var row sql.Row
var matches bool
var right sql.Row
var left sql.Row
var rIter sql.RowIter
var err error

// the common sequence is: LOAD_LEFT (esIncLeft) -> LOAD_RIGHT (esIncRight) -> COMPARE (esCompare) -> RET (esRet)
// notable exceptions are represented as state transitions:
Expand All @@ -282,64 +294,78 @@ func (i *existsIter) Next(ctx *sql.Context) (sql.Row, error) {
// - antiJoin succeeds to RET when LOAD_RIGHT EOF's
// - semiJoin fails when LOAD_RIGHT EOF's, falling back to LOAD_LEFT
// - antiJoin fails when COMPARE returns true, falling back to LOAD_LEFT
goto LOAD_LEFT
LOAD_LEFT:
r, err := i.primary.Next(ctx)
if err != nil {
return nil, err
}
left = i.parentRow.Append(r)
rIter, err := i.secondaryProvider.RowIter(ctx, left)
if err != nil {
return nil, err
}
if isEmptyIter(rIter) {
if i.nullRej {
return nil, io.EOF
}
goto COMPARE
}
goto LOAD_RIGHT
LOAD_RIGHT:
right, err = rIter.Next(ctx)
if err != nil {
iterErr := rIter.Close(ctx)
if iterErr != nil {
return nil, fmt.Errorf("%w; error on close: %s", err, iterErr)
}
if errors.Is(err, io.EOF) {
nextState := esIncLeft
for {
switch nextState {
case esIncLeft:
r, err := i.primary.Next(ctx)
if err != nil {
return nil, err
}
left = i.parentRow.Append(r)
rIter, err = i.secondaryProvider.RowIter(ctx, left)
if err != nil {
return nil, err
}
if isEmptyIter(rIter) {
if i.nullRej {
return nil, io.EOF
}
nextState = esCompare
} else {
nextState = esIncRight
}
case esIncRight:
right, err = rIter.Next(ctx)
if err != nil {
iterErr := rIter.Close(ctx)
if iterErr != nil {
return nil, fmt.Errorf("%w; error on close: %s", err, iterErr)
}
if errors.Is(err, io.EOF) {
nextState = esRightIterEOF
} else {
return nil, err
}
} else {
nextState = esCompare
}
case esRightIterEOF:
if i.typ.IsSemi() {
// reset iter, no match
goto LOAD_LEFT
nextState = esIncLeft
} else {
nextState = esRet
}
case esCompare:
row = i.buildRow(left, right)
matches, err = conditionIsTrue(ctx, row, i.cond)
if err != nil {
return nil, err
}
if !matches {
nextState = esIncRight
} else {
err = rIter.Close(ctx)
if err != nil {
return nil, err
}
if i.typ.IsAnti() {
// reset iter, found match -> no return row
nextState = esIncLeft
} else {
nextState = esRet
}
}
case esRet:
if i.typ.IsRightPartial() {
return append(left[:i.scopeLen], right...), nil
}
goto RET
return i.removeParentRow(left), nil
default:
return nil, fmt.Errorf("invalid exists join state")
}
return nil, err
}
goto COMPARE
COMPARE:
row = i.buildRow(left, right)
matches, err = conditionIsTrue(ctx, row, i.cond)
if err != nil {
return nil, err
}
if !matches {
goto LOAD_RIGHT
}
err = rIter.Close(ctx)
if err != nil {
return nil, err
}
if i.typ.IsAnti() {
// reset iter, found match -> no return row
goto LOAD_LEFT
}
goto RET
RET:
if i.typ.IsRightPartial() {
return append(left[:i.scopeLen], right...), nil
}
return i.removeParentRow(left), nil
}

func (i *existsIter) removeParentRow(r sql.Row) sql.Row {
Expand Down
Loading

0 comments on commit 8e0c643

Please sign in to comment.