diff --git a/cdc/sink/ddlsink/mysql/async_ddl_test.go b/cdc/sink/ddlsink/mysql/async_ddl_test.go index 91a6b21a3e3..871080c2eb5 100644 --- a/cdc/sink/ddlsink/mysql/async_ddl_test.go +++ b/cdc/sink/ddlsink/mysql/async_ddl_test.go @@ -19,7 +19,6 @@ import ( "errors" "fmt" "net/url" - "sync/atomic" "testing" "time" @@ -33,18 +32,8 @@ import ( ) func TestWaitAsynExecDone(t *testing.T) { - var dbIndex int32 = 0 - GetDBConnImpl = func(ctx context.Context, dsnStr string) (*sql.DB, error) { - defer func() { - atomic.AddInt32(&dbIndex, 1) - }() - if atomic.LoadInt32(&dbIndex) == 0 { - // test db - db, err := pmysql.MockTestDB() - require.Nil(t, err) - return db, nil - } - // normal db + dbConnFactory := pmysql.NewDBConnectionFactoryForTest() + dbConnFactory.SetStandardConnectionFactory(func(ctx context.Context, dsnStr string) (*sql.DB, error) { db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) require.Nil(t, err) mock.ExpectQuery("select tidb_version()"). @@ -71,7 +60,8 @@ func TestWaitAsynExecDone(t *testing.T) { mock.ExpectClose() return db, nil - } + }) + GetDBConnImpl = dbConnFactory ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -116,18 +106,8 @@ func TestWaitAsynExecDone(t *testing.T) { func TestAsyncExecAddIndex(t *testing.T) { ddlExecutionTime := time.Second * 15 - var dbIndex int32 = 0 - GetDBConnImpl = func(ctx context.Context, dsnStr string) (*sql.DB, error) { - defer func() { - atomic.AddInt32(&dbIndex, 1) - }() - if atomic.LoadInt32(&dbIndex) == 0 { - // test db - db, err := pmysql.MockTestDB() - require.Nil(t, err) - return db, nil - } - // normal db + dbConnFactory := pmysql.NewDBConnectionFactoryForTest() + dbConnFactory.SetStandardConnectionFactory(func(ctx context.Context, dsnStr string) (*sql.DB, error) { db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) require.Nil(t, err) mock.ExpectQuery("select tidb_version()"). @@ -145,7 +125,8 @@ func TestAsyncExecAddIndex(t *testing.T) { mock.ExpectCommit() mock.ExpectClose() return db, nil - } + }) + GetDBConnImpl = dbConnFactory ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/cdc/sink/ddlsink/mysql/mysql_ddl_sink.go b/cdc/sink/ddlsink/mysql/mysql_ddl_sink.go index 6d66a90e410..c159e159a53 100644 --- a/cdc/sink/ddlsink/mysql/mysql_ddl_sink.go +++ b/cdc/sink/ddlsink/mysql/mysql_ddl_sink.go @@ -44,9 +44,9 @@ const ( networkDriftDuration = 5 * time.Second ) -// GetDBConnImpl is the implementation of pmysql.Factory. +// GetDBConnImpl is the implementation of pmysql.IDBConnectionFactory. // Exported for testing. -var GetDBConnImpl pmysql.Factory = pmysql.CreateMySQLDBConn +var GetDBConnImpl pmysql.IDBConnectionFactory = &pmysql.DBConnectionFactory{} // Assert Sink implementation var _ ddlsink.Sink = (*DDLSink)(nil) @@ -81,12 +81,12 @@ func NewDDLSink( return nil, err } - dsnStr, err := pmysql.GenerateDSN(ctx, sinkURI, cfg, GetDBConnImpl) + dsnStr, err := pmysql.GenerateDSN(ctx, sinkURI, cfg, GetDBConnImpl.CreateTemporaryConnection) if err != nil { return nil, err } - db, err := GetDBConnImpl(ctx, dsnStr) + db, err := GetDBConnImpl.CreateStandardConnection(ctx, dsnStr) if err != nil { return nil, err } diff --git a/cdc/sink/ddlsink/mysql/mysql_ddl_sink_test.go b/cdc/sink/ddlsink/mysql/mysql_ddl_sink_test.go index 6174ac4fc1b..43936af146f 100644 --- a/cdc/sink/ddlsink/mysql/mysql_ddl_sink_test.go +++ b/cdc/sink/ddlsink/mysql/mysql_ddl_sink_test.go @@ -30,18 +30,8 @@ import ( ) func TestWriteDDLEvent(t *testing.T) { - dbIndex := 0 - GetDBConnImpl = func(ctx context.Context, dsnStr string) (*sql.DB, error) { - defer func() { - dbIndex++ - }() - if dbIndex == 0 { - // test db - db, err := pmysql.MockTestDB() - require.Nil(t, err) - return db, nil - } - // normal db + dbConnFactory := pmysql.NewDBConnectionFactoryForTest() + dbConnFactory.SetStandardConnectionFactory(func(ctx context.Context, dsnStr string) (*sql.DB, error) { db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) require.Nil(t, err) mock.ExpectQuery("select tidb_version()"). @@ -66,7 +56,8 @@ func TestWriteDDLEvent(t *testing.T) { mock.ExpectRollback() mock.ExpectClose() return db, nil - } + }) + GetDBConnImpl = dbConnFactory ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/cdc/sink/dmlsink/txn/mysql/mysql.go b/cdc/sink/dmlsink/txn/mysql/mysql.go index dea8debe570..ac11a198014 100644 --- a/cdc/sink/dmlsink/txn/mysql/mysql.go +++ b/cdc/sink/dmlsink/txn/mysql/mysql.go @@ -86,7 +86,7 @@ func NewMySQLBackends( changefeedID model.ChangeFeedID, sinkURI *url.URL, replicaConfig *config.ReplicaConfig, - dbConnFactory pmysql.Factory, + dbConnFactory pmysql.IDBConnectionFactory, statistics *metrics.Statistics, ) ([]*mysqlBackend, error) { changefeed := fmt.Sprintf("%s.%s", changefeedID.Namespace, changefeedID.ID) @@ -97,12 +97,12 @@ func NewMySQLBackends( return nil, err } - dsnStr, err := pmysql.GenerateDSN(ctx, sinkURI, cfg, dbConnFactory) + dsnStr, err := pmysql.GenerateDSN(ctx, sinkURI, cfg, dbConnFactory.CreateTemporaryConnection) if err != nil { return nil, err } - db, err := dbConnFactory(ctx, dsnStr) + db, err := dbConnFactory.CreateStandardConnection(ctx, dsnStr) if err != nil { return nil, err } diff --git a/cdc/sink/dmlsink/txn/mysql/mysql_test.go b/cdc/sink/dmlsink/txn/mysql/mysql_test.go index f0a9abb99a7..ce0adb637aa 100644 --- a/cdc/sink/dmlsink/txn/mysql/mysql_test.go +++ b/cdc/sink/dmlsink/txn/mysql/mysql_test.go @@ -68,7 +68,7 @@ func newMySQLBackend( changefeedID model.ChangeFeedID, sinkURI *url.URL, replicaConfig *config.ReplicaConfig, - dbConnFactory pmysql.Factory, + dbConnFactory pmysql.IDBConnectionFactory, ) (*mysqlBackend, error) { ctx1, cancel := context.WithCancel(ctx) statistics := metrics.NewStatistics(ctx1, changefeedID, sink.TxnSink) @@ -205,29 +205,19 @@ func TestAdjustSQLMode(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - dbIndex := 0 - mockGetDBConn := func(ctx context.Context, dsnStr string) (*sql.DB, error) { - defer func() { dbIndex++ }() - - if dbIndex == 0 { - // test db - db, err := pmysql.MockTestDB() - require.Nil(t, err) - return db, nil - } - - // normal db + dbConnFactory := pmysql.NewDBConnectionFactoryForTest() + dbConnFactory.SetStandardConnectionFactory(func(ctx context.Context, dsnStr string) (*sql.DB, error) { db, mock := newTestMockDB(t) mock.ExpectClose() return db, nil - } + }) changefeed := "test-changefeed" sinkURI, err := url.Parse("mysql://127.0.0.1:4000/?time-zone=UTC&worker-count=1" + "&cache-prep-stmts=false") require.Nil(t, err) sink, err := newMySQLBackend(ctx, model.DefaultChangeFeedID(changefeed), - sinkURI, config.GetDefaultReplicaConfig(), mockGetDBConn) + sinkURI, config.GetDefaultReplicaConfig(), dbConnFactory) require.Nil(t, err) require.Nil(t, sink.Close()) } @@ -290,24 +280,14 @@ func TestNewMySQLTimeout(t *testing.T) { sinkURI, err := url.Parse(fmt.Sprintf("mysql://%s/?read-timeout=1s&timeout=1s", addr)) require.Nil(t, err) _, err = newMySQLBackend(ctx, model.DefaultChangeFeedID(changefeed), sinkURI, - config.GetDefaultReplicaConfig(), pmysql.CreateMySQLDBConn) + config.GetDefaultReplicaConfig(), &pmysql.DBConnectionFactory{}) require.Equal(t, driver.ErrBadConn, errors.Cause(err)) } // Test OnTxnEvent and Flush interfaces. Event callbacks should be called correctly after flush. func TestNewMySQLBackendExecDML(t *testing.T) { - dbIndex := 0 - mockGetDBConn := func(ctx context.Context, dsnStr string) (*sql.DB, error) { - defer func() { dbIndex++ }() - - if dbIndex == 0 { - // test db - db, err := pmysql.MockTestDB() - require.Nil(t, err) - return db, nil - } - - // normal db + dbConnFactory := pmysql.NewDBConnectionFactoryForTest() + dbConnFactory.SetStandardConnectionFactory(func(ctx context.Context, dsnStr string) (*sql.DB, error) { db, mock := newTestMockDB(t) mock.ExpectBegin() mock.ExpectExec("INSERT INTO `s1`.`t1` (`a`,`b`) VALUES (?,?),(?,?)"). @@ -316,7 +296,7 @@ func TestNewMySQLBackendExecDML(t *testing.T) { mock.ExpectCommit() mock.ExpectClose() return db, nil - } + }) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -329,7 +309,7 @@ func TestNewMySQLBackendExecDML(t *testing.T) { "mysql://127.0.0.1:4000/?time-zone=UTC&worker-count=1&cache-prep-stmts=false") require.Nil(t, err) sink, err := newMySQLBackend(ctx, model.DefaultChangeFeedID(changefeed), sinkURI, - config.GetDefaultReplicaConfig(), mockGetDBConn) + config.GetDefaultReplicaConfig(), dbConnFactory) require.Nil(t, err) tableInfo := model.BuildTableInfo("s1", "t1", []*model.Column{ @@ -433,18 +413,8 @@ func TestExecDMLRollbackErrDatabaseNotExists(t *testing.T) { Number: uint16(infoschema.ErrDatabaseNotExists.Code()), } - dbIndex := 0 - mockGetDBConnErrDatabaseNotExists := func(ctx context.Context, dsnStr string) (*sql.DB, error) { - defer func() { dbIndex++ }() - - if dbIndex == 0 { - // test db - db, err := pmysql.MockTestDB() - require.Nil(t, err) - return db, nil - } - - // normal db + dbConnFactory := pmysql.NewDBConnectionFactoryForTest() + dbConnFactory.SetStandardConnectionFactory(func(ctx context.Context, dsnStr string) (*sql.DB, error) { db, mock := newTestMockDB(t) mock.ExpectBegin() mock.ExpectExec("REPLACE INTO `s1`.`t1` (`a`) VALUES (?),(?)"). @@ -453,7 +423,7 @@ func TestExecDMLRollbackErrDatabaseNotExists(t *testing.T) { mock.ExpectRollback() mock.ExpectClose() return db, nil - } + }) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -462,7 +432,7 @@ func TestExecDMLRollbackErrDatabaseNotExists(t *testing.T) { "mysql://127.0.0.1:4000/?time-zone=UTC&worker-count=1&cache-prep-stmts=false") require.Nil(t, err) sink, err := newMySQLBackend(ctx, model.DefaultChangeFeedID(changefeed), sinkURI, - config.GetDefaultReplicaConfig(), mockGetDBConnErrDatabaseNotExists) + config.GetDefaultReplicaConfig(), dbConnFactory) require.Nil(t, err) _ = sink.OnTxnEvent(&dmlsink.TxnCallbackableEvent{ @@ -510,18 +480,8 @@ func TestExecDMLRollbackErrTableNotExists(t *testing.T) { Number: uint16(infoschema.ErrTableNotExists.Code()), } - dbIndex := 0 - mockGetDBConnErrDatabaseNotExists := func(ctx context.Context, dsnStr string) (*sql.DB, error) { - defer func() { dbIndex++ }() - - if dbIndex == 0 { - // test db - db, err := pmysql.MockTestDB() - require.Nil(t, err) - return db, nil - } - - // normal db + dbConnFactory := pmysql.NewDBConnectionFactoryForTest() + dbConnFactory.SetStandardConnectionFactory(func(ctx context.Context, dsnStr string) (*sql.DB, error) { db, mock := newTestMockDB(t) mock.ExpectBegin() mock.ExpectExec("REPLACE INTO `s1`.`t1` (`a`) VALUES (?),(?)"). @@ -530,7 +490,7 @@ func TestExecDMLRollbackErrTableNotExists(t *testing.T) { mock.ExpectRollback() mock.ExpectClose() return db, nil - } + }) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -539,7 +499,7 @@ func TestExecDMLRollbackErrTableNotExists(t *testing.T) { "mysql://127.0.0.1:4000/?time-zone=UTC&worker-count=1&cache-prep-stmts=false") require.Nil(t, err) sink, err := newMySQLBackend(ctx, model.DefaultChangeFeedID(changefeed), sinkURI, - config.GetDefaultReplicaConfig(), mockGetDBConnErrDatabaseNotExists) + config.GetDefaultReplicaConfig(), dbConnFactory) require.Nil(t, err) _ = sink.OnTxnEvent(&dmlsink.TxnCallbackableEvent{ @@ -587,18 +547,8 @@ func TestExecDMLRollbackErrRetryable(t *testing.T) { Number: mysql.ErrLockDeadlock, } - dbIndex := 0 - mockGetDBConnErrDatabaseNotExists := func(ctx context.Context, dsnStr string) (*sql.DB, error) { - defer func() { dbIndex++ }() - - if dbIndex == 0 { - // test db - db, err := pmysql.MockTestDB() - require.Nil(t, err) - return db, nil - } - - // normal db + dbConnFactory := pmysql.NewDBConnectionFactoryForTest() + dbConnFactory.SetStandardConnectionFactory(func(ctx context.Context, dsnStr string) (*sql.DB, error) { db, mock := newTestMockDB(t) for i := 0; i < 2; i++ { mock.ExpectBegin() @@ -609,7 +559,7 @@ func TestExecDMLRollbackErrRetryable(t *testing.T) { } mock.ExpectClose() return db, nil - } + }) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -618,7 +568,7 @@ func TestExecDMLRollbackErrRetryable(t *testing.T) { "mysql://127.0.0.1:4000/?time-zone=UTC&worker-count=1&cache-prep-stmts=false") require.Nil(t, err) sink, err := newMySQLBackend(ctx, model.DefaultChangeFeedID(changefeed), sinkURI, - config.GetDefaultReplicaConfig(), mockGetDBConnErrDatabaseNotExists) + config.GetDefaultReplicaConfig(), dbConnFactory) require.Nil(t, err) sink.setDMLMaxRetry(2) @@ -656,18 +606,8 @@ func TestMysqlSinkNotRetryErrDupEntry(t *testing.T) { }, } - dbIndex := 0 - mockDBInsertDupEntry := func(ctx context.Context, dsnStr string) (*sql.DB, error) { - defer func() { dbIndex++ }() - - if dbIndex == 0 { - // test db - db, err := pmysql.MockTestDB() - require.Nil(t, err) - return db, nil - } - - // normal db + dbConnFactory := pmysql.NewDBConnectionFactoryForTest() + dbConnFactory.SetStandardConnectionFactory(func(ctx context.Context, dsnStr string) (*sql.DB, error) { db, mock := newTestMockDB(t) mock.ExpectBegin() mock.ExpectExec("INSERT INTO `s1`.`t1` (`a`) VALUES (?)"). @@ -677,7 +617,7 @@ func TestMysqlSinkNotRetryErrDupEntry(t *testing.T) { WillReturnError(errDup) mock.ExpectClose() return db, nil - } + }) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -687,7 +627,7 @@ func TestMysqlSinkNotRetryErrDupEntry(t *testing.T) { "&cache-prep-stmts=false") require.Nil(t, err) sink, err := newMySQLBackend(ctx, model.DefaultChangeFeedID(changefeed), sinkURI, - config.GetDefaultReplicaConfig(), mockDBInsertDupEntry) + config.GetDefaultReplicaConfig(), dbConnFactory) require.Nil(t, err) sink.setDMLMaxRetry(1) _ = sink.OnTxnEvent(&dmlsink.TxnCallbackableEvent{ @@ -708,22 +648,12 @@ func TestNeedSwitchDB(t *testing.T) { } func TestNewMySQLBackend(t *testing.T) { - dbIndex := 0 - mockGetDBConn := func(ctx context.Context, dsnStr string) (*sql.DB, error) { - defer func() { dbIndex++ }() - - if dbIndex == 0 { - // test db - db, err := pmysql.MockTestDB() - require.Nil(t, err) - return db, nil - } - - // normal db + dbConnFactory := pmysql.NewDBConnectionFactoryForTest() + dbConnFactory.SetStandardConnectionFactory(func(ctx context.Context, dsnStr string) (*sql.DB, error) { db, mock := newTestMockDB(t) mock.ExpectClose() return db, nil - } + }) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -733,7 +663,7 @@ func TestNewMySQLBackend(t *testing.T) { "&cache-prep-stmts=false") require.Nil(t, err) sink, err := newMySQLBackend(ctx, model.DefaultChangeFeedID(changefeed), sinkURI, - config.GetDefaultReplicaConfig(), mockGetDBConn) + config.GetDefaultReplicaConfig(), dbConnFactory) require.Nil(t, err) require.Nil(t, sink.Close()) @@ -742,23 +672,12 @@ func TestNewMySQLBackend(t *testing.T) { } func TestNewMySQLBackendWithIPv6Address(t *testing.T) { - dbIndex := 0 - mockGetDBConn := func(ctx context.Context, dsnStr string) (*sql.DB, error) { - require.Contains(t, dsnStr, "root@tcp([::1]:3306)") - defer func() { dbIndex++ }() - - if dbIndex == 0 { - // test db - db, err := pmysql.MockTestDB() - require.Nil(t, err) - return db, nil - } - - // normal db + dbConnFactory := pmysql.NewDBConnectionFactoryForTest() + dbConnFactory.SetStandardConnectionFactory(func(ctx context.Context, dsnStr string) (*sql.DB, error) { db, mock := newTestMockDB(t) mock.ExpectClose() return db, nil - } + }) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -768,28 +687,18 @@ func TestNewMySQLBackendWithIPv6Address(t *testing.T) { "&cache-prep-stmts=false") require.Nil(t, err) sink, err := newMySQLBackend(ctx, model.DefaultChangeFeedID(changefeed), sinkURI, - config.GetDefaultReplicaConfig(), mockGetDBConn) + config.GetDefaultReplicaConfig(), dbConnFactory) require.Nil(t, err) require.Nil(t, sink.Close()) } func TestGBKSupported(t *testing.T) { - dbIndex := 0 - mockGetDBConn := func(ctx context.Context, dsnStr string) (*sql.DB, error) { - defer func() { dbIndex++ }() - - if dbIndex == 0 { - // test db - db, err := pmysql.MockTestDB() - require.Nil(t, err) - return db, nil - } - - // normal db + dbConnFactory := pmysql.NewDBConnectionFactoryForTest() + dbConnFactory.SetStandardConnectionFactory(func(ctx context.Context, dsnStr string) (*sql.DB, error) { db, mock := newTestMockDB(t) mock.ExpectClose() return db, nil - } + }) zapcore, logs := observer.New(zap.WarnLevel) conf := &log.Config{Level: "warn", File: log.FileLogConfig{}} @@ -804,7 +713,7 @@ func TestGBKSupported(t *testing.T) { "&cache-prep-stmts=false") require.Nil(t, err) sink, err := newMySQLBackend(ctx, model.DefaultChangeFeedID(changefeed), sinkURI, - config.GetDefaultReplicaConfig(), mockGetDBConn) + config.GetDefaultReplicaConfig(), dbConnFactory) require.Nil(t, err) // no gbk-related warning log will be output because GBK charset is supported @@ -834,25 +743,15 @@ func TestHolderString(t *testing.T) { } func TestMySQLSinkExecDMLError(t *testing.T) { - dbIndex := 0 - mockGetDBConn := func(ctx context.Context, dsnStr string) (*sql.DB, error) { - defer func() { dbIndex++ }() - - if dbIndex == 0 { - // test db - db, err := pmysql.MockTestDB() - require.Nil(t, err) - return db, nil - } - - // normal db + dbConnFactory := pmysql.NewDBConnectionFactoryForTest() + dbConnFactory.SetStandardConnectionFactory(func(ctx context.Context, dsnStr string) (*sql.DB, error) { db, mock := newTestMockDB(t) mock.ExpectBegin() mock.ExpectExec("INSERT INTO `s1`.`t1` (`a`,`b`) VALUES (?,?),(?,?)").WillDelayFor(1 * time.Second). WillReturnError(&dmysql.MySQLError{Number: mysql.ErrNoSuchTable}) mock.ExpectClose() return db, nil - } + }) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -861,7 +760,7 @@ func TestMySQLSinkExecDMLError(t *testing.T) { "mysql://127.0.0.1:4000/?time-zone=UTC&worker-count=1&cache-prep-stmts=false") require.Nil(t, err) sink, err := newMySQLBackend(ctx, model.DefaultChangeFeedID(changefeed), sinkURI, - config.GetDefaultReplicaConfig(), mockGetDBConn) + config.GetDefaultReplicaConfig(), dbConnFactory) require.Nil(t, err) tableInfo := model.BuildTableInfo("s1", "t1", []*model.Column{ diff --git a/cdc/sink/dmlsink/txn/txn_dml_sink.go b/cdc/sink/dmlsink/txn/txn_dml_sink.go index f5b2cb78edb..3354de9f683 100644 --- a/cdc/sink/dmlsink/txn/txn_dml_sink.go +++ b/cdc/sink/dmlsink/txn/txn_dml_sink.go @@ -58,11 +58,9 @@ type dmlSink struct { scheme string } -// GetDBConnImpl is the implementation of pmysql.Factory. +// GetDBConnImpl is the implementation of pmysql.IDBConnectionFactory. // Exported for testing. -// Maybe we can use a better way to do this. Because this is not thread-safe. -// You can use `SetupSuite` and `TearDownSuite` to do this to get a better way. -var GetDBConnImpl pmysql.Factory = pmysql.CreateMySQLDBConn +var GetDBConnImpl pmysql.IDBConnectionFactory = &pmysql.DBConnectionFactory{} // NewMySQLSink creates a mysql dmlSink with given parameters. func NewMySQLSink( diff --git a/errors.toml b/errors.toml index d316f1ab9c4..bfa826c2a6b 100755 --- a/errors.toml +++ b/errors.toml @@ -171,6 +171,11 @@ error = ''' TiCDC cluster is unhealthy ''' +["CDC:ErrCodeNilFunction"] +error = ''' +function is not initialized +''' + ["CDC:ErrCodecDecode"] error = ''' codec decode error diff --git a/pkg/applier/redo_test.go b/pkg/applier/redo_test.go index 01615bbf7d1..2dc0582e2b8 100644 --- a/pkg/applier/redo_test.go +++ b/pkg/applier/redo_test.go @@ -100,25 +100,17 @@ func TestApply(t *testing.T) { return NewMockReader(checkpointTs, resolvedTs, redoLogCh, ddlEventCh), nil } - dbIndex := 0 // DML sink and DDL sink share the same db db := getMockDB(t) - mockGetDBConn := func(ctx context.Context, dsnStr string) (*sql.DB, error) { - defer func() { - dbIndex++ - }() - if dbIndex%2 == 0 { - testDB, err := pmysql.MockTestDB() - require.Nil(t, err) - return testDB, nil - } + dbConnFactory := pmysql.NewDBConnectionFactoryForTest() + dbConnFactory.SetStandardConnectionFactory(func(ctx context.Context, dsnStr string) (*sql.DB, error) { return db, nil - } + }) getDMLDBConnBak := txn.GetDBConnImpl - txn.GetDBConnImpl = mockGetDBConn + txn.GetDBConnImpl = dbConnFactory getDDLDBConnBak := mysqlDDL.GetDBConnImpl - mysqlDDL.GetDBConnImpl = mockGetDBConn + mysqlDDL.GetDBConnImpl = dbConnFactory createRedoReaderBak := createRedoReader createRedoReader = createMockReader defer func() { @@ -322,25 +314,17 @@ func TestApplyBigTxn(t *testing.T) { return NewMockReader(checkpointTs, resolvedTs, redoLogCh, ddlEventCh), nil } - dbIndex := 0 // DML sink and DDL sink share the same db db := getMockDBForBigTxn(t) - mockGetDBConn := func(ctx context.Context, dsnStr string) (*sql.DB, error) { - defer func() { - dbIndex++ - }() - if dbIndex%2 == 0 { - testDB, err := pmysql.MockTestDB() - require.Nil(t, err) - return testDB, nil - } + dbConnFactory := pmysql.NewDBConnectionFactoryForTest() + dbConnFactory.SetStandardConnectionFactory(func(ctx context.Context, dsnStr string) (*sql.DB, error) { return db, nil - } + }) getDMLDBConnBak := txn.GetDBConnImpl - txn.GetDBConnImpl = mockGetDBConn + txn.GetDBConnImpl = dbConnFactory getDDLDBConnBak := mysqlDDL.GetDBConnImpl - mysqlDDL.GetDBConnImpl = mockGetDBConn + mysqlDDL.GetDBConnImpl = dbConnFactory createRedoReaderBak := createRedoReader createRedoReader = createMockReader defer func() { diff --git a/pkg/errors/cdc_errors.go b/pkg/errors/cdc_errors.go index c87c0ddd6f6..f3c0f9eeeaa 100644 --- a/pkg/errors/cdc_errors.go +++ b/pkg/errors/cdc_errors.go @@ -140,6 +140,11 @@ var ( "stop processor by admin command", errors.RFCCodeText("CDC:ErrAdminStopProcessor"), ) + ErrCodeNilFunction = errors.Normalize( + "function is not initialized", + errors.RFCCodeText("CDC:ErrCodeNilFunction"), + ) + // ErrVersionIncompatible is an error for running CDC on an incompatible Cluster. ErrVersionIncompatible = errors.Normalize( "version is incompatible: %s", diff --git a/pkg/sink/mysql/db_helper.go b/pkg/sink/mysql/db_helper.go index 1a3d3c8e84c..eb1dd8288ca 100644 --- a/pkg/sink/mysql/db_helper.go +++ b/pkg/sink/mysql/db_helper.go @@ -33,27 +33,15 @@ import ( "go.uber.org/zap" ) -// CreateMySQLDBConn creates a mysql database connection with the given dsn. -func CreateMySQLDBConn(ctx context.Context, dsnStr string) (*sql.DB, error) { - db, err := sql.Open("mysql", dsnStr) - if err != nil { - return nil, cerror.ErrMySQLConnectionError.Wrap(err).GenWithStack("fail to open MySQL connection") - } - - err = db.PingContext(ctx) - if err != nil { - // close db to recycle resources - if closeErr := db.Close(); closeErr != nil { - log.Warn("close db failed", zap.Error(err)) - } - return nil, cerror.ErrMySQLConnectionError.Wrap(err).GenWithStack("fail to open MySQL connection") - } - - return db, nil -} - // GenerateDSN generates the dsn with the given config. -func GenerateDSN(ctx context.Context, sinkURI *url.URL, cfg *Config, dbConnFactory Factory) (dsnStr string, err error) { +// GenerateDSN uses the provided dbConnFactory to create a temporary connection +// to the downstream database specified by the sinkURI. This temporary connection +// is used to query important information from the downstream database, such as +// version, charset, and other relevant details. After the required information +// is retrieved, the temporary connection is closed. The retrieved data is then +// used to populate additional parameters into the Sink URI, refining +// the connection URL (dsnStr). +func GenerateDSN(ctx context.Context, sinkURI *url.URL, cfg *Config, dbConnFactory ConnectionFactory) (dsnStr string, err error) { // dsn format of the driver: // [username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN] dsn, err := GenBasicDSN(sinkURI, cfg) @@ -216,7 +204,7 @@ func checkTiDBVariable(ctx context.Context, db *sql.DB, variableName, defaultVal // GetTestDB checks and adjusts the password of the given DSN, // it will return a DB instance opened with the adjusted password. -func GetTestDB(ctx context.Context, dbConfig *dmysql.Config, dbConnFactory Factory) (*sql.DB, error) { +func GetTestDB(ctx context.Context, dbConfig *dmysql.Config, dbConnFactory ConnectionFactory) (*sql.DB, error) { password := dbConfig.Passwd if dbConnFactory == nil { dbConnFactory = CreateMySQLDBConn diff --git a/pkg/sink/mysql/factory.go b/pkg/sink/mysql/factory.go index d1c53e579d2..5882c722db4 100644 --- a/pkg/sink/mysql/factory.go +++ b/pkg/sink/mysql/factory.go @@ -16,7 +16,165 @@ package mysql import ( "context" "database/sql" + + "github.com/pingcap/log" + cerror "github.com/pingcap/tiflow/pkg/errors" + "go.uber.org/zap" ) -// Factory is the factory for creating db connection. -type Factory func(ctx context.Context, dsnStr string) (*sql.DB, error) +// IDBConnectionFactory is an interface designed specifically to facilitate unit testing +// in scenarios where connections to downstream databases are required. +// +// In the process of creating a downstream database connection based on a Sink URI, +// a temporary connection is first established. This temporary connection is used +// to query information from the downstream database, such as version, charset, etc. +// After retrieving this information, the temporary connection is closed. +// The retrieved data is then used to populate additional parameters into the Sink URI, +// and a more refined Sink URI is used to establish the final, standard connection that +// the Sink will use for subsequent operations. +// +// During normal system operation, it's perfectly acceptable to create both of these +// connections in the same manner. However, in the context of unit testing, where +// it's not feasible to start an actual downstream database, we need to mock these +// connections. Since both connections will send SQL requests, the unit tests require +// mocking two different connections to handle each phase of the connection creation process. +// +// This interface addresses this issue by providing two distinct methods: +// CreateTemporaryConnection, for creating the temporary connection, and +// CreateStandardConnection, for creating the standard, persistent connection. +// By using these separate methods, the interface allows for greater flexibility +// in mocking and testing, ensuring that the two connections can be created differently +// as needed during unit tests. +type IDBConnectionFactory interface { + CreateTemporaryConnection(ctx context.Context, dsnStr string) (*sql.DB, error) + CreateStandardConnection(ctx context.Context, dsnStr string) (*sql.DB, error) +} + +// DBConnectionFactory is an implementation of the IDBConnectionFactory interface, +// designed for use in normal system operations where only a single method of creating +// database connections is required. +// +// In regular workflows, both the temporary connection (used for querying information +// like version and charset) and the standard connection (used for subsequent operations) +// can be created in the same manner. Therefore, DBConnectionFactory provides a unified +// approach by implementing the IDBConnectionFactory interface and using the same +// CreateMySQLDBConn function for both CreateTemporaryConnection and CreateStandardConnection. +// +// This struct simplifies the process for scenarios where unit testing is not a concern. +// As long as you are not writing unit tests and do not need to mock different types +// of connections, using DBConnectionFactory is perfectly suitable for establishing +// the necessary database connections. +type DBConnectionFactory struct{} + +// CreateTemporaryConnection creates a temporary database connection used to query +// essential information from the downstream database, such as version and charset. +// This connection is intended to be short-lived and will be closed after the necessary +// information is retrieved. +func (d *DBConnectionFactory) CreateTemporaryConnection(ctx context.Context, dsnStr string) (*sql.DB, error) { + return CreateMySQLDBConn(ctx, dsnStr) +} + +// CreateStandardConnection creates the standard database connection that will be +// used by the system for ongoing operations. This connection is based on the refined +// Sink URI containing the necessary parameters gathered from the temporary connection. +func (d *DBConnectionFactory) CreateStandardConnection(ctx context.Context, dsnStr string) (*sql.DB, error) { + return CreateMySQLDBConn(ctx, dsnStr) +} + +// DBConnectionFactoryForTest is a utility implementation of the IDBConnectionFactory +// interface designed to simplify the process of writing unit tests that require +// different methods for creating database connections. +// +// Instead of implementing the IDBConnectionFactory interface from scratch in every +// unit test, DBConnectionFactoryForTest allows developers to easily set up custom +// connection creation methods. The SetTemporaryConnectionFactory and +// SetStandardConnectionFactory methods allow you to define how the temporary and +// standard connections should be created during testing. +// +// Once these methods are set, the system will automatically invoke CreateTemporaryConnection +// and CreateStandardConnection to establish the respective connections during the +// connection creation process. This approach provides flexibility and convenience +// in unit testing scenarios, where different connection behaviors need to be mocked +// or tested separately. +type DBConnectionFactoryForTest struct { + temp ConnectionFactory + standard ConnectionFactory +} + +// NewDBConnectionFactoryForTest creates a new instance of DBConnectionFactoryForTest +// with a predefined temporary connection creation method. This method is initialized +// to use a mock database connection, which is commonly required in the current unit tests. +// By setting up the temporary connection creation logic directly in this constructor, +// it eliminates the need to repeatedly call SetTemporaryConnectionFactory to set up +// the temporary connection factory in each test. This streamlines the setup process +// for unit tests that require a mock temporary connection. +func NewDBConnectionFactoryForTest() *DBConnectionFactoryForTest { + dbConnFactory := &DBConnectionFactoryForTest{} + dbConnFactory.SetTemporaryConnectionFactory(func(ctx context.Context, dsnStr string) (*sql.DB, error) { + return MockTestDB() + }) + return dbConnFactory +} + +// SetTemporaryConnectionFactory sets the connection factory that will be used to +// create the temporary connection during testing. This allows for custom behavior +// during unit tests when different connection logic is required. +func (d *DBConnectionFactoryForTest) SetTemporaryConnectionFactory(f ConnectionFactory) { + d.temp = f +} + +// SetStandardConnectionFactory sets the connection factory that will be used to +// create the standard connection during testing. This provides the ability to mock +// the connection behavior specifically for the standard connection phase. +func (d *DBConnectionFactoryForTest) SetStandardConnectionFactory(f ConnectionFactory) { + d.standard = f +} + +// CreateTemporaryConnection creates a temporary database connection during testing +// using the connection factory set by SetTemporaryConnectionFactory. If no factory +// has been set, it returns an error. This method allows for customized connection +// logic during the temporary connection phase in unit tests. +func (d *DBConnectionFactoryForTest) CreateTemporaryConnection(ctx context.Context, dsnStr string) (*sql.DB, error) { + if d.temp == nil { + return nil, cerror.ErrCodeNilFunction.GenWithStackByArgs() + } + return d.temp(ctx, dsnStr) +} + +// CreateStandardConnection creates the standard database connection during testing +// using the connection factory set by SetStandardConnectionFactory. If no factory +// has been set, it returns an error. This method supports custom connection behavior +// during the standard connection phase in unit tests. +func (d *DBConnectionFactoryForTest) CreateStandardConnection(ctx context.Context, dsnStr string) (*sql.DB, error) { + if d.standard == nil { + return nil, cerror.ErrCodeNilFunction.GenWithStackByArgs() + } + return d.standard(ctx, dsnStr) +} + +// ConnectionFactory is a function type that takes a context and a DSN (Data Source Name) string +// as input parameters, and returns a pointer to an sql.DB object and an error. This function type +// is typically used to create and configure database connections. +type ConnectionFactory func(ctx context.Context, dsnStr string) (*sql.DB, error) + +// CreateMySQLDBConn establishes a connection to a MySQL database and pings it to verify the connection. +// It takes a context for managing cancellation and timeouts, and a DSN (Data Source Name) string as input. +// The function returns a pointer to an sql.DB object, representing the connection, or an error if the connection fails. +func CreateMySQLDBConn(ctx context.Context, dsnStr string) (*sql.DB, error) { + db, err := sql.Open("mysql", dsnStr) + if err != nil { + return nil, cerror.ErrMySQLConnectionError.Wrap(err).GenWithStack("fail to open MySQL connection") + } + + // Ping the database to verify that the connection is established and functioning. + err = db.PingContext(ctx) + if err != nil { + // If pinging fails, attempt to close the connection to release resources. + if closeErr := db.Close(); closeErr != nil { + log.Warn("close db failed", zap.Error(err)) + } + return nil, cerror.ErrMySQLConnectionError.Wrap(err).GenWithStack("fail to open MySQL connection") + } + + return db, nil +} diff --git a/pkg/sink/observer/observer.go b/pkg/sink/observer/observer.go index 64eb31a4862..348ca6573f6 100644 --- a/pkg/sink/observer/observer.go +++ b/pkg/sink/observer/observer.go @@ -38,14 +38,14 @@ type Observer interface { // NewObserverOpt represents available options when creating a new observer. type NewObserverOpt struct { - dbConnFactory pmysql.Factory + dbConnFactory pmysql.ConnectionFactory } // NewObserverOption configures NewObserverOpt. type NewObserverOption func(*NewObserverOpt) // WithDBConnFactory specifies factory to create db connection. -func WithDBConnFactory(factory pmysql.Factory) NewObserverOption { +func WithDBConnFactory(factory pmysql.ConnectionFactory) NewObserverOption { return func(opt *NewObserverOpt) { opt.dbConnFactory = factory }