Skip to content

Commit

Permalink
credentials: refresh token in GetRequestMetadata
Browse files Browse the repository at this point in the history
  • Loading branch information
jschwinger233 committed Jul 19, 2020
1 parent 92d7feb commit e1664e1
Showing 1 changed file with 36 additions and 20 deletions.
56 changes: 36 additions & 20 deletions clientv3/credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,18 @@ type Config struct {
// Bundle defines gRPC credential interface.
type Bundle interface {
grpccredentials.Bundle
UpdateAuthToken(token string)
NeedRefreshAuthToken()
}

// NewBundle constructs a new gRPC credential bundle.
func NewBundle(cfg Config) Bundle {
func NewBundle(cfg Config, refreshTokens ...func(context.Context) (string, error)) Bundle {
var refreshToken func(context.Context) (string, error)
if refreshTokens != nil {
refreshToken = refreshTokens[0]
}
return &bundle{
tc: newTransportCredential(cfg.TLSConfig),
rc: newPerRPCCredential(),
rc: newPerRPCCredential(refreshToken),
}
}

Expand All @@ -51,6 +55,10 @@ type bundle struct {
rc *perRPCCredential
}

func (b *bundle) NeedRefreshAuthToken() {
b.rc.needRefreshAuthToken()
}

func (b *bundle) TransportCredentials() grpccredentials.TransportCredentials {
return b.tc
}
Expand Down Expand Up @@ -99,30 +107,38 @@ func (tc *transportCredential) OverrideServerName(serverNameOverride string) err

// perRPCCredential implements "grpccredentials.PerRPCCredentials" interface.
type perRPCCredential struct {
authToken string
authTokenMu sync.RWMutex
authToken string
needFresh bool
// refreshingMux ensures only one request at a time will be allowed to fetch token
refreshingMux sync.Mutex
refreshAuthToken func(context.Context) (string, error)
}

func newPerRPCCredential() *perRPCCredential { return &perRPCCredential{} }
func newPerRPCCredential(refreshToken func(context.Context) (string, error)) *perRPCCredential {
return &perRPCCredential{
needFresh: true,
refreshAuthToken: refreshToken,
}
}

func (rc *perRPCCredential) RequireTransportSecurity() bool { return false }

func (rc *perRPCCredential) GetRequestMetadata(ctx context.Context, s ...string) (map[string]string, error) {
rc.authTokenMu.RLock()
authToken := rc.authToken
rc.authTokenMu.RUnlock()
return map[string]string{rpctypes.TokenFieldNameGRPC: authToken}, nil
}

func (b *bundle) UpdateAuthToken(token string) {
if b.rc == nil {
// GetRequestMetadata is designed (and also implemented) to be called per RPC to refresh credential
func (rc *perRPCCredential) GetRequestMetadata(ctx context.Context, s ...string) (data map[string]string, err error) {
data = map[string]string{rpctypes.TokenFieldNameGRPC: rc.authToken}
rc.refreshingMux.Lock()
defer rc.refreshingMux.Unlock()
if !rc.needFresh {
return
}
if rc.authToken, err = rc.refreshAuthToken(ctx); err != nil {
return
}
b.rc.UpdateAuthToken(token)
rc.needFresh = false
data[rpctypes.TokenFieldNameGRPC] = rc.authToken
return
}

func (rc *perRPCCredential) UpdateAuthToken(token string) {
rc.authTokenMu.Lock()
rc.authToken = token
rc.authTokenMu.Unlock()
func (rc *perRPCCredential) needRefreshAuthToken() {
rc.needFresh = true
}

0 comments on commit e1664e1

Please sign in to comment.