From 74df9580ee627a5eb15427c9888b72ca10bf57e8 Mon Sep 17 00:00:00 2001 From: Giovanni Liva Date: Wed, 20 Dec 2023 21:08:47 +0100 Subject: [PATCH] fix: nil client causes panic (#408) Signed-off-by: Giovanni Liva Signed-off-by: Florian Bacher Co-authored-by: Florian Bacher --- providers/flagd/pkg/provider.go | 15 ++- providers/flagd/pkg/service/rpc/service.go | 58 ++++++++- .../service/rpc/service_evaluation_test.go | 122 ++++++++++++++++++ 3 files changed, 190 insertions(+), 5 deletions(-) diff --git a/providers/flagd/pkg/provider.go b/providers/flagd/pkg/provider.go index 6c88c331e..a335af7bf 100644 --- a/providers/flagd/pkg/provider.go +++ b/providers/flagd/pkg/provider.go @@ -3,13 +3,13 @@ package flagd import ( "context" "fmt" - "github.com/go-logr/logr" "github.com/open-feature/go-sdk-contrib/providers/flagd/internal/cache" "github.com/open-feature/go-sdk-contrib/providers/flagd/internal/logger" "github.com/open-feature/go-sdk-contrib/providers/flagd/pkg/service/in_process" rpcService "github.com/open-feature/go-sdk-contrib/providers/flagd/pkg/service/rpc" of "github.com/open-feature/go-sdk/openfeature" + "sync" ) type Provider struct { @@ -17,6 +17,7 @@ type Provider struct { providerConfiguration *providerConfiguration service IService status of.State + mtx sync.RWMutex eventStream chan of.Event } @@ -94,9 +95,9 @@ func (p *Provider) Init(evaluationContext of.EvaluationContext) error { switch event.EventType { case of.ProviderReady: case of.ProviderConfigChange: - p.status = of.ReadyState + p.setStatus(of.ReadyState) case of.ProviderError: - p.status = of.ErrorState + p.setStatus(of.ErrorState) } } }() @@ -105,6 +106,8 @@ func (p *Provider) Init(evaluationContext of.EvaluationContext) error { } func (p *Provider) Status() of.State { + p.mtx.RLock() + defer p.mtx.RUnlock() return p.status } @@ -158,6 +161,12 @@ func (p *Provider) ObjectEvaluation( return p.service.ResolveObject(ctx, flagKey, defaultValue, evalCtx) } +func (p *Provider) setStatus(status of.State) { + p.mtx.Lock() + defer p.mtx.Unlock() + p.status = status +} + // ProviderOptions type ProviderOption func(*Provider) diff --git a/providers/flagd/pkg/service/rpc/service.go b/providers/flagd/pkg/service/rpc/service.go index c7a29d1ba..402211cc6 100644 --- a/providers/flagd/pkg/service/rpc/service.go +++ b/providers/flagd/pkg/service/rpc/service.go @@ -26,9 +26,12 @@ import ( ) const ( - ReasonCached = "CACHED" + ReasonCached = "CACHED" + ClientNotReadyMsg = "client did not yet finish the initialization" ) +var ErrClientNotReady = of.NewProviderNotReadyResolutionError(ClientNotReadyMsg) + type Configuration struct { Port uint16 Host string @@ -91,7 +94,9 @@ func (s *Service) Init() error { } func (s *Service) Shutdown() { - s.cancelHook() + if s.cancelHook != nil { + s.cancelHook() + } } // ResolveBoolean handles the flag evaluation response from the flagd ResolveBoolean rpc @@ -109,6 +114,15 @@ func (s *Service) ResolveBoolean(ctx context.Context, key string, defaultValue b } } + if !s.isInitialised() { + return of.BoolResolutionDetail{ + Value: defaultValue, + ProviderResolutionDetail: of.ProviderResolutionDetail{ + ResolutionError: ErrClientNotReady, + }, + } + } + var e of.ResolutionError resp, err := resolve[schemaV1.ResolveBooleanRequest, schemaV1.ResolveBooleanResponse]( ctx, s.logger, s.client.ResolveBoolean, key, evalCtx, @@ -158,6 +172,15 @@ func (s *Service) ResolveString(ctx context.Context, key string, defaultValue st } } + if !s.isInitialised() { + return of.StringResolutionDetail{ + Value: defaultValue, + ProviderResolutionDetail: of.ProviderResolutionDetail{ + ResolutionError: ErrClientNotReady, + }, + } + } + var e of.ResolutionError resp, err := resolve[schemaV1.ResolveStringRequest, schemaV1.ResolveStringResponse]( ctx, s.logger, s.client.ResolveString, key, evalCtx, @@ -207,6 +230,15 @@ func (s *Service) ResolveFloat(ctx context.Context, key string, defaultValue flo } } + if !s.isInitialised() { + return of.FloatResolutionDetail{ + Value: defaultValue, + ProviderResolutionDetail: of.ProviderResolutionDetail{ + ResolutionError: ErrClientNotReady, + }, + } + } + var e of.ResolutionError resp, err := resolve[schemaV1.ResolveFloatRequest, schemaV1.ResolveFloatResponse]( ctx, s.logger, s.client.ResolveFloat, key, evalCtx, @@ -256,6 +288,15 @@ func (s *Service) ResolveInt(ctx context.Context, key string, defaultValue int64 } } + if !s.isInitialised() { + return of.IntResolutionDetail{ + Value: defaultValue, + ProviderResolutionDetail: of.ProviderResolutionDetail{ + ResolutionError: ErrClientNotReady, + }, + } + } + var e of.ResolutionError resp, err := resolve[schemaV1.ResolveIntRequest, schemaV1.ResolveIntResponse]( ctx, s.logger, s.client.ResolveInt, key, evalCtx, @@ -304,6 +345,15 @@ func (s *Service) ResolveObject(ctx context.Context, key string, defaultValue in } } + if !s.isInitialised() { + return of.InterfaceResolutionDetail{ + Value: defaultValue, + ProviderResolutionDetail: of.ProviderResolutionDetail{ + ResolutionError: ErrClientNotReady, + }, + } + } + var e of.ResolutionError resp, err := resolve[schemaV1.ResolveObjectRequest, schemaV1.ResolveObjectResponse]( ctx, s.logger, s.client.ResolveObject, key, evalCtx, @@ -338,6 +388,10 @@ func (s *Service) ResolveObject(ctx context.Context, key string, defaultValue in return detail } +func (s *Service) isInitialised() bool { + return s.client != nil +} + func resolve[req resolutionRequestConstraints, resp resolutionResponseConstraints]( ctx context.Context, logger logr.Logger, resolver func(context.Context, *connect.Request[req]) (*connect.Response[resp], error), diff --git a/providers/flagd/pkg/service/rpc/service_evaluation_test.go b/providers/flagd/pkg/service/rpc/service_evaluation_test.go index 7c215cbeb..884ae1509 100644 --- a/providers/flagd/pkg/service/rpc/service_evaluation_test.go +++ b/providers/flagd/pkg/service/rpc/service_evaluation_test.go @@ -2,6 +2,7 @@ package rpc import ( "context" + "net/http" "strings" "testing" @@ -151,6 +152,23 @@ func TestBooleanEvaluation(t *testing.T) { isCached: false, errorText: string(of.FlagNotFoundCode), }, + { + name: "simple error check - client not initialised", + getCache: func() *cache.Service { + return cache.NewCacheService(cache.DisabledValue, 0, log) + }, + getMockClient: func() schemaConnectV1.ServiceClient { + return nil + }, + expectResponse: of.BoolResolutionDetail{ + Value: defaultValue, + ProviderResolutionDetail: of.ProviderResolutionDetail{ + ResolutionError: of.NewFlagNotFoundResolutionError("requested flag not found"), + }, + }, + isCached: false, + errorText: string(of.ProviderNotReadyCode), + }, } for _, test := range tests { @@ -269,6 +287,23 @@ func TestStringEvaluation(t *testing.T) { isCached: false, errorText: string(of.FlagNotFoundCode), }, + { + name: "simple error check - client not initialised", + getCache: func() *cache.Service { + return cache.NewCacheService(cache.DisabledValue, 0, log) + }, + getMockClient: func() schemaConnectV1.ServiceClient { + return nil + }, + expectResponse: of.StringResolutionDetail{ + Value: defaultValue, + ProviderResolutionDetail: of.ProviderResolutionDetail{ + ResolutionError: of.NewFlagNotFoundResolutionError("requested flag not found"), + }, + }, + isCached: false, + errorText: string(of.ProviderNotReadyCode), + }, } for _, test := range tests { @@ -387,6 +422,23 @@ func TestFloatEvaluation(t *testing.T) { isCached: false, errorText: string(of.FlagNotFoundCode), }, + { + name: "simple error check - client not initialised", + getCache: func() *cache.Service { + return cache.NewCacheService(cache.DisabledValue, 0, log) + }, + getMockClient: func() schemaConnectV1.ServiceClient { + return nil + }, + expectResponse: of.FloatResolutionDetail{ + Value: defaultValue, + ProviderResolutionDetail: of.ProviderResolutionDetail{ + ResolutionError: of.NewFlagNotFoundResolutionError("requested flag not found"), + }, + }, + isCached: false, + errorText: string(of.ProviderNotReadyCode), + }, } for _, test := range tests { @@ -505,6 +557,23 @@ func TestIntEvaluation(t *testing.T) { isCached: false, errorText: string(of.FlagNotFoundCode), }, + { + name: "simple error check - client not initialised", + getCache: func() *cache.Service { + return cache.NewCacheService(cache.DisabledValue, 0, log) + }, + getMockClient: func() schemaConnectV1.ServiceClient { + return nil + }, + expectResponse: of.IntResolutionDetail{ + Value: int64(defaultValue), + ProviderResolutionDetail: of.ProviderResolutionDetail{ + ResolutionError: of.NewFlagNotFoundResolutionError("requested flag not found"), + }, + }, + isCached: false, + errorText: string(of.ProviderNotReadyCode), + }, } for _, test := range tests { @@ -636,6 +705,23 @@ func TestObjectEvaluation(t *testing.T) { isCached: false, errorText: string(of.FlagNotFoundCode), }, + { + name: "simple error check - client not ready", + getCache: func() *cache.Service { + return cache.NewCacheService(cache.DisabledValue, 0, log) + }, + getMockClient: func() schemaConnectV1.ServiceClient { + return nil + }, + expectResponse: of.InterfaceResolutionDetail{ + Value: defaultValue, + ProviderResolutionDetail: of.ProviderResolutionDetail{ + ResolutionError: of.NewFlagNotFoundResolutionError("requested flag not found"), + }, + }, + isCached: false, + errorText: string(of.ProviderNotReadyCode), + }, } for _, test := range tests { @@ -672,3 +758,39 @@ func validate[T responseType](t *testing.T, test testStruct[T], resolutionDetail test.name, test.errorText, error.Error()) } } + +func TestService_isInitialised(t *testing.T) { + type fields struct { + client schemaConnectV1.ServiceClient + } + tests := []struct { + name string + fields fields + want bool + }{ + { + name: "not initialised", + fields: fields{ + client: nil, + }, + want: false, + }, + { + name: "initialised", + fields: fields{ + client: schemaConnectV1.NewServiceClient(http.DefaultClient, ""), + }, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Service{ + client: tt.fields.client, + } + if got := s.isInitialised(); got != tt.want { + t.Errorf("isInitialised() = %v, want %v", got, tt.want) + } + }) + } +}