diff --git a/cdc/sink/dmlsink/txn/mysql/mysql.go b/cdc/sink/dmlsink/txn/mysql/mysql.go index 9ae33f49ed6..d3c1403b9fb 100644 --- a/cdc/sink/dmlsink/txn/mysql/mysql.go +++ b/cdc/sink/dmlsink/txn/mysql/mysql.go @@ -615,6 +615,82 @@ func (s *mysqlBackend) prepareDMLs() *preparedDMLs { } } +// execute SQLs in the multi statements way. +func (s *mysqlBackend) multiStmtExecute( + ctx context.Context, dmls *preparedDMLs, tx *sql.Tx, writeTimeout time.Duration, +) error { + start := time.Now() + multiStmtSQL := "" + multiStmtArgs := []any{} + for i, query := range dmls.sqls { + multiStmtSQL += query + if i != len(dmls.sqls)-1 { + multiStmtSQL += ";" + } + multiStmtArgs = append(multiStmtArgs, dmls.values[i]...) + } + ctx, cancel := context.WithTimeout(ctx, writeTimeout) + defer cancel() + _, execError := tx.ExecContext(ctx, multiStmtSQL, multiStmtArgs...) + if execError != nil { + err := logDMLTxnErr( + cerror.WrapError(cerror.ErrMySQLTxnError, execError), + start, s.changefeed, multiStmtSQL, dmls.rowCount, dmls.startTs) + if rbErr := tx.Rollback(); rbErr != nil { + if errors.Cause(rbErr) != context.Canceled { + log.Warn("failed to rollback txn", zap.Error(rbErr)) + } + } + return err + } + return nil +} + +// execute SQLs in each preparedDMLs one by one in the same transaction. +func (s *mysqlBackend) sequenceExecute( + ctx context.Context, dmls *preparedDMLs, tx *sql.Tx, writeTimeout time.Duration, +) error { + start := time.Now() + for i, query := range dmls.sqls { + args := dmls.values[i] + log.Debug("exec row", zap.Int("workerID", s.workerID), + zap.String("sql", query), zap.Any("args", args)) + ctx, cancelFunc := context.WithTimeout(ctx, writeTimeout) + var execError error + if s.cachePrepStmts { + stmt, ok := s.stmtCache.Get(query) + if !ok { + var err error + stmt, err = s.db.Prepare(query) + if err != nil { + cancelFunc() + return errors.Trace(err) + } + + s.stmtCache.Add(query, stmt) + } + //nolint:sqlclosecheck + _, execError = tx.Stmt(stmt.(*sql.Stmt)).ExecContext(ctx, args...) + } else { + _, execError = tx.ExecContext(ctx, query, args...) + } + if execError != nil { + err := logDMLTxnErr( + cerror.WrapError(cerror.ErrMySQLTxnError, execError), + start, s.changefeed, query, dmls.rowCount, dmls.startTs) + if rbErr := tx.Rollback(); rbErr != nil { + if errors.Cause(rbErr) != context.Canceled { + log.Warn("failed to rollback txn", zap.Error(rbErr)) + } + } + cancelFunc() + return err + } + cancelFunc() + } + return nil +} + func (s *mysqlBackend) execDMLWithMaxRetries(pctx context.Context, dmls *preparedDMLs) error { if len(dmls.sqls) != len(dmls.values) { log.Panic("unexpected number of sqls and values", @@ -623,6 +699,7 @@ func (s *mysqlBackend) execDMLWithMaxRetries(pctx context.Context, dmls *prepare } start := time.Now() + fallbackToSeqWay := false return retry.Do(pctx, func() error { writeTimeout, _ := time.ParseDuration(s.cfg.WriteTimeout) writeTimeout += networkDriftDuration @@ -644,42 +721,22 @@ func (s *mysqlBackend) execDMLWithMaxRetries(pctx context.Context, dmls *prepare start, s.changefeed, "BEGIN", dmls.rowCount, dmls.startTs) } - for i, query := range dmls.sqls { - args := dmls.values[i] - log.Debug("exec row", zap.Int("workerID", s.workerID), - zap.String("sql", query), zap.Any("args", args)) - ctx, cancelFunc := context.WithTimeout(pctx, writeTimeout) - var execError error - if s.cachePrepStmts { - stmt, ok := s.stmtCache.Get(query) - if !ok { - var err error - stmt, err = s.db.Prepare(query) - if err != nil { - cancelFunc() - return 0, errors.Trace(err) - } - - s.stmtCache.Add(query, stmt) - } - //nolint:sqlclosecheck - _, execError = tx.Stmt(stmt.(*sql.Stmt)).ExecContext(ctx, args...) - } else { - _, execError = tx.ExecContext(ctx, query, args...) + // If interplated SQL size exceeds maxAllowPacket, mysql driver will + // fall back to the sequantial way. + // error can be ErrPrepareMulti, ErrBadConn etc. + // TODO: add a quick path to check whether we should fallback to + // the sequence way. + if s.cfg.MultiStmtEnable && !fallbackToSeqWay { + err = s.multiStmtExecute(pctx, dmls, tx, writeTimeout) + if err != nil { + fallbackToSeqWay = true + return 0, err } - if execError != nil { - err := logDMLTxnErr( - cerror.WrapError(cerror.ErrMySQLTxnError, execError), - start, s.changefeed, query, dmls.rowCount, dmls.startTs) - if rbErr := tx.Rollback(); rbErr != nil { - if errors.Cause(rbErr) != context.Canceled { - log.Warn("failed to rollback txn", zap.Error(rbErr)) - } - } - cancelFunc() + } else { + err = s.sequenceExecute(pctx, dmls, tx, writeTimeout) + if err != nil { return 0, err } - cancelFunc() } // we set write source for each txn, @@ -724,6 +781,9 @@ func logDMLTxnErr( err error, start time.Time, changefeed string, query string, count int, startTs []model.Ts, ) error { + if len(query) > 1024 { + query = query[:1024] + } if isRetryableDMLError(err) { log.Warn("execute DMLs with error, retry later", zap.Error(err), zap.Duration("duration", time.Since(start)), diff --git a/pkg/applier/redo_test.go b/pkg/applier/redo_test.go index 1f77aa85616..f26a11db203 100644 --- a/pkg/applier/redo_test.go +++ b/pkg/applier/redo_test.go @@ -203,7 +203,8 @@ func TestApply(t *testing.T) { cfg := &RedoApplierConfig{ SinkURI: "mysql://127.0.0.1:4000/?worker-count=1&max-txn-row=1" + - "&tidb_placement_mode=ignore&safe-mode=true&cache-prep-stmts=false", + "&tidb_placement_mode=ignore&safe-mode=true&cache-prep-stmts=false" + + "&multi-stmt-enable=false", } ap := NewRedoApplier(cfg) err := ap.Apply(ctx) diff --git a/pkg/sink/mysql/config.go b/pkg/sink/mysql/config.go index 2a313bff808..e8e8a9fd660 100644 --- a/pkg/sink/mysql/config.go +++ b/pkg/sink/mysql/config.go @@ -71,7 +71,8 @@ const ( // BackoffMaxDelay indicates the max delay time for retrying. BackoffMaxDelay = 60 * time.Second - defaultBatchDMLEnable = true + defaultBatchDMLEnable = true + defaultMultiStmtEnable = true // defaultcachePrepStmts is the default value of cachePrepStmts defaultCachePrepStmts = true @@ -102,6 +103,7 @@ type Config struct { IsTiDB bool // IsTiDB is true if the downstream is TiDB SourceID uint64 BatchDMLEnable bool + MultiStmtEnable bool CachePrepStmts bool PrepStmtCacheSize int } @@ -119,6 +121,7 @@ func NewConfig() *Config { DialTimeout: defaultDialTimeout, SafeMode: defaultSafeMode, BatchDMLEnable: defaultBatchDMLEnable, + MultiStmtEnable: defaultMultiStmtEnable, CachePrepStmts: defaultCachePrepStmts, PrepStmtCacheSize: defaultPrepStmtCacheSize, } @@ -176,6 +179,9 @@ func (c *Config) Apply( if err = getBatchDMLEnable(query, &c.BatchDMLEnable); err != nil { return err } + if err = getMultiStmtEnable(query, &c.MultiStmtEnable); err != nil { + return err + } if err = getCachePrepStmts(query, &c.CachePrepStmts); err != nil { return err } @@ -385,6 +391,18 @@ func getBatchDMLEnable(values url.Values, batchDMLEnable *bool) error { return nil } +func getMultiStmtEnable(values url.Values, multiStmtEnable *bool) error { + s := values.Get("multi-stmt-enable") + if len(s) > 0 { + enable, err := strconv.ParseBool(s) + if err != nil { + return cerror.WrapError(cerror.ErrMySQLInvalidConfig, err) + } + *multiStmtEnable = enable + } + return nil +} + func getCachePrepStmts(values url.Values, cachePrepStmts *bool) error { s := values.Get("cache-prep-stmts") if len(s) > 0 {