diff --git a/pkg/registry/common/begin/nse_server.go b/pkg/registry/common/begin/nse_server.go index 8c4d39ed1..c949ec02d 100644 --- a/pkg/registry/common/begin/nse_server.go +++ b/pkg/registry/common/begin/nse_server.go @@ -100,8 +100,12 @@ func (b *beginNSEServer) Unregister(ctx context.Context, in *registry.NetworkSer _, err = b.Unregister(ctx, in) return } + registration := in + if eventFactoryServer.registration != nil { + registration = eventFactoryServer.registration + } withEventFactoryCtx := withEventFactory(ctx, eventFactoryServer) - _, err = next.NetworkServiceEndpointRegistryServer(withEventFactoryCtx).Unregister(withEventFactoryCtx, eventFactoryServer.registration) + _, err = next.NetworkServiceEndpointRegistryServer(withEventFactoryCtx).Unregister(withEventFactoryCtx, registration) eventFactoryServer.afterCloseFunc() }) return &emptypb.Empty{}, err diff --git a/pkg/registry/common/begin/serialize_client_test.go b/pkg/registry/common/begin/serialize_client_test.go index 2429bb989..883b18efb 100644 --- a/pkg/registry/common/begin/serialize_client_test.go +++ b/pkg/registry/common/begin/serialize_client_test.go @@ -66,6 +66,7 @@ func TestSerializeClient_StressTest(t *testing.T) { type parallelClient struct { t *testing.T states sync.Map + mu sync.Mutex } func newParallelClient(t *testing.T) *parallelClient { @@ -78,10 +79,12 @@ func (s *parallelClient) Register(ctx context.Context, in *registry.NetworkServi raw, _ := s.states.LoadOrStore(in.GetName(), new(int32)) statePtr := raw.(*int32) + s.mu.Lock() state := atomic.LoadInt32(statePtr) if !atomic.CompareAndSwapInt32(statePtr, state, state+1) { assert.Failf(s.t, "", "state has been changed for connection %s expected %d actual %d", in.GetName(), state, atomic.LoadInt32(statePtr)) } + s.mu.Unlock() return next.NetworkServiceEndpointRegistryClient(ctx).Register(ctx, in, opts...) } @@ -94,10 +97,12 @@ func (s *parallelClient) Unregister(ctx context.Context, in *registry.NetworkSer raw, _ := s.states.LoadOrStore(in.GetName(), new(int32)) statePtr := raw.(*int32) + s.mu.Lock() state := atomic.LoadInt32(statePtr) if !atomic.CompareAndSwapInt32(statePtr, state, state+1) { assert.Failf(s.t, "", "state has been changed for connection %s expected %d actual %d", in.GetName(), state, atomic.LoadInt32(statePtr)) } + s.mu.Unlock() return next.NetworkServiceEndpointRegistryClient(ctx).Unregister(ctx, in, opts...) }