Skip to content

Commit

Permalink
Lateral join uses prepend row on RHS (#1954)
Browse files Browse the repository at this point in the history
* Lateral join uses prepend row on RHS

* [ga-format-pr] Run ./format_repo.sh to fix formatting

---------

Co-authored-by: max-hoffman <max-hoffman@users.noreply.github.com>
  • Loading branch information
max-hoffman and max-hoffman authored Aug 21, 2023
1 parent 95bd5fa commit 5d3d97c
Show file tree
Hide file tree
Showing 10 changed files with 74 additions and 87 deletions.
2 changes: 1 addition & 1 deletion optgen/cmd/optgen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func main() {
case "frameFactory":
case "framer":
case "memo":
absPath, _ := filepath.Abs(path.Join("..", "go-mysql-server/optgen/cmd/source", "memo.yaml"))
absPath, _ := filepath.Abs(path.Join("../../../..", "go-mysql-server/optgen/cmd/source", "memo.yaml"))
defs, err = support.DecodeMemoExprs(absPath)
if err != nil {
exit(err)
Expand Down
6 changes: 1 addition & 5 deletions optgen/cmd/source/memo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,7 @@ exprs:
- [swapCmp, "bool"]
- name: "FullOuterJoin"
join: true
- name: "LateralCrossJoin"
join: true
- name: "LateralInnerJoin"
join: true
- name: "LateralLeftJoin"
- name: "LateralJoin"
join: true
- name: "TableScan"
sourceType: "*plan.ResolvedTable"
Expand Down
24 changes: 5 additions & 19 deletions sql/memo/coster.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,8 @@ func (c *coster) costRel(ctx *sql.Context, n RelExpr, s sql.StatsReader) (float6
return c.costLookupJoin(ctx, n, s)
case *RangeHeapJoin:
return c.costRangeHeapJoin(ctx, n, s)
case *LateralCrossJoin:
return c.costLateralCrossJoin(ctx, n, s)
case *LateralInnerJoin:
return c.costLateralInnerJoin(ctx, n, s)
case *LateralLeftJoin:
return c.costLateralLeftJoin(ctx, n, s)
case *LateralJoin:
return c.costLateralJoin(ctx, n, s)
case *SemiJoin:
return c.costSemiJoin(ctx, n, s)
case *AntiJoin:
Expand Down Expand Up @@ -203,19 +199,7 @@ func (c *coster) costRangeHeapJoin(_ *sql.Context, n *RangeHeapJoin, _ sql.Stats
return l * expectedNumberOfOverlappingJoins * (seqIOCostFactor), nil
}

// costLateralCrossJoin estimates the cost of a lateral cross join from the
// cardinalities of its left and right inputs. The row product l*r (less one)
// is weighted by seqIOCostFactor, the full product by cpuCostFactor, and the
// sum is then scaled by degeneratePenalty. The ctx and stats reader arguments
// are not used, and the returned error is always nil.
func (c *coster) costLateralCrossJoin(ctx *sql.Context, n *LateralCrossJoin, _ sql.StatsReader) (float64, error) {
l := n.Left.RelProps.card
r := n.Right.RelProps.card
return ((l*r-1)*seqIOCostFactor + (l*r)*cpuCostFactor) * degeneratePenalty, nil
}

// costLateralInnerJoin estimates the cost of a lateral inner join from the
// cardinalities of its left and right inputs. The row product l*r (less one)
// is weighted by seqIOCostFactor and the full product by cpuCostFactor. The
// ctx and stats reader arguments are not used, and the returned error is
// always nil.
func (c *coster) costLateralInnerJoin(ctx *sql.Context, n *LateralInnerJoin, _ sql.StatsReader) (float64, error) {
l := n.Left.RelProps.card
r := n.Right.RelProps.card
return (l*r-1)*seqIOCostFactor + (l*r)*cpuCostFactor, nil
}

func (c *coster) costLateralLeftJoin(ctx *sql.Context, n *LateralLeftJoin, _ sql.StatsReader) (float64, error) {
func (c *coster) costLateralJoin(ctx *sql.Context, n *LateralJoin, _ sql.StatsReader) (float64, error) {
l := n.Left.RelProps.card
r := n.Right.RelProps.card
return (l*r-1)*seqIOCostFactor + (l*r)*cpuCostFactor, nil
Expand Down Expand Up @@ -366,6 +350,8 @@ func (c *carder) cardRel(ctx *sql.Context, n RelExpr, s sql.StatsReader) (float6
sel += lookupJoinSelectivity(l)
}
return n.Left.RelProps.card * optimisticJoinSel * sel, nil
case *LateralJoin:
return n.Left.RelProps.card * n.Right.RelProps.card, nil
default:
}
if jp.Op.IsPartial() {
Expand Down
11 changes: 11 additions & 0 deletions sql/memo/exec_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,17 @@ func (b *ExecBuilder) buildMergeJoin(j *MergeJoin, input sql.Schema, children ..
return plan.NewJoin(inner, outer, j.Op, filters).WithScopeLen(j.g.m.scopeLen), nil
}

// buildLateralJoin converts a memo LateralJoin into an executable plan node
// over the two already-built children (children[0] is the left side,
// children[1] the right side).
//
// When the join carries filters, they are built into a single condition
// against input and a join of type j.Op.AsLateral() is returned. When there
// are no filters, a plain cross join of the two children is returned instead.
// Any error from building the filter condition is returned unchanged.
func (b *ExecBuilder) buildLateralJoin(j *LateralJoin, input sql.Schema, children ...sql.Node) (sql.Node, error) {
	if len(j.Filter) > 0 {
		cond, err := b.buildFilterConjunction(j.g.m.scope, input, j.Filter...)
		if err != nil {
			return nil, err
		}
		return plan.NewJoin(children[0], children[1], j.Op.AsLateral(), cond), nil
	}
	// No join condition to evaluate: fall back to a cross join of the children.
	return plan.NewCrossJoin(children[0], children[1]), nil
}

// buildSubqueryAlias returns the Table field stored on the memo SubqueryAlias
// as the plan node. The input schema and children are not consulted, and the
// returned error is always nil.
func (b *ExecBuilder) buildSubqueryAlias(r *SubqueryAlias, input sql.Schema, children ...sql.Node) (sql.Node, error) {
return r.Table, nil
}
Expand Down
20 changes: 17 additions & 3 deletions sql/memo/memo.go
Original file line number Diff line number Diff line change
Expand Up @@ -570,12 +570,26 @@ func buildBestJoinPlan(b *ExecBuilder, grp *ExprGroup, input sql.Schema) (sql.No
n := grp.Best
var err error
children := make([]sql.Node, len(n.Children()))
for i, g := range n.Children() {
children[i], err = buildBestJoinPlan(b, g, input)
switch n := n.(type) {
case *LateralJoin:
left, err := buildBestJoinPlan(b, n.Left, input)
if err != nil {
return nil, err
}
input = append(input, g.RelProps.OutputCols()...)
right, err := buildBestJoinPlan(b, n.Right, append(input, left.Schema()...))
if err != nil {
return nil, err
}
children[0] = left
children[1] = right
default:
for i, g := range n.Children() {
children[i], err = buildBestJoinPlan(b, g, input)
if err != nil {
return nil, err
}
input = append(input, g.RelProps.OutputCols()...)
}
}
return b.buildRel(n, input, children...)
}
Expand Down
52 changes: 10 additions & 42 deletions sql/memo/memo.og.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions sql/plan/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ func (f *Filter) Expressions() []sql.Expression {
type FilterIter struct {
cond sql.Expression
childIter sql.RowIter
ParentRow sql.Row
}

// NewFilterIter creates a new FilterIter.
Expand All @@ -113,7 +112,7 @@ func (i *FilterIter) Next(ctx *sql.Context) (sql.Row, error) {
return nil, err
}

res, err := sql.EvaluateCondition(ctx, i.cond, append(i.ParentRow, row...))
res, err := sql.EvaluateCondition(ctx, i.cond, row)
if err != nil {
return nil, err
}
Expand Down
13 changes: 13 additions & 0 deletions sql/plan/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,19 @@ func (i JoinType) AsLookup() JoinType {
}
}

// AsLateral returns the lateral counterpart of the join type: inner joins map
// to JoinTypeLateralInner, both left outer variants map to
// JoinTypeLateralLeft, and cross joins map to JoinTypeLateralCross. Any other
// join type is returned unchanged.
func (i JoinType) AsLateral() JoinType {
	switch {
	case i == JoinTypeInner:
		return JoinTypeLateralInner
	case i == JoinTypeLeftOuter || i == JoinTypeLeftOuterExcludeNulls:
		return JoinTypeLateralLeft
	case i == JoinTypeCross:
		return JoinTypeLateralCross
	}
	// No lateral equivalent is defined for the remaining join types.
	return i
}

// JoinNode contains all the common data fields and implements the common sql.Node getters for all join types.
type JoinNode struct {
BinaryNode
Expand Down
18 changes: 9 additions & 9 deletions sql/plan/subquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,10 @@ func (s *Subquery) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
return rows[0], nil
}

// prependRowInPlan returns a transformation function that prepends the row given to any row source in a query
// PrependRowInPlan returns a transformation function that prepends the row given to any row source in a query
// plan. Any source of rows, as well as any node that alters the schema of its children, will be wrapped so that its
// result rows are prepended with the row given. When lateral is true, the child of a SubqueryAlias is transformed
// even if the alias does not have OuterScopeVisibility set.
func prependRowInPlan(row sql.Row) func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
func PrependRowInPlan(row sql.Row, lateral bool) func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
return func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
switch n := n.(type) {
case sql.Table, sql.Projector, *ValueDerivedTable, *TableCountLookup:
Expand All @@ -190,11 +190,11 @@ func prependRowInPlan(row sql.Row) func(n sql.Node) (sql.Node, transform.TreeIde
}, transform.NewTree, nil
case *Union:
newUnion := *n
newRight, _, err := transform.Node(n.Right(), prependRowInPlan(row))
newRight, _, err := transform.Node(n.Right(), PrependRowInPlan(row, lateral))
if err != nil {
return n, transform.SameTree, err
}
newLeft, _, err := transform.Node(n.Left(), prependRowInPlan(row))
newLeft, _, err := transform.Node(n.Left(), PrependRowInPlan(row, lateral))
if err != nil {
return n, transform.SameTree, err
}
Expand All @@ -203,17 +203,17 @@ func prependRowInPlan(row sql.Row) func(n sql.Node) (sql.Node, transform.TreeIde
return &newUnion, transform.NewTree, nil
case *RecursiveCte:
newRecursiveCte := *n
newUnion, _, err := transform.Node(n.union, prependRowInPlan(row))
newUnion, _, err := transform.Node(n.union, PrependRowInPlan(row, lateral))
newRecursiveCte.union = newUnion.(*Union)
return &newRecursiveCte, transform.NewTree, err
case *SubqueryAlias:
// For SubqueryAliases (i.e. DerivedTables), since they may have visibility to outer scopes, we need to
// transform their inner nodes to prepend the outer scope row data. Ideally, we would only do this when
// the subquery alias references those outer fields. That will also require updating subquery expression
// scope handling to also make the same optimization.
if n.OuterScopeVisibility {
if n.OuterScopeVisibility || lateral {
newSubqueryAlias := *n
newChildNode, _, err := transform.Node(n.Child, prependRowInPlan(row))
newChildNode, _, err := transform.Node(n.Child, PrependRowInPlan(row, lateral))
newSubqueryAlias.Child = newChildNode
return &newSubqueryAlias, transform.NewTree, err
} else {
Expand Down Expand Up @@ -330,7 +330,7 @@ func (s *Subquery) EvalMultiple(ctx *sql.Context, row sql.Row) ([]interface{}, e
func (s *Subquery) evalMultiple(ctx *sql.Context, row sql.Row) ([]interface{}, error) {
// Any source of rows, as well as any node that alters the schema of its children, needs to be wrapped so that its
// result rows are prepended with the scope row.
q, _, err := transform.Node(s.Query, prependRowInPlan(row))
q, _, err := transform.Node(s.Query, PrependRowInPlan(row, false))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -415,7 +415,7 @@ func (s *Subquery) HasResultRow(ctx *sql.Context, row sql.Row) (bool, error) {

// Any source of rows, as well as any node that alters the schema of its children, needs to be wrapped so that its
// result rows are prepended with the scope row.
q, _, err := transform.Node(s.Query, prependRowInPlan(row))
q, _, err := transform.Node(s.Query, PrependRowInPlan(row, false))
if err != nil {
return false, err
}
Expand Down
12 changes: 6 additions & 6 deletions sql/rowexec/join_iters.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/plan"
"github.com/dolthub/go-mysql-server/sql/transform"
)

func newJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode, row sql.Row) (sql.RowIter, error) {
Expand Down Expand Up @@ -752,14 +753,13 @@ func (i *lateralJoinIterator) loadLeft(ctx *sql.Context) error {

func (i *lateralJoinIterator) buildRight(ctx *sql.Context) error {
if i.rIter == nil {
iter, err := i.b.Build(ctx, i.rNode, i.lRow)
prepended, _, err := transform.Node(i.rNode, plan.PrependRowInPlan(i.lRow, true))
if err != nil {
return err
}

// Prepend node doesn't work over filter, because it calls filter.Next(), then prepends the row
if _, ok := iter.(*plan.FilterIter); ok {
iter.(*plan.FilterIter).ParentRow = i.lRow
iter, err := i.b.Build(ctx, prepended, i.lRow)
if err != nil {
return err
}
i.rIter = iter
}
Expand All @@ -772,7 +772,7 @@ func (i *lateralJoinIterator) loadRight(ctx *sql.Context) error {
if err != nil {
return err
}
i.rRow = rRow
i.rRow = rRow[len(i.lRow):]
}
return nil
}
Expand Down

0 comments on commit 5d3d97c

Please sign in to comment.