diff --git a/server/BUILD.bazel b/server/BUILD.bazel index cd23fd7370de9..04149e4084416 100644 --- a/server/BUILD.bazel +++ b/server/BUILD.bazel @@ -237,7 +237,11 @@ go_test( "@com_github_tikv_client_go_v2//tikvrpc", "@io_etcd_go_etcd_tests_v3//integration", "@io_opencensus_go//stats/view", +<<<<<<< HEAD "@org_golang_x_exp//slices", +======= + "@org_uber_go_atomic//:atomic", +>>>>>>> 05a1ad36ce8 (server: fix connection double close (#53690)) "@org_uber_go_goleak//:goleak", "@org_uber_go_zap//:zap", ], diff --git a/server/conn.go b/server/conn.go index 7089a0ab61d98..a10182df67bba 100644 --- a/server/conn.go +++ b/server/conn.go @@ -336,14 +336,6 @@ func closeConn(cc *clientConn, connections int) error { // This is because closeConn() might be called after a connection read-timeout. logutil.Logger(context.Background()).Debug("could not close connection", zap.Error(err)) } - if cc.bufReadConn != nil { - err = cc.bufReadConn.Close() - if err != nil { - // We need to expect connection might have already disconnected. - // This is because closeConn() might be called after a connection read-timeout. - logutil.Logger(context.Background()).Debug("could not close connection", zap.Error(err)) - } - } // Close statements and session // This will release advisory locks, row locks, etc. if ctx := cc.getCtx(); ctx != nil { diff --git a/server/conn_test.go b/server/conn_test.go index 1278d24e9a742..5739068c923d3 100644 --- a/server/conn_test.go +++ b/server/conn_test.go @@ -1692,7 +1692,7 @@ func TestMaxAllowedPacket(t *testing.T) { bytes := append([]byte{0x00, 0x04, 0x00, 0x00}, []byte(fmt.Sprintf("SELECT length('%s') as len;", strings.Repeat("a", 999)))...) _, err := inBuffer.Write(bytes) require.NoError(t, err) - brc := newBufferedReadConn(&bytesConn{inBuffer}) + brc := newBufferedReadConn(&bytesConn{b: inBuffer}) pkt := newPacketIO(brc) pkt.setMaxAllowedPacket(maxAllowedPacket) readBytes, err = pkt.readPacket() @@ -1705,7 +1705,7 @@ func TestMaxAllowedPacket(t *testing.T) { bytes = append([]byte{0x01, 0x04, 0x00, 0x00}, []byte(fmt.Sprintf("SELECT length('%s') as len;", strings.Repeat("a", 1000)))...) _, err = inBuffer.Write(bytes) require.NoError(t, err) - brc = newBufferedReadConn(&bytesConn{inBuffer}) + brc = newBufferedReadConn(&bytesConn{b: inBuffer}) pkt = newPacketIO(brc) pkt.setMaxAllowedPacket(maxAllowedPacket) _, err = pkt.readPacket() @@ -1717,7 +1717,7 @@ func TestMaxAllowedPacket(t *testing.T) { bytes = append([]byte{0x01, 0x02, 0x00, 0x00}, []byte(fmt.Sprintf("SELECT length('%s') as len;", strings.Repeat("a", 488)))...) _, err = inBuffer.Write(bytes) require.NoError(t, err) - brc = newBufferedReadConn(&bytesConn{inBuffer}) + brc = newBufferedReadConn(&bytesConn{b: inBuffer}) pkt = newPacketIO(brc) pkt.setMaxAllowedPacket(maxAllowedPacket) readBytes, err = pkt.readPacket() @@ -1728,7 +1728,7 @@ func TestMaxAllowedPacket(t *testing.T) { bytes = append([]byte{0x01, 0x02, 0x00, 0x01}, []byte(fmt.Sprintf("SELECT length('%s') as len;", strings.Repeat("b", 488)))...) _, err = inBuffer.Write(bytes) require.NoError(t, err) - brc = newBufferedReadConn(&bytesConn{inBuffer}) + brc = newBufferedReadConn(&bytesConn{b: inBuffer}) pkt.setBufferedReadConn(brc) readBytes, err = pkt.readPacket() require.NoError(t, err) @@ -1987,6 +1987,7 @@ func TestProcessInfoForExecuteCommand(t *testing.T) { require.Equal(t, cc.ctx.Session.ShowProcess().Info, "select sum(col1) from t where col1 < ? and col1 > 100") } +<<<<<<< HEAD func TestLDAPAuthSwitch(t *testing.T) { store := testkit.CreateMockStore(t) cfg := newTestConfig() @@ -2026,6 +2027,8 @@ func TestLDAPAuthSwitch(t *testing.T) { require.Equal(t, []byte(mysql.AuthMySQLClearPassword), respAuthSwitch) } +======= +>>>>>>> 05a1ad36ce8 (server: fix connection double close (#53690)) func TestCloseConn(t *testing.T) { var outBuffer bytes.Buffer @@ -2036,7 +2039,12 @@ func TestCloseConn(t *testing.T) { drv := NewTiDBDriver(store) server, err := NewServer(cfg, drv) require.NoError(t, err) - + var inBuffer bytes.Buffer + _, err = inBuffer.Write([]byte{0x01, 0x00, 0x00, 0x00, 0x01}) + require.NoError(t, err) + // Test read one packet + brc := newBufferedReadConn(&bytesConn{b: inBuffer}) + require.NoError(t, err) cc := &clientConn{ connectionID: 0, salt: []byte{ @@ -2047,11 +2055,12 @@ func TestCloseConn(t *testing.T) { pkt: &packetIO{ bufWriter: bufio.NewWriter(&outBuffer), }, - collation: mysql.DefaultCollationID, - peerHost: "localhost", - alloc: arena.NewAllocator(512), - chunkAlloc: chunk.NewAllocator(), - capability: mysql.ClientProtocol41, + collation: mysql.DefaultCollationID, + peerHost: "localhost", + alloc: arena.NewAllocator(512), + chunkAlloc: chunk.NewAllocator(), + capability: mysql.ClientProtocol41, + bufReadConn: brc, } var wg sync.WaitGroup diff --git a/server/packetio_test.go b/server/packetio_test.go index fc0b38a23169b..0e9cde8bb26d2 100644 --- a/server/packetio_test.go +++ b/server/packetio_test.go @@ -21,8 +21,10 @@ import ( "testing" "time" + "github.com/pingcap/errors" "github.com/pingcap/tidb/parser/mysql" "github.com/stretchr/testify/require" + "go.uber.org/atomic" ) func BenchmarkPacketIOWrite(b *testing.B) { @@ -64,7 +66,7 @@ func TestPacketIORead(t *testing.T) { _, err := inBuffer.Write([]byte{0x01, 0x00, 0x00, 0x00, 0x01}) require.NoError(t, err) // Test read one packet - brc := newBufferedReadConn(&bytesConn{inBuffer}) + brc := newBufferedReadConn(&bytesConn{b: inBuffer}) pkt := newPacketIO(brc) readBytes, err := pkt.readPacket() require.NoError(t, err) @@ -86,7 +88,7 @@ func TestPacketIORead(t *testing.T) { _, err = inBuffer.Write(buf) require.NoError(t, err) // Test read multiple packets - brc = newBufferedReadConn(&bytesConn{inBuffer}) + brc = newBufferedReadConn(&bytesConn{b: inBuffer}) pkt = newPacketIO(brc) readBytes, err = pkt.readPacket() require.NoError(t, err) @@ -96,7 +98,8 @@ func TestPacketIORead(t *testing.T) { } type bytesConn struct { - b bytes.Buffer + b bytes.Buffer + closed atomic.Bool } func (c *bytesConn) Read(b []byte) (n int, err error) { @@ -108,6 +111,10 @@ func (c *bytesConn) Write(b []byte) (n int, err error) { } func (c *bytesConn) Close() error { + if c.closed.Load() { + return errors.New("already closed") + } + c.closed.Store(true) return nil }