From 3fb795b6e20f4b6c72102e429da2e0a23416b848 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 18 Dec 2024 14:03:10 +0100 Subject: [PATCH] evalengine: handle NULL in last_insert_id correctly Signed-off-by: Andres Taylor --- go/vt/vtgate/evalengine/compiler_test.go | 49 ++++++++++++++++++------ go/vt/vtgate/evalengine/fn_misc.go | 22 ++++++++--- 2 files changed, 54 insertions(+), 17 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index 035d117083f..9fbcffdcdf6 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -884,7 +884,7 @@ func TestBindVarLiteral(t *testing.T) { } type testVcursor struct { - lastInsertID uint64 + lastInsertID *uint64 env *vtenv.Environment } @@ -905,7 +905,7 @@ func (t *testVcursor) Environment() *vtenv.Environment { } func (t *testVcursor) SetLastInsertID(id uint64) { - t.lastInsertID = id + t.lastInsertID = &id } var _ evalengine.VCursor = (*testVcursor)(nil) @@ -914,16 +914,20 @@ func TestLastInsertID(t *testing.T) { var testCases = []struct { expression string result uint64 + missing bool }{ { expression: `last_insert_id(1)`, result: 1, }, { expression: `12`, - result: 0, + missing: true, }, { expression: `last_insert_id(666)`, result: 666, + }, { + expression: `last_insert_id(null)`, + result: 0, }, } @@ -932,22 +936,43 @@ func TestLastInsertID(t *testing.T) { t.Run(tc.expression, func(t *testing.T) { expr, err := venv.Parser().ParseExpr(tc.expression) require.NoError(t, err) + cfg := &evalengine.Config{ Collation: collations.CollationUtf8mb4ID, - Environment: venv, NoConstantFolding: true, + NoCompilation: false, + Environment: venv, } + t.Run("eval", func(t *testing.T) { + cfg.NoCompilation = true + runTest(t, expr, cfg, tc) + }) + t.Run("compiled", func(t *testing.T) { + cfg.NoCompilation = false + runTest(t, expr, cfg, tc) + }) + }) + } +} - converted, err := evalengine.Translate(expr, cfg) - require.NoError(t, err) +func runTest(t *testing.T, expr sqlparser.Expr, cfg *evalengine.Config, tc struct { + expression string + result uint64 + missing bool +}) { + converted, err := evalengine.Translate(expr, cfg) + require.NoError(t, err) - vc := &testVcursor{env: venv} - env := evalengine.NewExpressionEnv(context.Background(), nil, vc) + vc := &testVcursor{env: vtenv.NewTestEnv()} + env := evalengine.NewExpressionEnv(context.Background(), nil, vc) - _, err = env.EvaluateAST(converted) - require.NoError(t, err) - assert.Equal(t, tc.result, vc.lastInsertID) - }) + _, err = env.EvaluateAST(converted) + require.NoError(t, err) + if tc.missing { + require.Nil(t, vc.lastInsertID) + } else { + require.NotNil(t, vc.lastInsertID) + require.Equal(t, tc.result, *vc.lastInsertID) } } diff --git a/go/vt/vtgate/evalengine/fn_misc.go b/go/vt/vtgate/evalengine/fn_misc.go index e323d9a7dad..948383f5352 100644 --- a/go/vt/vtgate/evalengine/fn_misc.go +++ b/go/vt/vtgate/evalengine/fn_misc.go @@ -160,6 +160,7 @@ func (call *builtinInetNtoa) compile(c *compiler) (ctype, error) { c.compileToUint64(arg, 1) col := typedCoercionCollation(sqltypes.VarChar, call.collate) c.asm.Fn_INET_NTOA(col) + c.asm.jumpDestination(skip) return ctype{Type: sqltypes.VarChar, Flag: flagNullable, Col: col}, nil @@ -201,7 +202,11 @@ func (call *builtinInet6Aton) compile(c *compiler) (ctype, error) { func (call *builtinLastInsertID) eval(env *ExpressionEnv) (eval, error) { arg, err := call.arg1(env) - if arg == nil || err != nil { + if err != nil { + return nil, err + } + if arg == nil { + env.VCursor().SetLastInsertID(0) return nil, err } insertID := uint64(evalToInt64(arg).i) @@ -215,12 +220,19 @@ func (call *builtinLastInsertID) compile(c *compiler) (ctype, error) { return ctype{}, err } - skip := c.compileNullCheck1(arg) - + setZero := c.compileNullCheck1(arg) c.compileToUint64(arg, 1) - c.asm.Fn_LAST_INSERT_ID() + setLastInsertID := c.asm.jumpFrom() - c.asm.jumpDestination(skip) + c.asm.jumpDestination(setZero) + c.asm.emit(func(env *ExpressionEnv) int { + env.vm.stack[env.vm.sp] = env.vm.arena.newEvalUint64(0) + env.vm.sp++ + return 1 + }, "PUSH UINT64(0)") + + c.asm.jumpDestination(setLastInsertID) + c.asm.Fn_LAST_INSERT_ID() return ctype{Type: sqltypes.Uint64, Flag: flagNullable, Col: collationNumeric}, nil }