diff --git a/server/conn.go b/server/conn.go index 251d431f9a49a..5ad396b11ac5e 100644 --- a/server/conn.go +++ b/server/conn.go @@ -65,6 +65,7 @@ import ( "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tidb/util/memory" + "github.com/pingcap/tidb/util/sqlexec" "go.uber.org/zap" ) @@ -945,6 +946,10 @@ func (cc *clientConn) flush() error { func (cc *clientConn) writeOK() error { msg := cc.ctx.LastMessage() + return cc.writeOkWith(msg, cc.ctx.AffectedRows(), cc.ctx.LastInsertID(), cc.ctx.Status(), cc.ctx.WarningCount()) +} + +func (cc *clientConn) writeOkWith(msg string, affectedRows, lastInsertID uint64, status, warnCnt uint16) error { enclen := 0 if len(msg) > 0 { enclen = lengthEncodedIntSize(uint64(len(msg))) + len(msg) @@ -952,11 +957,11 @@ func (cc *clientConn) writeOK() error { data := cc.alloc.AllocWithLen(4, 32+enclen) data = append(data, mysql.OKHeader) - data = dumpLengthEncodedInt(data, cc.ctx.AffectedRows()) - data = dumpLengthEncodedInt(data, cc.ctx.LastInsertID()) + data = dumpLengthEncodedInt(data, affectedRows) + data = dumpLengthEncodedInt(data, lastInsertID) if cc.capability&mysql.ClientProtocol41 > 0 { - data = dumpUint16(data, cc.ctx.Status()) - data = dumpUint16(data, cc.ctx.WarningCount()) + data = dumpUint16(data, status) + data = dumpUint16(data, warnCnt) } if enclen > 0 { // although MySQL manual says the info message is string(https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html), @@ -1396,12 +1401,27 @@ func (cc *clientConn) writeChunksWithFetchSize(ctx context.Context, rs ResultSet } func (cc *clientConn) writeMultiResultset(ctx context.Context, rss []ResultSet, binary bool) error { - for _, rs := range rss { - if err := cc.writeResultset(ctx, rs, binary, mysql.ServerMoreResultsExists, 0); err != nil { + for i, rs := range rss { + lastRs := i == len(rss)-1 + if r, ok := rs.(*tidbResultSet).recordSet.(sqlexec.MultiQueryNoDelayResult); ok { + status := r.Status() + if !lastRs { + status |= mysql.ServerMoreResultsExists + } + if err := cc.writeOkWith(r.LastMessage(), r.AffectedRows(), r.LastInsertID(), status, r.WarnCount()); err != nil { + return err + } + continue + } + status := uint16(0) + if !lastRs { + status |= mysql.ServerMoreResultsExists + } + if err := cc.writeResultset(ctx, rs, binary, status, 0); err != nil { return err } } - return cc.writeOK() + return nil } func (cc *clientConn) setConn(conn net.Conn) { diff --git a/session/session.go b/session/session.go index 593db28078e88..edeb21747d53c 100644 --- a/session/session.go +++ b/session/session.go @@ -945,7 +945,7 @@ func (s *session) SetProcessInfo(sql string, t time.Time, command byte, maxExecu s.processInfo.Store(&pi) } -func (s *session) executeStatement(ctx context.Context, connID uint64, stmtNode ast.StmtNode, stmt sqlexec.Statement, recordSets []sqlexec.RecordSet) ([]sqlexec.RecordSet, error) { +func (s *session) executeStatement(ctx context.Context, connID uint64, stmtNode ast.StmtNode, stmt sqlexec.Statement, recordSets []sqlexec.RecordSet, inMulitQuery bool) ([]sqlexec.RecordSet, error) { s.SetValue(sessionctx.QueryString, stmt.OriginText()) if _, ok := stmtNode.(ast.DDLNode); ok { s.SetValue(sessionctx.LastExecuteDDL, true) @@ -970,6 +970,16 @@ func (s *session) executeStatement(ctx context.Context, connID uint64, stmtNode sessionExecuteRunDurationGeneral.Observe(time.Since(startTime).Seconds()) } + if inMulitQuery && recordSet == nil { + recordSet = &multiQueryNoDelayRecordSet{ + affectedRows: s.AffectedRows(), + lastMessage: s.LastMessage(), + warnCount: s.sessionVars.StmtCtx.WarningCount(), + lastInsertID: s.sessionVars.StmtCtx.LastInsertID, + status: s.sessionVars.Status, + } + } + if recordSet != nil { recordSets = append(recordSets, recordSet) } @@ -1016,6 +1026,7 @@ func (s *session) execute(ctx context.Context, sql string) (recordSets []sqlexec var tempStmtNodes []ast.StmtNode compiler := executor.Compiler{Ctx: s} + multiQuery := len(stmtNodes) > 1 for idx, stmtNode := range stmtNodes { s.PrepareTxnCtx(ctx) @@ -1052,7 +1063,7 @@ func (s *session) execute(ctx context.Context, sql string) (recordSets []sqlexec s.currentPlan = stmt.Plan // Step3: Execute the physical plan. - if recordSets, err = s.executeStatement(ctx, connID, stmtNode, stmt, recordSets); err != nil { + if recordSets, err = s.executeStatement(ctx, connID, stmtNode, stmt, recordSets, multiQuery); err != nil { return nil, err } } @@ -1889,3 +1900,47 @@ func (s *session) recordTransactionCounter(err error) { } } } + +type multiQueryNoDelayRecordSet struct { + affectedRows uint64 + lastMessage string + status uint16 + warnCount uint16 + lastInsertID uint64 +} + +func (c *multiQueryNoDelayRecordSet) Fields() []*ast.ResultField { + panic("unsupported method") +} + +func (c *multiQueryNoDelayRecordSet) Next(ctx context.Context, chk *chunk.Chunk) error { + panic("unsupported method") +} + +func (c *multiQueryNoDelayRecordSet) NewChunk() *chunk.Chunk { + panic("unsupported method") +} + +func (c *multiQueryNoDelayRecordSet) Close() error { + return nil +} + +func (c *multiQueryNoDelayRecordSet) AffectedRows() uint64 { + return c.affectedRows +} + +func (c *multiQueryNoDelayRecordSet) LastMessage() string { + return c.lastMessage +} + +func (c *multiQueryNoDelayRecordSet) WarnCount() uint16 { + return c.warnCount +} + +func (c *multiQueryNoDelayRecordSet) Status() uint16 { + return c.status +} + +func (c *multiQueryNoDelayRecordSet) LastInsertID() uint64 { + return c.lastInsertID +} diff --git a/util/sqlexec/restricted_sql_executor.go b/util/sqlexec/restricted_sql_executor.go index 5c5c0d31711f0..ec42336087752 100644 --- a/util/sqlexec/restricted_sql_executor.go +++ b/util/sqlexec/restricted_sql_executor.go @@ -96,3 +96,17 @@ type RecordSet interface { // restart the iteration. Close() error } + +// MultiQueryNoDelayResult is an interface for one no-delay result for one statement in multi-queries. +type MultiQueryNoDelayResult interface { + // AffectedRows return affected row for one statement in multi-queries. + AffectedRows() uint64 + // LastMessage return last message for one statement in multi-queries. + LastMessage() string + // WarnCount return warn count for one statement in multi-queries. + WarnCount() uint16 + // Status return status when executing one statement in multi-queries. + Status() uint16 + // LastInsertID return last insert id for one statement in multi-queries. + LastInsertID() uint64 +}