From 204c9747ca778fe44f0d50a2c93c0265f85d2ea2 Mon Sep 17 00:00:00 2001 From: Laura Brehm Date: Mon, 12 Aug 2024 13:00:25 +0100 Subject: [PATCH] refactor: cli/internal/oauth/ Signed-off-by: Laura Brehm --- cli/internal/oauth/api/api.go | 66 ++++++++++++++-------- cli/internal/oauth/api/api_test.go | 63 ++++++++++++++++++++- cli/internal/oauth/manager/manager.go | 10 ++-- cli/internal/oauth/manager/manager_test.go | 2 +- 4 files changed, 109 insertions(+), 32 deletions(-) diff --git a/cli/internal/oauth/api/api.go b/cli/internal/oauth/api/api.go index bdeebeec440c..031861297967 100644 --- a/cli/internal/oauth/api/api.go +++ b/cli/internal/oauth/api/api.go @@ -52,7 +52,7 @@ var ErrTimeout = errors.New("timed out waiting for device token") // GetDeviceCode initiates the device-code auth flow with the tenant. // The state returned contains the device code that the user must use to // authenticate, as well as the URL to visit, etc. -func (a API) GetDeviceCode(ctx context.Context, audience string) (state State, err error) { +func (a API) GetDeviceCode(ctx context.Context, audience string) (State, error) { data := url.Values{ "client_id": {a.ClientID}, "audience": {audience}, @@ -62,24 +62,33 @@ func (a API) GetDeviceCode(ctx context.Context, audience string) (state State, e deviceCodeURL := a.TenantURL + "/oauth/device/code" resp, err := postForm(ctx, deviceCodeURL, strings.NewReader(data.Encode())) if err != nil { - return + return State{}, err } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - var body map[string]any - err = json.NewDecoder(resp.Body).Decode(&body) - if errorDescription, ok := body["error_description"].(string); ok { - return state, errors.New(errorDescription) - } - return state, fmt.Errorf("failed to get device code: %w", err) + return State{}, tryDecodeOAuthError(resp) } + var state State err = json.NewDecoder(resp.Body).Decode(&state) + if err != nil { + return state, fmt.Errorf("failed to get device code: %w", err) + } - return + return state, nil +} + +func tryDecodeOAuthError(resp *http.Response) error { + var body map[string]any + if err := json.NewDecoder(resp.Body).Decode(&body); err == nil { + if errorDescription, ok := body["error_description"].(string); ok { + return errors.New(errorDescription) + } + } + return errors.New("unexpected response from tenant: " + resp.Status) } // WaitForDeviceToken polls the tenant to get access/refresh tokens for the user. @@ -118,7 +127,7 @@ func (a API) WaitForDeviceToken(ctx context.Context, state State) (TokenResponse } // getToken calls the token endpoint of Auth0 and returns the response. -func (a API) getDeviceToken(ctx context.Context, state State) (res TokenResponse, err error) { +func (a API) getDeviceToken(ctx context.Context, state State) (TokenResponse, error) { data := url.Values{ "client_id": {a.ClientID}, "grant_type": {"urn:ietf:params:oauth:grant-type:device_code"}, @@ -128,13 +137,22 @@ func (a API) getDeviceToken(ctx context.Context, state State) (res TokenResponse resp, err := postForm(ctx, oauthTokenURL, strings.NewReader(data.Encode())) if err != nil { - return res, fmt.Errorf("failed to get code: %w", err) + return TokenResponse{}, fmt.Errorf("failed to get tokens: %w", err) } + defer func() { + _ = resp.Body.Close() + }() + // this endpoint returns a 403 with an `authorization_pending` error until the + // user has authenticated, so we don't check the status code here and instead + // decode the response and check for the error. + var res TokenResponse err = json.NewDecoder(resp.Body).Decode(&res) - _ = resp.Body.Close() + if err != nil { + return res, fmt.Errorf("failed to decode response: %w", err) + } - return + return res, nil } // RevokeToken revokes a refresh token with the tenant so that it can no longer @@ -150,11 +168,14 @@ func (a API) RevokeToken(ctx context.Context, refreshToken string) error { if err != nil { return err } - defer resp.Body.Close() + defer func() { + _ = resp.Body.Close() + }() if resp.StatusCode != http.StatusOK { - return errors.New("failed to revoke token") + return tryDecodeOAuthError(resp) } + return nil } @@ -188,22 +209,17 @@ func (a API) GetAutoPAT(ctx context.Context, audience string, res TokenResponse) _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { - var body map[string]any - err = json.NewDecoder(resp.Body).Decode(&body) - if errorDescription, ok := body["error_description"].(string); ok { - return "", errors.New(errorDescription) - } - return "", fmt.Errorf("failed to get device code: %w", err) + if resp.StatusCode != http.StatusCreated { + return "", fmt.Errorf("unexpected response from Hub: %s", resp.Status) } - var respo patGenerateResponse - err = json.NewDecoder(resp.Body).Decode(&respo) + var response patGenerateResponse + err = json.NewDecoder(resp.Body).Decode(&response) if err != nil { return "", err } - return respo.Data.Token, nil + return response.Data.Token, nil } type patGenerateResponse struct { diff --git a/cli/internal/oauth/api/api_test.go b/cli/internal/oauth/api/api_test.go index ce09104a0e9b..34b9600f5cb3 100644 --- a/cli/internal/oauth/api/api_test.go +++ b/cli/internal/oauth/api/api_test.go @@ -90,7 +90,7 @@ func TestGetDeviceCode(t *testing.T) { _, err := api.GetDeviceCode(context.Background(), "anAudience") - assert.ErrorContains(t, err, "failed to get device code") + assert.ErrorContains(t, err, "unexpected response from tenant: 500 Internal Server Error") }) t.Run("canceled context", func(t *testing.T) { @@ -265,6 +265,49 @@ func TestRevoke(t *testing.T) { assert.NilError(t, err) }) + t.Run("unexpected response", func(t *testing.T) { + t.Parallel() + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + assert.Equal(t, "/oauth/revoke", r.URL.Path) + assert.Equal(t, r.FormValue("client_id"), "aClientID") + assert.Equal(t, r.FormValue("token"), "v1.a-refresh-token") + + w.WriteHeader(http.StatusNotFound) + })) + defer ts.Close() + api := API{ + TenantURL: ts.URL, + ClientID: "aClientID", + Scopes: []string{"bork", "meow"}, + } + + err := api.RevokeToken(context.Background(), "v1.a-refresh-token") + assert.ErrorContains(t, err, "unexpected response from tenant: 404 Not Found") + }) + + t.Run("error w/ description", func(t *testing.T) { + t.Parallel() + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + jsonState, err := json.Marshal(TokenResponse{ + ErrorDescription: "invalid client id", + }) + assert.NilError(t, err) + + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write(jsonState) + })) + defer ts.Close() + api := API{ + TenantURL: ts.URL, + ClientID: "aClientID", + Scopes: []string{"bork", "meow"}, + } + + err := api.RevokeToken(context.Background(), "v1.a-refresh-token") + assert.ErrorContains(t, err, "invalid client id") + }) + t.Run("canceled context", func(t *testing.T) { t.Parallel() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -327,6 +370,24 @@ func TestGetAutoPAT(t *testing.T) { assert.Equal(t, "a-docker-pat", pat) }) + t.Run("general error", func(t *testing.T) { + t.Parallel() + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer ts.Close() + api := API{ + TenantURL: ts.URL, + ClientID: "aClientID", + Scopes: []string{"bork", "meow"}, + } + + _, err := api.GetAutoPAT(context.Background(), ts.URL, TokenResponse{ + AccessToken: "bork", + }) + assert.ErrorContains(t, err, "unexpected response from Hub: 500 Internal Server Error") + }) + t.Run("context canceled", func(t *testing.T) { t.Parallel() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/cli/internal/oauth/manager/manager.go b/cli/internal/oauth/manager/manager.go index 44a9b4b9014c..389135c7a6bc 100644 --- a/cli/internal/oauth/manager/manager.go +++ b/cli/internal/oauth/manager/manager.go @@ -75,11 +75,11 @@ func New(options OAuthManagerOptions) *OAuthManager { func (m *OAuthManager) LoginDevice(ctx context.Context, w io.Writer) (*types.AuthConfig, error) { state, err := m.api.GetDeviceCode(ctx, m.audience) if err != nil { - return nil, fmt.Errorf("login failed: %w", err) + return nil, fmt.Errorf("failed to get device code: %w", err) } if state.UserCode == "" { - return nil, errors.New("login failed: no user code returned") + return nil, errors.New("no user code returned") } _, _ = fmt.Fprintln(w, "\n\033[1mUSING WEB BASED LOGIN\033[0m") @@ -110,18 +110,18 @@ func (m *OAuthManager) LoginDevice(ctx context.Context, w io.Writer) (*types.Aut case <-ctx.Done(): return nil, errors.New("login canceled") case err := <-waitForTokenErrChan: - return nil, fmt.Errorf("login failed: %w", err) + return nil, fmt.Errorf("failed waiting for authentication: %w", err) case tokenRes = <-tokenResChan: } claims, err := oauth.GetClaims(tokenRes.AccessToken) if err != nil { - return nil, fmt.Errorf("login failed: %w", err) + return nil, fmt.Errorf("failed to parse token claims: %w", err) } err = m.storeTokensInStore(tokenRes, claims.Domain.Username) if err != nil { - return nil, fmt.Errorf("login failed: %w", err) + return nil, fmt.Errorf("failed to store tokens: %w", err) } pat, err := m.api.GetAutoPAT(ctx, m.audience, tokenRes) diff --git a/cli/internal/oauth/manager/manager_test.go b/cli/internal/oauth/manager/manager_test.go index 575aee60dd98..39a9b86c2fdb 100644 --- a/cli/internal/oauth/manager/manager_test.go +++ b/cli/internal/oauth/manager/manager_test.go @@ -159,7 +159,7 @@ func TestLoginDevice(t *testing.T) { } _, err := manager.LoginDevice(context.Background(), os.Stderr) - assert.ErrorContains(t, err, "login failed: timed out waiting for device token") + assert.ErrorContains(t, err, "failed waiting for authentication: timed out waiting for device token") }) t.Run("canceled context", func(t *testing.T) {