Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get IDP endpoints from well-known config #470

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions internal/cmd/config/set/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ import (
)

const (
sessionTimeLimitFlag = "session-time-limit"
identityProviderCustomEndpointFlag = "identity-provider-custom-endpoint"
identityProviderCustomClientIdFlag = "identity-provider-custom-client-id"
allowedUrlDomainFlag = "allowed-url-domain"
sessionTimeLimitFlag = "session-time-limit"
identityProviderCustomWellKnownConfigurationFlag = "identity-provider-custom-well-known-configuration"
identityProviderCustomClientIdFlag = "identity-provider-custom-client-id"
allowedUrlDomainFlag = "allowed-url-domain"

authorizationCustomEndpointFlag = "authorization-custom-endpoint"
dnsCustomEndpointFlag = "dns-custom-endpoint"
Expand Down Expand Up @@ -131,7 +131,7 @@ Use "{{.CommandPath}} [command] --help" for more information about a command.{{e

func configureFlags(cmd *cobra.Command) {
cmd.Flags().String(sessionTimeLimitFlag, "", "Maximum time before authentication is required again. After this time, you will be prompted to login again to execute commands that require authentication. Can't be larger than 24h. Requires authentication after being set to take effect. Examples: 3h, 5h30m40s (BETA: currently values greater than 2h have no effect)")
cmd.Flags().String(identityProviderCustomEndpointFlag, "", "Identity Provider base URL, used for user authentication")
cmd.Flags().String(identityProviderCustomWellKnownConfigurationFlag, "", "Identity Provider well-known OpenID configuration URL, used for user authentication")
cmd.Flags().String(identityProviderCustomClientIdFlag, "", "Identity Provider client ID, used for user authentication")
cmd.Flags().String(allowedUrlDomainFlag, "", `Domain name, used for the verification of the URLs that are given in the custom identity provider endpoint and "STACKIT curl" command`)
cmd.Flags().String(observabilityCustomEndpointFlag, "", "Observability API base URL, used in calls to this API")
Expand Down Expand Up @@ -159,7 +159,7 @@ func configureFlags(cmd *cobra.Command) {

err := viper.BindPFlag(config.SessionTimeLimitKey, cmd.Flags().Lookup(sessionTimeLimitFlag))
cobra.CheckErr(err)
err = viper.BindPFlag(config.IdentityProviderCustomEndpointKey, cmd.Flags().Lookup(identityProviderCustomEndpointFlag))
err = viper.BindPFlag(config.IdentityProviderCustomWellKnownConfigurationKey, cmd.Flags().Lookup(identityProviderCustomWellKnownConfigurationFlag))
cobra.CheckErr(err)
err = viper.BindPFlag(config.IdentityProviderCustomClientIdKey, cmd.Flags().Lookup(identityProviderCustomClientIdFlag))
cobra.CheckErr(err)
Expand Down Expand Up @@ -190,7 +190,7 @@ func configureFlags(cmd *cobra.Command) {
cobra.CheckErr(err)
err = viper.BindPFlag(config.RedisCustomEndpointKey, cmd.Flags().Lookup(redisCustomEndpointFlag))
cobra.CheckErr(err)
err = viper.BindPFlag(config.ResourceManagerEndpointKey, cmd.Flags().Lookup(skeCustomEndpointFlag))
err = viper.BindPFlag(config.ResourceManagerEndpointKey, cmd.Flags().Lookup(resourceManagerCustomEndpointFlag))
cobra.CheckErr(err)
err = viper.BindPFlag(config.SecretsManagerCustomEndpointKey, cmd.Flags().Lookup(secretsManagerCustomEndpointFlag))
cobra.CheckErr(err)
Expand Down
14 changes: 7 additions & 7 deletions internal/cmd/config/unset/unset.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ const (
projectIdFlag = globalflags.ProjectIdFlag
verbosityFlag = globalflags.VerbosityFlag

sessionTimeLimitFlag = "session-time-limit"
identityProviderCustomEndpointFlag = "identity-provider-custom-endpoint"
identityProviderCustomClientIdFlag = "identity-provider-custom-client-id"
allowedUrlDomainFlag = "allowed-url-domain"
sessionTimeLimitFlag = "session-time-limit"
identityProviderCustomWellKnownConfigurationFlag = "identity-provider-custom-well-known-configuration"
identityProviderCustomClientIdFlag = "identity-provider-custom-client-id"
allowedUrlDomainFlag = "allowed-url-domain"

authorizationCustomEndpointFlag = "authorization-custom-endpoint"
dnsCustomEndpointFlag = "dns-custom-endpoint"
Expand Down Expand Up @@ -121,7 +121,7 @@ func NewCmd(p *print.Printer) *cobra.Command {
viper.Set(config.SessionTimeLimitKey, config.SessionTimeLimitDefault)
}
if model.IdentityProviderCustomEndpoint {
viper.Set(config.IdentityProviderCustomEndpointKey, "")
viper.Set(config.IdentityProviderCustomWellKnownConfigurationKey, "")
}
if model.IdentityProviderCustomClientID {
viper.Set(config.IdentityProviderCustomClientIdKey, "")
Expand Down Expand Up @@ -215,7 +215,7 @@ func configureFlags(cmd *cobra.Command) {
cmd.Flags().Bool(verbosityFlag, false, "Verbosity of the CLI")

cmd.Flags().Bool(sessionTimeLimitFlag, false, fmt.Sprintf("Maximum time before authentication is required again. If unset, defaults to %s", config.SessionTimeLimitDefault))
cmd.Flags().Bool(identityProviderCustomEndpointFlag, false, "Identity Provider base URL. If unset, uses the default base URL")
cmd.Flags().Bool(identityProviderCustomWellKnownConfigurationFlag, false, "Identity Provider well-known OpenID configuration URL. If unset, uses the default identity provider")
cmd.Flags().Bool(identityProviderCustomClientIdFlag, false, "Identity Provider client ID, used for user authentication")
cmd.Flags().Bool(allowedUrlDomainFlag, false, fmt.Sprintf("Domain name, used for the verification of the URLs that are given in the IDP endpoint and curl commands. If unset, defaults to %s", config.AllowedUrlDomainDefault))

Expand Down Expand Up @@ -251,7 +251,7 @@ func parseInput(p *print.Printer, cmd *cobra.Command) *inputModel {
Verbosity: flags.FlagToBoolValue(p, cmd, verbosityFlag),

SessionTimeLimit: flags.FlagToBoolValue(p, cmd, sessionTimeLimitFlag),
IdentityProviderCustomEndpoint: flags.FlagToBoolValue(p, cmd, identityProviderCustomEndpointFlag),
IdentityProviderCustomEndpoint: flags.FlagToBoolValue(p, cmd, identityProviderCustomWellKnownConfigurationFlag),
IdentityProviderCustomClientID: flags.FlagToBoolValue(p, cmd, identityProviderCustomClientIdFlag),
AllowedUrlDomain: flags.FlagToBoolValue(p, cmd, allowedUrlDomainFlag),

Expand Down
10 changes: 5 additions & 5 deletions internal/cmd/config/unset/unset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ func fixtureFlagValues(mods ...func(flagValues map[string]bool)) map[string]bool
projectIdFlag: true,
verbosityFlag: true,

sessionTimeLimitFlag: true,
identityProviderCustomEndpointFlag: true,
identityProviderCustomClientIdFlag: true,
allowedUrlDomainFlag: true,
sessionTimeLimitFlag: true,
identityProviderCustomWellKnownConfigurationFlag: true,
identityProviderCustomClientIdFlag: true,
allowedUrlDomainFlag: true,

authorizationCustomEndpointFlag: true,
dnsCustomEndpointFlag: true,
Expand Down Expand Up @@ -157,7 +157,7 @@ func TestParseInput(t *testing.T) {
{
description: "identity provider custom endpoint empty",
flagValues: fixtureFlagValues(func(flagValues map[string]bool) {
flagValues[identityProviderCustomEndpointFlag] = false
flagValues[identityProviderCustomWellKnownConfigurationFlag] = false
}),
isValid: true,
expectedModel: fixtureInputModel(func(model *inputModel) {
Expand Down
2 changes: 2 additions & 0 deletions internal/pkg/auth/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ const (
SERVICE_ACCOUNT_KEY authFieldKey = "service_account_key"
PRIVATE_KEY authFieldKey = "private_key"
TOKEN_CUSTOM_ENDPOINT authFieldKey = "token_custom_endpoint"
IDP_TOKEN_ENDPOINT authFieldKey = "idp_token_endpoint" //nolint:gosec // linter false positive
)

const (
Expand All @@ -57,6 +58,7 @@ var authFieldKeys = []authFieldKey{
SERVICE_ACCOUNT_KEY,
PRIVATE_KEY,
TOKEN_CUSTOM_ENDPOINT,
IDP_TOKEN_ENDPOINT,
authFlowType,
}

Expand Down
81 changes: 68 additions & 13 deletions internal/pkg/auth/user_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ import (
)

const (
defaultIDPEndpoint = "https://accounts.stackit.cloud/oauth/v2"
defaultCLIClientID = "stackit-cli-0000-0000-000000000001"
defaultWellKnownConfig = "https://accounts.stackit.cloud/.well-known/openid-configuration"
defaultCLIClientID = "stackit-cli-0000-0000-000000000001"

loginSuccessPath = "/login-successful"
stackitLandingPage = "https://www.stackit.de"
Expand All @@ -44,20 +44,31 @@ type User struct {
Email string
}

type apiClient interface {
Do(req *http.Request) (*http.Response, error)
}

// AuthorizeUser implements the PKCE OAuth2 flow.
func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
idpEndpoint, err := getIDPEndpoint()
idpWellKnownConfigURL, err := getIDPWellKnownConfigURL()
if err != nil {
return err
return fmt.Errorf("get IDP well-known configuration: %w", err)
}
if idpEndpoint != defaultIDPEndpoint {
p.Warn("You are using a custom identity provider (%s) for authentication.\n", idpEndpoint)
if idpWellKnownConfigURL != defaultWellKnownConfig {
p.Warn("You are using a custom identity provider well-known configuration (%s) for authentication.\n", idpWellKnownConfigURL)
err := p.PromptForEnter("Press Enter to proceed with the login...")
if err != nil {
return err
}
}

p.Debug(print.DebugLevel, "get IDP well-known configuration from %s", idpWellKnownConfigURL)
httpClient := &http.Client{}
idpWellKnownConfig, err := parseWellKnownConfiguration(httpClient, idpWellKnownConfigURL)
if err != nil {
return fmt.Errorf("parse IDP well-known configuration: %w", err)
}

idpClientID, err := getIDPClientID()
if err != nil {
return err
Expand Down Expand Up @@ -100,7 +111,7 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
conf := &oauth2.Config{
ClientID: idpClientID,
Endpoint: oauth2.Endpoint{
AuthURL: fmt.Sprintf("%s/authorize", idpEndpoint),
AuthURL: idpWellKnownConfig.AuthorizationEndpoint,
},
Scopes: []string{"openid offline_access email"},
RedirectURL: redirectURL,
Expand Down Expand Up @@ -147,7 +158,7 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
p.Debug(print.DebugLevel, "trading authorization code for access and refresh tokens")

// Trade the authorization code and the code verifier for access and refresh tokens
accessToken, refreshToken, err := getUserAccessAndRefreshTokens(idpEndpoint, idpClientID, codeVerifier, code, redirectURL)
accessToken, refreshToken, err := getUserAccessAndRefreshTokens(idpWellKnownConfig, idpClientID, codeVerifier, code, redirectURL)
if err != nil {
errServer = fmt.Errorf("retrieve tokens: %w", err)
return
Expand Down Expand Up @@ -222,7 +233,7 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
})

p.Debug(print.DebugLevel, "opening browser for authentication")
p.Debug(print.DebugLevel, "using authentication server on %s", idpEndpoint)
p.Debug(print.DebugLevel, "using authentication server on %s", idpWellKnownConfig.Issuer)
p.Debug(print.DebugLevel, "using client ID %s for authentication ", idpClientID)

// Open a browser window to the authorizationURL
Expand All @@ -248,9 +259,8 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
}

// getUserAccessAndRefreshTokens trades the authorization code retrieved from the first OAuth2 leg for an access token and a refresh token
func getUserAccessAndRefreshTokens(authDomain, clientID, codeVerifier, authorizationCode, callbackURL string) (accessToken, refreshToken string, err error) {
// Set the authUrl and form-encoded data for the POST to the access token endpoint
authUrl := fmt.Sprintf("%s/token", authDomain)
func getUserAccessAndRefreshTokens(idpWellKnownConfig *wellKnownConfig, clientID, codeVerifier, authorizationCode, callbackURL string) (accessToken, refreshToken string, err error) {
// Set form-encoded data for the POST to the access token endpoint
data := fmt.Sprintf(
"grant_type=authorization_code&client_id=%s"+
"&code_verifier=%s"+
Expand All @@ -260,7 +270,7 @@ func getUserAccessAndRefreshTokens(authDomain, clientID, codeVerifier, authoriza
payload := strings.NewReader(data)

// Create the request and execute it
req, _ := http.NewRequest("POST", authUrl, payload)
req, _ := http.NewRequest("POST", idpWellKnownConfig.TokenEndpoint, payload)
req.Header.Add("content-type", "application/x-www-form-urlencoded")
httpClient := &http.Client{}
res, err := httpClient.Do(req)
Expand Down Expand Up @@ -331,3 +341,48 @@ func openBrowser(pageUrl string) error {
}
return nil
}

// parseWellKnownConfiguration gets the well-known OpenID configuration from the provided URL and returns it as a JSON
// the method also stores the IDP token endpoint in the authentication storage
func parseWellKnownConfiguration(httpClient apiClient, wellKnownConfigURL string) (wellKnownConfig *wellKnownConfig, err error) {
req, _ := http.NewRequest("GET", wellKnownConfigURL, http.NoBody)
res, err := httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("make the request: %w", err)
}

// Process the response
defer func() {
closeErr := res.Body.Close()
if closeErr != nil {
err = fmt.Errorf("close response body: %w", closeErr)
}
}()
body, err := io.ReadAll(res.Body)
if err != nil {
return nil, fmt.Errorf("read response body: %w", err)
}

err = json.Unmarshal(body, &wellKnownConfig)
if err != nil {
return nil, fmt.Errorf("unmarshal response: %w", err)
}
if wellKnownConfig == nil {
return nil, fmt.Errorf("nil well-known configuration response")
}
if wellKnownConfig.Issuer == "" {
return nil, fmt.Errorf("found no issuer")
}
if wellKnownConfig.AuthorizationEndpoint == "" {
return nil, fmt.Errorf("found no authorization endpoint")
}
if wellKnownConfig.TokenEndpoint == "" {
return nil, fmt.Errorf("found no token endpoint")
}

err = SetAuthField(IDP_TOKEN_ENDPOINT, wellKnownConfig.TokenEndpoint)
if err != nil {
return nil, fmt.Errorf("set token endpoint in the authentication storage: %w", err)
}
return wellKnownConfig, err
}
110 changes: 110 additions & 0 deletions internal/pkg/auth/user_login_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package auth

import (
"fmt"
"io"
"net/http"
"strings"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/zalando/go-keyring"
)

type apiClientMocked struct {
getFails bool
getResponse string
}

func (a *apiClientMocked) Do(_ *http.Request) (*http.Response, error) {
if a.getFails {
return &http.Response{
StatusCode: http.StatusNotFound,
}, fmt.Errorf("not found")
}
return &http.Response{
Status: "200 OK",
StatusCode: http.StatusAccepted,
Body: io.NopCloser(strings.NewReader(a.getResponse)),
}, nil
}

func TestParseWellKnownConfig(t *testing.T) {
tests := []struct {
name string
getFails bool
getResponse string
isValid bool
expected *wellKnownConfig
}{
{
name: "success",
getFails: false,
getResponse: `{"issuer":"issuer","authorization_endpoint":"auth","token_endpoint":"token"}`,
isValid: true,
expected: &wellKnownConfig{
Issuer: "issuer",
AuthorizationEndpoint: "auth",
TokenEndpoint: "token",
},
},
{
name: "get_fails",
getFails: true,
getResponse: "",
isValid: false,
expected: nil,
},
{
name: "empty_response",
getFails: true,
getResponse: "",
isValid: false,
expected: nil,
},
{
name: "missing_issuer",
getFails: true,
getResponse: `{"authorization_endpoint":"auth","token_endpoint":"token"}`,
isValid: false,
expected: nil,
},
{
name: "missing_authorization",
getFails: true,
getResponse: `{"issuer":"issuer","token_endpoint":"token"}`,
isValid: false,
expected: nil,
},
{
name: "missing_token",
getFails: true,
getResponse: `{"issuer":"issuer","authorization_endpoint":"auth"}`,
isValid: false,
expected: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
keyring.MockInit()

testClient := apiClientMocked{
tt.getFails,
tt.getResponse,
}

got, err := parseWellKnownConfiguration(&testClient, "")

if tt.isValid && err != nil {
t.Fatalf("expected no error, got %v", err)
}
if !tt.isValid && err == nil {
t.Fatalf("expected error, got none")
}

if tt.isValid && !cmp.Equal(*got, *tt.expected) {
t.Fatalf("expected %v, got %v", tt.expected, got)
}
})
}
}
Loading