From 9e9c14db92dfe4a16340173c7295028c1c4fdbd4 Mon Sep 17 00:00:00 2001 From: Dheerendra Rathor Date: Thu, 2 May 2024 05:33:15 +0200 Subject: [PATCH] fix getSysConn to work with TLS (#918) `tls.Conn` by default doesn't implement `syscall.Conn` and hence tchannel emits `onnection does not implement SyscallConn.` log a lot. Since go 1.18, `tls.Conn` exposes method `NetConn()` to expose raw underlying TCP connection. This is used for getting the `syscall.Conn`. --- connection.go | 14 +++++++++++- connection_internal_test.go | 43 +++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/connection.go b/connection.go index 36cf64ab5..265807dff 100644 --- a/connection.go +++ b/connection.go @@ -21,6 +21,7 @@ package tchannel import ( + "crypto/tls" "errors" "fmt" "io" @@ -971,7 +972,18 @@ func (c *Connection) getLastActivityWriteTime() time.Time { } func getSysConn(conn net.Conn, log Logger) syscall.RawConn { - connSyscall, ok := conn.(syscall.Conn) + var ( + connSyscall syscall.Conn + ok bool + ) + switch v := conn.(type) { + case syscall.Conn: + connSyscall = v + ok = true + case *tls.Conn: + connSyscall, ok = v.NetConn().(syscall.Conn) + } + if !ok { log.WithFields(LogField{"connectionType", fmt.Sprintf("%T", conn)}). Error("Connection does not implement SyscallConn.") diff --git a/connection_internal_test.go b/connection_internal_test.go index 306aceca7..18b6bd913 100644 --- a/connection_internal_test.go +++ b/connection_internal_test.go @@ -22,7 +22,9 @@ package tchannel import ( "bytes" + "crypto/tls" "net" + "net/http/httptest" "syscall" "testing" @@ -79,4 +81,45 @@ func TestGetSysConn(t *testing.T) { require.NotNil(t, sysConn) assert.Empty(t, loggerBuf.String(), "expected no logs on success") }) + + t.Run("SyscallConn is successful with TLS", func(t *testing.T) { + var ( + loggerBuf = &bytes.Buffer{} + logger = NewLogger(loggerBuf) + server = httptest.NewTLSServer(nil) + ) + defer server.Close() + + conn, err := tls.Dial("tcp", server.Listener.Addr().String(), &tls.Config{InsecureSkipVerify: true}) + require.NoError(t, err, "failed to dial") + defer conn.Close() + + sysConn := getSysConn(conn, logger) + require.NotNil(t, sysConn) + assert.Empty(t, loggerBuf.String(), "expected no logs on success") + }) + + t.Run("no SyscallConn - nil net.Conn", func(t *testing.T) { + var ( + loggerBuf = &bytes.Buffer{} + logger = NewLogger(loggerBuf) + syscallConn = getSysConn(nil /* conn */, logger) + ) + + require.Nil(t, syscallConn, "expected no syscall.RawConn to be returned") + assert.Contains(t, loggerBuf.String(), "Connection does not implement SyscallConn", "missing log") + assert.Contains(t, loggerBuf.String(), "{connectionType }", "missing type in log") + }) + + t.Run("no SyscallConn - TLS with no net.Conn", func(t *testing.T) { + var ( + loggerBuf = &bytes.Buffer{} + logger = NewLogger(loggerBuf) + syscallConn = getSysConn(&tls.Conn{}, logger) + ) + + require.Nil(t, syscallConn, "expected no syscall.RawConn to be returned") + assert.Contains(t, loggerBuf.String(), "Connection does not implement SyscallConn", "missing log") + assert.Contains(t, loggerBuf.String(), "{connectionType *tls.Conn}", "missing type in log") + }) }