diff --git a/pkg/executor/checksum.go b/pkg/executor/checksum.go index 7a48671d8b346..e96e9dd5ad776 100644 --- a/pkg/executor/checksum.go +++ b/pkg/executor/checksum.go @@ -18,6 +18,7 @@ import ( "context" "strconv" + "github.com/pingcap/failpoint" "github.com/pingcap/tidb/pkg/distsql" "github.com/pingcap/tidb/pkg/executor/internal/exec" "github.com/pingcap/tidb/pkg/kv" @@ -129,6 +130,9 @@ func (e *ChecksumTableExec) checksumWorker(taskCh <-chan *checksumTask, resultCh } func (e *ChecksumTableExec) handleChecksumRequest(req *kv.Request) (resp *tipb.ChecksumResponse, err error) { + if err = e.Ctx().GetSessionVars().SQLKiller.HandleSignal(); err != nil { + return nil, err + } ctx := distsql.WithSQLKvExecCounterInterceptor(context.TODO(), e.Ctx().GetSessionVars().StmtCtx) res, err := distsql.Checksum(ctx, e.Ctx().GetClient(), req, e.Ctx().GetSessionVars().KVVars) if err != nil { @@ -138,6 +142,7 @@ func (e *ChecksumTableExec) handleChecksumRequest(req *kv.Request) (resp *tipb.C if err1 := res.Close(); err1 != nil { err = err1 } + failpoint.Inject("afterHandleChecksumRequest", nil) }() resp = &tipb.ChecksumResponse{} @@ -155,6 +160,9 @@ func (e *ChecksumTableExec) handleChecksumRequest(req *kv.Request) (resp *tipb.C return nil, err } updateChecksumResponse(resp, checksum) + if err = e.Ctx().GetSessionVars().SQLKiller.HandleSignal(); err != nil { + return nil, err + } } return resp, nil diff --git a/pkg/executor/importer/BUILD.bazel b/pkg/executor/importer/BUILD.bazel index 2e7469e47d602..13909c3cc3d23 100644 --- a/pkg/executor/importer/BUILD.bazel +++ b/pkg/executor/importer/BUILD.bazel @@ -65,6 +65,7 @@ go_library( "//pkg/util/mathutil", "//pkg/util/promutil", "//pkg/util/sqlexec", + "//pkg/util/sqlkiller", "//pkg/util/stringutil", "//pkg/util/syncutil", "@com_github_docker_go_units//:go-units", diff --git a/pkg/executor/importer/importer_testkit_test.go b/pkg/executor/importer/importer_testkit_test.go index 91c7b9716ed4c..a83b98e0b2fbb 100644 --- a/pkg/executor/importer/importer_testkit_test.go +++ b/pkg/executor/importer/importer_testkit_test.go @@ -81,6 +81,45 @@ func TestVerifyChecksum(t *testing.T) { localChecksum = verify.MakeKVChecksum(1, 2, 1) err = importer.VerifyChecksum(ctx, plan, localChecksum, tk.Session(), logutil.BgLogger()) require.ErrorIs(t, err, common.ErrChecksumMismatch) + + // check a slow checksum can be canceled + plan2 := &importer.Plan{ + DBName: "db", + TableInfo: &model.TableInfo{ + Name: model.NewCIStr("tb2"), + }, + Checksum: config.OpLevelRequired, + } + tk.MustExec(` + create table db.tb2( + id int, + index idx1(id), + index idx2(id), + index idx3(id), + index idx4(id), + index idx5(id), + index idx6(id), + index idx7(id), + index idx8(id), + index idx9(id), + index idx10(id) + )`) + tk.MustExec("insert into db.tb2 values(1)") + backup, err := tk.Session().GetSessionVars().GetSessionOrGlobalSystemVar(ctx, variable.TiDBChecksumTableConcurrency) + require.NoError(t, err) + err = tk.Session().GetSessionVars().SetSystemVar(variable.TiDBChecksumTableConcurrency, "1") + require.NoError(t, err) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/executor/afterHandleChecksumRequest", `sleep(1000)`)) + + ctx2, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + err = importer.VerifyChecksum(ctx2, plan2, localChecksum, tk.Session(), logutil.BgLogger()) + require.ErrorContains(t, err, "Query execution was interrupted") + + err = tk.Session().GetSessionVars().SetSystemVar(variable.TiDBChecksumTableConcurrency, backup) + require.NoError(t, err) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/executor/afterHandleChecksumRequest")) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/executor/importer/errWhenChecksum", `3*return(true)`)) defer func() { require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/executor/importer/errWhenChecksum")) diff --git a/pkg/executor/importer/table_import.go b/pkg/executor/importer/table_import.go index 9ad113992ec64..9e34897ecd548 100644 --- a/pkg/executor/importer/table_import.go +++ b/pkg/executor/importer/table_import.go @@ -58,6 +58,7 @@ import ( "github.com/pingcap/tidb/pkg/util/mathutil" "github.com/pingcap/tidb/pkg/util/promutil" "github.com/pingcap/tidb/pkg/util/sqlexec" + "github.com/pingcap/tidb/pkg/util/sqlkiller" "github.com/pingcap/tidb/pkg/util/syncutil" "github.com/prometheus/client_golang/prometheus" "github.com/tikv/client-go/v2/util" @@ -902,9 +903,21 @@ func checksumTable(ctx context.Context, se sessionctx.Context, plan *Plan, logge distSQLScanConcurrencyFactor = 1 remoteChecksum *local.RemoteChecksum txnErr error + doneCh = make(chan struct{}) ) + checkCtx, cancel := context.WithCancel(ctx) + defer func() { + cancel() + <-doneCh + }() + + go func() { + <-checkCtx.Done() + se.GetSessionVars().SQLKiller.SendKillSignal(sqlkiller.QueryInterrupted) + close(doneCh) + }() - ctx = util.WithInternalSourceType(ctx, tidbkv.InternalImportInto) + ctx = util.WithInternalSourceType(checkCtx, tidbkv.InternalImportInto) for i := 0; i < maxErrorRetryCount; i++ { txnErr = func() error { // increase backoff weight