diff --git a/pkg/executor/read_write_splitting.go b/pkg/executor/read_write_splitting.go index 2f71721..4874bf3 100644 --- a/pkg/executor/read_write_splitting.go +++ b/pkg/executor/read_write_splitting.go @@ -150,7 +150,7 @@ func (executor *ReadWriteSplittingExecutor) ExecuteFieldList(ctx context.Context } func (executor *ReadWriteSplittingExecutor) ExecutorComQuery( - ctx context.Context, _ string) (result proto.Result, warns uint16, err error) { + ctx context.Context, sqlText string) (result proto.Result, warns uint16, err error) { spanCtx, span := tracing.GetTraceSpan(ctx, tracing.RWSComQuery) defer span.End() @@ -178,9 +178,7 @@ func (executor *ReadWriteSplittingExecutor) ExecutorComQuery( connectionID := proto.ConnectionID(spanCtx) queryStmt := proto.QueryStmt(spanCtx) - if err := queryStmt.Restore(format.NewRestoreCtx(format.RestoreStringSingleQuotes| - format.RestoreKeyWordUppercase| - format.RestoreStringWithoutDefaultCharset, &sb)); err != nil { + if err := queryStmt.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)); err != nil { return nil, 0, err } sql := sb.String() @@ -202,12 +200,12 @@ func (executor *ReadWriteSplittingExecutor) ExecutorComQuery( txi, ok := executor.localTransactionMap.Load(connectionID) if ok { tx = txi.(proto.Tx) - return tx.Query(spanCtx, sql) + return tx.Query(spanCtx, sqlText) } // set to all db for _, db := range executor.all { go func(db *DataSourceBrief) { - if _, _, err := db.DB.Query(spanCtx, sql); err != nil { + if _, _, err := db.DB.Query(spanCtx, sqlText); err != nil { log.Error(err) } }(db) diff --git a/pkg/executor/single_db.go b/pkg/executor/single_db.go index 12c9533..0fec66a 100644 --- a/pkg/executor/single_db.go +++ b/pkg/executor/single_db.go @@ -124,7 +124,7 @@ func (executor *SingleDBExecutor) ExecuteFieldList(ctx context.Context, table, w } func (executor *SingleDBExecutor) ExecutorComQuery( - ctx context.Context, _ string) (result proto.Result, warns uint16, err error) { + ctx context.Context, sqlText string) (result proto.Result, warns uint16, err error) { spanCtx, span := tracing.GetTraceSpan(ctx, tracing.SDBComQuery) defer span.End() @@ -155,9 +155,7 @@ func (executor *SingleDBExecutor) ExecutorComQuery( if queryStmt == nil { return nil, 0, errors.New("query stmt should not be nil") } - if err := queryStmt.Restore(format.NewRestoreCtx(format.RestoreStringSingleQuotes| - format.RestoreKeyWordUppercase| - format.RestoreStringWithoutDefaultCharset, &sb)); err != nil { + if err := queryStmt.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)); err != nil { return nil, 0, err } sql := sb.String() @@ -179,9 +177,9 @@ func (executor *SingleDBExecutor) ExecutorComQuery( txi, ok := executor.localTransactionMap.Load(connectionID) if ok { tx = txi.(proto.Tx) - return tx.Query(spanCtx, sql) + return tx.Query(spanCtx, sqlText) } - return db.Query(spanCtx, sql) + return db.Query(spanCtx, sqlText) } case *ast.BeginStmt: // TODO add metrics