From d7ddb8b9e324830b1ede89c5fea090c824497c51 Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Sat, 23 Mar 2024 00:57:24 +0900 Subject: [PATCH] Fix issue 1567 (#1570) ### Description closes https://github.com/go-sql-driver/mysql/issues/1567 When TLS is enabled, `mc.netConn` is rewritten after the TLS handshak as detailed here: https://github.com/go-sql-driver/mysql/blob/d86c4527bae98ccd4e5060f72887520ce30eda5e/packets.go#L355 Therefore, `mc.netConn` should not be accessed within the watcher goroutine. Instead, `mc.rawConn` should be initialized prior to invoking `mc.startWatcher`, and `mc.rawConn` should be used in lieu of `mc.netConn`. ### Checklist - [x] Code compiles correctly - [x] Created tests which fail without the change (if possible) - [x] All tests passing - [x] Extended the README / documentation, if necessary - [x] Added myself / the copyright holder to the AUTHORS file ## Summary by CodeRabbit - **Refactor** - Improved variable naming for better code readability and maintenance. - Enhanced network connection handling logic. - **New Features** - Updated TCP connection handling to better support TCP Keepalives. - **Tests** - Added a new test to address and verify the fix for a specific issue related to TLS, connection pooling, and round trip time estimation. --- connection.go | 6 +++--- connector.go | 2 +- driver_test.go | 33 +++++++++++++++++++++++++++++++++ packets.go | 1 - 4 files changed, 37 insertions(+), 5 deletions(-) diff --git a/connection.go b/connection.go index f3656f0e6..7b8abeb00 100644 --- a/connection.go +++ b/connection.go @@ -153,11 +153,11 @@ func (mc *mysqlConn) cleanup() { // Makes cleanup idempotent close(mc.closech) - nc := mc.netConn - if nc == nil { + conn := mc.rawConn + if conn == nil { return } - if err := nc.Close(); err != nil { + if err := conn.Close(); err != nil { mc.log(err) } // This function can be called from multiple goroutines. diff --git a/connector.go b/connector.go index a0ee62839..b67077596 100644 --- a/connector.go +++ b/connector.go @@ -102,10 +102,10 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { nd := net.Dialer{Timeout: mc.cfg.Timeout} mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr) } - if err != nil { return nil, err } + mc.rawConn = mc.netConn // Enable TCP Keepalives on TCP connections if tc, ok := mc.netConn.(*net.TCPConn); ok { diff --git a/driver_test.go b/driver_test.go index 6b52650c2..4fd196d4b 100644 --- a/driver_test.go +++ b/driver_test.go @@ -20,6 +20,7 @@ import ( "io" "log" "math" + mrand "math/rand" "net" "net/url" "os" @@ -3577,3 +3578,35 @@ func runCallCommand(dbt *DBTest, query, name string) { } } } + +func TestIssue1567(t *testing.T) { + // enable TLS. + runTests(t, dsn+"&tls=skip-verify", func(dbt *DBTest) { + // disable connection pooling. + // data race happens when new connection is created. + dbt.db.SetMaxIdleConns(0) + + // estimate round trip time. + start := time.Now() + if err := dbt.db.PingContext(context.Background()); err != nil { + t.Fatal(err) + } + rtt := time.Since(start) + if rtt <= 0 { + // In some environments, rtt may become 0, so set it to at least 1ms. + rtt = time.Millisecond + } + + count := 1000 + if testing.Short() { + count = 10 + } + + for i := 0; i < count; i++ { + timeout := time.Duration(mrand.Int63n(int64(rtt))) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + dbt.db.PingContext(ctx) + cancel() + } + }) +} diff --git a/packets.go b/packets.go index d727f00fe..90a34728b 100644 --- a/packets.go +++ b/packets.go @@ -351,7 +351,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string if err := tlsConn.Handshake(); err != nil { return err } - mc.rawConn = mc.netConn mc.netConn = tlsConn mc.buf.nc = tlsConn }