Skip to content

Commit

Permalink
[extension/oauth2clientauth] Use new client auth helpers
Browse files Browse the repository at this point in the history
Signed-off-by: Juraci Paixão Kröhling <juraci@kroehling.de>
  • Loading branch information
jpkrohling committed Apr 8, 2022
1 parent a9b5da4 commit 8a0958b
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 56 deletions.
31 changes: 8 additions & 23 deletions extension/oauth2clientauthextension/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ import (
"fmt"
"net/http"

"go.opentelemetry.io/collector/component"
"go.opentelemetry.io/collector/config/configauth"
"go.uber.org/multierr"
"go.uber.org/zap"
"golang.org/x/oauth2"
Expand All @@ -29,17 +27,14 @@ import (
grpcOAuth "google.golang.org/grpc/credentials/oauth"
)

// ClientCredentialsAuthenticator provides implementation for providing client authentication using OAuth2 client credentials
// clientAuthenticator provides implementation for providing client authentication using OAuth2 client credentials
// workflow for both gRPC and HTTP clients.
type ClientCredentialsAuthenticator struct {
type clientAuthenticator struct {
clientCredentials *clientcredentials.Config
logger *zap.Logger
client *http.Client
}

// ClientCredentialsAuthenticator implements ClientAuthenticator
var _ configauth.ClientAuthenticator = (*ClientCredentialsAuthenticator)(nil)

type errorWrappingTokenSource struct {
ts oauth2.TokenSource
tokenURL string
Expand All @@ -51,7 +46,7 @@ var _ oauth2.TokenSource = (*errorWrappingTokenSource)(nil)
// errFailedToGetSecurityToken indicates a problem communicating with OAuth2 server.
var errFailedToGetSecurityToken = fmt.Errorf("failed to get security token from token endpoint")

func newClientCredentialsExtension(cfg *Config, logger *zap.Logger) (*ClientCredentialsAuthenticator, error) {
func newClientAuthenticator(cfg *Config, logger *zap.Logger) (*clientAuthenticator, error) {
if cfg.ClientID == "" {
return nil, errNoClientIDProvided
}
Expand All @@ -70,7 +65,7 @@ func newClientCredentialsExtension(cfg *Config, logger *zap.Logger) (*ClientCred
}
transport.TLSClientConfig = tlsCfg

return &ClientCredentialsAuthenticator{
return &clientAuthenticator{
clientCredentials: &clientcredentials.Config{
ClientID: cfg.ClientID,
ClientSecret: cfg.ClientSecret,
Expand All @@ -86,16 +81,6 @@ func newClientCredentialsExtension(cfg *Config, logger *zap.Logger) (*ClientCred
}, nil
}

// Start for ClientCredentialsAuthenticator extension does nothing
func (o *ClientCredentialsAuthenticator) Start(_ context.Context, _ component.Host) error {
return nil
}

// Shutdown for ClientCredentialsAuthenticator extension does nothing
func (o *ClientCredentialsAuthenticator) Shutdown(_ context.Context) error {
return nil
}

func (ewts errorWrappingTokenSource) Token() (*oauth2.Token, error) {
tok, err := ewts.ts.Token()
if err != nil {
Expand All @@ -106,9 +91,9 @@ func (ewts errorWrappingTokenSource) Token() (*oauth2.Token, error) {
return tok, nil
}

// RoundTripper returns oauth2.Transport, an http.RoundTripper that performs "client-credential" OAuth flow and
// roundTripper returns oauth2.Transport, an http.RoundTripper that performs "client-credential" OAuth flow and
// also auto refreshes OAuth tokens as needed.
func (o *ClientCredentialsAuthenticator) RoundTripper(base http.RoundTripper) (http.RoundTripper, error) {
func (o *clientAuthenticator) roundTripper(base http.RoundTripper) (http.RoundTripper, error) {
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, o.client)
return &oauth2.Transport{
Source: errorWrappingTokenSource{
Expand All @@ -119,9 +104,9 @@ func (o *ClientCredentialsAuthenticator) RoundTripper(base http.RoundTripper) (h
}, nil
}

// PerRPCCredentials returns gRPC PerRPCCredentials that supports "client-credential" OAuth flow. The underneath
// perRPCCredentials returns gRPC PerRPCCredentials that supports "client-credential" OAuth flow. The underneath
// oauth2.clientcredentials.Config instance will manage tokens performing auto refresh as necessary.
func (o *ClientCredentialsAuthenticator) PerRPCCredentials() (credentials.PerRPCCredentials, error) {
func (o *clientAuthenticator) perRPCCredentials() (credentials.PerRPCCredentials, error) {
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, o.client)
return grpcOAuth.TokenSource{
TokenSource: errorWrappingTokenSource{
Expand Down
40 changes: 8 additions & 32 deletions extension/oauth2clientauthextension/extension_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func TestOAuthClientSettings(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
rc, err := newClientCredentialsExtension(test.settings, zap.NewNop())
rc, err := newClientAuthenticator(test.settings, zap.NewNop())
if test.shouldError {
assert.NotNil(t, err)
assert.Contains(t, err.Error(), test.expectedError)
Expand Down Expand Up @@ -185,15 +185,15 @@ func TestRoundTripper(t *testing.T) {

for _, testcase := range tests {
t.Run(testcase.name, func(t *testing.T) {
oauth2Authenticator, err := newClientCredentialsExtension(testcase.settings, zap.NewNop())
oauth2Authenticator, err := newClientAuthenticator(testcase.settings, zap.NewNop())
if testcase.shouldError {
assert.Error(t, err)
assert.Nil(t, oauth2Authenticator)
return
}

assert.NotNil(t, oauth2Authenticator)
roundTripper, err := oauth2Authenticator.RoundTripper(baseRoundTripper)
roundTripper, err := oauth2Authenticator.roundTripper(baseRoundTripper)
assert.Nil(t, err)

// test roundTripper is an OAuth RoundTripper
Expand Down Expand Up @@ -239,14 +239,14 @@ func TestOAuth2PerRPCCredentials(t *testing.T) {

for _, testcase := range tests {
t.Run(testcase.name, func(t *testing.T) {
oauth2Authenticator, err := newClientCredentialsExtension(testcase.settings, zap.NewNop())
oauth2Authenticator, err := newClientAuthenticator(testcase.settings, zap.NewNop())
if testcase.shouldError {
assert.Error(t, err)
assert.Nil(t, oauth2Authenticator)
return
}
assert.NoError(t, err)
perRPCCredentials, err := oauth2Authenticator.PerRPCCredentials()
perRPCCredentials, err := oauth2Authenticator.perRPCCredentials()
assert.Nil(t, err)
// test perRPCCredentials is an grpc OAuthTokenSource
_, ok := perRPCCredentials.(grpcOAuth.TokenSource)
Expand All @@ -255,30 +255,6 @@ func TestOAuth2PerRPCCredentials(t *testing.T) {
}
}

func TestOAuthExtensionStart(t *testing.T) {
oAuthExtensionAuth, err := newClientCredentialsExtension(
&Config{
ClientID: "testclientid",
ClientSecret: "testsecret",
TokenURL: "https://example.com/v1/token",
Scopes: []string{"resource.read"},
}, nil)
assert.Nil(t, err)
assert.Nil(t, oAuthExtensionAuth.Start(context.Background(), nil))
}

func TestOAuthExtensionShutdown(t *testing.T) {
oAuthExtensionAuth, err := newClientCredentialsExtension(
&Config{
ClientID: "testclientid",
ClientSecret: "testsecret",
TokenURL: "https://example.com/v1/token",
Scopes: []string{"resource.read"},
}, nil)
assert.Nil(t, err)
assert.Nil(t, oAuthExtensionAuth.Shutdown(context.Background()))
}

func TestFailContactingOAuth(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
Expand All @@ -289,15 +265,15 @@ func TestFailContactingOAuth(t *testing.T) {
serverURL, err := url.Parse(server.URL)
assert.NoError(t, err)

oauth2Authenticator, err := newClientCredentialsExtension(&Config{
oauth2Authenticator, err := newClientAuthenticator(&Config{
ClientID: "dummy",
ClientSecret: "ABC",
TokenURL: serverURL.String(),
}, zap.NewNop())
assert.Nil(t, err)

// Test for gRPC connections
credential, err := oauth2Authenticator.PerRPCCredentials()
credential, err := oauth2Authenticator.perRPCCredentials()
assert.Nil(t, err)

_, err = credential.GetRequestMetadata(context.Background())
Expand All @@ -308,7 +284,7 @@ func TestFailContactingOAuth(t *testing.T) {
setting := confighttp.HTTPClientSettings{
Endpoint: "http://example.com/",
CustomRoundTripper: func(next http.RoundTripper) (http.RoundTripper, error) {
return oauth2Authenticator.RoundTripper(next)
return oauth2Authenticator.roundTripper(next)
},
}

Expand Down
11 changes: 10 additions & 1 deletion extension/oauth2clientauthextension/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (

"go.opentelemetry.io/collector/component"
"go.opentelemetry.io/collector/config"
"go.opentelemetry.io/collector/config/configauth"
)

const (
Expand All @@ -41,5 +42,13 @@ func createDefaultConfig() config.Extension {
}

func createExtension(_ context.Context, set component.ExtensionCreateSettings, cfg config.Extension) (component.Extension, error) {
return newClientCredentialsExtension(cfg.(*Config), set.Logger)
ca, err := newClientAuthenticator(cfg.(*Config), set.Logger)
if err != nil {
return nil, err
}

return configauth.NewClientAuthenticator(
configauth.WithClientRoundTripper(ca.roundTripper),
configauth.WithPerRPCCredentials(ca.perRPCCredentials),
), nil
}

0 comments on commit 8a0958b

Please sign in to comment.