Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sink(ticdc): support multi statements in mysql backend #8395

Merged
merged 16 commits into from
Mar 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 93 additions & 33 deletions cdc/sink/dmlsink/txn/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,82 @@ func (s *mysqlBackend) prepareDMLs() *preparedDMLs {
}
}

// execute SQLs in the multi statements way.
func (s *mysqlBackend) multiStmtExecute(
ctx context.Context, dmls *preparedDMLs, tx *sql.Tx, writeTimeout time.Duration,
) error {
start := time.Now()
multiStmtSQL := ""
multiStmtArgs := []any{}
for i, query := range dmls.sqls {
multiStmtSQL += query
if i != len(dmls.sqls)-1 {
multiStmtSQL += ";"
}
multiStmtArgs = append(multiStmtArgs, dmls.values[i]...)
}
ctx, cancel := context.WithTimeout(ctx, writeTimeout)
defer cancel()
_, execError := tx.ExecContext(ctx, multiStmtSQL, multiStmtArgs...)
if execError != nil {
err := logDMLTxnErr(
cerror.WrapError(cerror.ErrMySQLTxnError, execError),
start, s.changefeed, multiStmtSQL, dmls.rowCount, dmls.startTs)
if rbErr := tx.Rollback(); rbErr != nil {
if errors.Cause(rbErr) != context.Canceled {
log.Warn("failed to rollback txn", zap.Error(rbErr))
}
}
return err
}
return nil
}

// execute SQLs in each preparedDMLs one by one in the same transaction.
func (s *mysqlBackend) sequenceExecute(
ctx context.Context, dmls *preparedDMLs, tx *sql.Tx, writeTimeout time.Duration,
) error {
start := time.Now()
for i, query := range dmls.sqls {
args := dmls.values[i]
log.Debug("exec row", zap.Int("workerID", s.workerID),
zap.String("sql", query), zap.Any("args", args))
ctx, cancelFunc := context.WithTimeout(ctx, writeTimeout)
var execError error
if s.cachePrepStmts {
stmt, ok := s.stmtCache.Get(query)
if !ok {
var err error
stmt, err = s.db.Prepare(query)
if err != nil {
cancelFunc()
return errors.Trace(err)
}

s.stmtCache.Add(query, stmt)
}
//nolint:sqlclosecheck
_, execError = tx.Stmt(stmt.(*sql.Stmt)).ExecContext(ctx, args...)
} else {
_, execError = tx.ExecContext(ctx, query, args...)
}
if execError != nil {
err := logDMLTxnErr(
cerror.WrapError(cerror.ErrMySQLTxnError, execError),
start, s.changefeed, query, dmls.rowCount, dmls.startTs)
if rbErr := tx.Rollback(); rbErr != nil {
if errors.Cause(rbErr) != context.Canceled {
log.Warn("failed to rollback txn", zap.Error(rbErr))
}
}
cancelFunc()
return err
}
cancelFunc()
}
return nil
}

func (s *mysqlBackend) execDMLWithMaxRetries(pctx context.Context, dmls *preparedDMLs) error {
if len(dmls.sqls) != len(dmls.values) {
log.Panic("unexpected number of sqls and values",
Expand All @@ -623,6 +699,7 @@ func (s *mysqlBackend) execDMLWithMaxRetries(pctx context.Context, dmls *prepare
}

start := time.Now()
fallbackToSeqWay := false
return retry.Do(pctx, func() error {
writeTimeout, _ := time.ParseDuration(s.cfg.WriteTimeout)
writeTimeout += networkDriftDuration
Expand All @@ -644,42 +721,22 @@ func (s *mysqlBackend) execDMLWithMaxRetries(pctx context.Context, dmls *prepare
start, s.changefeed, "BEGIN", dmls.rowCount, dmls.startTs)
}

for i, query := range dmls.sqls {
args := dmls.values[i]
log.Debug("exec row", zap.Int("workerID", s.workerID),
zap.String("sql", query), zap.Any("args", args))
ctx, cancelFunc := context.WithTimeout(pctx, writeTimeout)
var execError error
if s.cachePrepStmts {
stmt, ok := s.stmtCache.Get(query)
if !ok {
var err error
stmt, err = s.db.Prepare(query)
if err != nil {
cancelFunc()
return 0, errors.Trace(err)
}

s.stmtCache.Add(query, stmt)
}
//nolint:sqlclosecheck
_, execError = tx.Stmt(stmt.(*sql.Stmt)).ExecContext(ctx, args...)
} else {
_, execError = tx.ExecContext(ctx, query, args...)
// If interplated SQL size exceeds maxAllowPacket, mysql driver will
// fall back to the sequantial way.
// error can be ErrPrepareMulti, ErrBadConn etc.
// TODO: add a quick path to check whether we should fallback to
// the sequence way.
if s.cfg.MultiStmtEnable && !fallbackToSeqWay {
err = s.multiStmtExecute(pctx, dmls, tx, writeTimeout)
if err != nil {
fallbackToSeqWay = true
return 0, err
}
if execError != nil {
err := logDMLTxnErr(
cerror.WrapError(cerror.ErrMySQLTxnError, execError),
start, s.changefeed, query, dmls.rowCount, dmls.startTs)
if rbErr := tx.Rollback(); rbErr != nil {
if errors.Cause(rbErr) != context.Canceled {
log.Warn("failed to rollback txn", zap.Error(rbErr))
}
}
cancelFunc()
} else {
err = s.sequenceExecute(pctx, dmls, tx, writeTimeout)
if err != nil {
return 0, err
}
cancelFunc()
}

// we set write source for each txn,
Expand Down Expand Up @@ -724,6 +781,9 @@ func logDMLTxnErr(
err error, start time.Time, changefeed string,
query string, count int, startTs []model.Ts,
) error {
if len(query) > 1024 {
query = query[:1024]
}
if isRetryableDMLError(err) {
log.Warn("execute DMLs with error, retry later",
zap.Error(err), zap.Duration("duration", time.Since(start)),
Expand Down
3 changes: 2 additions & 1 deletion pkg/applier/redo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ func TestApply(t *testing.T) {

cfg := &RedoApplierConfig{
SinkURI: "mysql://127.0.0.1:4000/?worker-count=1&max-txn-row=1" +
"&tidb_placement_mode=ignore&safe-mode=true&cache-prep-stmts=false",
"&tidb_placement_mode=ignore&safe-mode=true&cache-prep-stmts=false" +
"&multi-stmt-enable=false",
}
ap := NewRedoApplier(cfg)
err := ap.Apply(ctx)
Expand Down
20 changes: 19 additions & 1 deletion pkg/sink/mysql/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ const (
// BackoffMaxDelay indicates the max delay time for retrying.
BackoffMaxDelay = 60 * time.Second

defaultBatchDMLEnable = true
defaultBatchDMLEnable = true
defaultMultiStmtEnable = true

// defaultcachePrepStmts is the default value of cachePrepStmts
defaultCachePrepStmts = true
Expand Down Expand Up @@ -102,6 +103,7 @@ type Config struct {
IsTiDB bool // IsTiDB is true if the downstream is TiDB
SourceID uint64
BatchDMLEnable bool
MultiStmtEnable bool
CachePrepStmts bool
PrepStmtCacheSize int
}
Expand All @@ -119,6 +121,7 @@ func NewConfig() *Config {
DialTimeout: defaultDialTimeout,
SafeMode: defaultSafeMode,
BatchDMLEnable: defaultBatchDMLEnable,
MultiStmtEnable: defaultMultiStmtEnable,
CachePrepStmts: defaultCachePrepStmts,
PrepStmtCacheSize: defaultPrepStmtCacheSize,
}
Expand Down Expand Up @@ -176,6 +179,9 @@ func (c *Config) Apply(
if err = getBatchDMLEnable(query, &c.BatchDMLEnable); err != nil {
return err
}
if err = getMultiStmtEnable(query, &c.MultiStmtEnable); err != nil {
return err
}
if err = getCachePrepStmts(query, &c.CachePrepStmts); err != nil {
return err
}
Expand Down Expand Up @@ -385,6 +391,18 @@ func getBatchDMLEnable(values url.Values, batchDMLEnable *bool) error {
return nil
}

func getMultiStmtEnable(values url.Values, multiStmtEnable *bool) error {
s := values.Get("multi-stmt-enable")
if len(s) > 0 {
enable, err := strconv.ParseBool(s)
if err != nil {
return cerror.WrapError(cerror.ErrMySQLInvalidConfig, err)
}
*multiStmtEnable = enable
}
return nil
}

func getCachePrepStmts(values url.Values, cachePrepStmts *bool) error {
s := values.Get("cache-prep-stmts")
if len(s) > 0 {
Expand Down