Skip to content

Commit

Permalink
mysql sink(ticdc): Refactor DB Connection Creation to Facilitate Unit…
Browse files Browse the repository at this point in the history
… Testing (#11491)

close #11490
  • Loading branch information
wlwilliamx authored Aug 20, 2024
1 parent 3c1e436 commit 691380d
Show file tree
Hide file tree
Showing 12 changed files with 254 additions and 245 deletions.
35 changes: 8 additions & 27 deletions cdc/sink/ddlsink/mysql/async_ddl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"errors"
"fmt"
"net/url"
"sync/atomic"
"testing"
"time"

Expand All @@ -33,18 +32,8 @@ import (
)

func TestWaitAsynExecDone(t *testing.T) {
var dbIndex int32 = 0
GetDBConnImpl = func(ctx context.Context, dsnStr string) (*sql.DB, error) {
defer func() {
atomic.AddInt32(&dbIndex, 1)
}()
if atomic.LoadInt32(&dbIndex) == 0 {
// test db
db, err := pmysql.MockTestDB()
require.Nil(t, err)
return db, nil
}
// normal db
dbConnFactory := pmysql.NewDBConnectionFactoryForTest()
dbConnFactory.SetStandardConnectionFactory(func(ctx context.Context, dsnStr string) (*sql.DB, error) {
db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
require.Nil(t, err)
mock.ExpectQuery("select tidb_version()").
Expand All @@ -71,7 +60,8 @@ func TestWaitAsynExecDone(t *testing.T) {

mock.ExpectClose()
return db, nil
}
})
GetDBConnImpl = dbConnFactory

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down Expand Up @@ -116,18 +106,8 @@ func TestWaitAsynExecDone(t *testing.T) {

func TestAsyncExecAddIndex(t *testing.T) {
ddlExecutionTime := time.Second * 15
var dbIndex int32 = 0
GetDBConnImpl = func(ctx context.Context, dsnStr string) (*sql.DB, error) {
defer func() {
atomic.AddInt32(&dbIndex, 1)
}()
if atomic.LoadInt32(&dbIndex) == 0 {
// test db
db, err := pmysql.MockTestDB()
require.Nil(t, err)
return db, nil
}
// normal db
dbConnFactory := pmysql.NewDBConnectionFactoryForTest()
dbConnFactory.SetStandardConnectionFactory(func(ctx context.Context, dsnStr string) (*sql.DB, error) {
db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
require.Nil(t, err)
mock.ExpectQuery("select tidb_version()").
Expand All @@ -145,7 +125,8 @@ func TestAsyncExecAddIndex(t *testing.T) {
mock.ExpectCommit()
mock.ExpectClose()
return db, nil
}
})
GetDBConnImpl = dbConnFactory

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down
8 changes: 4 additions & 4 deletions cdc/sink/ddlsink/mysql/mysql_ddl_sink.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ const (
networkDriftDuration = 5 * time.Second
)

// GetDBConnImpl is the implementation of pmysql.Factory.
// GetDBConnImpl is the implementation of pmysql.IDBConnectionFactory.
// Exported for testing.
var GetDBConnImpl pmysql.Factory = pmysql.CreateMySQLDBConn
var GetDBConnImpl pmysql.IDBConnectionFactory = &pmysql.DBConnectionFactory{}

// Assert Sink implementation
var _ ddlsink.Sink = (*DDLSink)(nil)
Expand Down Expand Up @@ -81,12 +81,12 @@ func NewDDLSink(
return nil, err
}

dsnStr, err := pmysql.GenerateDSN(ctx, sinkURI, cfg, GetDBConnImpl)
dsnStr, err := pmysql.GenerateDSN(ctx, sinkURI, cfg, GetDBConnImpl.CreateTemporaryConnection)
if err != nil {
return nil, err
}

db, err := GetDBConnImpl(ctx, dsnStr)
db, err := GetDBConnImpl.CreateStandardConnection(ctx, dsnStr)
if err != nil {
return nil, err
}
Expand Down
17 changes: 4 additions & 13 deletions cdc/sink/ddlsink/mysql/mysql_ddl_sink_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,8 @@ import (
)

func TestWriteDDLEvent(t *testing.T) {
dbIndex := 0
GetDBConnImpl = func(ctx context.Context, dsnStr string) (*sql.DB, error) {
defer func() {
dbIndex++
}()
if dbIndex == 0 {
// test db
db, err := pmysql.MockTestDB()
require.Nil(t, err)
return db, nil
}
// normal db
dbConnFactory := pmysql.NewDBConnectionFactoryForTest()
dbConnFactory.SetStandardConnectionFactory(func(ctx context.Context, dsnStr string) (*sql.DB, error) {
db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
require.Nil(t, err)
mock.ExpectQuery("select tidb_version()").
Expand All @@ -66,7 +56,8 @@ func TestWriteDDLEvent(t *testing.T) {
mock.ExpectRollback()
mock.ExpectClose()
return db, nil
}
})
GetDBConnImpl = dbConnFactory

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down
6 changes: 3 additions & 3 deletions cdc/sink/dmlsink/txn/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func NewMySQLBackends(
changefeedID model.ChangeFeedID,
sinkURI *url.URL,
replicaConfig *config.ReplicaConfig,
dbConnFactory pmysql.Factory,
dbConnFactory pmysql.IDBConnectionFactory,
statistics *metrics.Statistics,
) ([]*mysqlBackend, error) {
changefeed := fmt.Sprintf("%s.%s", changefeedID.Namespace, changefeedID.ID)
Expand All @@ -97,12 +97,12 @@ func NewMySQLBackends(
return nil, err
}

dsnStr, err := pmysql.GenerateDSN(ctx, sinkURI, cfg, dbConnFactory)
dsnStr, err := pmysql.GenerateDSN(ctx, sinkURI, cfg, dbConnFactory.CreateTemporaryConnection)
if err != nil {
return nil, err
}

db, err := dbConnFactory(ctx, dsnStr)
db, err := dbConnFactory.CreateStandardConnection(ctx, dsnStr)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 691380d

Please sign in to comment.