diff --git a/pkg/ddl/BUILD.bazel b/pkg/ddl/BUILD.bazel
index 296ac8df576a3..914731ca6d462 100644
--- a/pkg/ddl/BUILD.bazel
+++ b/pkg/ddl/BUILD.bazel
@@ -281,6 +281,7 @@ go_test(
         "//pkg/disttask/framework/proto",
         "//pkg/disttask/framework/scheduler",
         "//pkg/disttask/framework/storage",
+        "//pkg/disttask/operator",
         "//pkg/domain",
         "//pkg/domain/infosync",
         "//pkg/errctx",
diff --git a/pkg/ddl/backfilling_dist_scheduler.go b/pkg/ddl/backfilling_dist_scheduler.go
index 55abe4225da12..2e7f490cec106 100644
--- a/pkg/ddl/backfilling_dist_scheduler.go
+++ b/pkg/ddl/backfilling_dist_scheduler.go
@@ -294,7 +294,7 @@ func generateNonPartitionPlan(
 		return true, nil
 	}
 
-	regionBatch := calculateRegionBatch(len(recordRegionMetas), instanceCnt, !useCloud)
+	regionBatch := CalculateRegionBatch(len(recordRegionMetas), instanceCnt, !useCloud)
 	for i := 0; i < len(recordRegionMetas); i += regionBatch {
 		end := i + regionBatch
@@ -329,7 +329,8 @@ func generateNonPartitionPlan(
 	return subTaskMetas, nil
 }
 
-func calculateRegionBatch(totalRegionCnt int, instanceCnt int, useLocalDisk bool) int {
+// CalculateRegionBatch is exported for test.
+func CalculateRegionBatch(totalRegionCnt int, instanceCnt int, useLocalDisk bool) int {
 	failpoint.Inject("mockRegionBatch", func(val failpoint.Value) {
 		failpoint.Return(val.(int))
 	})
diff --git a/pkg/ddl/backfilling_dist_scheduler_test.go b/pkg/ddl/backfilling_dist_scheduler_test.go
index 3d08208317b2a..6e4f1f32c9264 100644
--- a/pkg/ddl/backfilling_dist_scheduler_test.go
+++ b/pkg/ddl/backfilling_dist_scheduler_test.go
@@ -113,19 +113,19 @@ func TestBackfillingSchedulerLocalMode(t *testing.T) {
 
 func TestCalculateRegionBatch(t *testing.T) {
 	// Test calculate in cloud storage.
-	batchCnt := ddl.CalculateRegionBatchForTest(100, 8, false)
+	batchCnt := ddl.CalculateRegionBatch(100, 8, false)
 	require.Equal(t, 13, batchCnt)
-	batchCnt = ddl.CalculateRegionBatchForTest(2, 8, false)
+	batchCnt = ddl.CalculateRegionBatch(2, 8, false)
 	require.Equal(t, 1, batchCnt)
-	batchCnt = ddl.CalculateRegionBatchForTest(8, 8, false)
+	batchCnt = ddl.CalculateRegionBatch(8, 8, false)
 	require.Equal(t, 1, batchCnt)
 	// Test calculate in local storage.
-	batchCnt = ddl.CalculateRegionBatchForTest(100, 8, true)
+	batchCnt = ddl.CalculateRegionBatch(100, 8, true)
 	require.Equal(t, 13, batchCnt)
-	batchCnt = ddl.CalculateRegionBatchForTest(2, 8, true)
+	batchCnt = ddl.CalculateRegionBatch(2, 8, true)
 	require.Equal(t, 1, batchCnt)
-	batchCnt = ddl.CalculateRegionBatchForTest(24, 8, true)
+	batchCnt = ddl.CalculateRegionBatch(24, 8, true)
 	require.Equal(t, 3, batchCnt)
 }
diff --git a/pkg/ddl/bench_test.go b/pkg/ddl/bench_test.go
index e0702e1ed367d..f778e4a5c6ec6 100644
--- a/pkg/ddl/bench_test.go
+++ b/pkg/ddl/bench_test.go
@@ -53,7 +53,7 @@ func BenchmarkExtractDatumByOffsets(b *testing.B) {
 	endKey := startKey.PrefixNext()
 	txn, err := store.Begin()
 	require.NoError(b, err)
-	copChunk := ddl.FetchChunk4Test(copCtx, tbl.(table.PhysicalTable), startKey, endKey, store, 10)
+	copChunk, err := FetchChunk4Test(copCtx, tbl.(table.PhysicalTable), startKey, endKey, store, 10)
 	require.NoError(b, err)
 	require.NoError(b, txn.Rollback())
 
@@ -66,7 +66,7 @@ func BenchmarkExtractDatumByOffsets(b *testing.B) {
 
 	b.ResetTimer()
 	for i := 0; i < b.N; i++ {
-		ddl.ExtractDatumByOffsetsForTest(tk.Session().GetExprCtx().GetEvalCtx(), row, offsets, c.ExprColumnInfos, handleDataBuf)
+		ddl.ExtractDatumByOffsets(tk.Session().GetExprCtx().GetEvalCtx(), row, offsets, c.ExprColumnInfos, handleDataBuf)
 	}
 }
diff --git a/pkg/ddl/export_test.go b/pkg/ddl/export_test.go
index d814000eff867..9e63c99c07be9 100644
--- a/pkg/ddl/export_test.go
+++ b/pkg/ddl/export_test.go
@@ -12,59 +12,59 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-package ddl
+package ddl_test
 
 import (
 	"context"
 	"time"
 
 	"github.com/ngaut/pools"
+	"github.com/pingcap/tidb/pkg/ddl"
 	"github.com/pingcap/tidb/pkg/ddl/copr"
 	"github.com/pingcap/tidb/pkg/ddl/session"
+	"github.com/pingcap/tidb/pkg/ddl/testutil"
+	"github.com/pingcap/tidb/pkg/disttask/operator"
 	"github.com/pingcap/tidb/pkg/errctx"
 	"github.com/pingcap/tidb/pkg/expression"
 	"github.com/pingcap/tidb/pkg/kv"
-	"github.com/pingcap/tidb/pkg/sessionctx/variable"
 	"github.com/pingcap/tidb/pkg/table"
 	"github.com/pingcap/tidb/pkg/types"
 	"github.com/pingcap/tidb/pkg/util/chunk"
 	"github.com/pingcap/tidb/pkg/util/mock"
 )
 
-type resultChanForTest struct {
-	ch chan IndexRecordChunk
-}
-
-func (r *resultChanForTest) AddTask(rs IndexRecordChunk) {
-	r.ch <- rs
-}
-
 func FetchChunk4Test(copCtx copr.CopContext, tbl table.PhysicalTable, startKey, endKey kv.Key, store kv.Storage,
-	batchSize int) *chunk.Chunk {
-	variable.SetDDLReorgBatchSize(int32(batchSize))
-	task := &reorgBackfillTask{
-		id:            1,
-		startKey:      startKey,
-		endKey:        endKey,
-		physicalTable: tbl,
-	}
-	taskCh := make(chan *reorgBackfillTask, 5)
-	resultCh := make(chan IndexRecordChunk, 5)
+	batchSize int) (*chunk.Chunk, error) {
 	resPool := pools.NewResourcePool(func() (pools.Resource, error) {
 		ctx := mock.NewContext()
 		ctx.Store = store
 		return ctx, nil
 	}, 8, 8, 0)
 	sessPool := session.NewSessionPool(resPool)
-	pool := newCopReqSenderPool(context.Background(), copCtx, store, taskCh, sessPool, nil)
-	pool.chunkSender = &resultChanForTest{ch: resultCh}
-	pool.adjustSize(1)
-	pool.tasksCh <- task
-	rs := <-resultCh
-	close(taskCh)
-	pool.close(false)
-	sessPool.Close()
-	return rs.Chunk
+	srcChkPool := make(chan *chunk.Chunk, 10)
+	for i := 0; i < 10; i++ {
+		srcChkPool <- chunk.NewChunkWithCapacity(copCtx.GetBase().FieldTypes, batchSize)
+	}
+	opCtx := ddl.NewLocalOperatorCtx(context.Background(), 1)
+	src := testutil.NewOperatorTestSource(ddl.TableScanTask{1, startKey, endKey})
+	scanOp := ddl.NewTableScanOperator(opCtx, sessPool, copCtx, srcChkPool, 1, nil)
+	sink := testutil.NewOperatorTestSink[ddl.IndexRecordChunk]()
+
+	operator.Compose[ddl.TableScanTask](src, scanOp)
+	operator.Compose[ddl.IndexRecordChunk](scanOp, sink)
+
+	pipeline := operator.NewAsyncPipeline(src, scanOp, sink)
+	err := pipeline.Execute()
+	if err != nil {
+		return nil, err
+	}
+	err = pipeline.Close()
+	if err != nil {
+		return nil, err
+	}
+
+	results := sink.Collect()
+	return results[0].Chunk, nil
 }
 
 func ConvertRowToHandleAndIndexDatum(
@@ -72,14 +72,8 @@
 	handleDataBuf, idxDataBuf []types.Datum,
 	row chunk.Row, copCtx copr.CopContext, idxID int64) (kv.Handle, []types.Datum, error) {
 	c := copCtx.GetBase()
-	idxData := extractDatumByOffsets(ctx, row, copCtx.IndexColumnOutputOffsets(idxID), c.ExprColumnInfos, idxDataBuf)
-	handleData := extractDatumByOffsets(ctx, row, c.HandleOutputOffsets, c.ExprColumnInfos, handleDataBuf)
-	handle, err := buildHandle(handleData, c.TableInfo, c.PrimaryKeyInfo, time.Local, errctx.StrictNoWarningContext)
+	idxData := ddl.ExtractDatumByOffsets(ctx, row, copCtx.IndexColumnOutputOffsets(idxID), c.ExprColumnInfos, idxDataBuf)
+	handleData := ddl.ExtractDatumByOffsets(ctx, row, c.HandleOutputOffsets, c.ExprColumnInfos, handleDataBuf)
+	handle, err := ddl.BuildHandle(handleData, c.TableInfo, c.PrimaryKeyInfo, time.Local, errctx.StrictNoWarningContext)
 	return handle, idxData, err
 }
-
-// ExtractDatumByOffsetsForTest is used for test.
-var ExtractDatumByOffsetsForTest = extractDatumByOffsets
-
-// CalculateRegionBatchForTest is used for test.
-var CalculateRegionBatchForTest = calculateRegionBatch
diff --git a/pkg/ddl/index.go b/pkg/ddl/index.go
index d9b2c97d50903..b3a8af5bed61d 100644
--- a/pkg/ddl/index.go
+++ b/pkg/ddl/index.go
@@ -1710,20 +1710,20 @@ func writeChunkToLocal(
 		restoreDataBuf = make([]types.Datum, len(c.HandleOutputOffsets))
 	}
 	for row := iter.Begin(); row != iter.End(); row = iter.Next() {
-		handleDataBuf := extractDatumByOffsets(ectx, row, c.HandleOutputOffsets, c.ExprColumnInfos, handleDataBuf)
+		handleDataBuf := ExtractDatumByOffsets(ectx, row, c.HandleOutputOffsets, c.ExprColumnInfos, handleDataBuf)
 		if restore {
 			// restoreDataBuf should not truncate index values.
 			for i, datum := range handleDataBuf {
 				restoreDataBuf[i] = *datum.Clone()
 			}
 		}
-		h, err := buildHandle(handleDataBuf, c.TableInfo, c.PrimaryKeyInfo, loc, errCtx)
+		h, err := BuildHandle(handleDataBuf, c.TableInfo, c.PrimaryKeyInfo, loc, errCtx)
 		if err != nil {
 			return 0, nil, errors.Trace(err)
 		}
 		for i, index := range indexes {
 			idxID := index.Meta().ID
-			idxDataBuf = extractDatumByOffsets(ectx,
+			idxDataBuf = ExtractDatumByOffsets(ectx,
 				row, copCtx.IndexColumnOutputOffsets(idxID), c.ExprColumnInfos, idxDataBuf)
 			idxData := idxDataBuf[:len(index.Meta().Columns)]
 			var rsData []types.Datum
diff --git a/pkg/ddl/index_cop.go b/pkg/ddl/index_cop.go
index 30d6f70c8b9e0..6a9c443b7e032 100644
--- a/pkg/ddl/index_cop.go
+++ b/pkg/ddl/index_cop.go
@@ -16,14 +16,10 @@ package ddl
 
 import (
 	"context"
-	"encoding/hex"
-	"sync"
 	"time"
 
 	"github.com/pingcap/errors"
-	"github.com/pingcap/failpoint"
 	"github.com/pingcap/tidb/pkg/ddl/copr"
-	"github.com/pingcap/tidb/pkg/ddl/ingest"
 	sess "github.com/pingcap/tidb/pkg/ddl/session"
 	"github.com/pingcap/tidb/pkg/distsql"
 	distsqlctx "github.com/pingcap/tidb/pkg/distsql/context"
@@ -31,24 +27,18 @@ import (
 	"github.com/pingcap/tidb/pkg/expression"
 	exprctx "github.com/pingcap/tidb/pkg/expression/context"
 	"github.com/pingcap/tidb/pkg/kv"
-	"github.com/pingcap/tidb/pkg/metrics"
 	"github.com/pingcap/tidb/pkg/parser/model"
-	"github.com/pingcap/tidb/pkg/parser/terror"
 	"github.com/pingcap/tidb/pkg/sessionctx/variable"
 	"github.com/pingcap/tidb/pkg/table"
 	"github.com/pingcap/tidb/pkg/table/tables"
 	"github.com/pingcap/tidb/pkg/tablecodec"
 	"github.com/pingcap/tidb/pkg/types"
-	"github.com/pingcap/tidb/pkg/util"
 	"github.com/pingcap/tidb/pkg/util/chunk"
 	"github.com/pingcap/tidb/pkg/util/codec"
 	"github.com/pingcap/tidb/pkg/util/collate"
-	"github.com/pingcap/tidb/pkg/util/dbterror"
-	"github.com/pingcap/tidb/pkg/util/logutil"
 	"github.com/pingcap/tidb/pkg/util/timeutil"
 	"github.com/pingcap/tipb/go-tipb"
 	kvutil "github.com/tikv/client-go/v2/util"
-	"go.uber.org/zap"
 )
 
 // copReadBatchSize is the batch size of coprocessor read.
@@ -65,121 +55,6 @@ func copReadChunkPoolSize() int {
 	return 10 * int(variable.GetDDLReorgWorkerCounter())
 }
 
-// chunkSender is used to receive the result of coprocessor request.
-type chunkSender interface {
-	AddTask(IndexRecordChunk)
-}
-
-type copReqSenderPool struct {
-	tasksCh       chan *reorgBackfillTask
-	chunkSender   chunkSender
-	checkpointMgr *ingest.CheckpointManager
-	sessPool      *sess.Pool
-
-	ctx    context.Context
-	copCtx copr.CopContext
-	store  kv.Storage
-
-	senders []*copReqSender
-	wg      sync.WaitGroup
-	closed  bool
-
-	srcChkPool chan *chunk.Chunk
-}
-
-type copReqSender struct {
-	senderPool *copReqSenderPool
-
-	ctx    context.Context
-	cancel context.CancelFunc
-}
-
-func (c *copReqSender) run() {
-	p := c.senderPool
-	defer p.wg.Done()
-	defer util.Recover(metrics.LabelDDL, "copReqSender.run", func() {
-		p.chunkSender.AddTask(IndexRecordChunk{Err: dbterror.ErrReorgPanic})
-	}, false)
-	sessCtx, err := p.sessPool.Get()
-	if err != nil {
-		logutil.Logger(p.ctx).Error("copReqSender get session from pool failed", zap.Error(err))
-		p.chunkSender.AddTask(IndexRecordChunk{Err: err})
-		return
-	}
-	se := sess.NewSession(sessCtx)
-	defer p.sessPool.Put(sessCtx)
-	var (
-		task *reorgBackfillTask
-		ok   bool
-	)
-
-	for {
-		select {
-		case <-c.ctx.Done():
-			return
-		case task, ok = <-p.tasksCh:
-		}
-		if !ok {
-			return
-		}
-		if p.checkpointMgr != nil && p.checkpointMgr.IsKeyProcessed(task.endKey) {
-			logutil.Logger(p.ctx).Info("checkpoint detected, skip a cop-request task",
-				zap.Int("task ID", task.id),
-				zap.String("task end key", hex.EncodeToString(task.endKey)))
-			continue
-		}
-		err := scanRecords(p, task, se)
-		if err != nil {
-			p.chunkSender.AddTask(IndexRecordChunk{ID: task.id, Err: err})
-			return
-		}
-	}
-}
-
-func scanRecords(p *copReqSenderPool, task *reorgBackfillTask, se *sess.Session) error {
-	logutil.Logger(p.ctx).Info("start a cop-request task",
-		zap.Int("id", task.id), zap.Stringer("task", task))
-
-	return wrapInBeginRollback(se, func(startTS uint64) error {
-		rs, err := buildTableScan(p.ctx, p.copCtx.GetBase(), startTS, task.startKey, task.endKey)
-		if err != nil {
-			return err
-		}
-		failpoint.Inject("mockCopSenderPanic", func(val failpoint.Value) {
-			if val.(bool) {
-				panic("mock panic")
-			}
-		})
-		if p.checkpointMgr != nil {
-			p.checkpointMgr.Register(task.id, task.endKey)
-		}
-		var done bool
-		startTime := time.Now()
-		for !done {
-			srcChk := p.getChunk()
-			done, err = fetchTableScanResult(p.ctx, p.copCtx.GetBase(), rs, srcChk)
-			if err != nil {
-				p.recycleChunk(srcChk)
-				terror.Call(rs.Close)
-				return err
-			}
-			if p.checkpointMgr != nil {
-				p.checkpointMgr.UpdateTotalKeys(task.id, srcChk.NumRows(), done)
-			}
-			idxRs := IndexRecordChunk{ID: task.id, Chunk: srcChk, Done: done}
-			rate := float64(srcChk.MemoryUsage()) / 1024.0 / 1024.0 / time.Since(startTime).Seconds()
-			metrics.AddIndexScanRate.WithLabelValues(metrics.LblAddIndex).Observe(rate)
-			failpoint.Inject("mockCopSenderError", func() {
-				idxRs.Err = errors.New("mock cop error")
-			})
-			p.chunkSender.AddTask(idxRs)
-			startTime = time.Now()
-		}
-		terror.Call(rs.Close)
-		return nil
-	})
-}
-
 func wrapInBeginRollback(se *sess.Session, f func(startTS uint64) error) error {
 	err := se.Begin(context.Background())
 	if err != nil {
@@ -194,81 +69,6 @@ func wrapInBeginRollback(se *sess.Session, f func(startTS uint64) error) error {
 	return f(startTS)
 }
 
-func newCopReqSenderPool(ctx context.Context, copCtx copr.CopContext, store kv.Storage,
-	taskCh chan *reorgBackfillTask, sessPool *sess.Pool,
-	checkpointMgr *ingest.CheckpointManager) *copReqSenderPool {
-	poolSize := copReadChunkPoolSize()
-	srcChkPool := make(chan *chunk.Chunk, poolSize)
-	for i := 0; i < poolSize; i++ {
-		srcChkPool <- chunk.NewChunkWithCapacity(copCtx.GetBase().FieldTypes, copReadBatchSize())
-	}
-	return &copReqSenderPool{
-		tasksCh:       taskCh,
-		ctx:           ctx,
-		copCtx:        copCtx,
-		store:         store,
-		senders:       make([]*copReqSender, 0, variable.GetDDLReorgWorkerCounter()),
-		wg:            sync.WaitGroup{},
-		srcChkPool:    srcChkPool,
-		sessPool:      sessPool,
-		checkpointMgr: checkpointMgr,
-	}
-}
-
-func (c *copReqSenderPool) adjustSize(n int) {
-	// Add some senders.
-	for i := len(c.senders); i < n; i++ {
-		ctx, cancel := context.WithCancel(c.ctx)
-		c.senders = append(c.senders, &copReqSender{
-			senderPool: c,
-			ctx:        ctx,
-			cancel:     cancel,
-		})
-		c.wg.Add(1)
-		go c.senders[i].run()
-	}
-	// Remove some senders.
-	if n < len(c.senders) {
-		for i := n; i < len(c.senders); i++ {
-			c.senders[i].cancel()
-		}
-		c.senders = c.senders[:n]
-	}
-}
-
-func (c *copReqSenderPool) close(force bool) {
-	if c.closed {
-		return
-	}
-	logutil.Logger(c.ctx).Info("close cop-request sender pool", zap.Bool("force", force))
-	if force {
-		for _, w := range c.senders {
-			w.cancel()
-		}
-	}
-	// Wait for all cop-req senders to exit.
-	c.wg.Wait()
-	c.closed = true
-}
-
-func (c *copReqSenderPool) getChunk() *chunk.Chunk {
-	chk := <-c.srcChkPool
-	newCap := copReadBatchSize()
-	if chk.Capacity() != newCap {
-		chk = chunk.NewChunkWithCapacity(c.copCtx.GetBase().FieldTypes, newCap)
-	}
-	chk.Reset()
-	return chk
-}
-
-// recycleChunk puts the index record slice and the chunk back to the pool for reuse.
-func (c *copReqSenderPool) recycleChunk(chk *chunk.Chunk) {
-	if chk == nil {
-		return
-	}
-	c.srcChkPool <- chk
-}
-
 func buildTableScan(ctx context.Context, c *copr.CopContextBase, startTS uint64, start, end kv.Key) (distsql.SelectResult, error) {
 	dagPB, err := buildDAGPB(c.ExprCtx, c.DistSQLCtx, c.PushDownFlags, c.TableInfo, c.ColumnInfos)
 	if err != nil {
@@ -369,7 +169,8 @@ func constructTableScanPB(ctx exprctx.BuildContext, tblInfo *model.TableInfo, co
 	return &tipb.Executor{Tp: tipb.ExecType_TypeTableScan, TblScan: tblScan}, err
 }
 
-func extractDatumByOffsets(ctx expression.EvalContext, row chunk.Row, offsets []int, expCols []*expression.Column, buf []types.Datum) []types.Datum {
+// ExtractDatumByOffsets is exported for test.
+func ExtractDatumByOffsets(ctx expression.EvalContext, row chunk.Row, offsets []int, expCols []*expression.Column, buf []types.Datum) []types.Datum {
 	for i, offset := range offsets {
 		c := expCols[offset]
 		row.DatumWithBuffer(offset, c.GetType(ctx), &buf[i])
@@ -377,7 +178,8 @@ func extractDatumByOffsets(ctx expression.EvalContext, row chunk.Row, offsets []
 	return buf
 }
 
-func buildHandle(pkDts []types.Datum, tblInfo *model.TableInfo,
+// BuildHandle is exported for test.
+func BuildHandle(pkDts []types.Datum, tblInfo *model.TableInfo,
 	pkInfo *model.IndexInfo, loc *time.Location, errCtx errctx.Context) (kv.Handle, error) {
 	if tblInfo.IsCommonHandle {
 		tablecodec.TruncateIndexValues(tblInfo, pkInfo, pkDts)
diff --git a/pkg/ddl/index_cop_test.go b/pkg/ddl/index_cop_test.go
index 28f07752ab6ed..7c2f8fb2c15d3 100644
--- a/pkg/ddl/index_cop_test.go
+++ b/pkg/ddl/index_cop_test.go
@@ -50,7 +50,7 @@ func TestAddIndexFetchRowsFromCoprocessor(t *testing.T) {
 		endKey := startKey.PrefixNext()
 		txn, err := store.Begin()
 		require.NoError(t, err)
-		copChunk := ddl.FetchChunk4Test(copCtx, tbl.(table.PhysicalTable), startKey, endKey, store, 10)
+		copChunk, err := FetchChunk4Test(copCtx, tbl.(table.PhysicalTable), startKey, endKey, store, 10)
 		require.NoError(t, err)
 		require.NoError(t, txn.Rollback())
 
@@ -61,7 +61,7 @@ func TestAddIndexFetchRowsFromCoprocessor(t *testing.T) {
 		idxDataBuf := make([]types.Datum, len(idxInfo.Columns))
 
 		for row := iter.Begin(); row != iter.End(); row = iter.Next() {
-			handle, idxDatum, err := ddl.ConvertRowToHandleAndIndexDatum(tk.Session().GetExprCtx().GetEvalCtx(), handleDataBuf, idxDataBuf, row, copCtx, idxInfo.ID)
+			handle, idxDatum, err := ConvertRowToHandleAndIndexDatum(tk.Session().GetExprCtx().GetEvalCtx(), handleDataBuf, idxDataBuf, row, copCtx, idxInfo.ID)
 			require.NoError(t, err)
 			handles = append(handles, handle)
 			copiedIdxDatum := make([]types.Datum, len(idxDatum))
diff --git a/pkg/ddl/testutil/BUILD.bazel b/pkg/ddl/testutil/BUILD.bazel
index 4cfa751c86c17..cee0cfee17fcc 100644
--- a/pkg/ddl/testutil/BUILD.bazel
+++ b/pkg/ddl/testutil/BUILD.bazel
@@ -2,11 +2,15 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library")
 
 go_library(
     name = "testutil",
-    srcs = ["testutil.go"],
+    srcs = [
+        "operator.go",
+        "testutil.go",
+    ],
     importpath = "github.com/pingcap/tidb/pkg/ddl/testutil",
     visibility = ["//visibility:public"],
     deps = [
         "//pkg/ddl/logutil",
+        "//pkg/disttask/operator",
         "//pkg/domain",
         "//pkg/kv",
         "//pkg/parser/model",
@@ -18,6 +22,7 @@ go_library(
         "//pkg/types",
         "@com_github_pingcap_errors//:errors",
         "@com_github_stretchr_testify//require",
+        "@org_golang_x_sync//errgroup",
         "@org_uber_go_zap//:zap",
     ],
 )
diff --git a/pkg/ddl/testutil/operator.go b/pkg/ddl/testutil/operator.go
new file mode 100644
index 0000000000000..90ba745bc387d
--- /dev/null
+++ b/pkg/ddl/testutil/operator.go
@@ -0,0 +1,107 @@
+// Copyright 2024 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testutil
+
+import (
+	"github.com/pingcap/tidb/pkg/disttask/operator"
+	"golang.org/x/sync/errgroup"
+)
+
+// OperatorTestSource is used for dist task operator test.
+type OperatorTestSource[T any] struct {
+	errGroup errgroup.Group
+	ch       chan T
+	toBeSent []T
+}
+
+// NewOperatorTestSource creates a new OperatorTestSource.
+func NewOperatorTestSource[T any](toBeSent ...T) *OperatorTestSource[T] {
+	return &OperatorTestSource[T]{
+		ch:       make(chan T),
+		toBeSent: toBeSent,
+	}
+}
+
+// SetSink implements disttask/operator.WithSink.
+func (s *OperatorTestSource[T]) SetSink(sink operator.DataChannel[T]) {
+	s.ch = sink.Channel()
+}
+
+// Open implements disttask/operator.Operator.
+func (s *OperatorTestSource[T]) Open() error {
+	s.errGroup.Go(func() error {
+		for _, data := range s.toBeSent {
+			s.ch <- data
+		}
+		close(s.ch)
+		return nil
+	})
+	return nil
+}
+
+// Close implements disttask/operator.Operator.
+func (s *OperatorTestSource[T]) Close() error {
+	return s.errGroup.Wait()
+}
+
+// String implements disttask/operator.Operator.
+func (*OperatorTestSource[T]) String() string {
+	return "testSource"
+}
+
+// OperatorTestSink is used for dist task operator test.
+type OperatorTestSink[T any] struct {
+	errGroup  errgroup.Group
+	ch        chan T
+	collected []T
+}
+
+// NewOperatorTestSink creates a new OperatorTestSink.
+func NewOperatorTestSink[T any]() *OperatorTestSink[T] {
+	return &OperatorTestSink[T]{
+		ch: make(chan T),
+	}
+}
+
+// Open implements disttask/operator.Operator.
+func (s *OperatorTestSink[T]) Open() error {
+	s.errGroup.Go(func() error {
+		for data := range s.ch {
+			s.collected = append(s.collected, data)
+		}
+		return nil
+	})
+	return nil
+}
+
+// Close implements disttask/operator.Operator.
+func (s *OperatorTestSink[T]) Close() error {
+	return s.errGroup.Wait()
+}
+
+// SetSource implements disttask/operator.WithSource.
+func (s *OperatorTestSink[T]) SetSource(dataCh operator.DataChannel[T]) {
+	s.ch = dataCh.Channel()
+}
+
+// String implements disttask/operator.Operator.
+func (*OperatorTestSink[T]) String() string {
+	return "testSink"
+}
+
+// Collect returns the results collected by the OperatorTestSink.
+func (s *OperatorTestSink[T]) Collect() []T {
+	return s.collected
+}
diff --git a/tests/realtikvtest/addindextest3/BUILD.bazel b/tests/realtikvtest/addindextest3/BUILD.bazel
index 72ee42fdcc780..ac0fcc528e572 100644
--- a/tests/realtikvtest/addindextest3/BUILD.bazel
+++ b/tests/realtikvtest/addindextest3/BUILD.bazel
@@ -36,6 +36,5 @@ go_test(
         "@com_github_stretchr_testify//require",
         "@com_github_tikv_client_go_v2//tikv",
         "@com_github_tikv_client_go_v2//util",
-        "@org_golang_x_sync//errgroup",
     ],
 )
diff --git a/tests/realtikvtest/addindextest3/operator_test.go b/tests/realtikvtest/addindextest3/operator_test.go
index 1b941850011d8..74bb4a003a7b6 100644
--- a/tests/realtikvtest/addindextest3/operator_test.go
+++ b/tests/realtikvtest/addindextest3/operator_test.go
@@ -27,6 +27,7 @@ import (
 	"github.com/pingcap/tidb/pkg/ddl"
 	"github.com/pingcap/tidb/pkg/ddl/copr"
 	"github.com/pingcap/tidb/pkg/ddl/ingest"
+	"github.com/pingcap/tidb/pkg/ddl/testutil"
 	"github.com/pingcap/tidb/pkg/disttask/operator"
 	"github.com/pingcap/tidb/pkg/domain"
 	"github.com/pingcap/tidb/pkg/kv"
@@ -38,7 +39,6 @@ import (
 	"github.com/pingcap/tidb/pkg/util/chunk"
 	"github.com/pingcap/tidb/tests/realtikvtest"
 	"github.com/stretchr/testify/require"
-	"golang.org/x/sync/errgroup"
 )
 
 func init() {
@@ -61,7 +61,7 @@ func TestBackfillOperators(t *testing.T) {
 	opCtx := ddl.NewDistTaskOperatorCtx(ctx, 1, 1)
 	pTbl := tbl.(table.PhysicalTable)
 	src := ddl.NewTableScanTaskSource(opCtx, store, pTbl, startKey, endKey, nil)
-	sink := newTestSink[ddl.TableScanTask]()
+	sink := testutil.NewOperatorTestSink[ddl.TableScanTask]()
 
 	operator.Compose[ddl.TableScanTask](src, sink)
 
@@ -71,7 +71,7 @@ func TestBackfillOperators(t *testing.T) {
 	err = pipeline.Close()
 	require.NoError(t, err)
 
-	tasks := sink.collect()
+	tasks := sink.Collect()
 	require.Len(t, tasks, 10)
 	require.Equal(t, 1, tasks[0].ID)
 	require.Equal(t, startKey, tasks[0].Start)
@@ -94,9 +94,9 @@ func TestBackfillOperators(t *testing.T) {
 	ctx := context.Background()
 	opCtx := ddl.NewDistTaskOperatorCtx(ctx, 1, 1)
-	src := newTestSource(opTasks...)
+	src := testutil.NewOperatorTestSource(opTasks...)
 	scanOp := ddl.NewTableScanOperator(opCtx, sessPool, copCtx, srcChkPool, 3, nil)
-	sink := newTestSink[ddl.IndexRecordChunk]()
+	sink := testutil.NewOperatorTestSink[ddl.IndexRecordChunk]()
 
 	operator.Compose[ddl.TableScanTask](src, scanOp)
 	operator.Compose[ddl.IndexRecordChunk](scanOp, sink)
 
@@ -107,7 +107,7 @@ func TestBackfillOperators(t *testing.T) {
 	err = pipeline.Close()
 	require.NoError(t, err)
 
-	results := sink.collect()
+	results := sink.Collect()
 	cnt := 0
 	for _, rs := range results {
 		require.NoError(t, rs.Err)
@@ -140,12 +140,12 @@ func TestBackfillOperators(t *testing.T) {
 	mockEngine := ingest.NewMockEngineInfo(nil)
 	mockEngine.SetHook(onWrite)
 
-	src := newTestSource(chunkResults...)
+	src := testutil.NewOperatorTestSource(chunkResults...)
 	reorgMeta := ddl.NewDDLReorgMeta(tk.Session())
 	ingestOp := ddl.NewIndexIngestOperator(
 		opCtx, copCtx, mockBackendCtx, sessPool, pTbl, []table.Index{index}, []ingest.Engine{mockEngine}, srcChkPool, 3, reorgMeta, nil, &ddl.EmptyRowCntListener{})
-	sink := newTestSink[ddl.IndexWriteResult]()
+	sink := testutil.NewOperatorTestSink[ddl.IndexWriteResult]()
 
 	operator.Compose[ddl.IndexRecordChunk](src, ingestOp)
 	operator.Compose[ddl.IndexWriteResult](ingestOp, sink)
 
@@ -156,7 +156,7 @@ func TestBackfillOperators(t *testing.T) {
 	err = pipeline.Close()
 	require.NoError(t, err)
 
-	results := sink.collect()
+	results := sink.Collect()
 	cnt := 0
 	for _, rs := range results {
 		cnt += rs.Added
@@ -368,77 +368,3 @@ func (p *sessPoolForTest) Get() (sessionctx.Context, error) {
 func (p *sessPoolForTest) Put(sctx sessionctx.Context) {
 	p.pool.Put(sctx.(pools.Resource))
 }
-
-type testSink[T any] struct {
-	errGroup  errgroup.Group
-	ch        chan T
-	collected []T
-}
-
-func newTestSink[T any]() *testSink[T] {
-	return &testSink[T]{
-		ch: make(chan T),
-	}
-}
-
-func (s *testSink[T]) Open() error {
-	s.errGroup.Go(func() error {
-		for data := range s.ch {
-			s.collected = append(s.collected, data)
-		}
-		return nil
-	})
-	return nil
-}
-
-func (s *testSink[T]) Close() error {
-	return s.errGroup.Wait()
-}
-
-func (s *testSink[T]) SetSource(dataCh operator.DataChannel[T]) {
-	s.ch = dataCh.Channel()
-}
-
-func (s *testSink[T]) String() string {
-	return "testSink"
-}
-
-func (s *testSink[T]) collect() []T {
-	return s.collected
-}
-
-type testSource[T any] struct {
-	errGroup errgroup.Group
-	ch       chan T
-	toBeSent []T
-}
-
-func newTestSource[T any](toBeSent ...T) *testSource[T] {
-	return &testSource[T]{
-		ch:       make(chan T),
-		toBeSent: toBeSent,
-	}
-}
-
-func (s *testSource[T]) SetSink(sink operator.DataChannel[T]) {
-	s.ch = sink.Channel()
-}
-
-func (s *testSource[T]) Open() error {
-	s.errGroup.Go(func() error {
-		for _, data := range s.toBeSent {
-			s.ch <- data
-		}
-		close(s.ch)
-		return nil
-	})
-	return nil
-}
-
-func (s *testSource[T]) Close() error {
-	return s.errGroup.Wait()
-}
-
-func (s *testSource[T]) String() string {
-	return "testSource"
-}
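
For reviewers, a minimal sketch of how the relocated helpers are meant to be composed from another package. It only uses APIs that appear in this diff (NewOperatorTestSource/NewOperatorTestSink, operator.Compose, operator.NewAsyncPipeline with Execute/Close, Collect); the test name and the int payload type are illustrative, not part of the change:

```go
package testutil_test

import (
	"testing"

	"github.com/pingcap/tidb/pkg/ddl/testutil"
	"github.com/pingcap/tidb/pkg/disttask/operator"
	"github.com/stretchr/testify/require"
)

func TestSourceToSinkPipeline(t *testing.T) {
	// Feed three values through a source -> sink pipeline and collect them.
	src := testutil.NewOperatorTestSource(1, 2, 3)
	sink := testutil.NewOperatorTestSink[int]()
	operator.Compose[int](src, sink)

	pipeline := operator.NewAsyncPipeline(src, sink)
	require.NoError(t, pipeline.Execute())
	// Close waits for the source and sink goroutines, so Collect is safe after it.
	require.NoError(t, pipeline.Close())

	require.Equal(t, []int{1, 2, 3}, sink.Collect())
}
```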