Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: handle kill signal during write result to connection #52882

Merged
merged 18 commits into from
Jun 12, 2024
53 changes: 40 additions & 13 deletions pkg/executor/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,24 @@ type processinfoSetter interface {

// recordSet wraps an executor, implements sqlexec.RecordSet interface
type recordSet struct {
fields []*ast.ResultField
executor exec.Executor
fields []*ast.ResultField
executor exec.Executor
// The `Fields` method may be called after `Close`, and the executor is cleared in the `Close` function.
// Therefore, we need to store the schema in `recordSet` to avoid a null pointer exception when calling `executor.Schema()`.
schema *expression.Schema
wshwsh12 marked this conversation as resolved.
Show resolved Hide resolved
stmt *ExecStmt
lastErrs []error
txnStartTS uint64
once sync.Once
// finishLock is a mutex used to synchronize access to the `Next` and `Finish` functions of the adapter.
// It ensures that only one goroutine can access the `Next` and `Finish` functions at a time, preventing race conditions.
// When we terminate the current SQL externally (e.g., kill query), an additional goroutine would be used to call the `Finish` function.
finishLock sync.Mutex
wshwsh12 marked this conversation as resolved.
Show resolved Hide resolved
}

// Fields returns the column descriptors of the result set, building and
// caching them on first use.
//
// The descriptors are derived from the schema stored on the recordSet rather
// than from a.executor: Finish sets a.executor to nil, and Fields may still be
// called afterwards, so calling a.executor.Schema() here could dereference nil.
func (a *recordSet) Fields() []*ast.ResultField {
	if len(a.fields) == 0 {
		a.fields = colNames2ResultFields(a.schema, a.stmt.OutputNames, a.stmt.Ctx.GetSessionVars().CurrentDB)
	}
	return a.fields
}
Expand Down Expand Up @@ -156,6 +163,13 @@ func (a *recordSet) Next(ctx context.Context, req *chunk.Chunk) (err error) {
err = util2.GetRecoverError(r)
logutil.Logger(ctx).Error("execute sql panic", zap.String("sql", a.stmt.GetTextToLog(false)), zap.Stack("stack"))
}()
a.finishLock.Lock()
defer a.finishLock.Unlock()
if a.stmt != nil {
if err := a.stmt.Ctx.GetSessionVars().SQLKiller.HandleSignal(); err != nil {
return err
}
}

err = a.stmt.next(ctx, a.executor, req)
if err != nil {
Expand Down Expand Up @@ -186,16 +200,27 @@ func (a *recordSet) NewChunk(alloc chunk.Allocator) *chunk.Chunk {

func (a *recordSet) Finish() error {
var err error
a.once.Do(func() {
err = exec.Close(a.executor)
cteErr := resetCTEStorageMap(a.stmt.Ctx)
if cteErr != nil {
logutil.BgLogger().Error("got error when reset cte storage, should check if the spill disk file deleted or not", zap.Error(cteErr))
}
if err == nil {
err = cteErr
}
})
if a.finishLock.TryLock() {
defer a.finishLock.Unlock()
a.once.Do(func() {
err = exec.Close(a.executor)
cteErr := resetCTEStorageMap(a.stmt.Ctx)
if cteErr != nil {
logutil.BgLogger().Error("got error when reset cte storage, should check if the spill disk file deleted or not", zap.Error(cteErr))
}
if err == nil {
err = cteErr
}
a.executor = nil
if a.stmt != nil {
status := a.stmt.Ctx.GetSessionVars().SQLKiller.GetKillSignal()
inWriteResultSet := a.stmt.Ctx.GetSessionVars().SQLKiller.InWriteResultSet.Load()
if status > 0 && inWriteResultSet {
logutil.BgLogger().Warn("kill, this SQL might be stuck in the network stack while writing packets to the client.", zap.Uint64("connection ID", a.stmt.Ctx.GetSessionVars().ConnectionID))
}
}
})
}
if err != nil {
a.lastErrs = append(a.lastErrs, err)
}
Expand Down Expand Up @@ -336,6 +361,7 @@ func (a *ExecStmt) PointGet(ctx context.Context) (*recordSet, error) {

return &recordSet{
executor: executor,
schema: executor.Schema(),
stmt: a,
txnStartTS: startTs,
}, nil
Expand Down Expand Up @@ -571,6 +597,7 @@ func (a *ExecStmt) Exec(ctx context.Context) (_ sqlexec.RecordSet, err error) {

return &recordSet{
executor: e,
schema: e.Schema(),
stmt: a,
txnStartTS: txnStartTS,
}, nil
Expand Down
6 changes: 6 additions & 0 deletions pkg/server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2048,6 +2048,12 @@ func (cc *clientConn) handleStmt(
if cc.getStatus() == connStatusShutdown {
return false, exeerrors.ErrQueryInterrupted
}
cc.ctx.GetSessionVars().SQLKiller.Finish = func() {
//nolint: errcheck
rs.Finish()
}
cc.ctx.GetSessionVars().SQLKiller.InWriteResultSet.Store(true)
defer cc.ctx.GetSessionVars().SQLKiller.InWriteResultSet.Store(false)
if retryable, err := cc.writeResultSet(ctx, rs, false, status, 0); err != nil {
return retryable, err
}
Expand Down
8 changes: 8 additions & 0 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -907,6 +907,13 @@ func (s *Server) Kill(connectionID uint64, query bool, maxExecutionTime bool) {
// Mark the client connection status as WaitShutdown, when clientConn.Run detect
// this, it will end the dispatch loop and exit.
conn.setStatus(connStatusWaitShutdown)
if conn.bufReadConn != nil {
// When attempting to 'kill connection' and TiDB is stuck in the network stack while writing packets,
// we can quickly exit the network stack and terminate the SQL execution by setting WriteDeadline.
if err := conn.bufReadConn.SetWriteDeadline(time.Now()); err != nil {
wshwsh12 marked this conversation as resolved.
Show resolved Hide resolved
logutil.BgLogger().Warn("error setting write deadline for kill.", zap.Error(err))
}
}
}
killQuery(conn, maxExecutionTime)
}
Expand Down Expand Up @@ -940,6 +947,7 @@ func killQuery(conn *clientConn, maxExecutionTime bool) {
logutil.BgLogger().Warn("error setting read deadline for kill.", zap.Error(err))
}
}
sessVars.SQLKiller.FinishResultSet()
}

// KillSysProcesses kill sys processes such as auto analyze.
Expand Down
10 changes: 10 additions & 0 deletions pkg/util/servermemorylimit/servermemorylimit.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,20 @@ func killSessIfNeeded(s *sessionToBeKilled, bt uint64, sm util.SessionManager) {
zap.String("sql text", fmt.Sprintf("%.100v", info.Info)),
zap.Int64("sql memory usage", info.MemTracker.BytesConsumed()))
s.lastLogTime = time.Now()

if seconds := time.Since(s.killStartTime) / time.Second; seconds >= 60 {
// If the SQL cannot be terminated after 60 seconds, it may be stuck in the network stack while writing packets to the client,
// encountering some bugs that cause it to hang, or failing to detect the kill signal.
// In this case, the resources can be reclaimed by calling the `Finish` method, and then we can start looking for the next SQL with the largest memory usage.
logutil.BgLogger().Warn(fmt.Sprintf("global memory controller failed to kill the top-consumer in %d seconds. Attempting to force close the executors.", seconds))
s.sessionTracker.Killer.FinishResultSet()
goto Succ
}
}
return
}
}
Succ:
s.reset()
IsKilling.Store(false)
memory.MemUsageTop1Tracker.CompareAndSwap(s.sessionTracker, nil)
Expand Down
56 changes: 45 additions & 11 deletions pkg/util/sqlkiller/sqlkiller.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,48 @@ const (
type SQLKiller struct {
// Signal is the pending kill signal; zero means no signal is pending.
// It is read and written with sync/atomic operations (SendKillSignal, GetKillSignal, HandleSignal, Reset).
Signal killSignal
// ConnID is the connection ID reported in kill-related log messages and in the memory-exceeded errors.
ConnID uint64
// Finish is an optional callback invoked by FinishResultSet when it is non-nil. Reset sets it back to nil.
Finish func()
// InWriteResultSet is used to indicate whether the query is currently calling clientConn.writeResultSet().
// If the query is in writeResultSet and Finish() can acquire rs.finishLock, we can assume the query is waiting for the client to receive data from the server over network I/O.
InWriteResultSet atomic.Bool
}

// SendKillSignal sends a kill signal to the query.
//
// The signal is recorded only when no other signal is pending (Signal is 0);
// the first signal wins and later calls are no-ops until Reset clears it.
// A "kill initiated" warning is logged once, by the call that actually
// recorded the signal.
func (killer *SQLKiller) SendKillSignal(reason killSignal) {
	// A single compare-and-swap decides whether this call is the one that
	// records the signal; repeating it would always fail after a success.
	if !atomic.CompareAndSwapUint32(&killer.Signal, 0, reason) {
		return
	}
	// Describe the signal from `reason` instead of re-loading killer.Signal:
	// the stored value may already have been cleared by Reset. getKillError
	// returns nil for signals it does not recognize, so guard the call to
	// Error() to avoid a nil dereference.
	desc := "unknown kill signal"
	if err := killer.getKillError(reason); err != nil {
		desc = err.Error()
	}
	logutil.BgLogger().Warn("kill initiated", zap.Uint64("connection ID", killer.ConnID), zap.String("reason", desc))
}

// GetKillSignal atomically reads the kill signal currently recorded on the
// killer and returns it. A zero value means no signal is pending.
func (killer *SQLKiller) GetKillSignal() killSignal {
	current := atomic.LoadUint32(&killer.Signal)
	return current
}

// getKillError maps a kill signal to the error reported for it.
// The two memory-exceeded errors carry the killer's connection ID as their
// argument. Any other value, including zero, yields nil.
func (killer *SQLKiller) getKillError(status killSignal) error {
	switch status {
	case ServerMemoryExceeded:
		return exeerrors.ErrMemoryExceedForInstance.GenWithStackByArgs(killer.ConnID)
	case QueryMemoryExceeded:
		return exeerrors.ErrMemoryExceedForQuery.GenWithStackByArgs(killer.ConnID)
	case MaxExecTimeExceeded:
		return exeerrors.ErrMaxExecTimeExceeded.GenWithStackByArgs()
	case QueryInterrupted:
		return exeerrors.ErrQueryInterrupted.GenWithStackByArgs()
	default:
		// No error is associated with this signal value.
		return nil
	}
}

// FinishResultSet is used to close the result set.
// If a kill signal is sent but the SQL query is stuck in the network stack while writing packets to the client,
// encountering some bugs that cause it to hang, or failing to detect the kill signal, we can call Finish to release resources used during the SQL execution process.
//
// It is a no-op when no Finish callback is registered.
func (killer *SQLKiller) FinishResultSet() {
	// Read the callback exactly once. Reset sets killer.Finish to nil, and this
	// method is called from the code paths that kill a query, so reading the
	// field again after the nil check could observe nil and panic on the call.
	// NOTE(review): the snapshot only removes the check-then-call hazard; access
	// to the field itself is still unsynchronized.
	if finish := killer.Finish; finish != nil {
		finish()
	}
}

// HandleSignal handles the kill signal and return the error.
Expand All @@ -61,22 +98,19 @@ func (killer *SQLKiller) HandleSignal() error {
}
})
status := atomic.LoadUint32(&killer.Signal)
switch status {
case QueryInterrupted:
return exeerrors.ErrQueryInterrupted.GenWithStackByArgs()
case MaxExecTimeExceeded:
return exeerrors.ErrMaxExecTimeExceeded.GenWithStackByArgs()
case QueryMemoryExceeded:
return exeerrors.ErrMemoryExceedForQuery.GenWithStackByArgs(killer.ConnID)
case ServerMemoryExceeded:
err := killer.getKillError(status)
if status == ServerMemoryExceeded {
logutil.BgLogger().Warn("global memory controller, NeedKill signal is received successfully",
zap.Uint64("conn", killer.ConnID))
return exeerrors.ErrMemoryExceedForInstance.GenWithStackByArgs(killer.ConnID)
}
return nil
return err
}

// Reset returns the SQLKiller to its idle state: it clears the pending kill
// signal and drops the registered Finish callback. If a signal was pending,
// a "kill finished" warning is logged before it is cleared.
func (killer *SQLKiller) Reset() {
	pending := atomic.LoadUint32(&killer.Signal)
	if pending != 0 {
		logutil.BgLogger().Warn("kill finished", zap.Uint64("conn", killer.ConnID))
	}
	atomic.StoreUint32(&killer.Signal, 0)
	killer.Finish = nil
}