Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply bindvars to subqueries #871

Merged
merged 2 commits into from
Mar 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func (e *Engine) QueryNodeWithBindings(
}

if len(bindings) > 0 {
parsed, err = plan.ApplyBindings(ctx, parsed, bindings)
parsed, err = plan.ApplyBindings(parsed, bindings)
if err != nil {
return nil, nil, err
}
Expand Down
23 changes: 23 additions & 0 deletions enginetest/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -1447,6 +1447,29 @@ var QueryTests = []QueryTest{
"var": expression.NewLiteral(int64(2), sql.Int64),
},
},
{
Query: "SELECT i, 1 AS foo, 2 AS bar FROM (SELECT i FROM mYtABLE WHERE i = ?) AS a ORDER BY foo, i",
Expected: []sql.Row{
{2, 1, 2}},
Bindings: map[string]sql.Expression{
"v1": expression.NewLiteral(int64(2), sql.Int64),
},
},
{
Query: "SELECT (select sum(?) from mytable) as x FROM mytable ORDER BY (select sum(?) from mytable)",
Bindings: map[string]sql.Expression{
"v1": expression.NewLiteral(1, sql.Int8),
"v2": expression.NewLiteral(1, sql.Int8),
},
Expected: []sql.Row{{float64(3)}, {float64(3)}, {float64(3)}},
},
{
Query: "SELECT exists(select i from mytable where i = ?)",
Bindings: map[string]sql.Expression{
"v1": expression.NewLiteral(1, sql.Int8),
},
Expected: []sql.Row{{true}},
},
{
Query: "SELECT i, 1 AS foo, 2 AS bar FROM MyTable WHERE bar = 1 ORDER BY foo, i;",
Expected: []sql.Row{},
Expand Down
31 changes: 1 addition & 30 deletions sql/analyzer/resolve_ctes.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func resolveCtesInNode(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope,
cur = n
for i := 0; i < maxCteDepth && !nodesEqual(prev, cur); i++ {
prev = cur
cur, err = transformUpWithOpaque(prev, func(n sql.Node) (sql.Node, error) {
cur, err = plan.TransformUpWithOpaque(prev, func(n sql.Node) (sql.Node, error) {
switch n := n.(type) {
case *plan.UnresolvedTable:
lowerName := strings.ToLower(n.Name())
Expand Down Expand Up @@ -144,35 +144,6 @@ func stripWith(ctx *sql.Context, a *Analyzer, scope *Scope, n sql.Node, ctes map
return with.Child, nil
}

// transformUpWithOpaque applies a transformation function to the given tree from the bottom up, including through
// opaque nodes. This method is generally not safe to use for a transformation. Opaque nodes need to be considered in
// isolation except for very specific exceptions.
// TODO: a better way to do this might be to keep the WITH nodes around until the very end of anlysis, so that
// resolve_subqueries can get at this info during that stage. But we couldn't use the existing scope mechanism for
// that, so it's a bit of a headache.
func transformUpWithOpaque(node sql.Node, f sql.TransformNodeFunc) (sql.Node, error) {
children := node.Children()
if len(children) == 0 {
return f(node)
}

newChildren := make([]sql.Node, len(children))
for i, c := range children {
c, err := transformUpWithOpaque(c, f)
if err != nil {
return nil, err
}
newChildren[i] = c
}

node, err := node.WithChildren(newChildren...)
if err != nil {
return nil, err
}

return f(node)
}

// schemaLength returns the length of a node's schema without actually accessing it. Useful when a node isn't yet
// resolved, so Schema() could fail.
func schemaLength(node sql.Node) int {
Expand Down
49 changes: 25 additions & 24 deletions sql/plan/bindvar.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,38 +24,39 @@ import (
// If a binding for a |BindVar| expression is not found in the map, no error is
// returned and the |BindVar| expression is left in place. There is no check on
// whether all entries in |bindings| are used at least once throughout the |n|.
//
// This applies binding substitutions across *SubqueryAlias nodes, but will
// fail to apply bindings across other |sql.Opaque| nodes.
func ApplyBindings(ctx *sql.Context, n sql.Node, bindings map[string]sql.Expression) (sql.Node, error) {
withSubqueries, err := TransformUp(n, func(n sql.Node) (sql.Node, error) {
switch n := n.(type) {
case *SubqueryAlias:
child, err := ApplyBindings(ctx, n.Child, bindings)
// sql.DeferredType instances will be resolved by the binding types.
func ApplyBindings(n sql.Node, bindings map[string]sql.Expression) (sql.Node, error) {
fixBindings := func(expr sql.Expression) (sql.Expression, error) {
switch e := expr.(type) {
case *expression.BindVar:
val, found := bindings[e.Name]
if found {
return val, nil
}
case *Subquery:
// *Subquery is a sql.Expression with a sql.Node not reachable
// by the visitor. Manually apply bindings to [Query] field.
q, err := ApplyBindings(e.Query, bindings)
if err != nil {
return nil, err
}
return n.WithChildren(child)
return e.WithQuery(q), nil
}
return expr, nil
}

return TransformUpWithOpaque(n, func(node sql.Node) (sql.Node, error) {
switch n := node.(type) {
case *InsertInto:
source, err := ApplyBindings(ctx, n.Source, bindings)
// Manually apply bindings to [Source] because it is separated
// from [Destination].
newSource, err := ApplyBindings(n.Source, bindings)
if err != nil {
return nil, err
}
return n.WithSource(source), nil
return n.WithSource(newSource), nil
default:
return n, nil
}
})
if err != nil {
return nil, err
}
return TransformExpressionsUp(withSubqueries, func(e sql.Expression) (sql.Expression, error) {
if bv, ok := e.(*expression.BindVar); ok {
val, found := bindings[bv.Name]
if found {
return val, nil
}
return TransformExpressionsUp(node, fixBindings)
}
return e, nil
})
}
2 changes: 1 addition & 1 deletion sql/plan/bindvar_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ func TestApplyBindings(t *testing.T) {

for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
res, err := ApplyBindings(sql.NewEmptyContext(), c.Node, c.Bindings)
res, err := ApplyBindings(c.Node, c.Bindings)
if assert.NoError(t, err) {
assert.Equal(t, res, c.Expected)
}
Expand Down
26 changes: 26 additions & 0 deletions sql/plan/transform.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,29 @@ func TransformExpressionsWithNode(n sql.Node, f expression.TransformExprWithNode

return e.WithExpressions(newExprs...)
}

// TransformUpWithOpaque applies a transformation function to the given tree from the bottom up, including through
// opaque nodes. This method is generally not safe to use for a transformation. Opaque nodes need to be considered in
// isolation except for very specific exceptions.
func TransformUpWithOpaque(node sql.Node, f sql.TransformNodeFunc) (sql.Node, error) {
children := node.Children()
if len(children) == 0 {
return f(node)
}

newChildren := make([]sql.Node, len(children))
for i, c := range children {
c, err := TransformUpWithOpaque(c, f)
if err != nil {
return nil, err
}
newChildren[i] = c
}

node, err := node.WithChildren(newChildren...)
if err != nil {
return nil, err
}

return f(node)
}