Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: nil client causes panic #408

Merged
merged 2 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions providers/flagd/pkg/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,21 @@ package flagd
import (
"context"
"fmt"

"github.com/go-logr/logr"
"github.com/open-feature/go-sdk-contrib/providers/flagd/internal/cache"
"github.com/open-feature/go-sdk-contrib/providers/flagd/internal/logger"
"github.com/open-feature/go-sdk-contrib/providers/flagd/pkg/service/in_process"
rpcService "github.com/open-feature/go-sdk-contrib/providers/flagd/pkg/service/rpc"
of "github.com/open-feature/go-sdk/openfeature"
"sync"
)

type Provider struct {
logger logr.Logger
providerConfiguration *providerConfiguration
service IService
status of.State
mtx sync.RWMutex

eventStream chan of.Event
}
Expand Down Expand Up @@ -94,9 +95,9 @@ func (p *Provider) Init(evaluationContext of.EvaluationContext) error {
switch event.EventType {
case of.ProviderReady:
case of.ProviderConfigChange:
p.status = of.ReadyState
p.setStatus(of.ReadyState)
case of.ProviderError:
p.status = of.ErrorState
p.setStatus(of.ErrorState)
}
}
}()
Expand All @@ -105,6 +106,8 @@ func (p *Provider) Init(evaluationContext of.EvaluationContext) error {
}

func (p *Provider) Status() of.State {
p.mtx.RLock()
defer p.mtx.RUnlock()
return p.status
}

Expand Down Expand Up @@ -158,6 +161,12 @@ func (p *Provider) ObjectEvaluation(
return p.service.ResolveObject(ctx, flagKey, defaultValue, evalCtx)
}

func (p *Provider) setStatus(status of.State) {
p.mtx.Lock()
defer p.mtx.Unlock()
p.status = status
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would assume this is atomic and therefore locking doesn't help? 🤔

Copy link
Member

@toddbaert toddbaert Dec 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, this might be needed actually, to prevent interleaving threads checking it during init.

I think I thread isn't blocked by a lock it already owns, right? So I guess you might want to lock in Init() and well as here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we do need it, yes - if you execute the tests with the -race flag enabled, a data race error occurs, due to p.status being set in the goroutine created within the Init() function, while being accessed via the Status() function in the test

}

// ProviderOptions

type ProviderOption func(*Provider)
Expand Down
58 changes: 56 additions & 2 deletions providers/flagd/pkg/service/rpc/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,12 @@ import (
)

const (
ReasonCached = "CACHED"
ReasonCached = "CACHED"
ClientNotReadyMsg = "client did not yet finish the initialization"
)

var ErrClientNotReady = of.NewProviderNotReadyResolutionError(ClientNotReadyMsg)

type Configuration struct {
Port uint16
Host string
Expand Down Expand Up @@ -91,7 +94,9 @@ func (s *Service) Init() error {
}

func (s *Service) Shutdown() {
s.cancelHook()
if s.cancelHook != nil {
s.cancelHook()
}
}

// ResolveBoolean handles the flag evaluation response from the flagd ResolveBoolean rpc
Expand All @@ -109,6 +114,15 @@ func (s *Service) ResolveBoolean(ctx context.Context, key string, defaultValue b
}
}

if !s.isInitialised() {
return of.BoolResolutionDetail{
Value: defaultValue,
ProviderResolutionDetail: of.ProviderResolutionDetail{
ResolutionError: ErrClientNotReady,
},
}
}

var e of.ResolutionError
resp, err := resolve[schemaV1.ResolveBooleanRequest, schemaV1.ResolveBooleanResponse](
ctx, s.logger, s.client.ResolveBoolean, key, evalCtx,
Expand Down Expand Up @@ -158,6 +172,15 @@ func (s *Service) ResolveString(ctx context.Context, key string, defaultValue st
}
}

if !s.isInitialised() {
return of.StringResolutionDetail{
Value: defaultValue,
ProviderResolutionDetail: of.ProviderResolutionDetail{
ResolutionError: ErrClientNotReady,
},
}
}

var e of.ResolutionError
resp, err := resolve[schemaV1.ResolveStringRequest, schemaV1.ResolveStringResponse](
ctx, s.logger, s.client.ResolveString, key, evalCtx,
Expand Down Expand Up @@ -207,6 +230,15 @@ func (s *Service) ResolveFloat(ctx context.Context, key string, defaultValue flo
}
}

if !s.isInitialised() {
return of.FloatResolutionDetail{
Value: defaultValue,
ProviderResolutionDetail: of.ProviderResolutionDetail{
ResolutionError: ErrClientNotReady,
},
}
}

var e of.ResolutionError
resp, err := resolve[schemaV1.ResolveFloatRequest, schemaV1.ResolveFloatResponse](
ctx, s.logger, s.client.ResolveFloat, key, evalCtx,
Expand Down Expand Up @@ -256,6 +288,15 @@ func (s *Service) ResolveInt(ctx context.Context, key string, defaultValue int64
}
}

if !s.isInitialised() {
return of.IntResolutionDetail{
Value: defaultValue,
ProviderResolutionDetail: of.ProviderResolutionDetail{
ResolutionError: ErrClientNotReady,
},
}
}

var e of.ResolutionError
resp, err := resolve[schemaV1.ResolveIntRequest, schemaV1.ResolveIntResponse](
ctx, s.logger, s.client.ResolveInt, key, evalCtx,
Expand Down Expand Up @@ -304,6 +345,15 @@ func (s *Service) ResolveObject(ctx context.Context, key string, defaultValue in
}
}

if !s.isInitialised() {
return of.InterfaceResolutionDetail{
Value: defaultValue,
ProviderResolutionDetail: of.ProviderResolutionDetail{
ResolutionError: ErrClientNotReady,
},
}
}

var e of.ResolutionError
resp, err := resolve[schemaV1.ResolveObjectRequest, schemaV1.ResolveObjectResponse](
ctx, s.logger, s.client.ResolveObject, key, evalCtx,
Expand Down Expand Up @@ -338,6 +388,10 @@ func (s *Service) ResolveObject(ctx context.Context, key string, defaultValue in
return detail
}

func (s *Service) isInitialised() bool {
return s.client != nil
}

func resolve[req resolutionRequestConstraints, resp resolutionResponseConstraints](
ctx context.Context, logger logr.Logger,
resolver func(context.Context, *connect.Request[req]) (*connect.Response[resp], error),
Expand Down
122 changes: 122 additions & 0 deletions providers/flagd/pkg/service/rpc/service_evaluation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package rpc

import (
"context"
"net/http"
"strings"
"testing"

Expand Down Expand Up @@ -151,6 +152,23 @@ func TestBooleanEvaluation(t *testing.T) {
isCached: false,
errorText: string(of.FlagNotFoundCode),
},
{
name: "simple error check - client not initialised",
getCache: func() *cache.Service {
return cache.NewCacheService(cache.DisabledValue, 0, log)
},
getMockClient: func() schemaConnectV1.ServiceClient {
return nil
},
expectResponse: of.BoolResolutionDetail{
Value: defaultValue,
ProviderResolutionDetail: of.ProviderResolutionDetail{
ResolutionError: of.NewFlagNotFoundResolutionError("requested flag not found"),
},
},
isCached: false,
errorText: string(of.ProviderNotReadyCode),
},
}

for _, test := range tests {
Expand Down Expand Up @@ -269,6 +287,23 @@ func TestStringEvaluation(t *testing.T) {
isCached: false,
errorText: string(of.FlagNotFoundCode),
},
{
name: "simple error check - client not initialised",
getCache: func() *cache.Service {
return cache.NewCacheService(cache.DisabledValue, 0, log)
},
getMockClient: func() schemaConnectV1.ServiceClient {
return nil
},
expectResponse: of.StringResolutionDetail{
Value: defaultValue,
ProviderResolutionDetail: of.ProviderResolutionDetail{
ResolutionError: of.NewFlagNotFoundResolutionError("requested flag not found"),
},
},
isCached: false,
errorText: string(of.ProviderNotReadyCode),
},
}

for _, test := range tests {
Expand Down Expand Up @@ -387,6 +422,23 @@ func TestFloatEvaluation(t *testing.T) {
isCached: false,
errorText: string(of.FlagNotFoundCode),
},
{
name: "simple error check - client not initialised",
getCache: func() *cache.Service {
return cache.NewCacheService(cache.DisabledValue, 0, log)
},
getMockClient: func() schemaConnectV1.ServiceClient {
return nil
},
expectResponse: of.FloatResolutionDetail{
Value: defaultValue,
ProviderResolutionDetail: of.ProviderResolutionDetail{
ResolutionError: of.NewFlagNotFoundResolutionError("requested flag not found"),
},
},
isCached: false,
errorText: string(of.ProviderNotReadyCode),
},
}

for _, test := range tests {
Expand Down Expand Up @@ -505,6 +557,23 @@ func TestIntEvaluation(t *testing.T) {
isCached: false,
errorText: string(of.FlagNotFoundCode),
},
{
name: "simple error check - client not initialised",
getCache: func() *cache.Service {
return cache.NewCacheService(cache.DisabledValue, 0, log)
},
getMockClient: func() schemaConnectV1.ServiceClient {
return nil
},
expectResponse: of.IntResolutionDetail{
Value: int64(defaultValue),
ProviderResolutionDetail: of.ProviderResolutionDetail{
ResolutionError: of.NewFlagNotFoundResolutionError("requested flag not found"),
},
},
isCached: false,
errorText: string(of.ProviderNotReadyCode),
},
}

for _, test := range tests {
Expand Down Expand Up @@ -636,6 +705,23 @@ func TestObjectEvaluation(t *testing.T) {
isCached: false,
errorText: string(of.FlagNotFoundCode),
},
{
name: "simple error check - client not ready",
getCache: func() *cache.Service {
return cache.NewCacheService(cache.DisabledValue, 0, log)
},
getMockClient: func() schemaConnectV1.ServiceClient {
return nil
},
expectResponse: of.InterfaceResolutionDetail{
Value: defaultValue,
ProviderResolutionDetail: of.ProviderResolutionDetail{
ResolutionError: of.NewFlagNotFoundResolutionError("requested flag not found"),
},
},
isCached: false,
errorText: string(of.ProviderNotReadyCode),
},
}

for _, test := range tests {
Expand Down Expand Up @@ -672,3 +758,39 @@ func validate[T responseType](t *testing.T, test testStruct[T], resolutionDetail
test.name, test.errorText, error.Error())
}
}

func TestService_isInitialised(t *testing.T) {
type fields struct {
client schemaConnectV1.ServiceClient
}
tests := []struct {
name string
fields fields
want bool
}{
{
name: "not initialised",
fields: fields{
client: nil,
},
want: false,
},
{
name: "initialised",
fields: fields{
client: schemaConnectV1.NewServiceClient(http.DefaultClient, ""),
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &Service{
client: tt.fields.client,
}
if got := s.isInitialised(); got != tt.want {
t.Errorf("isInitialised() = %v, want %v", got, tt.want)
}
})
}
}
Loading