diff --git a/engine.go b/engine.go index 5e152aab0c..e60028c090 100644 --- a/engine.go +++ b/engine.go @@ -172,7 +172,7 @@ func (e *Engine) QueryNodeWithBindings( } if len(bindings) > 0 { - parsed, err = plan.ApplyBindings(ctx, parsed, bindings) + parsed, err = plan.ApplyBindings(parsed, bindings) if err != nil { return nil, nil, err } diff --git a/enginetest/queries.go b/enginetest/queries.go index 9ff2d8c29b..ccb6289f4e 100644 --- a/enginetest/queries.go +++ b/enginetest/queries.go @@ -1447,6 +1447,29 @@ var QueryTests = []QueryTest{ "var": expression.NewLiteral(int64(2), sql.Int64), }, }, + { + Query: "SELECT i, 1 AS foo, 2 AS bar FROM (SELECT i FROM mYtABLE WHERE i = ?) AS a ORDER BY foo, i", + Expected: []sql.Row{ + {2, 1, 2}}, + Bindings: map[string]sql.Expression{ + "v1": expression.NewLiteral(int64(2), sql.Int64), + }, + }, + { + Query: "SELECT (select sum(?) from mytable) as x FROM mytable ORDER BY (select sum(?) from mytable)", + Bindings: map[string]sql.Expression{ + "v1": expression.NewLiteral(1, sql.Int8), + "v2": expression.NewLiteral(1, sql.Int8), + }, + Expected: []sql.Row{{float64(3)}, {float64(3)}, {float64(3)}}, + }, + { + Query: "SELECT exists(select i from mytable where i = ?)", + Bindings: map[string]sql.Expression{ + "v1": expression.NewLiteral(1, sql.Int8), + }, + Expected: []sql.Row{{true}}, + }, { Query: "SELECT i, 1 AS foo, 2 AS bar FROM MyTable WHERE bar = 1 ORDER BY foo, i;", Expected: []sql.Row{}, diff --git a/sql/analyzer/resolve_ctes.go b/sql/analyzer/resolve_ctes.go index d2d27844f3..5a3be78078 100644 --- a/sql/analyzer/resolve_ctes.go +++ b/sql/analyzer/resolve_ctes.go @@ -75,7 +75,7 @@ func resolveCtesInNode(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope, cur = n for i := 0; i < maxCteDepth && !nodesEqual(prev, cur); i++ { prev = cur - cur, err = transformUpWithOpaque(prev, func(n sql.Node) (sql.Node, error) { + cur, err = plan.TransformUpWithOpaque(prev, func(n sql.Node) (sql.Node, error) { switch n := n.(type) { case *plan.UnresolvedTable: lowerName := strings.ToLower(n.Name()) @@ -144,35 +144,6 @@ func stripWith(ctx *sql.Context, a *Analyzer, scope *Scope, n sql.Node, ctes map return with.Child, nil } -// transformUpWithOpaque applies a transformation function to the given tree from the bottom up, including through -// opaque nodes. This method is generally not safe to use for a transformation. Opaque nodes need to be considered in -// isolation except for very specific exceptions. -// TODO: a better way to do this might be to keep the WITH nodes around until the very end of anlysis, so that -// resolve_subqueries can get at this info during that stage. But we couldn't use the existing scope mechanism for -// that, so it's a bit of a headache. -func transformUpWithOpaque(node sql.Node, f sql.TransformNodeFunc) (sql.Node, error) { - children := node.Children() - if len(children) == 0 { - return f(node) - } - - newChildren := make([]sql.Node, len(children)) - for i, c := range children { - c, err := transformUpWithOpaque(c, f) - if err != nil { - return nil, err - } - newChildren[i] = c - } - - node, err := node.WithChildren(newChildren...) - if err != nil { - return nil, err - } - - return f(node) -} - // schemaLength returns the length of a node's schema without actually accessing it. Useful when a node isn't yet // resolved, so Schema() could fail. func schemaLength(node sql.Node) int { diff --git a/sql/plan/bindvar.go b/sql/plan/bindvar.go index 0760136703..50d89967fb 100644 --- a/sql/plan/bindvar.go +++ b/sql/plan/bindvar.go @@ -24,38 +24,39 @@ import ( // If a binding for a |BindVar| expression is not found in the map, no error is // returned and the |BindVar| expression is left in place. There is no check on // whether all entries in |bindings| are used at least once throughout the |n|. -// -// This applies binding substitutions across *SubqueryAlias nodes, but will -// fail to apply bindings across other |sql.Opaque| nodes. -func ApplyBindings(ctx *sql.Context, n sql.Node, bindings map[string]sql.Expression) (sql.Node, error) { - withSubqueries, err := TransformUp(n, func(n sql.Node) (sql.Node, error) { - switch n := n.(type) { - case *SubqueryAlias: - child, err := ApplyBindings(ctx, n.Child, bindings) +// sql.DeferredType instances will be resolved by the binding types. +func ApplyBindings(n sql.Node, bindings map[string]sql.Expression) (sql.Node, error) { + fixBindings := func(expr sql.Expression) (sql.Expression, error) { + switch e := expr.(type) { + case *expression.BindVar: + val, found := bindings[e.Name] + if found { + return val, nil + } + case *Subquery: + // *Subquery is a sql.Expression with a sql.Node not reachable + // by the visitor. Manually apply bindings to [Query] field. + q, err := ApplyBindings(e.Query, bindings) if err != nil { return nil, err } - return n.WithChildren(child) + return e.WithQuery(q), nil + } + return expr, nil + } + + return TransformUpWithOpaque(n, func(node sql.Node) (sql.Node, error) { + switch n := node.(type) { case *InsertInto: - source, err := ApplyBindings(ctx, n.Source, bindings) + // Manually apply bindings to [Source] because it is separated + // from [Destination]. + newSource, err := ApplyBindings(n.Source, bindings) if err != nil { return nil, err } - return n.WithSource(source), nil + return n.WithSource(newSource), nil default: - return n, nil - } - }) - if err != nil { - return nil, err - } - return TransformExpressionsUp(withSubqueries, func(e sql.Expression) (sql.Expression, error) { - if bv, ok := e.(*expression.BindVar); ok { - val, found := bindings[bv.Name] - if found { - return val, nil - } + return TransformExpressionsUp(node, fixBindings) } - return e, nil }) } diff --git a/sql/plan/bindvar_test.go b/sql/plan/bindvar_test.go index e2aabea8c1..f1d12b9d8d 100644 --- a/sql/plan/bindvar_test.go +++ b/sql/plan/bindvar_test.go @@ -196,7 +196,7 @@ func TestApplyBindings(t *testing.T) { for _, c := range cases { t.Run(c.Name, func(t *testing.T) { - res, err := ApplyBindings(sql.NewEmptyContext(), c.Node, c.Bindings) + res, err := ApplyBindings(c.Node, c.Bindings) if assert.NoError(t, err) { assert.Equal(t, res, c.Expected) } diff --git a/sql/plan/transform.go b/sql/plan/transform.go index 58bd846d7a..d0f836af57 100644 --- a/sql/plan/transform.go +++ b/sql/plan/transform.go @@ -147,3 +147,29 @@ func TransformExpressionsWithNode(n sql.Node, f expression.TransformExprWithNode return e.WithExpressions(newExprs...) } + +// TransformUpWithOpaque applies a transformation function to the given tree from the bottom up, including through +// opaque nodes. This method is generally not safe to use for a transformation. Opaque nodes need to be considered in +// isolation except for very specific exceptions. +func TransformUpWithOpaque(node sql.Node, f sql.TransformNodeFunc) (sql.Node, error) { + children := node.Children() + if len(children) == 0 { + return f(node) + } + + newChildren := make([]sql.Node, len(children)) + for i, c := range children { + c, err := TransformUpWithOpaque(c, f) + if err != nil { + return nil, err + } + newChildren[i] = c + } + + node, err := node.WithChildren(newChildren...) + if err != nil { + return nil, err + } + + return f(node) +}