diff --git a/pkg/bindinfo/BUILD.bazel b/pkg/bindinfo/BUILD.bazel index 6e9946f7f1b74..aaff8c85961ca 100644 --- a/pkg/bindinfo/BUILD.bazel +++ b/pkg/bindinfo/BUILD.bazel @@ -57,7 +57,7 @@ go_test( embed = [":bindinfo"], flaky = True, race = "on", - shard_count = 43, + shard_count = 44, deps = [ "//pkg/bindinfo/internal", "//pkg/config", diff --git a/pkg/bindinfo/session_handle_test.go b/pkg/bindinfo/session_handle_test.go index 82633bb6c0eb5..fa8243c2e09d2 100644 --- a/pkg/bindinfo/session_handle_test.go +++ b/pkg/bindinfo/session_handle_test.go @@ -16,6 +16,7 @@ package bindinfo_test import ( "context" + "fmt" "strconv" "testing" "time" @@ -430,6 +431,25 @@ func TestDropSingleBindings(t *testing.T) { require.Len(t, rows, 0) } +func TestIssue53834(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec(`use test`) + tk.MustExec(`create table t (a varchar(1024))`) + tk.MustExec(`insert into t values (space(1024))`) + for i := 0; i < 12; i++ { + tk.MustExec(`insert into t select * from t`) + } + oomAction := tk.MustQuery(`select @@tidb_mem_oom_action`).Rows()[0][0].(string) + defer func() { + tk.MustExec(fmt.Sprintf(`set global tidb_mem_oom_action='%v'`, oomAction)) + }() + tk.MustExec(`set global tidb_mem_oom_action='cancel'`) + tk.MustExec(`create binding for replace into t select * from t using replace into t select /*+ memory_quota(1 mb) */ * from t`) + err := tk.ExecToErr(`replace into t select * from t`) + require.ErrorContains(t, err, "cancelled due to exceeding the allowed memory limit") +} + func TestPreparedStmt(t *testing.T) { store := testkit.CreateMockStore(t) tk := testkit.NewTestKit(t, store) diff --git a/pkg/planner/optimize.go b/pkg/planner/optimize.go index 1c5c1b6d16f54..00703f9d2c7b1 100644 --- a/pkg/planner/optimize.go +++ b/pkg/planner/optimize.go @@ -296,7 +296,7 @@ func Optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is in } metrics.BindUsageCounter.WithLabelValues(scope).Inc() hint.BindHint(stmtNode, binding.Hint) - curStmtHints, _, curWarns := handleStmtHints(binding.Hint.GetFirstTableHints()) + curStmtHints, _, curWarns := handleStmtHints(binding.Hint.GetStmtHints()) sessVars.StmtCtx.StmtHints = curStmtHints // update session var by hint /set_var/ for name, val := range sessVars.StmtCtx.StmtHints.SetVars { diff --git a/pkg/util/hint/hint_processor.go b/pkg/util/hint/hint_processor.go index 6bb29dce4123a..e62245c59303b 100644 --- a/pkg/util/hint/hint_processor.go +++ b/pkg/util/hint/hint_processor.go @@ -45,12 +45,30 @@ type HintsSet struct { indexHints [][]*ast.IndexHint // Slice offset is the traversal order of `TableName` in the ast. } -// GetFirstTableHints gets the first table hints. -func (hs *HintsSet) GetFirstTableHints() []*ast.TableOptimizerHint { +// GetStmtHints gets all statement-level hints. +func (hs *HintsSet) GetStmtHints() []*ast.TableOptimizerHint { + var result []*ast.TableOptimizerHint if len(hs.tableHints) > 0 { - return hs.tableHints[0] + result = append(result, hs.tableHints[0]...) // keep the same behavior with prior implementation + } + for _, tHints := range hs.tableHints[1:] { + for _, h := range tHints { + if isStmtHint(h) { + result = append(result, h) + } + } + } + return result +} + +// isStmtHint checks whether this hint is a statement-level hint. +func isStmtHint(h *ast.TableOptimizerHint) bool { + switch h.HintName.L { + case "max_execution_time", "memory_quota", "resource_group": + return true + default: + return false } - return nil } // ContainTableHint checks whether the table hint set contains a hint.