Skip to content

Commit

Permalink
Rework memory registry server (#699)
Browse files Browse the repository at this point in the history
* Rework memory registry server

Signed-off-by: Vladimir Popov <vladimir.popov@xored.com>

* Fix memory registry slow receiver test

Signed-off-by: Vladimir Popov <vladimir.popov@xored.com>

* Fix tests

Signed-off-by: Vladimir Popov <vladimir.popov@xored.com>
  • Loading branch information
Vladimir Popov authored Feb 10, 2021
1 parent ac14b1d commit 536f38f
Show file tree
Hide file tree
Showing 4 changed files with 532 additions and 160 deletions.
176 changes: 103 additions & 73 deletions pkg/registry/common/memory/ns_server.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2020 Doc.ai and/or its affiliates.
// Copyright (c) 2020-2021 Doc.ai and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand All @@ -18,113 +18,143 @@ package memory

import (
"context"
"errors"
"io"

"github.com/edwarnicke/serialize"
"github.com/golang/protobuf/ptypes/empty"
"github.com/google/uuid"
"github.com/networkservicemesh/api/pkg/api/registry"

"github.com/edwarnicke/serialize"
"github.com/networkservicemesh/api/pkg/api/registry"

"github.com/networkservicemesh/sdk/pkg/registry/core/next"
"github.com/networkservicemesh/sdk/pkg/tools/matchutils"
)

type networkServiceRegistryServer struct {
type memoryNSServer struct {
networkServices NetworkServiceSyncMap
executor serialize.Executor
eventChannels map[string]chan *registry.NetworkService
eventChannelSize int
}

func (n *networkServiceRegistryServer) Register(ctx context.Context, ns *registry.NetworkService) (*registry.NetworkService, error) {
// NewNetworkServiceRegistryServer creates new memory based NetworkServiceRegistryServer
func NewNetworkServiceRegistryServer(options ...Option) registry.NetworkServiceRegistryServer {
s := &memoryNSServer{
eventChannelSize: defaultEventChannelSize,
eventChannels: make(map[string]chan *registry.NetworkService),
}
for _, o := range options {
o.apply(s)
}
return s
}

func (s *memoryNSServer) setEventChannelSize(l int) {
s.eventChannelSize = l
}

func (s *memoryNSServer) Register(ctx context.Context, ns *registry.NetworkService) (*registry.NetworkService, error) {
r, err := next.NetworkServiceRegistryServer(ctx).Register(ctx, ns)
if err != nil {
return nil, err
}
n.networkServices.Store(r.Name, r.Clone())
n.executor.AsyncExec(func() {
for _, ch := range n.eventChannels {
ch <- r.Clone()

s.networkServices.Store(r.Name, r.Clone())

s.sendEvent(r)

return r, nil
}

func (s *memoryNSServer) sendEvent(event *registry.NetworkService) {
event = event.Clone()
s.executor.AsyncExec(func() {
for _, ch := range s.eventChannels {
ch <- event.Clone()
}
})
return r, nil
}

func (n *networkServiceRegistryServer) Find(query *registry.NetworkServiceQuery, s registry.NetworkServiceRegistry_FindServer) error {
sendAllMatches := func(ns *registry.NetworkService) error {
var err error
n.networkServices.Range(func(key string, value *registry.NetworkService) bool {
if matchutils.MatchNetworkServices(ns, value) {
err = s.Send(value.Clone())
return err == nil
func (s *memoryNSServer) Find(query *registry.NetworkServiceQuery, server registry.NetworkServiceRegistry_FindServer) error {
if !query.Watch {
for _, ns := range s.allMatches(query) {
if err := server.Send(ns); err != nil {
return err
}
return true
})
}
return next.NetworkServiceRegistryServer(server.Context()).Find(query, server)
}

eventCh := make(chan *registry.NetworkService, s.eventChannelSize)
id := uuid.New().String()

s.executor.AsyncExec(func() {
s.eventChannels[id] = eventCh
for _, entity := range s.allMatches(query) {
eventCh <- entity
}
})
defer s.closeEventChannel(id, eventCh)

var err error
for ; err == nil; err = s.receiveEvent(query, server, eventCh) {
}
if err != io.EOF {
return err
}
if query.Watch {
eventCh := make(chan *registry.NetworkService, n.eventChannelSize)
id := uuid.New().String()
n.executor.AsyncExec(func() {
n.eventChannels[id] = eventCh
})
defer n.executor.AsyncExec(func() {
delete(n.eventChannels, id)
})
err := sendAllMatches(query.NetworkService)
if err != nil {
return err
return next.NetworkServiceRegistryServer(server.Context()).Find(query, server)
}

func (s *memoryNSServer) allMatches(query *registry.NetworkServiceQuery) (matches []*registry.NetworkService) {
s.networkServices.Range(func(_ string, ns *registry.NetworkService) bool {
if matchutils.MatchNetworkServices(query.NetworkService, ns) {
matches = append(matches, ns.Clone())
}
notifyChannel := func() error {
select {
case <-s.Context().Done():
return io.EOF
case event := <-eventCh:
if matchutils.MatchNetworkServices(query.NetworkService, event) {
if s.Context().Err() != nil {
return io.EOF
}
if err := s.Send(event); err != nil {
return err
}
}
return nil
}
return true
})
return matches
}

func (s *memoryNSServer) closeEventChannel(id string, eventCh <-chan *registry.NetworkService) {
ctx, cancel := context.WithCancel(context.Background())

s.executor.AsyncExec(func() {
delete(s.eventChannels, id)
cancel()
})

for {
select {
case <-ctx.Done():
return
case <-eventCh:
}
for {
err := notifyChannel()
if errors.Is(err, io.EOF) {
break
}
}

func (s *memoryNSServer) receiveEvent(
query *registry.NetworkServiceQuery,
server registry.NetworkServiceRegistry_FindServer,
eventCh <-chan *registry.NetworkService,
) error {
select {
case <-server.Context().Done():
return io.EOF
case event := <-eventCh:
if matchutils.MatchNetworkServices(query.NetworkService, event) {
if server.Context().Err() != nil {
return io.EOF
}
if err != nil {
if err := server.Send(event); err != nil {
return err
}
}
} else if err := sendAllMatches(query.NetworkService); err != nil {
return err
return nil
}
return next.NetworkServiceRegistryServer(s.Context()).Find(query, s)
}

func (n *networkServiceRegistryServer) Unregister(ctx context.Context, ns *registry.NetworkService) (*empty.Empty, error) {
n.networkServices.Delete(ns.Name)
return next.NetworkServiceRegistryServer(ctx).Unregister(ctx, ns)
}
func (s *memoryNSServer) Unregister(ctx context.Context, ns *registry.NetworkService) (*empty.Empty, error) {
s.networkServices.Delete(ns.Name)

func (n *networkServiceRegistryServer) setEventChannelSize(l int) {
n.eventChannelSize = l
}

// NewNetworkServiceRegistryServer creates new memory based NetworkServiceRegistryServer
func NewNetworkServiceRegistryServer(options ...Option) registry.NetworkServiceRegistryServer {
r := &networkServiceRegistryServer{
eventChannelSize: defaultEventChannelSize,
eventChannels: make(map[string]chan *registry.NetworkService),
}
for _, o := range options {
o.apply(r)
}
return r
return next.NetworkServiceRegistryServer(ctx).Unregister(ctx, ns)
}
136 changes: 134 additions & 2 deletions pkg/registry/common/memory/ns_server_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2020 Doc.ai and/or its affiliates.
// Copyright (c) 2020-2021 Doc.ai and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand All @@ -18,14 +18,20 @@ package memory_test

import (
"context"
"fmt"
"sync"
"testing"
"time"

"github.com/networkservicemesh/api/pkg/api/registry"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"google.golang.org/protobuf/proto"

"github.com/networkservicemesh/api/pkg/api/registry"

"github.com/networkservicemesh/sdk/pkg/registry/common/memory"
"github.com/networkservicemesh/sdk/pkg/registry/core/adapters"
"github.com/networkservicemesh/sdk/pkg/registry/core/next"
"github.com/networkservicemesh/sdk/pkg/registry/core/streamchannel"
)
Expand Down Expand Up @@ -106,3 +112,129 @@ func TestNetworkServiceRegistryServer_RegisterAndFindWatch(t *testing.T) {
require.NoError(t, err)
require.True(t, proto.Equal(expected, <-ch))
}

func TestNetworkServiceRegistryServer_DataRace(t *testing.T) {
defer goleak.VerifyNone(t, goleak.IgnoreCurrent())

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

s := memory.NewNetworkServiceRegistryServer()

_, err := s.Register(ctx, &registry.NetworkService{Name: "ns"})
require.NoError(t, err)

c := adapters.NetworkServiceServerToClient(s)

var wgStart, wgEnd sync.WaitGroup
for i := 0; i < 10; i++ {
wgStart.Add(1)
wgEnd.Add(1)
go func() {
defer wgEnd.Done()

findCtx, findCancel := context.WithTimeout(ctx, time.Second)
defer findCancel()

stream, err := c.Find(findCtx, &registry.NetworkServiceQuery{
NetworkService: &registry.NetworkService{Name: "ns"},
Watch: true,
})
assert.NoError(t, err)

_, err = stream.Recv()
assert.NoError(t, err)

wgStart.Done()

for j := 0; j < 100; j++ {
ns, err := stream.Recv()
assert.NoError(t, err)

ns.Name = ""
}
}()
}
wgStart.Wait()

for i := 0; i < 100; i++ {
_, err := s.Register(ctx, &registry.NetworkService{Name: fmt.Sprintf("ns-%d", i)})
require.NoError(t, err)
}

wgEnd.Wait()
}

func TestNetworkServiceRegistryServer_SlowReceiver(t *testing.T) {
defer goleak.VerifyNone(t, goleak.IgnoreCurrent())

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

s := memory.NewNetworkServiceRegistryServer()

c := adapters.NetworkServiceServerToClient(s)

findCtx, findCancel := context.WithCancel(ctx)

stream, err := c.Find(findCtx, &registry.NetworkServiceQuery{
NetworkService: &registry.NetworkService{Name: "ns"},
Watch: true,
})
require.NoError(t, err)

for i := 0; i < 1000; i++ {
_, err = s.Register(ctx, &registry.NetworkService{Name: fmt.Sprintf("ns-%d", i)})
require.NoError(t, err)
}

ignoreCurrent := goleak.IgnoreCurrent()

_, err = stream.Recv()
require.NoError(t, err)

findCancel()

require.Eventually(t, func() bool {
return goleak.Find(ignoreCurrent) == nil
}, 100*time.Millisecond, time.Millisecond)
}

func TestNetworkServiceRegistryServer_ShouldReceiveAllRegisters(t *testing.T) {
defer goleak.VerifyNone(t, goleak.IgnoreCurrent())

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

s := memory.NewNetworkServiceRegistryServer()

c := adapters.NetworkServiceServerToClient(s)

var wg sync.WaitGroup
for i := 0; i < 300; i++ {
wg.Add(1)
name := fmt.Sprintf("ns-%d", i)

go func() {
_, err := s.Register(ctx, &registry.NetworkService{Name: name})
require.NoError(t, err)
}()

go func() {
defer wg.Done()

findCtx, findCancel := context.WithTimeout(ctx, time.Second)
defer findCancel()

stream, err := c.Find(findCtx, &registry.NetworkServiceQuery{
NetworkService: &registry.NetworkService{Name: name},
Watch: true,
})
assert.NoError(t, err)

_, err = stream.Recv()
assert.NoError(t, err)
}()
}
wg.Wait()
}
Loading

0 comments on commit 536f38f

Please sign in to comment.