diff --git a/executor/seqtest/seq_executor_test.go b/executor/seqtest/seq_executor_test.go index 9eb337dc31a17..c5ee8df40bee3 100644 --- a/executor/seqtest/seq_executor_test.go +++ b/executor/seqtest/seq_executor_test.go @@ -1085,8 +1085,10 @@ func (c *checkPrioClient) SendRequest(ctx context.Context, addr string, req *tik if c.mu.checkPrio { switch req.Type { case tikvrpc.CmdCop: - if c.getCheckPriority() != req.Priority { - return nil, errors.New("fail to set priority") + if ctx.Value(c) != nil { + if c.getCheckPriority() != req.Priority { + return nil, errors.New("fail to set priority") + } } } } @@ -1100,6 +1102,8 @@ func TestCoprocessorPriority(t *testing.T) { return cli })) + ctx := context.WithValue(context.Background(), cli, 42) + tk := testkit.NewTestKit(t, store) tk.MustExec("use test") tk.MustExec("create table t (id int primary key)") @@ -1118,18 +1122,18 @@ func TestCoprocessorPriority(t *testing.T) { cli.mu.Unlock() cli.setCheckPriority(kvrpcpb.CommandPri_High) - tk.MustQuery("select id from t where id = 1") - tk.MustQuery("select * from t1 where id = 1") + tk.MustQueryWithContext(ctx, "select id from t where id = 1") + tk.MustQueryWithContext(ctx, "select * from t1 where id = 1") cli.setCheckPriority(kvrpcpb.CommandPri_Normal) - tk.MustQuery("select count(*) from t") - tk.MustExec("update t set id = 3") - tk.MustExec("delete from t") - tk.MustExec("insert into t select * from t limit 2") - tk.MustExec("delete from t") + tk.MustQueryWithContext(ctx, "select count(*) from t") + tk.MustExecWithContext(ctx, "update t set id = 3") + tk.MustExecWithContext(ctx, "delete from t") + tk.MustExecWithContext(ctx, "insert into t select * from t limit 2") + tk.MustExecWithContext(ctx, "delete from t") // Insert some data to make sure plan build IndexLookup for t. - tk.MustExec("insert into t values (1), (2)") + tk.MustExecWithContext(ctx, "insert into t values (1), (2)") defer config.RestoreFunc()() config.UpdateGlobal(func(conf *config.Config) { @@ -1137,47 +1141,46 @@ func TestCoprocessorPriority(t *testing.T) { }) cli.setCheckPriority(kvrpcpb.CommandPri_High) - tk.MustQuery("select id from t where id = 1") - tk.MustQuery("select * from t1 where id = 1") - tk.MustExec("delete from t where id = 2") - tk.MustExec("update t set id = 2 where id = 1") + tk.MustQueryWithContext(ctx, "select id from t where id = 1") + tk.MustQueryWithContext(ctx, "select * from t1 where id = 1") + tk.MustExecWithContext(ctx, "delete from t where id = 2") + tk.MustExecWithContext(ctx, "update t set id = 2 where id = 1") cli.setCheckPriority(kvrpcpb.CommandPri_Low) - tk.MustQuery("select count(*) from t") - tk.MustExec("delete from t") - tk.MustExec("insert into t values (3)") + tk.MustQueryWithContext(ctx, "select count(*) from t") + tk.MustExecWithContext(ctx, "delete from t") + tk.MustExecWithContext(ctx, "insert into t values (3)") // Test priority specified by SQL statement. cli.setCheckPriority(kvrpcpb.CommandPri_High) - tk.MustQuery("select HIGH_PRIORITY * from t") + tk.MustQueryWithContext(ctx, "select HIGH_PRIORITY * from t") cli.setCheckPriority(kvrpcpb.CommandPri_Low) - tk.MustQuery("select LOW_PRIORITY id from t where id = 1") + tk.MustQueryWithContext(ctx, "select LOW_PRIORITY id from t where id = 1") cli.setCheckPriority(kvrpcpb.CommandPri_High) - tk.MustExec("set tidb_force_priority = 'HIGH_PRIORITY'") - tk.MustQuery("select * from t").Check(testkit.Rows("3")) - tk.MustExec("update t set id = id + 1") - tk.MustQuery("select v from t1 where id = 0 or id = 1").Check(testkit.Rows("0", "1")) + tk.MustExecWithContext(ctx, "set tidb_force_priority = 'HIGH_PRIORITY'") + tk.MustQueryWithContext(ctx, "select * from t").Check(testkit.Rows("3")) + tk.MustExecWithContext(ctx, "update t set id = id + 1") + tk.MustQueryWithContext(ctx, "select v from t1 where id = 0 or id = 1").Check(testkit.Rows("0", "1")) cli.setCheckPriority(kvrpcpb.CommandPri_Low) - tk.MustExec("set tidb_force_priority = 'LOW_PRIORITY'") - tk.MustQuery("select * from t").Check(testkit.Rows("4")) - tk.MustExec("update t set id = id + 1") - tk.MustQuery("select v from t1 where id = 0 or id = 1").Check(testkit.Rows("0", "1")) + tk.MustExecWithContext(ctx, "set tidb_force_priority = 'LOW_PRIORITY'") + tk.MustQueryWithContext(ctx, "select * from t").Check(testkit.Rows("4")) + tk.MustExecWithContext(ctx, "update t set id = id + 1") + tk.MustQueryWithContext(ctx, "select v from t1 where id = 0 or id = 1").Check(testkit.Rows("0", "1")) cli.setCheckPriority(kvrpcpb.CommandPri_Normal) - tk.MustExec("set tidb_force_priority = 'DELAYED'") - tk.MustQuery("select * from t").Check(testkit.Rows("5")) - tk.MustExec("update t set id = id + 1") - tk.MustQuery("select v from t1 where id = 0 or id = 1").Check(testkit.Rows("0", "1")) + tk.MustExecWithContext(ctx, "set tidb_force_priority = 'DELAYED'") + tk.MustQueryWithContext(ctx, "select * from t").Check(testkit.Rows("5")) + tk.MustExecWithContext(ctx, "update t set id = id + 1") + tk.MustQueryWithContext(ctx, "select v from t1 where id = 0 or id = 1").Check(testkit.Rows("0", "1")) cli.setCheckPriority(kvrpcpb.CommandPri_Low) - tk.MustExec("set tidb_force_priority = 'NO_PRIORITY'") - tk.MustQuery("select * from t").Check(testkit.Rows("6")) - tk.MustExec("update t set id = id + 1") - tk.MustQuery("select v from t1 where id = 0 or id = 1").Check(testkit.Rows("0", "1")) - + tk.MustExecWithContext(ctx, "set tidb_force_priority = 'NO_PRIORITY'") + tk.MustQueryWithContext(ctx, "select * from t").Check(testkit.Rows("6")) + tk.MustExecWithContext(ctx, "update t set id = id + 1") + tk.MustQueryWithContext(ctx, "select v from t1 where id = 0 or id = 1").Check(testkit.Rows("0", "1")) cli.mu.Lock() cli.mu.checkPrio = false cli.mu.Unlock() diff --git a/testkit/testkit.go b/testkit/testkit.go index 3514305e8e05c..071cb60144393 100644 --- a/testkit/testkit.go +++ b/testkit/testkit.go @@ -103,7 +103,12 @@ func (tk *TestKit) Session() session.Session { // MustExec executes a sql statement and asserts nil error. func (tk *TestKit) MustExec(sql string, args ...interface{}) { - res, err := tk.Exec(sql, args...) + tk.MustExecWithContext(context.Background(), sql, args...) +} + +// MustExecWithContext executes a sql statement and asserts nil error. +func (tk *TestKit) MustExecWithContext(ctx context.Context, sql string, args ...interface{}) { + res, err := tk.ExecWithContext(ctx, sql, args...) comment := fmt.Sprintf("sql:%s, %v, error stack %v", sql, args, errors.ErrorStack(err)) tk.require.NoError(err, comment) @@ -115,11 +120,16 @@ func (tk *TestKit) MustExec(sql string, args ...interface{}) { // MustQuery query the statements and returns result rows. // If expected result is set it asserts the query result equals expected result. func (tk *TestKit) MustQuery(sql string, args ...interface{}) *Result { + return tk.MustQueryWithContext(context.Background(), sql, args...) +} + +// MustQueryWithContext query the statements and returns result rows. +func (tk *TestKit) MustQueryWithContext(ctx context.Context, sql string, args ...interface{}) *Result { comment := fmt.Sprintf("sql:%s, args:%v", sql, args) - rs, err := tk.Exec(sql, args...) + rs, err := tk.ExecWithContext(ctx, sql, args...) tk.require.NoError(err, comment) tk.require.NotNil(rs, comment) - return tk.ResultSetToResult(rs, comment) + return tk.ResultSetToResultWithCtx(ctx, rs, comment) } // MustIndexLookup checks whether the plan for the sql is IndexLookUp. @@ -233,7 +243,11 @@ func (tk *TestKit) HasPlan4ExplainFor(result *Result, plan string) bool { // Exec executes a sql statement using the prepared stmt API func (tk *TestKit) Exec(sql string, args ...interface{}) (sqlexec.RecordSet, error) { - ctx := context.Background() + return tk.ExecWithContext(context.Background(), sql, args...) +} + +// ExecWithContext executes a sql statement using the prepared stmt API +func (tk *TestKit) ExecWithContext(ctx context.Context, sql string, args ...interface{}) (sqlexec.RecordSet, error) { if len(args) == 0 { sc := tk.session.GetSessionVars().StmtCtx prevWarns := sc.GetWarnings()