From 7be85dac835a55e91fa82a62b815fd76d683fa45 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Tue, 10 Dec 2024 12:13:22 +0100 Subject: [PATCH] feat: handle last_insert_id with arguments in complex expressions Signed-off-by: Andres Taylor --- go/vt/sqlparser/cow.go | 26 +++++ go/vt/vtgate/executorcontext/vcursor_impl.go | 4 +- .../planbuilder/operators/query_planning.go | 100 +++++++++++++----- .../planbuilder/testdata/select_cases.json | 12 +-- 4 files changed, 109 insertions(+), 33 deletions(-) diff --git a/go/vt/sqlparser/cow.go b/go/vt/sqlparser/cow.go index e807fdeef63..be376f84403 100644 --- a/go/vt/sqlparser/cow.go +++ b/go/vt/sqlparser/cow.go @@ -53,6 +53,32 @@ func CopyOnRewrite( return out } +func CopyAndReplaceExpr(node SQLNode, replaceFn func(node Expr) (Expr, bool)) SQLNode { + var replace Expr + pre := func(node, _ SQLNode) bool { + expr, ok := node.(Expr) + if !ok { + return true + } + newExpr, ok := replaceFn(expr) + if !ok { + return true + } + replace = newExpr + return false + } + + post := func(cursor *CopyOnWriteCursor) { + if replace == nil { + return + } + cursor.Replace(replace) + replace = nil + } + + return CopyOnRewrite(node, pre, post, nil) +} + // StopTreeWalk aborts the current tree walking. No more nodes will be visited, and the rewriter will exit out early func (c *CopyOnWriteCursor) StopTreeWalk() { c.stop = true diff --git a/go/vt/vtgate/executorcontext/vcursor_impl.go b/go/vt/vtgate/executorcontext/vcursor_impl.go index 4a249059983..6b83953fe23 100644 --- a/go/vt/vtgate/executorcontext/vcursor_impl.go +++ b/go/vt/vtgate/executorcontext/vcursor_impl.go @@ -1130,8 +1130,8 @@ func (vc *VCursorImpl) SetFoundRows(foundRows uint64) { vc.SafeSession.SetFoundRows(foundRows) } -func (vc *vcursorImpl) SetLastInsertID(id uint64) { - vc.safeSession.LastInsertId = id +func (vc *VCursorImpl) SetLastInsertID(id uint64) { + vc.SafeSession.LastInsertId = id } // SetDDLStrategy implements the SessionActions interface diff --git a/go/vt/vtgate/planbuilder/operators/query_planning.go b/go/vt/vtgate/planbuilder/operators/query_planning.go index f47d6b6927c..81e3a099ebc 100644 --- a/go/vt/vtgate/planbuilder/operators/query_planning.go +++ b/go/vt/vtgate/planbuilder/operators/query_planning.go @@ -32,8 +32,7 @@ import ( func planQuery(ctx *plancontext.PlanningContext, root Operator) Operator { var selExpr sqlparser.SelectExprs if horizon, isHorizon := root.(*Horizon); isHorizon { - sel := sqlparser.GetFirstSelect(horizon.Query) - selExpr = sqlparser.Clone(sel.SelectExprs) + selExpr = extractSelectExpressions(horizon) } output := runPhases(ctx, root) @@ -821,36 +820,46 @@ func tryPushUnion(ctx *plancontext.PlanningContext, op *Union) (Operator, *Apply return newUnion(sources, selects, op.unionColumns, op.distinct), Rewrote("merge union inputs") } -// addTruncationOrProjectionToReturnOutput uses the original Horizon to make sure that the output columns line up with what the user asked for -func addTruncationOrProjectionToReturnOutput(ctx *plancontext.PlanningContext, selExprs sqlparser.SelectExprs, output Operator) Operator { - if len(selExprs) == 0 { - return output - } - - cols := output.GetSelectExprs(ctx) - sizeCorrect := len(selExprs) == len(cols) || tryTruncateColumnsAt(output, len(selExprs)) - if !sizeCorrect || !colNamesAlign(selExprs, cols) { - output = createSimpleProjection(ctx, selExprs, output) - } - - if !ctx.SemTable.QuerySignature.LastInsertIDArg { - return output - } - - var offset int - for i, expr := range selExprs { +func handleLastInsertIDColumns(ctx *plancontext.PlanningContext, output Operator) Operator { + offset := -1 + topLevel := false + var arg sqlparser.Expr + for i, expr := range output.GetSelectExprs(ctx) { ae, ok := expr.(*sqlparser.AliasedExpr) if !ok { panic(vterrors.VT09015()) } - fnc, ok := ae.Expr.(*sqlparser.FuncExpr) - if !ok || !fnc.Name.EqualString("last_insert_id") { - continue + + replaceFn := func(node sqlparser.Expr) (sqlparser.Expr, bool) { + fnc, ok := node.(*sqlparser.FuncExpr) + if !ok || !fnc.Name.EqualString("last_insert_id") { + return node, false + } + if offset != -1 { + panic(vterrors.VT12001("last_insert_id() found multiple times in select list")) + } + arg = fnc.Exprs[0] + if node == ae.Expr { + topLevel = true + } + offset = i + return arg, true } - offset = i - break + + newExpr := sqlparser.CopyAndReplaceExpr(ae.Expr, replaceFn) + ae.Expr = newExpr.(sqlparser.Expr) } + if topLevel { + return &SaveToSession{ + unaryOperator: unaryOperator{ + Source: output, + }, + Offset: offset, + } + } + + offset = output.AddColumn(ctx, false, false, aeWrap(arg)) return &SaveToSession{ unaryOperator: unaryOperator{ Source: output, @@ -859,6 +868,47 @@ func addTruncationOrProjectionToReturnOutput(ctx *plancontext.PlanningContext, s } } +// addTruncationOrProjectionToReturnOutput uses the original Horizon to make sure that the output columns line up with what the user asked for +func addTruncationOrProjectionToReturnOutput(ctx *plancontext.PlanningContext, selExprs sqlparser.SelectExprs, output Operator) Operator { + if len(selExprs) == 0 { + return output + } + + if ctx.SemTable.QuerySignature.LastInsertIDArg { + output = handleLastInsertIDColumns(ctx, output) + } + + cols := output.GetSelectExprs(ctx) + sizeCorrect := len(selExprs) == len(cols) || tryTruncateColumnsAt(output, len(selExprs)) + if sizeCorrect && colNamesAlign(selExprs, cols) { + return output + } + + return createSimpleProjection(ctx, selExprs, output) +} + +func extractSelectExpressions(horizon *Horizon) sqlparser.SelectExprs { + sel := sqlparser.GetFirstSelect(horizon.Query) + // we handle last_insert_id with arguments separately - no need to send this down to mysql + selExprs := sqlparser.CopyAndReplaceExpr(sel.SelectExprs, func(node sqlparser.Expr) (sqlparser.Expr, bool) { + switch node := node.(type) { + case *sqlparser.FuncExpr: + if node.Name.EqualString("last_insert_id") && len(node.Exprs) == 1 { + return node.Exprs[0], true + } + return node, true + case sqlparser.Expr: + // we do this to make sure we get a clone of the expression + // if planning changes the expression, we should not change the original + return node, true + default: + return nil, false + } + }) + + return selExprs.(sqlparser.SelectExprs) +} + func colNamesAlign(expected, actual sqlparser.SelectExprs) bool { if len(expected) > len(actual) { // if we expect more columns than we have, we can't align diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.json b/go/vt/vtgate/planbuilder/testdata/select_cases.json index b1134becef6..6f4b8716ce0 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.json @@ -2177,8 +2177,8 @@ "Name": "main", "Sharded": false }, - "FieldQuery": "select last_insert_id(12) from dual where 1 != 1", - "Query": "select last_insert_id(12) from dual", + "FieldQuery": "select 12 from dual where 1 != 1", + "Query": "select 12 from dual", "Table": "dual" } ] @@ -2205,8 +2205,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select bar, 12, last_insert_id(foo) from `user` where 1 != 1", - "Query": "select bar, 12, last_insert_id(foo) from `user`", + "FieldQuery": "select bar, 12, foo from `user` where 1 != 1", + "Query": "select bar, 12, foo from `user`", "Table": "`user`" } ] @@ -5687,8 +5687,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select last_insert_id(id) from `user` where 1 != 1", - "Query": "select last_insert_id(id) from `user`", + "FieldQuery": "select id from `user` where 1 != 1", + "Query": "select id from `user`", "Table": "`user`" } ]