Skip to content

Commit

Permalink
Fix potential leaks of nse/ns streams in case of lost close (#1641)
Browse files Browse the repository at this point in the history
* fix linter

Signed-off-by: denis-tingaikin <denis.tingajkin@xored.com>

* fix ci issues

Signed-off-by: denis-tingaikin <denis.tingajkin@xored.com>

* fix tests

Signed-off-by: denis-tingaikin <denis.tingajkin@xored.com>

---------

Signed-off-by: denis-tingaikin <denis.tingajkin@xored.com>
  • Loading branch information
denis-tingaikin authored Jul 3, 2024
1 parent 21369bd commit 34ab19f
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 61 deletions.
129 changes: 69 additions & 60 deletions pkg/networkservice/common/netsvcmonitor/server.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2023 Cisco Systems, Inc.
// Copyright (c) 2023-2024 Cisco Systems, Inc.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand Down Expand Up @@ -57,69 +57,24 @@ func (m *monitorServer) Request(ctx context.Context, request *networkservice.Net
return resp, err
}

var conn = resp.Clone()
monitorCtx, cancel := context.WithCancel(m.chainCtx)
minT := time.Time{}

var monitorCtx, cancel = context.WithCancel(m.chainCtx)

storeCancelFunction(ctx, cancel)

var logger = log.FromContext(ctx).WithField("monitorServer", "Find")

var monitorNetworkServiceGoroutine = func() {
for ; monitorCtx.Err() == nil; time.Sleep(time.Millisecond * 100) {
// nolint:govet
var stream, err = m.nsClient.Find(monitorCtx, &registry.NetworkServiceQuery{
Watch: true,
NetworkService: &registry.NetworkService{
Name: conn.GetNetworkService(),
},
})
if err != nil {
logger.Errorf("an error happened during finding network service: %v", err.Error())
continue
}

var networkServiceCh = registry.ReadNetworkServiceChannel(stream)
var netsvcStreamIsAlive = true

for netsvcStreamIsAlive && monitorCtx.Err() == nil {
select {
case <-monitorCtx.Done():
return
case netsvc, ok := <-networkServiceCh:
if !ok {
netsvcStreamIsAlive = false
break
}

nseStream, err := m.nseClient.Find(monitorCtx, &registry.NetworkServiceEndpointQuery{
NetworkServiceEndpoint: &registry.NetworkServiceEndpoint{
Name: conn.GetNetworkServiceEndpointName(),
},
})
if err != nil {
logger.Errorf("an error happened during finding nse: %v", err.Error())
break
}

var nses = registry.ReadNetworkServiceEndpointList(nseStream)

if len(nses) == 0 {
continue
}

if len(matchutils.MatchEndpoint(resp.GetLabels(), netsvc.GetNetworkService(), nses...)) == 0 {
begin.FromContext(ctx).Close()
logger.Warnf("nse %v doesn't match with networkservice: %v", conn.GetNetworkServiceEndpointName(), conn.GetNetworkService())

return
}
}
}
for _, seg := range resp.GetPath().GetPathSegments() {
var t = seg.Expires.AsTime().Local()
if minT.After(t) || minT.IsZero() {
minT = t
}
}

go monitorNetworkServiceGoroutine()
if !minT.IsZero() {
cancel()
monitorCtx, cancel = context.WithTimeout(m.chainCtx, time.Until(minT))
}

storeCancelFunction(ctx, cancel)

go m.monitorNetworkService(monitorCtx, resp.Clone(), begin.FromContext(ctx))

return resp, err
}
Expand All @@ -131,3 +86,57 @@ func (m *monitorServer) Close(ctx context.Context, conn *networkservice.Connecti

return next.Server(ctx).Close(ctx, conn)
}

func (m *monitorServer) monitorNetworkService(monitorCtx context.Context, conn *networkservice.Connection, factory begin.EventFactory) {
var logger = log.FromContext(monitorCtx).WithField("monitorServer", "Find")
for ; monitorCtx.Err() == nil; time.Sleep(time.Millisecond * 100) {
// nolint:govet
var stream, err = m.nsClient.Find(monitorCtx, &registry.NetworkServiceQuery{
Watch: true,
NetworkService: &registry.NetworkService{
Name: conn.GetNetworkService(),
},
})
if err != nil {
logger.Errorf("an error happened during finding network service: %v", err.Error())
continue
}

var networkServiceCh = registry.ReadNetworkServiceChannel(stream)
var netsvcStreamIsAlive = true

for netsvcStreamIsAlive && monitorCtx.Err() == nil {
select {
case <-monitorCtx.Done():
return
case netsvc, ok := <-networkServiceCh:
if !ok {
netsvcStreamIsAlive = false
break
}

nseStream, err := m.nseClient.Find(monitorCtx, &registry.NetworkServiceEndpointQuery{
NetworkServiceEndpoint: &registry.NetworkServiceEndpoint{
Name: conn.GetNetworkServiceEndpointName(),
},
})
if err != nil {
logger.Errorf("an error happened during finding nse: %v", err.Error())
break
}

var nses = registry.ReadNetworkServiceEndpointList(nseStream)

if len(nses) == 0 {
continue
}

if len(matchutils.MatchEndpoint(conn.GetLabels(), netsvc.GetNetworkService(), nses...)) == 0 {
factory.Close()
logger.Warnf("nse %v doesn't match with networkservice: %v", conn.GetNetworkServiceEndpointName(), conn.GetNetworkService())
return
}
}
}
}
}
59 changes: 58 additions & 1 deletion pkg/networkservice/common/netsvcmonitor/server_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2023 Cisco Systems, Inc.
// Copyright (c) 2023-2024 Cisco Systems, Inc.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand All @@ -25,6 +25,8 @@ import (
"github.com/networkservicemesh/api/pkg/api/networkservice"
"github.com/networkservicemesh/api/pkg/api/registry"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"google.golang.org/protobuf/types/known/timestamppb"

"github.com/networkservicemesh/sdk/pkg/networkservice/common/begin"
"github.com/networkservicemesh/sdk/pkg/networkservice/common/netsvcmonitor"
Expand All @@ -36,6 +38,10 @@ import (
)

func Test_Netsvcmonitor_And_GroupOfSimilarNetworkServices(t *testing.T) {
t.Cleanup(func() {
goleak.VerifyNone(t)
})

var testCtx, cancel = context.WithTimeout(context.Background(), time.Second)
defer cancel()

Expand Down Expand Up @@ -91,3 +97,54 @@ func Test_Netsvcmonitor_And_GroupOfSimilarNetworkServices(t *testing.T) {
return counter.Closes() > 0
}, time.Millisecond*300, time.Millisecond*50)
}

func Test_NetsvcMonitor_ShouldNotLeakWithoutClose(t *testing.T) {
var testCtx, cancel = context.WithTimeout(context.Background(), time.Second*5)
t.Cleanup(func() {
require.Eventually(t, func() bool {
return goleak.Find(goleak.IgnoreAnyFunction("github.com/stretchr/testify/assert.Eventually")) == nil
}, time.Second*2, time.Second/10)
cancel()
})
var nsServer = memory.NewNetworkServiceRegistryServer()
var nseServer = memory.NewNetworkServiceEndpointRegistryServer()
var counter count.Server

_, _ = nsServer.Register(context.Background(), &registry.NetworkService{
Name: "service-1",
})

_, _ = nseServer.Register(context.Background(), &registry.NetworkServiceEndpoint{
Name: "endpoint-1",
NetworkServiceNames: []string{"service-1"},
})

var server = chain.NewNetworkServiceServer(
metadata.NewServer(),
begin.NewServer(),
netsvcmonitor.NewServer(
testCtx,
adapters.NetworkServiceServerToClient(nsServer),
adapters.NetworkServiceEndpointServerToClient(nseServer),
),
&counter,
)

var request = &networkservice.NetworkServiceRequest{
Connection: &networkservice.Connection{
Id: "1",
NetworkService: "service-1",
NetworkServiceEndpointName: "endpoint-1",
Path: &networkservice.Path{
PathSegments: []*networkservice.PathSegment{
{
Expires: timestamppb.New(time.Now().Add(time.Second)),
},
},
},
},
}

var _, err = server.Request(testCtx, request)
require.NoError(t, err)
}

0 comments on commit 34ab19f

Please sign in to comment.