From 34ab19f2a884a54d241d33980ac29cfdcaa1589e Mon Sep 17 00:00:00 2001 From: Denis Tingaikin Date: Wed, 3 Jul 2024 17:06:02 +0300 Subject: [PATCH] Fix potential leaks of nse/ns streams in case of lost close (#1641) * fix linter Signed-off-by: denis-tingaikin * fix ci issues Signed-off-by: denis-tingaikin * fix tests Signed-off-by: denis-tingaikin --------- Signed-off-by: denis-tingaikin --- .../common/netsvcmonitor/server.go | 129 ++++++++++-------- .../common/netsvcmonitor/server_test.go | 59 +++++++- 2 files changed, 127 insertions(+), 61 deletions(-) diff --git a/pkg/networkservice/common/netsvcmonitor/server.go b/pkg/networkservice/common/netsvcmonitor/server.go index 5c75f92a9..8527b139c 100644 --- a/pkg/networkservice/common/netsvcmonitor/server.go +++ b/pkg/networkservice/common/netsvcmonitor/server.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Cisco Systems, Inc. +// Copyright (c) 2023-2024 Cisco Systems, Inc. // // SPDX-License-Identifier: Apache-2.0 // @@ -57,69 +57,24 @@ func (m *monitorServer) Request(ctx context.Context, request *networkservice.Net return resp, err } - var conn = resp.Clone() + monitorCtx, cancel := context.WithCancel(m.chainCtx) + minT := time.Time{} - var monitorCtx, cancel = context.WithCancel(m.chainCtx) - - storeCancelFunction(ctx, cancel) - - var logger = log.FromContext(ctx).WithField("monitorServer", "Find") - - var monitorNetworkServiceGoroutine = func() { - for ; monitorCtx.Err() == nil; time.Sleep(time.Millisecond * 100) { - // nolint:govet - var stream, err = m.nsClient.Find(monitorCtx, ®istry.NetworkServiceQuery{ - Watch: true, - NetworkService: ®istry.NetworkService{ - Name: conn.GetNetworkService(), - }, - }) - if err != nil { - logger.Errorf("an error happened during finding network service: %v", err.Error()) - continue - } - - var networkServiceCh = registry.ReadNetworkServiceChannel(stream) - var netsvcStreamIsAlive = true - - for netsvcStreamIsAlive && monitorCtx.Err() == nil { - select { - case <-monitorCtx.Done(): - return - case netsvc, ok := <-networkServiceCh: - if !ok { - netsvcStreamIsAlive = false - break - } - - nseStream, err := m.nseClient.Find(monitorCtx, ®istry.NetworkServiceEndpointQuery{ - NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{ - Name: conn.GetNetworkServiceEndpointName(), - }, - }) - if err != nil { - logger.Errorf("an error happened during finding nse: %v", err.Error()) - break - } - - var nses = registry.ReadNetworkServiceEndpointList(nseStream) - - if len(nses) == 0 { - continue - } - - if len(matchutils.MatchEndpoint(resp.GetLabels(), netsvc.GetNetworkService(), nses...)) == 0 { - begin.FromContext(ctx).Close() - logger.Warnf("nse %v doesn't match with networkservice: %v", conn.GetNetworkServiceEndpointName(), conn.GetNetworkService()) - - return - } - } - } + for _, seg := range resp.GetPath().GetPathSegments() { + var t = seg.Expires.AsTime().Local() + if minT.After(t) || minT.IsZero() { + minT = t } } - go monitorNetworkServiceGoroutine() + if !minT.IsZero() { + cancel() + monitorCtx, cancel = context.WithTimeout(m.chainCtx, time.Until(minT)) + } + + storeCancelFunction(ctx, cancel) + + go m.monitorNetworkService(monitorCtx, resp.Clone(), begin.FromContext(ctx)) return resp, err } @@ -131,3 +86,57 @@ func (m *monitorServer) Close(ctx context.Context, conn *networkservice.Connecti return next.Server(ctx).Close(ctx, conn) } + +func (m *monitorServer) monitorNetworkService(monitorCtx context.Context, conn *networkservice.Connection, factory begin.EventFactory) { + var logger = log.FromContext(monitorCtx).WithField("monitorServer", "Find") + for ; monitorCtx.Err() == nil; time.Sleep(time.Millisecond * 100) { + // nolint:govet + var stream, err = m.nsClient.Find(monitorCtx, ®istry.NetworkServiceQuery{ + Watch: true, + NetworkService: ®istry.NetworkService{ + Name: conn.GetNetworkService(), + }, + }) + if err != nil { + logger.Errorf("an error happened during finding network service: %v", err.Error()) + continue + } + + var networkServiceCh = registry.ReadNetworkServiceChannel(stream) + var netsvcStreamIsAlive = true + + for netsvcStreamIsAlive && monitorCtx.Err() == nil { + select { + case <-monitorCtx.Done(): + return + case netsvc, ok := <-networkServiceCh: + if !ok { + netsvcStreamIsAlive = false + break + } + + nseStream, err := m.nseClient.Find(monitorCtx, ®istry.NetworkServiceEndpointQuery{ + NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{ + Name: conn.GetNetworkServiceEndpointName(), + }, + }) + if err != nil { + logger.Errorf("an error happened during finding nse: %v", err.Error()) + break + } + + var nses = registry.ReadNetworkServiceEndpointList(nseStream) + + if len(nses) == 0 { + continue + } + + if len(matchutils.MatchEndpoint(conn.GetLabels(), netsvc.GetNetworkService(), nses...)) == 0 { + factory.Close() + logger.Warnf("nse %v doesn't match with networkservice: %v", conn.GetNetworkServiceEndpointName(), conn.GetNetworkService()) + return + } + } + } + } +} diff --git a/pkg/networkservice/common/netsvcmonitor/server_test.go b/pkg/networkservice/common/netsvcmonitor/server_test.go index 7847201dd..4ef66bf1c 100644 --- a/pkg/networkservice/common/netsvcmonitor/server_test.go +++ b/pkg/networkservice/common/netsvcmonitor/server_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Cisco Systems, Inc. +// Copyright (c) 2023-2024 Cisco Systems, Inc. // // SPDX-License-Identifier: Apache-2.0 // @@ -25,6 +25,8 @@ import ( "github.com/networkservicemesh/api/pkg/api/networkservice" "github.com/networkservicemesh/api/pkg/api/registry" "github.com/stretchr/testify/require" + "go.uber.org/goleak" + "google.golang.org/protobuf/types/known/timestamppb" "github.com/networkservicemesh/sdk/pkg/networkservice/common/begin" "github.com/networkservicemesh/sdk/pkg/networkservice/common/netsvcmonitor" @@ -36,6 +38,10 @@ import ( ) func Test_Netsvcmonitor_And_GroupOfSimilarNetworkServices(t *testing.T) { + t.Cleanup(func() { + goleak.VerifyNone(t) + }) + var testCtx, cancel = context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -91,3 +97,54 @@ func Test_Netsvcmonitor_And_GroupOfSimilarNetworkServices(t *testing.T) { return counter.Closes() > 0 }, time.Millisecond*300, time.Millisecond*50) } + +func Test_NetsvcMonitor_ShouldNotLeakWithoutClose(t *testing.T) { + var testCtx, cancel = context.WithTimeout(context.Background(), time.Second*5) + t.Cleanup(func() { + require.Eventually(t, func() bool { + return goleak.Find(goleak.IgnoreAnyFunction("github.com/stretchr/testify/assert.Eventually")) == nil + }, time.Second*2, time.Second/10) + cancel() + }) + var nsServer = memory.NewNetworkServiceRegistryServer() + var nseServer = memory.NewNetworkServiceEndpointRegistryServer() + var counter count.Server + + _, _ = nsServer.Register(context.Background(), ®istry.NetworkService{ + Name: "service-1", + }) + + _, _ = nseServer.Register(context.Background(), ®istry.NetworkServiceEndpoint{ + Name: "endpoint-1", + NetworkServiceNames: []string{"service-1"}, + }) + + var server = chain.NewNetworkServiceServer( + metadata.NewServer(), + begin.NewServer(), + netsvcmonitor.NewServer( + testCtx, + adapters.NetworkServiceServerToClient(nsServer), + adapters.NetworkServiceEndpointServerToClient(nseServer), + ), + &counter, + ) + + var request = &networkservice.NetworkServiceRequest{ + Connection: &networkservice.Connection{ + Id: "1", + NetworkService: "service-1", + NetworkServiceEndpointName: "endpoint-1", + Path: &networkservice.Path{ + PathSegments: []*networkservice.PathSegment{ + { + Expires: timestamppb.New(time.Now().Add(time.Second)), + }, + }, + }, + }, + } + + var _, err = server.Request(testCtx, request) + require.NoError(t, err) +}