From 8445821f105477d5e31de9d8451b7fbf35a02cd1 Mon Sep 17 00:00:00 2001 From: Ti Chi Robot Date: Mon, 4 Dec 2023 14:25:22 +0800 Subject: [PATCH] server: make `clientConn()` thread-safe (#49073) (#49104) ref pingcap/tidb#48224 --- server/conn.go | 41 +++++++++++++++++++++++++++-------------- server/conn_test.go | 41 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 14 deletions(-) diff --git a/server/conn.go b/server/conn.go index ecd5977f0d101..6afe49299593c 100644 --- a/server/conn.go +++ b/server/conn.go @@ -207,6 +207,7 @@ type clientConn struct { lastActive time.Time // last active time authPlugin string // default authentication plugin isUnixSocket bool // connection is Unix Socket file + closeOnce sync.Once // closeOnce is used to make sure clientConn closes only once rsEncoder *resultEncoder // rsEncoder is used to encode the string result to different charsets. inputDecoder *inputDecoder // inputDecoder is used to decode the different charsets of incoming strings to utf-8. socketCredUID uint32 // UID from the other end of the Unix Socket @@ -346,21 +347,33 @@ func (cc *clientConn) Close() error { } func closeConn(cc *clientConn, connections int) error { - metrics.ConnGauge.Set(float64(connections)) - if cc.bufReadConn != nil { - err := cc.bufReadConn.Close() - if err != nil { - // We need to expect connection might have already disconnected. - // This is because closeConn() might be called after a connection read-timeout. - logutil.Logger(context.Background()).Debug("could not close connection", zap.Error(err)) + var err error + cc.closeOnce.Do(func() { + metrics.ConnGauge.Set(float64(connections)) + + if cc.bufReadConn != nil { + err = cc.bufReadConn.Close() + if err != nil { + // We need to expect connection might have already disconnected. + // This is because closeConn() might be called after a connection read-timeout. + logutil.Logger(context.Background()).Debug("could not close connection", zap.Error(err)) + } + if cc.bufReadConn != nil { + err = cc.bufReadConn.Close() + if err != nil { + // We need to expect connection might have already disconnected. + // This is because closeConn() might be called after a connection read-timeout. + logutil.Logger(context.Background()).Debug("could not close connection", zap.Error(err)) + } + } + // Close statements and session + // This will release advisory locks, row locks, etc. + if ctx := cc.getCtx(); ctx != nil { + err = ctx.Close() + } } - } - // Close statements and session - // This will release advisory locks, row locks, etc. - if ctx := cc.getCtx(); ctx != nil { - return ctx.Close() - } - return nil + }) + return err } func (cc *clientConn) closeWithoutLock() error { diff --git a/server/conn_test.go b/server/conn_test.go index fa3b9d5317a96..7c0cd7e93ae81 100644 --- a/server/conn_test.go +++ b/server/conn_test.go @@ -25,6 +25,7 @@ import ( "io" "path/filepath" "strings" + "sync" "sync/atomic" "testing" "time" @@ -1830,3 +1831,43 @@ func TestProcessInfoForExecuteCommand(t *testing.T) { 0x0A, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0})) require.Equal(t, cc.ctx.Session.ShowProcess().Info, "select sum(col1) from t where col1 < ? and col1 > 100") } +func TestCloseConn(t *testing.T) { + var outBuffer bytes.Buffer + + store, _ := testkit.CreateMockStoreAndDomain(t) + cfg := newTestConfig() + cfg.Port = 0 + cfg.Status.StatusPort = 0 + drv := NewTiDBDriver(store) + server, err := NewServer(cfg, drv) + require.NoError(t, err) + + cc := &clientConn{ + connectionID: 0, + salt: []byte{ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, + 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, + }, + server: server, + pkt: &packetIO{ + bufWriter: bufio.NewWriter(&outBuffer), + }, + collation: mysql.DefaultCollationID, + peerHost: "localhost", + alloc: arena.NewAllocator(512), + chunkAlloc: chunk.NewAllocator(), + capability: mysql.ClientProtocol41, + } + + var wg sync.WaitGroup + const numGoroutines = 10 + wg.Add(numGoroutines) + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + err := closeConn(cc, 1) + require.NoError(t, err) + }() + } + wg.Wait() +}