diff --git a/enginetest/join_op_tests.go b/enginetest/join_op_tests.go index 92fa0aef82..b7bb15c547 100644 --- a/enginetest/join_op_tests.go +++ b/enginetest/join_op_tests.go @@ -68,6 +68,12 @@ var JoinOpTests = []struct { types: []plan.JoinType{plan.JoinTypeLeftOuterMerge}, exp: []sql.Row{{0, 0, 1, 0}, {1, 0, 1, 0}, {2, 0, 1, 0}, {4, 4, nil, nil}, {5, 4, nil, nil}}, }, + { + // extra join condition does not filter left-only rows + q: "select /*+ JOIN_ORDER(rs, xy) */ * from rs left join xy on y = s and y+s = 0 order by 1, 3", + types: []plan.JoinType{plan.JoinTypeLeftOuterMerge}, + exp: []sql.Row{{0, 0, 1, 0}, {1, 0, 1, 0}, {2, 0, 1, 0}, {4, 4, nil, nil}, {5, 4, nil, nil}}, + }, { q: "select /*+ JOIN_ORDER(rs, xy) */ * from rs join xy on y = r order by 1, 3", types: []plan.JoinType{plan.JoinTypeMerge}, diff --git a/enginetest/queries/query_plans.go b/enginetest/queries/query_plans.go index e80b557b1e..c301715f00 100644 --- a/enginetest/queries/query_plans.go +++ b/enginetest/queries/query_plans.go @@ -71,7 +71,7 @@ var PlanTests = []QueryPlanTest{ Query: `select /*+ JOIN_ORDER(rs, xy) */ * from rs left join xy on y = s order by 1, 3`, ExpectedPlan: "Sort(rs.r:0!null ASC nullsFirst, xy.x:2 ASC nullsFirst)\n" + " └─ LeftOuterMergeJoin\n" + - " ├─ Eq\n" + + " ├─ cmp: Eq\n" + " │ ├─ rs.s:1\n" + " │ └─ xy.y:3\n" + " ├─ IndexedTableAccess\n" + @@ -103,7 +103,7 @@ var PlanTests = []QueryPlanTest{ " │ ├─ outerVisibility: false\n" + " │ ├─ cacheable: true\n" + " │ └─ MergeJoin\n" + - " │ ├─ Eq\n" + + " │ ├─ cmp: Eq\n" + " │ │ ├─ ab.a:0!null\n" + " │ │ └─ xy.y:3\n" + " │ ├─ IndexedTableAccess\n" + @@ -131,7 +131,7 @@ var PlanTests = []QueryPlanTest{ { Query: `select /*+ JOIN_ORDER(ab, xy) */ * from ab join xy on y = a`, ExpectedPlan: "MergeJoin\n" + - " ├─ Eq\n" + + " ├─ cmp: Eq\n" + " │ ├─ ab.a:0!null\n" + " │ └─ xy.y:3\n" + " ├─ IndexedTableAccess\n" + @@ -154,7 +154,7 @@ var PlanTests = []QueryPlanTest{ Query: `select /*+ JOIN_ORDER(rs, xy) */ * from rs join xy on y = s 
order by 1, 3`, ExpectedPlan: "Sort(rs.r:0!null ASC nullsFirst, xy.x:2!null ASC nullsFirst)\n" + " └─ MergeJoin\n" + - " ├─ Eq\n" + + " ├─ cmp: Eq\n" + " │ ├─ rs.s:1\n" + " │ └─ xy.y:3\n" + " ├─ IndexedTableAccess\n" + @@ -176,7 +176,7 @@ var PlanTests = []QueryPlanTest{ { Query: `select /*+ JOIN_ORDER(rs, xy) */ * from rs join xy on y = s`, ExpectedPlan: "MergeJoin\n" + - " ├─ Eq\n" + + " ├─ cmp: Eq\n" + " │ ├─ rs.s:1\n" + " │ └─ xy.y:3\n" + " ├─ IndexedTableAccess\n" + @@ -198,7 +198,7 @@ var PlanTests = []QueryPlanTest{ { Query: `select /*+ JOIN_ORDER(rs, xy) */ * from rs join xy on y+10 = s`, ExpectedPlan: "MergeJoin\n" + - " ├─ Eq\n" + + " ├─ cmp: Eq\n" + " │ ├─ rs.s:1\n" + " │ └─ (xy.y:3 + 10 (tinyint))\n" + " ├─ IndexedTableAccess\n" + @@ -1465,7 +1465,7 @@ inner join pq on true ExpectedPlan: "Project\n" + " ├─ columns: [t1.i:0!null]\n" + " └─ MergeJoin\n" + - " ├─ Eq\n" + + " ├─ cmp: Eq\n" + " │ ├─ t1.i:0!null\n" + " │ └─ (t2.i:1!null + 1 (tinyint))\n" + " ├─ Filter\n" + @@ -16863,16 +16863,15 @@ FROM " │ │ ├─ J4JYP.ZH72S:27\n" + " │ │ └─ TIZHK.TVNW2:1\n" + " │ ├─ LeftOuterMergeJoin\n" + - " │ │ ├─ AND\n" + + " │ │ ├─ cmp: Eq\n" + + " │ │ │ ├─ TIZHK.TVNW2:1\n" + + " │ │ │ └─ NHMXW.NOHHR:11!null\n" + + " │ │ ├─ sel: AND\n" + " │ │ │ ├─ AND\n" + " │ │ │ │ ├─ AND\n" + - " │ │ │ │ │ ├─ AND\n" + - " │ │ │ │ │ │ ├─ Eq\n" + - " │ │ │ │ │ │ │ ├─ TIZHK.TVNW2:1\n" + - " │ │ │ │ │ │ │ └─ NHMXW.NOHHR:11!null\n" + - " │ │ │ │ │ │ └─ Eq\n" + - " │ │ │ │ │ │ ├─ NHMXW.SWCQV:17!null\n" + - " │ │ │ │ │ │ └─ 0 (tinyint)\n" + + " │ │ │ │ │ ├─ Eq\n" + + " │ │ │ │ │ │ ├─ NHMXW.SWCQV:17!null\n" + + " │ │ │ │ │ │ └─ 0 (tinyint)\n" + " │ │ │ │ │ └─ Eq\n" + " │ │ │ │ │ ├─ NHMXW.AVPYF:12!null\n" + " │ │ │ │ │ └─ TIZHK.ZHITY:2\n" + @@ -17125,15 +17124,14 @@ WHERE " │ │ └─ Project\n" + " │ │ ├─ columns: [uct.NO52D:7, uct.VYO5E:9, uct.ZH72S:2, I7HCR.FVUCX:17]\n" + " │ │ └─ LeftOuterMergeJoin\n" + - " │ │ ├─ AND\n" + + " │ │ ├─ cmp: Eq\n" + + " │ │ │ ├─ uct.FTQLQ:1\n" + + " │ │ │ └─ 
I7HCR.TOFPN:14!null\n" +
+			" │ │ ├─ sel: AND\n" +
 			" │ │ │ ├─ AND\n" +
-			" │ │ │ │ ├─ AND\n" +
-			" │ │ │ │ │ ├─ Eq\n" +
-			" │ │ │ │ │ │ ├─ uct.FTQLQ:1\n" +
-			" │ │ │ │ │ │ └─ I7HCR.TOFPN:14!null\n" +
-			" │ │ │ │ │ └─ Eq\n" +
-			" │ │ │ │ │ ├─ I7HCR.SWCQV:18!null\n" +
-			" │ │ │ │ │ └─ 0 (tinyint)\n" +
+			" │ │ │ │ ├─ Eq\n" +
+			" │ │ │ │ │ ├─ I7HCR.SWCQV:18!null\n" +
+			" │ │ │ │ │ └─ 0 (tinyint)\n" +
 			" │ │ │ │ └─ Eq\n" +
 			" │ │ │ │ ├─ I7HCR.SJYN2:15!null\n" +
 			" │ │ │ │ └─ uct.ZH72S:2\n" +
@@ -17419,15 +17417,14 @@ WHERE
 			" │ │ │ └─ N/A (longtext)\n" +
 			" │ │ │ )) THEN uct.FHCYT:11 ELSE NULL (null) END as FHCYT, uct.ZH72S:2 as K3B6V, uct.LJLUM:5 as BTXC5, I7HCR.FVUCX:17 as H4DMT]\n" +
 			" │ │ └─ LeftOuterMergeJoin\n" +
-			" │ │ ├─ AND\n" +
+			" │ │ ├─ cmp: Eq\n" +
+			" │ │ │ ├─ uct.FTQLQ:1\n" +
+			" │ │ │ └─ I7HCR.TOFPN:14!null\n" +
+			" │ │ ├─ sel: AND\n" +
 			" │ │ │ ├─ AND\n" +
-			" │ │ │ │ ├─ AND\n" +
-			" │ │ │ │ │ ├─ Eq\n" +
-			" │ │ │ │ │ │ ├─ uct.FTQLQ:1\n" +
-			" │ │ │ │ │ │ └─ I7HCR.TOFPN:14!null\n" +
-			" │ │ │ │ │ └─ Eq\n" +
-			" │ │ │ │ │ ├─ I7HCR.SWCQV:18!null\n" +
-			" │ │ │ │ │ └─ 0 (tinyint)\n" +
+			" │ │ │ │ ├─ Eq\n" +
+			" │ │ │ │ │ ├─ I7HCR.SWCQV:18!null\n" +
+			" │ │ │ │ │ └─ 0 (tinyint)\n" +
 			" │ │ │ │ └─ Eq\n" +
 			" │ │ │ │ ├─ I7HCR.SJYN2:15!null\n" +
 			" │ │ │ │ └─ uct.ZH72S:2\n" +
diff --git a/sql/plan/join.go b/sql/plan/join.go
index f1cae65a6b..72c62dd672 100644
--- a/sql/plan/join.go
+++ b/sql/plan/join.go
@@ -19,6 +19,8 @@ import (
 	"os"
 	"strings"
 
+	"github.com/dolthub/go-mysql-server/sql/expression"
+
 	"github.com/dolthub/go-mysql-server/sql"
 )
 
@@ -100,6 +102,15 @@ func (i JoinType) IsDegenerate() bool {
 		i == JoinTypeCross
 }
 
+func (i JoinType) IsMerge() bool { // IsMerge reports whether i is one of the merge join operators listed in the case below.
+	switch i {
+	case JoinTypeMerge, JoinTypeSemiMerge, JoinTypeAntiMerge, JoinTypeLeftOuterMerge:
+		return true
+	default:
+		return false
+	}
+}
+
 func (i JoinType) IsRightPartial() bool {
 	switch i {
 	case JoinTypeRightSemi, JoinTypeRightSemiLookup:
@@ -150,11 +161,6 @@ func (i JoinType) IsLookup() bool {
 		i ==
JoinTypeLeftOuterLookup
 }
 
-func (i JoinType) IsMerge() bool {
-	return i == JoinTypeMerge ||
-		i == JoinTypeLeftOuterMerge
-}
-
 func (i JoinType) IsCross() bool {
 	return i == JoinTypeCross
 }
@@ -311,7 +317,15 @@ func (j *JoinNode) String() string {
 	pr := sql.NewTreePrinter()
 	var children []string
 	if j.Filter != nil {
-		children = append(children, j.Filter.String())
+		if j.Op.IsMerge() {
+			filters := expression.SplitConjunction(j.Filter)
+			children = append(children, fmt.Sprintf("cmp: %s", filters[0])) // first conjunct is printed as the merge comparison
+			if len(filters) > 1 {
+				children = append(children, fmt.Sprintf("sel: %s", expression.JoinAnd(filters[1:]...))) // remaining conjuncts are re-joined with JoinAnd and printed under "sel"
+			}
+		} else {
+			children = append(children, j.Filter.String())
+		}
 	}
 	children = append(children, j.left.String(), j.right.String())
 	pr.WriteNode("%s", j.Op)
@@ -323,7 +337,15 @@ func (j *JoinNode) DebugString() string {
 	pr := sql.NewTreePrinter()
 	var children []string
 	if j.Filter != nil {
-		children = append(children, sql.DebugString(j.Filter))
+		if j.Op.IsMerge() {
+			filters := expression.SplitConjunction(j.Filter)
+			children = append(children, fmt.Sprintf("cmp: %s", sql.DebugString(filters[0]))) // same cmp/sel split as String(), rendered with sql.DebugString
+			if len(filters) > 1 {
+				children = append(children, fmt.Sprintf("sel: %s", sql.DebugString(expression.JoinAnd(filters[1:]...))))
+			}
+		} else {
+			children = append(children, sql.DebugString(j.Filter))
+		}
 	}
 	children = append(children, sql.DebugString(j.left), sql.DebugString(j.right))
 	pr.WriteNode("%s", j.Op)
diff --git a/sql/plan/join_iters.go b/sql/plan/join_iters.go
index 5abd4417f7..b07a9546ad 100644
--- a/sql/plan/join_iters.go
+++ b/sql/plan/join_iters.go
@@ -269,11 +269,23 @@ func (i *existsIter) loadSecondary(ctx *sql.Context, left sql.Row) (row sql.Row,
 	return iter.Next(ctx)
 }
 
+type existsState uint8 // existsState identifies the next step of the existsIter.Next state machine
+
+const (
+	esIncLeft existsState = iota // read the next primary row and open a secondary iterator for it
+	esIncRight // read the next row from the secondary iterator
+	esRightIterEOF // the secondary iterator returned io.EOF for the current left row
+	esCompare // evaluate the join condition on the combined left/right row
+	esRet // return a row built from the current left (and possibly right) row
+)
+
 func (i *existsIter) Next(ctx *sql.Context) (sql.Row, error) {
 	var row sql.Row
 	var matches bool
 	var right sql.Row
 	var left sql.Row
+	var rIter sql.RowIter
+	var err error
 
 	// the common sequence is: LOAD_LEFT -> LOAD_RIGHT -> COMPARE -> RET
 	// notable exceptions are represented as goto jumps:
@@ -282,64 +294,78 @@ func (i *existsIter) Next(ctx *sql.Context) (sql.Row, error) {
 	// - antiJoin succeeds to RET when LOAD_RIGHT EOF's
 	// - semiJoin fails when LOAD_RIGHT EOF's, falling back to LOAD_LEFT
 	// - antiJoin fails when COMPARE returns true, falling back to LOAD_LEFT
-	goto LOAD_LEFT
-LOAD_LEFT:
-	r, err := i.primary.Next(ctx)
-	if err != nil {
-		return nil, err
-	}
-	left = i.parentRow.Append(r)
-	rIter, err := i.secondaryProvider.RowIter(ctx, left)
-	if err != nil {
-		return nil, err
-	}
-	if isEmptyIter(rIter) {
-		if i.nullRej {
-			return nil, io.EOF
-		}
-		goto COMPARE
-	}
-	goto LOAD_RIGHT
-LOAD_RIGHT:
-	right, err = rIter.Next(ctx)
-	if err != nil {
-		iterErr := rIter.Close(ctx)
-		if iterErr != nil {
-			return nil, fmt.Errorf("%w; error on close: %s", err, iterErr)
-		}
-		if errors.Is(err, io.EOF) {
+	nextState := esIncLeft // the states replace the former labels: LOAD_LEFT=esIncLeft, LOAD_RIGHT=esIncRight, COMPARE=esCompare, RET=esRet
+	for {
+		switch nextState {
+		case esIncLeft:
+			r, err := i.primary.Next(ctx) // NOTE(review): := declares a case-local err that shadows the function-level err above
+			if err != nil {
+				return nil, err
+			}
+			left = i.parentRow.Append(r)
+			rIter, err = i.secondaryProvider.RowIter(ctx, left)
+			if err != nil {
+				return nil, err
+			}
+			if isEmptyIter(rIter) {
+				if i.nullRej {
+					return nil, io.EOF
+				}
+				nextState = esCompare
+			} else {
+				nextState = esIncRight
+			}
+		case esIncRight:
+			right, err = rIter.Next(ctx)
+			if err != nil {
+				iterErr := rIter.Close(ctx)
+				if iterErr != nil {
+					return nil, fmt.Errorf("%w; error on close: %s", err, iterErr)
+				}
+				if errors.Is(err, io.EOF) {
+					nextState = esRightIterEOF
+				} else {
+					return nil, err
+				}
+			} else {
+				nextState = esCompare
+			}
+		case esRightIterEOF:
 			if i.typ.IsSemi() {
 				// reset iter, no match
-				goto LOAD_LEFT
+				nextState = esIncLeft
+			} else {
+				nextState = esRet
+			}
+		case esCompare:
+			row = i.buildRow(left, right)
+			matches, err = conditionIsTrue(ctx, row, i.cond)
+			if err != nil {
+				return nil, err
+			}
+			if !matches {
+				nextState = esIncRight
+			} else {
+				err = rIter.Close(ctx)
+				if err != nil {
+					return nil, err
+				}
+				if i.typ.IsAnti() {
+					// reset iter, found match -> no return row
+					nextState = esIncLeft
+				} else {
+					nextState = esRet
+				}
+			}
+		case esRet:
+			if i.typ.IsRightPartial() {
+				return append(left[:i.scopeLen], right...), nil
 			}
-			goto RET
+			return i.removeParentRow(left), nil
+		default:
+			return nil, fmt.Errorf("invalid exists join state")
 		}
-		return nil, err
-	}
-	goto COMPARE
-COMPARE:
-	row = i.buildRow(left, right)
-	matches, err = conditionIsTrue(ctx, row, i.cond)
-	if err != nil {
-		return nil, err
-	}
-	if !matches {
-		goto LOAD_RIGHT
-	}
-	err = rIter.Close(ctx)
-	if err != nil {
-		return nil, err
-	}
-	if i.typ.IsAnti() {
-		// reset iter, found match -> no return row
-		goto LOAD_LEFT
-	}
-	goto RET
-RET:
-	if i.typ.IsRightPartial() {
-		return append(left[:i.scopeLen], right...), nil
 	}
-	return i.removeParentRow(left), nil
 }
 
 func (i *existsIter) removeParentRow(r sql.Row) sql.Row {
diff --git a/sql/plan/merge_join.go b/sql/plan/merge_join.go
index 3b569ef0fe..992fe1be63 100644
--- a/sql/plan/merge_join.go
+++ b/sql/plan/merge_join.go
@@ -52,17 +52,12 @@ func newMergeJoinIter(ctx *sql.Context, j *JoinNode, row sql.Row) (sql.RowIter,
 		copy(fullRow[0:], row[:])
 	}
 
-	var first expression.Comparer
-	var filters []sql.Expression
-	for i, f := range expression.SplitConjunction(j.Filter) {
-		c, ok := f.(expression.Comparer)
-		if !ok {
-			return nil, sql.ErrMergeJoinExpectsComparerFilters.New(f)
-		}
-		if i == 0 {
-			first = c
-		}
-		filters = append(filters, c)
+	// a merge join's first filter provides direction information
+	// for which iter to update next
+	filters := expression.SplitConjunction(j.Filter)
+	cmp, ok := filters[0].(expression.Comparer) // NOTE(review): filters[0] is indexed before the len(filters) == 0 check below; confirm SplitConjunction cannot return an empty slice here
+	if !ok {
+		return nil, sql.ErrMergeJoinExpectsComparerFilters.New(filters[0])
 	}
 
 	if len(filters) == 0 {
@@ -72,27 +67,25 @@ func newMergeJoinIter(ctx *sql.Context, j *JoinNode, row sql.Row) (sql.RowIter,
 	var iter sql.RowIter =
&mergeJoinIter{
 		left:        l,
 		right:       r,
-		expr:        first,
+		filters:     filters[1:],
+		cmp:         cmp,
 		typ:         j.Op,
 		fullRow:     fullRow,
 		scopeLen:    j.ScopeLen,
 		leftRowLen:  len(j.left.Schema()),
 		rightRowLen: len(j.right.Schema()),
 	}
-	if len(filters) > 1 {
-		iter = NewFilterIter(expression.JoinAnd(filters...), iter)
-	}
 	return iter, nil
 }
 
 // mergeJoinIter alternates incrementing two RowIters, assuming
 // rows will be provided in a sorted order given the join |expr|
-// (see findSortedIndexScanForRel).
-// TODO: my first iteration of this saves the future row state in Next(),
-// but it might be more appropriate to save the historical state for
-// future iterators to switch on
+// (see sortedIndexScanForTableCol). Extra join |filters| that do
+// not provide a directional ordering signal for index iteration
+// are evaluated separately.
 type mergeJoinIter struct {
-	expr        expression.Comparer
+	cmp         expression.Comparer
+	filters     []sql.Expression
 	left        sql.RowIter
 	right       sql.RowIter
 	fullRow     sql.Row
@@ -117,52 +110,116 @@ type mergeJoinIter struct {
 	parentLen int
 }
 
-func (i *mergeJoinIter) Next(ctx *sql.Context) (sql.Row, error) {
-	if !i.init {
-		err := i.initIters(ctx)
+// sel reports whether |row| satisfies every filter in i.filters, stopping at the first filter that errors or is not true.
+func (i *mergeJoinIter) sel(ctx *sql.Context, row sql.Row) (bool, error) {
+	for _, f := range i.filters {
+		res, err := sql.EvaluateCondition(ctx, f, row)
 		if err != nil {
-			return nil, err
+			return false, err
+		}
+
+		if !sql.IsTrue(res) {
+			return false, nil
 		}
 	}
+	return true, nil
+}
+
+// mergeState identifies the next step of the mergeJoinIter.Next state machine.
+type mergeState uint8
+
+const (
+	msInit         mergeState = iota // run initIters if the iterator has not been initialized yet
+	msExhaustCheck                   // consult lojFinalize and exhausted before comparing
+	msCompare                        // evaluate |cmp| on fullRow to pick which side to advance
+	msIncLeft                        // advance the left iterator
+	msIncRight                       // advance the right iterator
+	msSelect                         // apply the extra |filters| to a row whose |cmp| result was 0
+	msRet                            // advance with incMatch and return the matched row
+	msRetLeft                        // return the current row with nullifyRightRow applied, then advance left
+)
+
+// Next runs the merge state machine until a row can be returned, or returns io.EOF once exhausted reports true.
+func (i *mergeJoinIter) Next(ctx *sql.Context) (sql.Row, error) {
+	var err error
+	var ret sql.Row
+	var res int
+
+	nextState := msInit
 	for {
-		if i.lojFinalize() {
-			ret := i.copyReturnRow()
-			err := i.incLeft(ctx)
+		switch nextState {
+		case msInit:
+			if !i.init {
+				err = i.initIters(ctx)
+				if err != nil {
+					return nil, err
+				}
+			}
+			nextState = msExhaustCheck
+		case msExhaustCheck:
+			if i.lojFinalize() {
+				nextState = msRetLeft
+			} else if i.exhausted() {
+				return nil, io.EOF
+			} else {
+				nextState = msCompare
+			}
+		case msCompare:
+			res, err = i.cmp.Compare(ctx, i.fullRow)
 			if err != nil {
 				return nil, err
 			}
-			return i.removeParentRow(i.nullifyRightRow(ret)), nil
-		} else if i.exhausted() {
-			return nil, io.EOF
-		}
-
-		res, err := i.expr.Compare(ctx, i.fullRow)
-		if err != nil {
-			return nil, err
-		}
-
-		switch {
-		case res < 0:
-			if i.typ.IsLeftOuter() {
-				ret := i.copyReturnRow()
-				err = i.incLeft(ctx)
-				return i.removeParentRow(i.nullifyRightRow(ret)), nil
+			switch {
+			case res < 0:
+				if i.typ.IsLeftOuter() {
+					nextState = msRetLeft // as before the refactor: a left outer join returns the null-extended left row here
+				} else {
+					nextState = msIncLeft
+				}
+			case res > 0:
+				nextState = msIncRight
+			case res == 0:
+				nextState = msSelect
 			}
+		case msIncLeft:
 			err = i.incLeft(ctx)
-		case res > 0:
+			if err != nil {
+				return nil, err
+			}
+			nextState = msExhaustCheck
+		case msIncRight:
 			err = i.incRight(ctx)
-		case res == 0:
-			ret := i.copyReturnRow()
+			if err != nil {
+				return nil, err
+			}
+			nextState = msExhaustCheck
+		case msSelect:
+			ret = i.copyReturnRow()
+			if ok, err := i.sel(ctx, ret); err != nil {
+				return nil, err
+			} else if !ok {
+				if i.typ.IsLeftOuter() {
+					nextState = msRetLeft
+				} else {
+					nextState = msIncLeft
+				}
+			} else {
+				nextState = msRet
+			}
+		case msRet:
 			err = i.incMatch(ctx)
 			if err != nil {
 				return nil, err
 			}
 			return i.removeParentRow(ret), nil
+		case msRetLeft:
+			ret = i.removeParentRow(i.nullifyRightRow(i.copyReturnRow()))
+			err = i.incLeft(ctx)
+			if err != nil {
+				return nil, err
+			}
+			return ret, nil
 		}
-		if err != nil {
-			return nil, err
-		}
-
 	}
 }
 
@@ -324,7 +381,7 @@ func (i *mergeJoinIter) peekMatch(ctx *sql.Context, iter sql.RowIter) (bool, sql
 
 	// check if lookahead valid
 	copySubslice(i.fullRow, peek, off)
-	res, err := i.expr.Compare(ctx, i.fullRow)
+	res, err := i.cmp.Compare(ctx, i.fullRow)
 	if err != nil {
 		return false, nil, err
 	}