Skip to content

Commit

Permalink
refactor: cli/internal/oauth/
Browse files Browse the repository at this point in the history
Signed-off-by: Laura Brehm <laurabrehm@hey.com>
  • Loading branch information
laurazard committed Aug 12, 2024
1 parent 16bda05 commit 569af38
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 32 deletions.
63 changes: 38 additions & 25 deletions cli/internal/oauth/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ var ErrTimeout = errors.New("timed out waiting for device token")
// GetDeviceCode initiates the device-code auth flow with the tenant.
// The state returned contains the device code that the user must use to
// authenticate, as well as the URL to visit, etc.
func (a API) GetDeviceCode(ctx context.Context, audience string) (state State, err error) {
func (a API) GetDeviceCode(ctx context.Context, audience string) (State, error) {
data := url.Values{
"client_id": {a.ClientID},
"audience": {audience},
Expand All @@ -62,24 +62,33 @@ func (a API) GetDeviceCode(ctx context.Context, audience string) (state State, e
deviceCodeURL := a.TenantURL + "/oauth/device/code"
resp, err := postForm(ctx, deviceCodeURL, strings.NewReader(data.Encode()))
if err != nil {
return
return State{}, err
}
defer func() {
_ = resp.Body.Close()
}()

if resp.StatusCode != http.StatusOK {
var body map[string]any
err = json.NewDecoder(resp.Body).Decode(&body)
if errorDescription, ok := body["error_description"].(string); ok {
return state, errors.New(errorDescription)
}
return state, fmt.Errorf("failed to get device code: %w", err)
return State{}, tryDecodeOAuthError(resp)
}

var state State
err = json.NewDecoder(resp.Body).Decode(&state)
if err != nil {
return state, fmt.Errorf("failed to get device code: %w", err)
}

return
return state, nil
}

func tryDecodeOAuthError(resp *http.Response) error {
var body map[string]any
if err := json.NewDecoder(resp.Body).Decode(&body); err == nil {
if errorDescription, ok := body["error_description"].(string); ok {
return errors.New(errorDescription)
}
}
return errors.New("unexpected response from tenant: " + resp.Status)
}

// WaitForDeviceToken polls the tenant to get access/refresh tokens for the user.
Expand Down Expand Up @@ -118,7 +127,7 @@ func (a API) WaitForDeviceToken(ctx context.Context, state State) (TokenResponse
}

// getToken calls the token endpoint of Auth0 and returns the response.
func (a API) getDeviceToken(ctx context.Context, state State) (res TokenResponse, err error) {
func (a API) getDeviceToken(ctx context.Context, state State) (TokenResponse, error) {
data := url.Values{
"client_id": {a.ClientID},
"grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
Expand All @@ -128,13 +137,19 @@ func (a API) getDeviceToken(ctx context.Context, state State) (res TokenResponse

resp, err := postForm(ctx, oauthTokenURL, strings.NewReader(data.Encode()))
if err != nil {
return res, fmt.Errorf("failed to get code: %w", err)
return TokenResponse{}, fmt.Errorf("failed to get tokens: %w", err)
}
defer func() {
_ = resp.Body.Close()
}()

var res TokenResponse
err = json.NewDecoder(resp.Body).Decode(&res)
_ = resp.Body.Close()
if err != nil {
return res, fmt.Errorf("failed to decode response: %w", err)
}

return
return res, nil
}

// RevokeToken revokes a refresh token with the tenant so that it can no longer
Expand All @@ -150,11 +165,14 @@ func (a API) RevokeToken(ctx context.Context, refreshToken string) error {
if err != nil {
return err
}
defer resp.Body.Close()
defer func() {
_ = resp.Body.Close()
}()

if resp.StatusCode != http.StatusOK {
return errors.New("failed to revoke token")
return tryDecodeOAuthError(resp)
}

return nil
}

Expand Down Expand Up @@ -188,22 +206,17 @@ func (a API) GetAutoPAT(ctx context.Context, audience string, res TokenResponse)
_ = resp.Body.Close()
}()

if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
var body map[string]any
err = json.NewDecoder(resp.Body).Decode(&body)
if errorDescription, ok := body["error_description"].(string); ok {
return "", errors.New(errorDescription)
}
return "", fmt.Errorf("failed to get device code: %w", err)
if resp.StatusCode != http.StatusCreated {
return "", fmt.Errorf("unexpected response from Hub: %s", resp.Status)
}

var respo patGenerateResponse
err = json.NewDecoder(resp.Body).Decode(&respo)
var response patGenerateResponse
err = json.NewDecoder(resp.Body).Decode(&response)
if err != nil {
return "", err
}

return respo.Data.Token, nil
return response.Data.Token, nil
}

type patGenerateResponse struct {
Expand Down
85 changes: 84 additions & 1 deletion cli/internal/oauth/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func TestGetDeviceCode(t *testing.T) {

_, err := api.GetDeviceCode(context.Background(), "anAudience")

assert.ErrorContains(t, err, "failed to get device code")
assert.ErrorContains(t, err, "unexpected response from tenant: 500 Internal Server Error")
})

t.Run("canceled context", func(t *testing.T) {
Expand Down Expand Up @@ -173,6 +173,28 @@ func TestWaitForDeviceToken(t *testing.T) {
assert.DeepEqual(t, token, expectedToken)
})

t.Run("unexpected response", func(t *testing.T) {
t.Parallel()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer ts.Close()
api := API{
TenantURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
}
state := State{
DeviceCode: "aDeviceCode",
UserCode: "aUserCode",
Interval: 1,
ExpiresIn: 30,
}
_, err := api.WaitForDeviceToken(context.Background(), state)

assert.ErrorContains(t, err, "unexpected response from tenant: 500 Internal Server Error")
})

t.Run("timeout", func(t *testing.T) {
t.Parallel()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -265,6 +287,49 @@ func TestRevoke(t *testing.T) {
assert.NilError(t, err)
})

t.Run("unexpected response", func(t *testing.T) {
t.Parallel()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
assert.Equal(t, "/oauth/revoke", r.URL.Path)
assert.Equal(t, r.FormValue("client_id"), "aClientID")
assert.Equal(t, r.FormValue("token"), "v1.a-refresh-token")

w.WriteHeader(http.StatusNotFound)
}))
defer ts.Close()
api := API{
TenantURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
}

err := api.RevokeToken(context.Background(), "v1.a-refresh-token")
assert.ErrorContains(t, err, "unexpected response from tenant: 404 Not Found")
})

t.Run("error w/ description", func(t *testing.T) {
t.Parallel()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
jsonState, err := json.Marshal(TokenResponse{
ErrorDescription: "invalid client id",
})
assert.NilError(t, err)

w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write(jsonState)
}))
defer ts.Close()
api := API{
TenantURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
}

err := api.RevokeToken(context.Background(), "v1.a-refresh-token")
assert.ErrorContains(t, err, "invalid client id")
})

t.Run("canceled context", func(t *testing.T) {
t.Parallel()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -327,6 +392,24 @@ func TestGetAutoPAT(t *testing.T) {
assert.Equal(t, "a-docker-pat", pat)
})

t.Run("general error", func(t *testing.T) {
t.Parallel()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer ts.Close()
api := API{
TenantURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
}

_, err := api.GetAutoPAT(context.Background(), ts.URL, TokenResponse{
AccessToken: "bork",
})
assert.ErrorContains(t, err, "unexpected response from Hub: 500 Internal Server Error")
})

t.Run("context canceled", func(t *testing.T) {
t.Parallel()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down
10 changes: 5 additions & 5 deletions cli/internal/oauth/manager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ func New(options OAuthManagerOptions) *OAuthManager {
func (m *OAuthManager) LoginDevice(ctx context.Context, w io.Writer) (*types.AuthConfig, error) {
state, err := m.api.GetDeviceCode(ctx, m.audience)
if err != nil {
return nil, fmt.Errorf("login failed: %w", err)
return nil, fmt.Errorf("failed to get device code: %w", err)
}

if state.UserCode == "" {
return nil, errors.New("login failed: no user code returned")
return nil, errors.New("no user code returned")
}

_, _ = fmt.Fprintln(w, "\n\033[1mUSING WEB BASED LOGIN\033[0m")
Expand Down Expand Up @@ -110,18 +110,18 @@ func (m *OAuthManager) LoginDevice(ctx context.Context, w io.Writer) (*types.Aut
case <-ctx.Done():
return nil, errors.New("login canceled")
case err := <-waitForTokenErrChan:
return nil, fmt.Errorf("login failed: %w", err)
return nil, fmt.Errorf("failed waiting for authentication: %w", err)
case tokenRes = <-tokenResChan:
}

claims, err := oauth.GetClaims(tokenRes.AccessToken)
if err != nil {
return nil, fmt.Errorf("login failed: %w", err)
return nil, fmt.Errorf("failed to parse token claims: %w", err)
}

err = m.storeTokensInStore(tokenRes, claims.Domain.Username)
if err != nil {
return nil, fmt.Errorf("login failed: %w", err)
return nil, fmt.Errorf("failed to store tokens: %w", err)
}

pat, err := m.api.GetAutoPAT(ctx, m.audience, tokenRes)
Expand Down
2 changes: 1 addition & 1 deletion cli/internal/oauth/manager/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func TestLoginDevice(t *testing.T) {
}

_, err := manager.LoginDevice(context.Background(), os.Stderr)
assert.ErrorContains(t, err, "login failed: timed out waiting for device token")
assert.ErrorContains(t, err, "failed waiting for authentication: timed out waiting for device token")
})

t.Run("canceled context", func(t *testing.T) {
Expand Down

0 comments on commit 569af38

Please sign in to comment.