Skip to content

Commit

Permalink
Propagate correct User-Agent for CLI (#1264)
Browse files Browse the repository at this point in the history
## Changes
This PR migrates `databricks auth login` HTTP client to the one from Go
SDK, making API calls more robust and containing our unified user agent.

## Tests
Unit tests left almost unchanged
  • Loading branch information
nfx authored Mar 11, 2024
1 parent 4a9a12a commit 945d522
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 48 deletions.
47 changes: 18 additions & 29 deletions libs/auth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,14 @@ import (
"crypto/sha256"
_ "embed"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"time"

"github.com/databricks/cli/libs/auth/cache"
"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/retries"
"github.com/pkg/browser"
"golang.org/x/oauth2"
Expand Down Expand Up @@ -43,16 +41,12 @@ type PersistentAuth struct {
Host string
AccountID string

http httpGet
http *httpclient.ApiClient
cache tokenCache
ln net.Listener
browser func(string) error
}

type httpGet interface {
Get(string) (*http.Response, error)
}

type tokenCache interface {
Store(key string, t *oauth2.Token) error
Lookup(key string) (*oauth2.Token, error)
Expand All @@ -77,10 +71,12 @@ func (a *PersistentAuth) Load(ctx context.Context) (*oauth2.Token, error) {
}
// OAuth2 config is invoked only for expired tokens to speed up
// the happy path in the token retrieval
cfg, err := a.oauth2Config()
cfg, err := a.oauth2Config(ctx)
if err != nil {
return nil, err
}
// make OAuth2 library use our client
ctx = a.http.InContextForOAuth2(ctx)
// eagerly refresh token
refreshed, err := cfg.TokenSource(ctx, t).Token()
if err != nil {
Expand Down Expand Up @@ -110,7 +106,7 @@ func (a *PersistentAuth) Challenge(ctx context.Context) error {
if err != nil {
return fmt.Errorf("init: %w", err)
}
cfg, err := a.oauth2Config()
cfg, err := a.oauth2Config(ctx)
if err != nil {
return err
}
Expand All @@ -120,6 +116,8 @@ func (a *PersistentAuth) Challenge(ctx context.Context) error {
}
defer cb.Close()
state, pkce := a.stateAndPKCE()
// make OAuth2 library use our client
ctx = a.http.InContextForOAuth2(ctx)
ts := authhandler.TokenSourceWithPKCE(ctx, cfg, state, cb.Handler, pkce)
t, err := ts.Token()
if err != nil {
Expand All @@ -138,7 +136,9 @@ func (a *PersistentAuth) init(ctx context.Context) error {
return ErrFetchCredentials
}
if a.http == nil {
a.http = http.DefaultClient
a.http = httpclient.NewApiClient(httpclient.ClientConfig{
// noop
})
}
if a.cache == nil {
a.cache = &cache.TokenCache{}
Expand Down Expand Up @@ -172,39 +172,28 @@ func (a *PersistentAuth) Close() error {
return a.ln.Close()
}

func (a *PersistentAuth) oidcEndpoints() (*oauthAuthorizationServer, error) {
func (a *PersistentAuth) oidcEndpoints(ctx context.Context) (*oauthAuthorizationServer, error) {
prefix := a.key()
if a.AccountID != "" {
return &oauthAuthorizationServer{
AuthorizationEndpoint: fmt.Sprintf("%s/v1/authorize", prefix),
TokenEndpoint: fmt.Sprintf("%s/v1/token", prefix),
}, nil
}
var oauthEndpoints oauthAuthorizationServer
oidc := fmt.Sprintf("%s/oidc/.well-known/oauth-authorization-server", prefix)
oidcResponse, err := a.http.Get(oidc)
err := a.http.Do(ctx, "GET", oidc, httpclient.WithResponseUnmarshal(&oauthEndpoints))
if err != nil {
return nil, fmt.Errorf("fetch .well-known: %w", err)
}
if oidcResponse.StatusCode != 200 {
var httpErr *httpclient.HttpError
if errors.As(err, &httpErr) && httpErr.StatusCode != 200 {
return nil, ErrOAuthNotSupported
}
if oidcResponse.Body == nil {
return nil, fmt.Errorf("fetch .well-known: empty body")
}
defer oidcResponse.Body.Close()
raw, err := io.ReadAll(oidcResponse.Body)
if err != nil {
return nil, fmt.Errorf("read .well-known: %w", err)
}
var oauthEndpoints oauthAuthorizationServer
err = json.Unmarshal(raw, &oauthEndpoints)
if err != nil {
return nil, fmt.Errorf("parse .well-known: %w", err)
}
return &oauthEndpoints, nil
}

func (a *PersistentAuth) oauth2Config() (*oauth2.Config, error) {
func (a *PersistentAuth) oauth2Config(ctx context.Context) (*oauth2.Config, error) {
// in this iteration of CLI, we're using all scopes by default,
// because tools like CLI and Terraform do use all apis. This
// decision may be reconsidered later, once we have a proper
Expand All @@ -213,7 +202,7 @@ func (a *PersistentAuth) oauth2Config() (*oauth2.Config, error) {
"offline_access",
"all-apis",
}
endpoints, err := a.oidcEndpoints()
endpoints, err := a.oidcEndpoints(ctx)
if err != nil {
return nil, fmt.Errorf("oidc: %w", err)
}
Expand Down
33 changes: 14 additions & 19 deletions libs/auth/oauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ import (
"crypto/tls"
_ "embed"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"testing"
"time"

"github.com/databricks/databricks-sdk-go/client"
"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/httpclient/fixtures"
"github.com/databricks/databricks-sdk-go/qa"
"github.com/stretchr/testify/assert"
"golang.org/x/oauth2"
Expand All @@ -24,34 +24,29 @@ func TestOidcEndpointsForAccounts(t *testing.T) {
AccountID: "xyz",
}
defer p.Close()
s, err := p.oidcEndpoints()
s, err := p.oidcEndpoints(context.Background())
assert.NoError(t, err)
assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/authorize", s.AuthorizationEndpoint)
assert.Equal(t, "https://abc/oidc/accounts/xyz/v1/token", s.TokenEndpoint)
}

type mockGet func(url string) (*http.Response, error)

func (m mockGet) Get(url string) (*http.Response, error) {
return m(url)
}

func TestOidcForWorkspace(t *testing.T) {
p := &PersistentAuth{
Host: "abc",
http: mockGet(func(url string) (*http.Response, error) {
assert.Equal(t, "https://abc/oidc/.well-known/oauth-authorization-server", url)
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(strings.NewReader(`{
"authorization_endpoint": "a",
"token_endpoint": "b"
}`)),
}, nil
http: httpclient.NewApiClient(httpclient.ClientConfig{
Transport: fixtures.MappingTransport{
"GET /oidc/.well-known/oauth-authorization-server": {
Status: 200,
Response: map[string]string{
"authorization_endpoint": "a",
"token_endpoint": "b",
},
},
},
}),
}
defer p.Close()
endpoints, err := p.oidcEndpoints()
endpoints, err := p.oidcEndpoints(context.Background())
assert.NoError(t, err)
assert.Equal(t, "a", endpoints.AuthorizationEndpoint)
assert.Equal(t, "b", endpoints.TokenEndpoint)
Expand Down

0 comments on commit 945d522

Please sign in to comment.