Skip to content

Commit

Permalink
Fix a memory leak in socket_listener when using SSL
Browse files Browse the repository at this point in the history
This fixes a leak in the socket code (used by socket_listener) for tracking connections. Connections are stored so that they can be closed on plugin shutdown, however in the case of SSL sockets, the underlying net.Conn for the TCP connection was stored, but when a socket shut down on its own, it was looked up by it's tls net.Conn handle, which didn't match, resulting in the connection list growing endlessly.

This fix addresses the issue by switching to using context for shutdown notification, and dropping the need for connection tracking. socket_listener predated context being in the minimum supported version, and context.AfterFunc didn't come around until 6 years later, which is why this approach wasn't used initially.

This change also adds some minor refactoring to deduplicate some relevant code.

The added test to check for any future memleak issues does unfortunately take a little while to run (2.2s on my workstation). It basically creates/destroys a few thousand connections and checks the memory profile for allocations. When the number of objects on the heap grows by more than half the number of loops in the test, the test fails. There is inherent noise in the number of objects in the heap, so it requires a few thousand loops. 1000 was not enough to prevent random failures. 2000 seems fine, but I went well above to 5000 for safety. The threshold could possibly be increased to 75% or so, allowing the loop count to be lowered, but this may reduce the safety margin.
  • Loading branch information
phemmer committed Jun 29, 2024
1 parent e2a8625 commit b3f1771
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 88 deletions.
67 changes: 59 additions & 8 deletions plugins/common/socket/socket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"os"
"runtime"
"strings"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -477,14 +478,6 @@ func TestClosingConnections(t *testing.T) {
return acc.NMetrics() >= 1
}, time.Second, 100*time.Millisecond, "did not receive metric")

// This has to be a stream-listener...
listener, ok := sock.listener.(*streamListener)
require.True(t, ok)
listener.Lock()
conns := len(listener.connections)
listener.Unlock()
require.NotZero(t, conns)

sock.Close()

// Verify that plugin.Stop() closed the client's connection
Expand Down Expand Up @@ -605,6 +598,64 @@ func TestNoSplitter(t *testing.T) {
testutil.RequireMetricsEqual(t, expected, actual, testutil.SortMetrics())
}

func TestMemoryLeak(t *testing.T) {
cfg := &Config{
ServerConfig: *pki.TLSServerConfig(),
}

sock, err := cfg.NewSocket("tcp://127.0.0.1:0", nil, &testutil.Logger{})
require.NoError(t, err)
require.NoError(t, sock.Setup())
sock.ListenConnection(func(_ net.Addr, r io.ReadCloser) {
_, _ = io.Copy(io.Discard, r)
}, func(_ error) {})
defer sock.Close()

clientTLS := pki.TLSClientConfig()
tlsConfig, err := clientTLS.TLSConfig()
require.NoError(t, err)
msg := []byte("test v=1i")
client := func() {
conn, err := tls.Dial("tcp", sock.Address().String(), tlsConfig)
require.NoError(t, err)
_, err = conn.Write(msg)
require.NoError(t, err)
require.NoError(t, conn.Close())
}

var memStart, mem runtime.MemStats

run := func(nClients int) {
nWorkers := runtime.GOMAXPROCS(0)
var wg sync.WaitGroup
for i := 0; i < nWorkers; i++ {
wg.Add(1)
go func(i int) {
for j := 0; j < nClients/nWorkers; j++ {
client()
}
wg.Done()
}(i)
}
wg.Wait()
}
// warmup
run(100)
runtime.GC()
runtime.GC()
runtime.ReadMemStats(&memStart)

n := 5000
run(n)
runtime.GC()
runtime.GC()
runtime.ReadMemStats(&mem)

// It's unavoidable that there's going to be some fluctuation. But if there's going to be a leak, it's likely to be at
// least 1 object per loop. So use half the loop count as the threshold.
require.Less(t, mem.HeapObjects, memStart.HeapObjects+uint64(n/2))
}

func createClient(endpoint string, addr net.Addr, tlsCfg *tls.Config) (net.Conn, error) {
// Determine the protocol in a crude fashion
parts := strings.SplitN(endpoint, "://", 2)
Expand Down
126 changes: 46 additions & 80 deletions plugins/common/socket/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package socket

import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
Expand All @@ -15,6 +16,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"

Expand All @@ -38,9 +40,9 @@ type streamListener struct {
Splitter bufio.SplitFunc
Log telegraf.Logger

listener net.Listener
connections map[net.Conn]bool
path string
cancel func()
listener net.Listener
path string

wg sync.WaitGroup
sync.Mutex
Expand Down Expand Up @@ -126,17 +128,6 @@ func (l *streamListener) setupConnection(conn net.Conn) error {
conn = c.NetConn()
}

addr := conn.RemoteAddr().String()
l.Lock()
if l.MaxConnections > 0 && len(l.connections) >= l.MaxConnections {
l.Unlock()
// Ignore the returned error as we cannot do anything about it anyway
_ = conn.Close()
return fmt.Errorf("unable to accept connection from %q: too many connections", addr)
}
l.connections[conn] = true
l.Unlock()

if l.ReadBufferSize > 0 {
if rb, ok := conn.(hasSetReadBuffer); ok {
if err := rb.SetReadBuffer(l.ReadBufferSize); err != nil {
Expand Down Expand Up @@ -171,14 +162,6 @@ func (l *streamListener) setupConnection(conn net.Conn) error {
return nil
}

func (l *streamListener) closeConnection(conn net.Conn) {
addr := conn.RemoteAddr().String()
if err := conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, syscall.EPIPE) {
l.Log.Warnf("Cannot close connection to %q: %v", addr, err)
}
delete(l.connections, conn)
}

func (l *streamListener) address() net.Addr {
return l.listener.Addr()
}
Expand All @@ -188,11 +171,7 @@ func (l *streamListener) close() error {
return err
}

l.Lock()
for conn := range l.connections {
l.closeConnection(conn)
}
l.Unlock()
l.cancel()
l.wg.Wait()

if l.path != "" {
Expand All @@ -208,12 +187,16 @@ func (l *streamListener) close() error {
return nil
}

func (l *streamListener) listenData(onData CallbackData, onError CallbackError) {
l.connections = make(map[net.Conn]bool)
func (l *streamListener) listen(connFunc func(c net.Conn), onError CallbackError) {
var ctx context.Context
ctx, l.cancel = context.WithCancel(context.Background())

l.wg.Add(1)
go func() {
defer l.wg.Done()
defer context.AfterFunc(ctx, func() { _ = l.listener.Close() })()

var connCount int32

var wg sync.WaitGroup
for {
Expand All @@ -225,76 +208,59 @@ func (l *streamListener) listenData(onData CallbackData, onError CallbackError)
break
}

if err := l.setupConnection(conn); err != nil && onError != nil {
onError(err)
if l.MaxConnections > 0 && int(atomic.LoadInt32(&connCount)) >= l.MaxConnections {
onError(fmt.Errorf("unable to accept connection from %q: too many connections", conn.RemoteAddr().String()))
_ = conn.Close()
continue
}

atomic.AddInt32(&connCount, 1)

wg.Add(1)
go func(c net.Conn) {
defer wg.Done()
defer func() {
l.Lock()
l.closeConnection(conn)
l.Unlock()
}()

reader := l.read
if l.Splitter == nil {
reader = l.readAll
}
if err := reader(c, onData); err != nil {
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.ECONNRESET) {
if onError != nil {
onError(err)
}
}
defer func() { _ = c.Close() }()
defer context.AfterFunc(ctx, func() { _ = conn.Close() })()
defer atomic.AddInt32(&connCount, -1)

if err := l.setupConnection(c); err != nil && onError != nil {
onError(err)
return
}

connFunc(c)
}(conn)
}
wg.Wait()
}()
}

func (l *streamListener) listenConnection(onConnection CallbackConnection, onError CallbackError) {
l.connections = make(map[net.Conn]bool)

l.wg.Add(1)
go func() {
defer l.wg.Done()

var wg sync.WaitGroup
for {
conn, err := l.listener.Accept()
if err != nil {
if !errors.Is(err, net.ErrClosed) && onError != nil {
func (l *streamListener) listenData(onData CallbackData, onError CallbackError) {
l.listen(func(c net.Conn) {
reader := l.read
if l.Splitter == nil {
reader = l.readAll
}
if err := reader(c, onData); err != nil {
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.ECONNRESET) {
if onError != nil {
onError(err)
}
break
}

if err := l.setupConnection(conn); err != nil && onError != nil {
onError(err)
continue
}
}
}, onError)
}

wg.Add(1)
go func(c net.Conn) {
defer wg.Done()
if err := l.handleConnection(c, onConnection); err != nil {
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.ECONNRESET) {
if onError != nil {
onError(err)
}
}
func (l *streamListener) listenConnection(onConnection CallbackConnection, onError CallbackError) {
l.listen(func(c net.Conn) {
if err := l.handleConnection(c, onConnection); err != nil {
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.ECONNRESET) {
if onError != nil {
onError(err)
}
l.Lock()
l.closeConnection(conn)
l.Unlock()
}(conn)
}
}
wg.Wait()
}()
}, onError)
}

func (l *streamListener) read(conn net.Conn, onData CallbackData) error {
Expand Down

0 comments on commit b3f1771

Please sign in to comment.