diff --git a/br/pkg/lightning/backend/importer/importer_test.go b/br/pkg/lightning/backend/importer/importer_test.go index 5d75d1badc245..a0d7878ea5ecb 100644 --- a/br/pkg/lightning/backend/importer/importer_test.go +++ b/br/pkg/lightning/backend/importer/importer_test.go @@ -24,13 +24,13 @@ import ( "github.com/golang/mock/gomock" "github.com/google/uuid" - . "github.com/pingcap/check" "github.com/pingcap/errors" kvpb "github.com/pingcap/kvproto/pkg/import_kvpb" "github.com/pingcap/tidb/br/pkg/lightning/backend" "github.com/pingcap/tidb/br/pkg/lightning/backend/kv" "github.com/pingcap/tidb/br/pkg/lightning/common" "github.com/pingcap/tidb/br/pkg/mock" + "github.com/stretchr/testify/require" ) type importerSuite struct { @@ -43,20 +43,14 @@ type importerSuite struct { kvPairs kv.Rows } -var _ = Suite(&importerSuite{}) - const testPDAddr = "pd-addr:2379" -// FIXME: Cannot use the real SetUpTest/TearDownTest to set up the mock -// otherwise the mock error will be ignored. - -func (s *importerSuite) setUpTest(c *C) { - s.controller = gomock.NewController(c) - s.mockClient = mock.NewMockImportKVClient(s.controller) - s.mockWriter = mock.NewMockImportKV_WriteEngineClient(s.controller) - importer := NewMockImporter(s.mockClient, testPDAddr) - - s.ctx = context.Background() +func createImportSuite(t *testing.T) *importerSuite { + controller := gomock.NewController(t) + mockClient := mock.NewMockImportKVClient(controller) + mockWriter := mock.NewMockImportKV_WriteEngineClient(controller) + importer := NewMockImporter(mockClient, testPDAddr) + s := &importerSuite{controller: controller, mockClient: mockClient, mockWriter: mockWriter, ctx: context.Background()} engineUUID := uuid.MustParse("7e3f3a3c-67ce-506d-af34-417ec138fbcb") s.engineUUID = engineUUID[:] s.kvPairs = kv.MakeRowsFromKvPairs([]common.KvPair{ @@ -76,15 +70,17 @@ func (s *importerSuite) setUpTest(c *C) { var err error s.engine, err = importer.OpenEngine(s.ctx, &backend.EngineConfig{}, "`db`.`table`", -1) - c.Assert(err, IsNil) + require.NoError(t, err) + return s } func (s *importerSuite) tearDownTest() { s.controller.Finish() } -func (s *importerSuite) TestWriteRows(c *C) { - s.setUpTest(c) +func TestWriteRows(t *testing.T) { + t.Parallel() + s := createImportSuite(t) defer s.tearDownTest() s.mockClient.EXPECT().WriteEngine(s.ctx).Return(s.mockWriter, nil) @@ -99,10 +95,10 @@ func (s *importerSuite) TestWriteRows(c *C) { batchSendCall := s.mockWriter.EXPECT(). Send(gomock.Any()). DoAndReturn(func(x *kvpb.WriteEngineRequest) error { - c.Assert(x.GetBatch().GetMutations(), DeepEquals, []*kvpb.Mutation{ + require.Equal(t, []*kvpb.Mutation{ {Op: kvpb.Mutation_Put, Key: []byte("k1"), Value: []byte("v1")}, {Op: kvpb.Mutation_Put, Key: []byte("k2"), Value: []byte("v2")}, - }) + }, x.GetBatch().GetMutations()) return nil }). After(headSendCall) @@ -112,16 +108,17 @@ func (s *importerSuite) TestWriteRows(c *C) { After(batchSendCall) writer, err := s.engine.LocalWriter(s.ctx, nil) - c.Assert(err, IsNil) + require.NoError(t, err) err = writer.WriteRows(s.ctx, nil, s.kvPairs) - c.Assert(err, IsNil) + require.NoError(t, err) st, err := writer.Close(s.ctx) - c.Assert(err, IsNil) - c.Assert(st, IsNil) + require.NoError(t, err) + require.Nil(t, st) } -func (s *importerSuite) TestWriteHeadSendFailed(c *C) { - s.setUpTest(c) +func TestWriteHeadSendFailed(t *testing.T) { + t.Parallel() + s := createImportSuite(t) defer s.tearDownTest() s.mockClient.EXPECT().WriteEngine(s.ctx).Return(s.mockWriter, nil) @@ -129,7 +126,7 @@ func (s *importerSuite) TestWriteHeadSendFailed(c *C) { headSendCall := s.mockWriter.EXPECT(). Send(gomock.Any()). DoAndReturn(func(x *kvpb.WriteEngineRequest) error { - c.Assert(x.GetHead(), NotNil) + require.NotNil(t, x.GetHead()) return errors.Annotate(context.Canceled, "fake unrecoverable write head error") }) s.mockWriter.EXPECT(). @@ -138,13 +135,15 @@ func (s *importerSuite) TestWriteHeadSendFailed(c *C) { After(headSendCall) writer, err := s.engine.LocalWriter(s.ctx, nil) - c.Assert(err, IsNil) + require.NoError(t, err) err = writer.WriteRows(s.ctx, nil, s.kvPairs) - c.Assert(err, ErrorMatches, "fake unrecoverable write head error.*") + require.Error(t, err) + require.Regexp(t, "^fake unrecoverable write head error", err.Error()) } -func (s *importerSuite) TestWriteBatchSendFailed(c *C) { - s.setUpTest(c) +func TestWriteBatchSendFailed(t *testing.T) { + t.Parallel() + s := createImportSuite(t) defer s.tearDownTest() s.mockClient.EXPECT().WriteEngine(s.ctx).Return(s.mockWriter, nil) @@ -152,13 +151,13 @@ func (s *importerSuite) TestWriteBatchSendFailed(c *C) { headSendCall := s.mockWriter.EXPECT(). Send(gomock.Any()). DoAndReturn(func(x *kvpb.WriteEngineRequest) error { - c.Assert(x.GetHead(), NotNil) + require.NotNil(t, x.GetHead()) return nil }) batchSendCall := s.mockWriter.EXPECT(). Send(gomock.Any()). DoAndReturn(func(x *kvpb.WriteEngineRequest) error { - c.Assert(x.GetBatch(), NotNil) + require.NotNil(t, x.GetBatch()) return errors.Annotate(context.Canceled, "fake unrecoverable write batch error") }). After(headSendCall) @@ -168,13 +167,15 @@ func (s *importerSuite) TestWriteBatchSendFailed(c *C) { After(batchSendCall) writer, err := s.engine.LocalWriter(s.ctx, nil) - c.Assert(err, IsNil) + require.NoError(t, err) err = writer.WriteRows(s.ctx, nil, s.kvPairs) - c.Assert(err, ErrorMatches, "fake unrecoverable write batch error.*") + require.Error(t, err) + require.Regexp(t, "^fake unrecoverable write batch error", err.Error()) } -func (s *importerSuite) TestWriteCloseFailed(c *C) { - s.setUpTest(c) +func TestWriteCloseFailed(t *testing.T) { + t.Parallel() + s := createImportSuite(t) defer s.tearDownTest() s.mockClient.EXPECT().WriteEngine(s.ctx).Return(s.mockWriter, nil) @@ -182,13 +183,13 @@ func (s *importerSuite) TestWriteCloseFailed(c *C) { headSendCall := s.mockWriter.EXPECT(). Send(gomock.Any()). DoAndReturn(func(x *kvpb.WriteEngineRequest) error { - c.Assert(x.GetHead(), NotNil) + require.NotNil(t, x.GetHead()) return nil }) batchSendCall := s.mockWriter.EXPECT(). Send(gomock.Any()). DoAndReturn(func(x *kvpb.WriteEngineRequest) error { - c.Assert(x.GetBatch(), NotNil) + require.NotNil(t, x.GetBatch()) return nil }). After(headSendCall) @@ -198,13 +199,15 @@ func (s *importerSuite) TestWriteCloseFailed(c *C) { After(batchSendCall) writer, err := s.engine.LocalWriter(s.ctx, nil) - c.Assert(err, IsNil) + require.NoError(t, err) err = writer.WriteRows(s.ctx, nil, s.kvPairs) - c.Assert(err, ErrorMatches, "fake unrecoverable close stream error.*") + require.Error(t, err) + require.Regexp(t, "^fake unrecoverable close stream error", err.Error()) } -func (s *importerSuite) TestCloseImportCleanupEngine(c *C) { - s.setUpTest(c) +func TestCloseImportCleanupEngine(t *testing.T) { + t.Parallel() + s := createImportSuite(t) defer s.tearDownTest() s.mockClient.EXPECT(). @@ -218,11 +221,11 @@ func (s *importerSuite) TestCloseImportCleanupEngine(c *C) { Return(nil, nil) engine, err := s.engine.Close(s.ctx, nil) - c.Assert(err, IsNil) + require.NoError(t, err) err = engine.Import(s.ctx, 1) - c.Assert(err, IsNil) + require.NoError(t, err) err = engine.Cleanup(s.ctx) - c.Assert(err, IsNil) + require.NoError(t, err) } func BenchmarkMutationAlloc(b *testing.B) { @@ -261,33 +264,41 @@ func BenchmarkMutationPool(b *testing.B) { _ = g } -func (s *importerSuite) TestCheckTiDBVersion(c *C) { +func TestCheckTiDBVersion(t *testing.T) { var version string ctx := context.Background() mockServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - c.Assert(req.URL.Path, Equals, "/status") + require.Equal(t, "/status", req.URL.Path) w.WriteHeader(http.StatusOK) err := json.NewEncoder(w).Encode(map[string]interface{}{ "version": version, }) - c.Assert(err, IsNil) + require.NoError(t, err) })) tls := common.NewTLSFromMockServer(mockServer) version = "5.7.25-TiDB-v4.0.0" - c.Assert(checkTiDBVersionByTLS(ctx, tls, requiredMinTiDBVersion, requiredMaxTiDBVersion), IsNil) + require.Nil(t, checkTiDBVersionByTLS(ctx, tls, requiredMinTiDBVersion, requiredMaxTiDBVersion)) version = "5.7.25-TiDB-v9999.0.0" - c.Assert(checkTiDBVersionByTLS(ctx, tls, requiredMinTiDBVersion, requiredMaxTiDBVersion), ErrorMatches, "TiDB version too new.*") + err := checkTiDBVersionByTLS(ctx, tls, requiredMinTiDBVersion, requiredMaxTiDBVersion) + require.Error(t, err) + require.Regexp(t, "^TiDB version too new", err.Error()) version = "5.7.25-TiDB-v6.0.0" - c.Assert(checkTiDBVersionByTLS(ctx, tls, requiredMinTiDBVersion, requiredMaxTiDBVersion), ErrorMatches, "TiDB version too new.*") + err = checkTiDBVersionByTLS(ctx, tls, requiredMinTiDBVersion, requiredMaxTiDBVersion) + require.Error(t, err) + require.Regexp(t, "^TiDB version too new", err.Error()) version = "5.7.25-TiDB-v6.0.0-beta" - c.Assert(checkTiDBVersionByTLS(ctx, tls, requiredMinTiDBVersion, requiredMaxTiDBVersion), ErrorMatches, "TiDB version too new.*") + err = checkTiDBVersionByTLS(ctx, tls, requiredMinTiDBVersion, requiredMaxTiDBVersion) + require.Error(t, err) + require.Regexp(t, "^TiDB version too new", err.Error()) version = "5.7.25-TiDB-v1.0.0" - c.Assert(checkTiDBVersionByTLS(ctx, tls, requiredMinTiDBVersion, requiredMaxTiDBVersion), ErrorMatches, "TiDB version too old.*") + err = checkTiDBVersionByTLS(ctx, tls, requiredMinTiDBVersion, requiredMaxTiDBVersion) + require.Error(t, err) + require.Regexp(t, "^TiDB version too old", err.Error()) }