Skip to content

Commit

Permalink
TransformUp is now sensitive to tree modifications (#867)
Browse files Browse the repository at this point in the history
* TransformUp now sensitive to tree modifications

`TransformUp` and related node/expression DFS helper functions expect the
visit function to return an additional parameter indicating
whether the visit changed the node: `sql.SameTree` or `sql.NewTree`.

We use `plan.InspectUp` when possible, and then `plan.TransformUp` where
possible, resorting to the more expensive `plan.TransformUpCtx` and
`plan.TransformUpCtxSchema` only when necessary.

* progress for nodesEqual refactor

* prog

* renames

* fixup names more

* Move transform related interfaces into transform package

* [ga-format-pr] Run ./format_repo.sh to fix formatting

* missing docstring

* prog

* Zach's comments

Co-authored-by: max-hoffman <max-hoffman@users.noreply.github.com>
  • Loading branch information
max-hoffman and max-hoffman authored Apr 1, 2022
1 parent b5db32c commit 24320fe
Show file tree
Hide file tree
Showing 115 changed files with 2,984 additions and 2,215 deletions.
6 changes: 4 additions & 2 deletions engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
"fmt"
"os"

"github.com/dolthub/go-mysql-server/sql/transform"

"github.com/dolthub/go-mysql-server/memory"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/analyzer"
Expand Down Expand Up @@ -225,7 +227,7 @@ func (e *Engine) QueryNodeWithBindings(
// allNode2 returns whether all the nodes in the tree implement Node2.
func allNode2(n sql.Node) bool {
allNode2 := true
plan.Inspect(n, func(n sql.Node) bool {
transform.Inspect(n, func(n sql.Node) bool {
switch n := n.(type) {
case *plan.ResolvedTable:
table := n.Table
Expand All @@ -249,7 +251,7 @@ func allNode2(n sql.Node) bool {

// All expressions in the tree must likewise be Expression2, and all types Type2, or we can't use rowFrame iteration
// TODO: likely that some nodes rely on expressions but don't implement sql.Expressioner, or implement it incompletely
plan.InspectExpressions(n, func(e sql.Expression) bool {
transform.InspectExpressions(n, func(e sql.Expression) bool {
if e == nil {
return false
}
Expand Down
2 changes: 1 addition & 1 deletion enginetest/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ func TestTrackProcess(t *testing.T) {
require.NoError(err)

rule := getRuleFrom(analyzer.OnceAfterAll, "track_process")
result, err := rule.Apply(ctx, a, node, nil)
result, _, err := rule.Apply(ctx, a, node, nil)
require.NoError(err)

processes := ctx.ProcessList.Processes()
Expand Down
17 changes: 9 additions & 8 deletions enginetest/enginetests.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import (
"github.com/dolthub/go-mysql-server/sql/information_schema"
"github.com/dolthub/go-mysql-server/sql/parse"
"github.com/dolthub/go-mysql-server/sql/plan"
"github.com/dolthub/go-mysql-server/sql/transform"
"github.com/dolthub/go-mysql-server/test"
)

Expand Down Expand Up @@ -853,7 +854,7 @@ func TestTruncate(t *testing.T, harness Harness) {
analyzed, err := e.Analyzer.Analyze(ctx, parsed, nil)
require.NoError(t, err)
truncateFound := false
plan.Inspect(analyzed, func(n sql.Node) bool {
transform.Inspect(analyzed, func(n sql.Node) bool {
switch n.(type) {
case *plan.Truncate:
truncateFound = true
Expand Down Expand Up @@ -882,7 +883,7 @@ func TestTruncate(t *testing.T, harness Harness) {
analyzed, err := e.Analyzer.Analyze(ctx, parsed, nil)
require.NoError(t, err)
truncateFound := false
plan.Inspect(analyzed, func(n sql.Node) bool {
transform.Inspect(analyzed, func(n sql.Node) bool {
switch n.(type) {
case *plan.Truncate:
truncateFound = true
Expand Down Expand Up @@ -910,7 +911,7 @@ func TestTruncate(t *testing.T, harness Harness) {
analyzed, err := e.Analyzer.Analyze(ctx, parsed, nil)
require.NoError(t, err)
truncateFound := false
plan.Inspect(analyzed, func(n sql.Node) bool {
transform.Inspect(analyzed, func(n sql.Node) bool {
switch n.(type) {
case *plan.Truncate:
truncateFound = true
Expand Down Expand Up @@ -938,7 +939,7 @@ func TestTruncate(t *testing.T, harness Harness) {
analyzed, err := e.Analyzer.Analyze(ctx, parsed, nil)
require.NoError(t, err)
truncateFound := false
plan.Inspect(analyzed, func(n sql.Node) bool {
transform.Inspect(analyzed, func(n sql.Node) bool {
switch n.(type) {
case *plan.Truncate:
truncateFound = true
Expand Down Expand Up @@ -967,7 +968,7 @@ func TestTruncate(t *testing.T, harness Harness) {
analyzed, err := e.Analyzer.Analyze(ctx, parsed, nil)
require.NoError(t, err)
truncateFound := false
plan.Inspect(analyzed, func(n sql.Node) bool {
transform.Inspect(analyzed, func(n sql.Node) bool {
switch n.(type) {
case *plan.Truncate:
truncateFound = true
Expand All @@ -994,7 +995,7 @@ func TestTruncate(t *testing.T, harness Harness) {
analyzed, err := e.Analyzer.Analyze(ctx, parsed, nil)
require.NoError(t, err)
truncateFound := false
plan.Inspect(analyzed, func(n sql.Node) bool {
transform.Inspect(analyzed, func(n sql.Node) bool {
switch n.(type) {
case *plan.Truncate:
truncateFound = true
Expand All @@ -1021,7 +1022,7 @@ func TestTruncate(t *testing.T, harness Harness) {
analyzed, err := e.Analyzer.Analyze(ctx, parsed, nil)
require.NoError(t, err)
truncateFound := false
plan.Inspect(analyzed, func(n sql.Node) bool {
transform.Inspect(analyzed, func(n sql.Node) bool {
switch n.(type) {
case *plan.Truncate:
truncateFound = true
Expand Down Expand Up @@ -1053,7 +1054,7 @@ func TestTruncate(t *testing.T, harness Harness) {
analyzed, err := e.Analyzer.Analyze(ctx, parsed, nil)
require.NoError(t, err)
truncateFound := false
plan.Inspect(analyzed, func(n sql.Node) bool {
transform.Inspect(analyzed, func(n sql.Node) bool {
switch n.(type) {
case *plan.Truncate:
truncateFound = true
Expand Down
2 changes: 1 addition & 1 deletion enginetest/memory_engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func TestSingleQuery(t *testing.T) {

var test enginetest.QueryTest
test = enginetest.QueryTest{
Query: `SELECT a.* FROM one_pk a CROSS JOIN one_pk b LEFT JOIN one_pk c ON b.pk = c.pk`,
Query: `SELECT COUNT(*) FROM mytable`,
Expected: []sql.Row{},
}

Expand Down
4 changes: 0 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,6 @@ github.com/dolthub/jsonpath v0.0.0-20210609232853-d49537a30474 h1:xTrR+l5l+1Lfq0
github.com/dolthub/jsonpath v0.0.0-20210609232853-d49537a30474/go.mod h1:kMz7uXOXq4qRriCEyZ/LUeTqraLJCjf0WVZcUi6TxUY=
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE=
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY=
github.com/dolthub/vitess v0.0.0-20220321185459-a24823fdb878 h1:i3U7Vppl2PuubOH6+EF13elXNUfBz1/CksCxyCn94qY=
github.com/dolthub/vitess v0.0.0-20220321185459-a24823fdb878/go.mod h1:qpZ4j0dval04OgZJ5fyKnlniSFUosTH280pdzUjUJig=
github.com/dolthub/vitess v0.0.0-20220323175412-7e0381fb7c3f h1:5kW1g4mscnLfJJR0C871vqQhQVifLm8mkhKBU131Nt4=
github.com/dolthub/vitess v0.0.0-20220323175412-7e0381fb7c3f/go.mod h1:jxgvpEvrTNw2i4BKlwT75E775eUXBeMv5MPeQkIb9zI=
github.com/dolthub/vitess v0.0.0-20220328235252-487aaa0ff789 h1:8SSJGCb73qXjGfd7P0D+aR4Cucgoco1dtGvupvbyGSU=
github.com/dolthub/vitess v0.0.0-20220328235252-487aaa0ff789/go.mod h1:jxgvpEvrTNw2i4BKlwT75E775eUXBeMv5MPeQkIb9zI=
github.com/dolthub/vitess v0.0.0-20220330190824-c23b568183c5 h1:66HHMuSkaCV3S7bNyLD9T9nRRqeHDvKZt8o4nq81IDc=
Expand Down
7 changes: 4 additions & 3 deletions memory/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/transform"
)

// Table represents an in-memory database table.
Expand Down Expand Up @@ -550,11 +551,11 @@ func (t *Table) addColumnToSchema(ctx *sql.Context, newCol *sql.Column, order *s
if i == newColIdx {
continue
}
newDefault, _ := expression.TransformUp(newSchCol.Default, func(expr sql.Expression) (sql.Expression, error) {
newDefault, _, _ := transform.Expr(newSchCol.Default, func(expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
if expr, ok := expr.(*expression.GetField); ok {
return expr.WithIndex(newSch.IndexOf(expr.Name(), t.name)), nil
return expr.WithIndex(newSch.IndexOf(expr.Name(), t.name)), transform.NewTree, nil
}
return expr, nil
return expr, transform.SameTree, nil
})
newSchCol.Default = newDefault.(*sql.ColumnDefaultValue)
}
Expand Down
5 changes: 3 additions & 2 deletions optgen/cmd/support/agg_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func (g *AggGen) Generate(defines GenDefs, w io.Writer) {
fmt.Fprintf(g.w, " \"fmt\"\n")
fmt.Fprintf(g.w, " \"github.com/dolthub/go-mysql-server/sql\"\n")
fmt.Fprintf(g.w, " \"github.com/dolthub/go-mysql-server/sql/expression\"\n")
fmt.Fprintf(g.w, " \"github.com/dolthub/go-mysql-server/sql/transform\"\n")
fmt.Fprintf(g.w, ")\n\n")

for _, define := range g.defines {
Expand Down Expand Up @@ -111,7 +112,7 @@ func (g *AggGen) genAggWithWindow(define AggDef) {

func (g *AggGen) genAggWindowConstructor(define AggDef) {
fmt.Fprintf(g.w, "func (a *%s) NewWindowFunction() (sql.WindowFunction, error) {\n", define.Name)
fmt.Fprintf(g.w, " child, err := expression.Clone(a.UnaryExpression.Child)\n")
fmt.Fprintf(g.w, " child, err := transform.Clone(a.UnaryExpression.Child)\n")
fmt.Fprintf(g.w, " if err != nil {\n")
fmt.Fprintf(g.w, " return nil, err\n")
fmt.Fprintf(g.w, " }\n")
Expand All @@ -121,7 +122,7 @@ func (g *AggGen) genAggWindowConstructor(define AggDef) {

func (g *AggGen) genAggNewBuffer(define AggDef) {
fmt.Fprintf(g.w, "func (a *%s) NewBuffer() (sql.AggregationBuffer, error) {\n", define.Name)
fmt.Fprintf(g.w, " child, err := expression.Clone(a.UnaryExpression.Child)\n")
fmt.Fprintf(g.w, " child, err := transform.Clone(a.UnaryExpression.Child)\n")
fmt.Fprintf(g.w, " if err != nil {\n")
fmt.Fprintf(g.w, " return nil, err\n")
fmt.Fprintf(g.w, " }\n")
Expand Down
5 changes: 3 additions & 2 deletions optgen/cmd/support/agg_gen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ func TestAggGen(t *testing.T) {
"fmt"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/transform"
)
type Test struct{
Expand Down Expand Up @@ -67,15 +68,15 @@ func TestAggGen(t *testing.T) {
}
func (a *Test) NewBuffer() (sql.AggregationBuffer, error) {
child, err := expression.Clone(a.UnaryExpression.Child)
child, err := transform.Clone(a.UnaryExpression.Child)
if err != nil {
return nil, err
}
return NewTestBuffer(child), nil
}
func (a *Test) NewWindowFunction() (sql.WindowFunction, error) {
child, err := expression.Clone(a.UnaryExpression.Child)
child, err := transform.Clone(a.UnaryExpression.Child)
if err != nil {
return nil, err
}
Expand Down
59 changes: 31 additions & 28 deletions sql/analyzer/aggregations.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/plan"
"github.com/dolthub/go-mysql-server/sql/transform"
)

// flattenAggregationExpressions flattens any complex aggregate or window expressions in a GroupBy or Window node and
Expand All @@ -27,81 +28,81 @@ import (
// e.g. GroupBy(sum(a) + sum(b)) becomes project(sum(a) + sum(b), GroupBy(sum(a), sum(b)).
// e.g. Window(sum(a) + sum(b) over (partition by a)) becomes
// project(sum(a) + sum(b) over (partition by a), Window(sum(a), sum(b) over (partition by a))).
func flattenAggregationExpressions(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope) (sql.Node, error) {
func flattenAggregationExpressions(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope) (sql.Node, transform.TreeIdentity, error) {
span, _ := ctx.Span("flatten_aggregation_exprs")
defer span.Finish()

if !n.Resolved() {
return n, nil
return n, transform.SameTree, nil
}

return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
switch n := n.(type) {
case *plan.Window:
if !hasHiddenAggregations(n.SelectExprs) && !hasHiddenWindows(n.SelectExprs) {
return n, nil
return n, transform.SameTree, nil
}

return flattenedWindow(ctx, n.SelectExprs, n.Child)
case *plan.GroupBy:
if !hasHiddenAggregations(n.SelectedExprs) {
return n, nil
return n, transform.SameTree, nil
}

return flattenedGroupBy(ctx, n.SelectedExprs, n.GroupByExprs, n.Child)
default:
return n, nil
return n, transform.SameTree, nil
}
})
}

func flattenedGroupBy(ctx *sql.Context, projection, grouping []sql.Expression, child sql.Node) (sql.Node, error) {
newProjection, newAggregates, err := replaceAggregatesWithGetFieldProjections(ctx, projection)
func flattenedGroupBy(ctx *sql.Context, projection, grouping []sql.Expression, child sql.Node) (sql.Node, transform.TreeIdentity, error) {
newProjection, newAggregates, allSame, err := replaceAggregatesWithGetFieldProjections(ctx, projection)
if err != nil {
return nil, err
return nil, transform.SameTree, err
}
if allSame {
return nil, transform.SameTree, nil
}

return plan.NewProject(
newProjection,
plan.NewGroupBy(newAggregates, grouping, child),
), nil
), transform.NewTree, nil
}

// replaceAggregatesWithGetFieldProjections takes a slice of projection expressions and flattens out any aggregate
// expressions within, wrapping all such flattened aggregations into a GetField projection. Returns two new slices: the
// new set of project expressions, and the new set of aggregations. The former always matches the size of the projection
// expressions passed in. The latter will have the size of the number of aggregate expressions contained in the input
// slice.
func replaceAggregatesWithGetFieldProjections(ctx *sql.Context, projection []sql.Expression) (projections, aggregations []sql.Expression, err error) {
func replaceAggregatesWithGetFieldProjections(ctx *sql.Context, projection []sql.Expression) (projections, aggregations []sql.Expression, identity transform.TreeIdentity, err error) {
var newProjection = make([]sql.Expression, len(projection))
var newAggregates []sql.Expression
allGetFields := make(map[int]sql.Expression)
projDeps := make(map[int]struct{})
for i, p := range projection {
var transformed bool
e, err := expression.TransformUp(p, func(e sql.Expression) (sql.Expression, error) {
e, same, err := transform.Expr(p, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
switch e := e.(type) {
case sql.Aggregation, sql.WindowAggregation:
// continue on
case *expression.GetField:
allGetFields[e.Index()] = e
projDeps[e.Index()] = struct{}{}
return e, nil
return e, transform.SameTree, nil
default:
return e, nil
return e, transform.SameTree, nil
}

transformed = true
newAggregates = append(newAggregates, e)
return expression.NewGetField(
len(newAggregates)-1, e.Type(), e.String(), e.IsNullable(),
), nil
), transform.NewTree, nil
})
if err != nil {
return nil, nil, err
return nil, nil, transform.SameTree, err
}

if !transformed {
if same {
newAggregates = append(newAggregates, e)
name, source := getNameAndSource(e)
newProjection[i] = expression.NewGetFieldWithTable(
Expand All @@ -115,12 +116,12 @@ func replaceAggregatesWithGetFieldProjections(ctx *sql.Context, projection []sql
// find subset of allGetFields not covered by newAggregates
newAggDeps := make(map[int]struct{}, 0)
for _, agg := range newAggregates {
_, _ = expression.TransformUp(agg, func(e sql.Expression) (sql.Expression, error) {
_ = transform.InspectExpr(agg, func(e sql.Expression) bool {
switch e := e.(type) {
case *expression.GetField:
newAggDeps[e.Index()] = struct{}{}
}
return e, nil
return false
})
}
for i, _ := range projDeps {
Expand All @@ -130,19 +131,21 @@ func replaceAggregatesWithGetFieldProjections(ctx *sql.Context, projection []sql
}
}

return newProjection, newAggregates, nil
return newProjection, newAggregates, transform.NewTree, nil
}

func flattenedWindow(ctx *sql.Context, projection []sql.Expression, child sql.Node) (sql.Node, error) {
newProjection, newAggregates, err := replaceAggregatesWithGetFieldProjections(ctx, projection)
func flattenedWindow(ctx *sql.Context, projection []sql.Expression, child sql.Node) (sql.Node, transform.TreeIdentity, error) {
newProjection, newAggregates, allSame, err := replaceAggregatesWithGetFieldProjections(ctx, projection)
if err != nil {
return nil, err
return nil, transform.SameTree, err
}
if allSame {
return nil, allSame, nil
}

return plan.NewProject(
newProjection,
plan.NewWindow(newAggregates, child),
), nil
), transform.NewTree, nil
}

func getNameAndSource(e sql.Expression) (name, source string) {
Expand Down
2 changes: 1 addition & 1 deletion sql/analyzer/aggregations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func TestFlattenAggregationExprs(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result, err := rule.Apply(sql.NewEmptyContext(), NewDefault(nil), test.node, nil)
result, _, err := rule.Apply(sql.NewEmptyContext(), NewDefault(nil), test.node, nil)
require.NoError(err)
require.Equal(test.expected, result)
})
Expand Down
Loading

0 comments on commit 24320fe

Please sign in to comment.