Skip to content

Commit

Permalink
feat: handle last_insert_id with arguments in complex expressions
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <andres@planetscale.com>
  • Loading branch information
systay committed Dec 10, 2024
1 parent f3ee39d commit 7be85da
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 33 deletions.
26 changes: 26 additions & 0 deletions go/vt/sqlparser/cow.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,32 @@ func CopyOnRewrite(
return out
}

func CopyAndReplaceExpr(node SQLNode, replaceFn func(node Expr) (Expr, bool)) SQLNode {
var replace Expr
pre := func(node, _ SQLNode) bool {
expr, ok := node.(Expr)
if !ok {
return true
}
newExpr, ok := replaceFn(expr)
if !ok {
return true
}
replace = newExpr
return false
}

post := func(cursor *CopyOnWriteCursor) {
if replace == nil {
return
}
cursor.Replace(replace)
replace = nil
}

return CopyOnRewrite(node, pre, post, nil)
}

// StopTreeWalk aborts the current tree walking. No more nodes will be visited, and the rewriter will exit out early
func (c *CopyOnWriteCursor) StopTreeWalk() {
c.stop = true
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/executorcontext/vcursor_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -1130,8 +1130,8 @@ func (vc *VCursorImpl) SetFoundRows(foundRows uint64) {
vc.SafeSession.SetFoundRows(foundRows)
}

func (vc *vcursorImpl) SetLastInsertID(id uint64) {
vc.safeSession.LastInsertId = id
func (vc *VCursorImpl) SetLastInsertID(id uint64) {
vc.SafeSession.LastInsertId = id
}

// SetDDLStrategy implements the SessionActions interface
Expand Down
100 changes: 75 additions & 25 deletions go/vt/vtgate/planbuilder/operators/query_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ import (
func planQuery(ctx *plancontext.PlanningContext, root Operator) Operator {
var selExpr sqlparser.SelectExprs
if horizon, isHorizon := root.(*Horizon); isHorizon {
sel := sqlparser.GetFirstSelect(horizon.Query)
selExpr = sqlparser.Clone(sel.SelectExprs)
selExpr = extractSelectExpressions(horizon)
}

output := runPhases(ctx, root)
Expand Down Expand Up @@ -821,36 +820,46 @@ func tryPushUnion(ctx *plancontext.PlanningContext, op *Union) (Operator, *Apply
return newUnion(sources, selects, op.unionColumns, op.distinct), Rewrote("merge union inputs")
}

// addTruncationOrProjectionToReturnOutput uses the original Horizon to make sure that the output columns line up with what the user asked for
func addTruncationOrProjectionToReturnOutput(ctx *plancontext.PlanningContext, selExprs sqlparser.SelectExprs, output Operator) Operator {
if len(selExprs) == 0 {
return output
}

cols := output.GetSelectExprs(ctx)
sizeCorrect := len(selExprs) == len(cols) || tryTruncateColumnsAt(output, len(selExprs))
if !sizeCorrect || !colNamesAlign(selExprs, cols) {
output = createSimpleProjection(ctx, selExprs, output)
}

if !ctx.SemTable.QuerySignature.LastInsertIDArg {
return output
}

var offset int
for i, expr := range selExprs {
func handleLastInsertIDColumns(ctx *plancontext.PlanningContext, output Operator) Operator {
offset := -1
topLevel := false
var arg sqlparser.Expr
for i, expr := range output.GetSelectExprs(ctx) {
ae, ok := expr.(*sqlparser.AliasedExpr)
if !ok {
panic(vterrors.VT09015())
}
fnc, ok := ae.Expr.(*sqlparser.FuncExpr)
if !ok || !fnc.Name.EqualString("last_insert_id") {
continue

replaceFn := func(node sqlparser.Expr) (sqlparser.Expr, bool) {
fnc, ok := node.(*sqlparser.FuncExpr)
if !ok || !fnc.Name.EqualString("last_insert_id") {
return node, false
}
if offset != -1 {
panic(vterrors.VT12001("last_insert_id() found multiple times in select list"))
}
arg = fnc.Exprs[0]
if node == ae.Expr {
topLevel = true
}
offset = i
return arg, true
}
offset = i
break

newExpr := sqlparser.CopyAndReplaceExpr(ae.Expr, replaceFn)
ae.Expr = newExpr.(sqlparser.Expr)
}

if topLevel {
return &SaveToSession{
unaryOperator: unaryOperator{
Source: output,
},
Offset: offset,
}
}

offset = output.AddColumn(ctx, false, false, aeWrap(arg))
return &SaveToSession{
unaryOperator: unaryOperator{
Source: output,
Expand All @@ -859,6 +868,47 @@ func addTruncationOrProjectionToReturnOutput(ctx *plancontext.PlanningContext, s
}
}

// addTruncationOrProjectionToReturnOutput uses the original Horizon to make sure that the output columns line up with what the user asked for
func addTruncationOrProjectionToReturnOutput(ctx *plancontext.PlanningContext, selExprs sqlparser.SelectExprs, output Operator) Operator {
if len(selExprs) == 0 {
return output
}

if ctx.SemTable.QuerySignature.LastInsertIDArg {
output = handleLastInsertIDColumns(ctx, output)
}

cols := output.GetSelectExprs(ctx)
sizeCorrect := len(selExprs) == len(cols) || tryTruncateColumnsAt(output, len(selExprs))
if sizeCorrect && colNamesAlign(selExprs, cols) {
return output
}

return createSimpleProjection(ctx, selExprs, output)
}

func extractSelectExpressions(horizon *Horizon) sqlparser.SelectExprs {
sel := sqlparser.GetFirstSelect(horizon.Query)
// we handle last_insert_id with arguments separately - no need to send this down to mysql
selExprs := sqlparser.CopyAndReplaceExpr(sel.SelectExprs, func(node sqlparser.Expr) (sqlparser.Expr, bool) {
switch node := node.(type) {
case *sqlparser.FuncExpr:
if node.Name.EqualString("last_insert_id") && len(node.Exprs) == 1 {
return node.Exprs[0], true
}
return node, true
case sqlparser.Expr:
// we do this to make sure we get a clone of the expression
// if planning changes the expression, we should not change the original
return node, true
default:
return nil, false
}
})

return selExprs.(sqlparser.SelectExprs)
}

func colNamesAlign(expected, actual sqlparser.SelectExprs) bool {
if len(expected) > len(actual) {
// if we expect more columns than we have, we can't align
Expand Down
12 changes: 6 additions & 6 deletions go/vt/vtgate/planbuilder/testdata/select_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -2177,8 +2177,8 @@
"Name": "main",
"Sharded": false
},
"FieldQuery": "select last_insert_id(12) from dual where 1 != 1",
"Query": "select last_insert_id(12) from dual",
"FieldQuery": "select 12 from dual where 1 != 1",
"Query": "select 12 from dual",
"Table": "dual"
}
]
Expand All @@ -2205,8 +2205,8 @@
"Name": "user",
"Sharded": true
},
"FieldQuery": "select bar, 12, last_insert_id(foo) from `user` where 1 != 1",
"Query": "select bar, 12, last_insert_id(foo) from `user`",
"FieldQuery": "select bar, 12, foo from `user` where 1 != 1",
"Query": "select bar, 12, foo from `user`",
"Table": "`user`"
}
]
Expand Down Expand Up @@ -5687,8 +5687,8 @@
"Name": "user",
"Sharded": true
},
"FieldQuery": "select last_insert_id(id) from `user` where 1 != 1",
"Query": "select last_insert_id(id) from `user`",
"FieldQuery": "select id from `user` where 1 != 1",
"Query": "select id from `user`",
"Table": "`user`"
}
]
Expand Down

0 comments on commit 7be85da

Please sign in to comment.