diff --git a/cmd/cmdtest.go b/cmd/cmdtest.go index dde3b47e..342bfd68 100644 --- a/cmd/cmdtest.go +++ b/cmd/cmdtest.go @@ -10,6 +10,8 @@ import ( "github.com/stretchr/testify/require" "github.com/launchdarkly/ldcli/internal/analytics" + "github.com/launchdarkly/ldcli/internal/config" + "github.com/launchdarkly/ldcli/internal/resources" ) var StubbedSuccessResponse = `{ @@ -24,6 +26,7 @@ func CallCmd( args []string, ) ([]byte, error) { rootCmd, err := NewRootCommand( + config.NewService(&resources.MockClient{}), trackerFn, clients, "test", diff --git a/cmd/config/config.go b/cmd/config/config.go index fc3f3697..2a6bcbe6 100644 --- a/cmd/config/config.go +++ b/cmd/config/config.go @@ -2,6 +2,7 @@ package config import ( "encoding/json" + "errors" "fmt" "os" "sort" @@ -15,7 +16,7 @@ import ( "github.com/launchdarkly/ldcli/cmd/cliflags" "github.com/launchdarkly/ldcli/internal/analytics" "github.com/launchdarkly/ldcli/internal/config" - "github.com/launchdarkly/ldcli/internal/errors" + errs "github.com/launchdarkly/ldcli/internal/errors" "github.com/launchdarkly/ldcli/internal/output" ) @@ -38,10 +39,10 @@ func (cmd ConfigCmd) HelpCalled() bool { return cmd.helpCalled } -func NewConfigCmd(analyticsTrackerFn analytics.TrackerFn) *ConfigCmd { +func NewConfigCmd(service config.Service, analyticsTrackerFn analytics.TrackerFn) *ConfigCmd { cmd := &cobra.Command{ Long: "View and modify specific configuration values", - RunE: run(), + RunE: run(service), Short: "View and modify specific configuration values", Use: "config", PreRun: func(cmd *cobra.Command, args []string) { @@ -93,7 +94,7 @@ func NewConfigCmd(analyticsTrackerFn analytics.TrackerFn) *ConfigCmd { return &configCmd } -func run() func(*cobra.Command, []string) error { +func run(service config.Service) func(*cobra.Command, []string) error { return func(cmd *cobra.Command, args []string) error { switch { case viper.GetBool(ListFlag): @@ -120,7 +121,7 @@ func run() func(*cobra.Command, []string) error { case viper.GetBool(SetFlag): // flag needs two arguments: a key and value if len(args)%2 != 0 { - return errors.NewError("flag needs an argument: --set") + return errs.NewError("flag needs an argument: --set") } for i := 0; i < len(args)-1; i += 2 { @@ -132,7 +133,7 @@ func run() func(*cobra.Command, []string) error { rawConfig, v, err := getRawConfig() if err != nil { - return errors.NewError(output.CmdOutputError(viper.GetString(cliflags.OutputFlag), err)) + return errs.NewError(output.CmdOutputError(viper.GetString(cliflags.OutputFlag), err)) } // add arg pairs to config where each argument is --set arg1 val1 --set arg2 val2 @@ -146,9 +147,29 @@ func run() func(*cobra.Command, []string) error { } } + var updatingAccessToken bool + for _, f := range newFields { + if f == cliflags.AccessTokenFlag { + updatingAccessToken = true + break + } + } + if updatingAccessToken { + if !service.VerifyAccessToken( + rawConfig[cliflags.AccessTokenFlag].(string), + viper.GetString(cliflags.BaseURIFlag), + ) { + errorMessage := fmt.Sprintf("%s is invalid. ", cliflags.AccessTokenFlag) + errorMessage += errs.AccessTokenInvalidErrMessage(viper.GetString(cliflags.BaseURIFlag)) + err := errors.New(errorMessage) + + return errs.NewError(output.CmdOutputError(viper.GetString(cliflags.OutputFlag), err)) + } + } + configFile, err := config.NewConfig(rawConfig) if err != nil { - return errors.NewError(output.CmdOutputError(viper.GetString(cliflags.OutputFlag), err)) + return errs.NewError(output.CmdOutputError(viper.GetString(cliflags.OutputFlag), err)) } setKeyFn := func(key string, value interface{}, v *viper.Viper) { @@ -156,7 +177,7 @@ func run() func(*cobra.Command, []string) error { } err = writeConfig(configFile, v, setKeyFn) if err != nil { - return errors.NewError(output.CmdOutputError(viper.GetString(cliflags.OutputFlag), err)) + return errs.NewError(output.CmdOutputError(viper.GetString(cliflags.OutputFlag), err)) } output, err := outputSetAction(newFields) @@ -173,7 +194,7 @@ func run() func(*cobra.Command, []string) error { config, v, err := getConfig() if err != nil { - return errors.NewError(output.CmdOutputError(viper.GetString(cliflags.OutputFlag), err)) + return errs.NewError(output.CmdOutputError(viper.GetString(cliflags.OutputFlag), err)) } unsetKeyFn := func(key string, value interface{}, v *viper.Viper) { @@ -183,7 +204,7 @@ func run() func(*cobra.Command, []string) error { } err = writeConfig(config, v, unsetKeyFn) if err != nil { - return errors.NewError(output.CmdOutputError(viper.GetString(cliflags.OutputFlag), err)) + return errs.NewError(output.CmdOutputError(viper.GetString(cliflags.OutputFlag), err)) } output, err := outputUnsetAction(viper.GetString(UnsetFlag)) @@ -297,7 +318,7 @@ func writeConfig( } func newErr(flag string) error { - err := errors.NewError( + err := errs.NewError( fmt.Sprintf( `{ "message": "%s is not a valid configuration option" @@ -306,7 +327,7 @@ func newErr(flag string) error { ), ) - return errors.NewError(output.CmdOutputError(viper.GetString(cliflags.OutputFlag), err)) + return errs.NewError(output.CmdOutputError(viper.GetString(cliflags.OutputFlag), err)) } func writeAlphabetizedFlags(sb *strings.Builder) { @@ -329,7 +350,7 @@ func outputSetAction(newFields []string) (string, error) { fieldsJSON, _ := json.Marshal(fields) output, err := output.CmdOutput("update", viper.GetString(cliflags.OutputFlag), fieldsJSON) if err != nil { - return "", errors.NewError(err.Error()) + return "", errs.NewError(err.Error()) } return output, nil @@ -344,7 +365,7 @@ func outputUnsetAction(newField string) (string, error) { fieldJSON, _ := json.Marshal(field) output, err := output.CmdOutput("delete", viper.GetString(cliflags.OutputFlag), fieldJSON) if err != nil { - return "", errors.NewError(err.Error()) + return "", errs.NewError(err.Error()) } return output, nil diff --git a/cmd/root.go b/cmd/root.go index 96112c12..972d80e9 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -68,6 +68,7 @@ func (cmd RootCmd) Execute() error { } func NewRootCommand( + configService config.Service, analyticsTrackerFn analytics.TrackerFn, clients APIClients, version string, @@ -182,7 +183,7 @@ func NewRootCommand( return nil, err } - configCmd := configcmd.NewConfigCmd(analyticsTrackerFn) + configCmd := configcmd.NewConfigCmd(configService, analyticsTrackerFn) cmd.AddCommand(configCmd.Cmd()) cmd.AddCommand(NewQuickStartCmd(analyticsTrackerFn, clients.EnvironmentsClient, clients.FlagsClient)) cmd.AddCommand(resourcecmd.NewResourcesCmd()) @@ -213,10 +214,12 @@ func Execute(version string) { ProjectsClient: projects.NewClient(version), ResourcesClient: resources.NewClient(version), } + configService := config.NewService(resources.NewClient(version)) trackerFn := analytics.ClientFn{ ID: uuid.New().String(), } rootCmd, err := NewRootCommand( + configService, trackerFn.Tracker(version), clients, version, diff --git a/cmd/validators/validators.go b/cmd/validators/validators.go index 4c539877..ef1d6050 100644 --- a/cmd/validators/validators.go +++ b/cmd/validators/validators.go @@ -46,7 +46,8 @@ func CmdError(err error, commandPath string, baseURI string) error { if strings.Contains(err.Error(), cliflags.AccessTokenFlag) { errorMessage += "\n\n" if baseURI != "" { - errorMessage += fmt.Sprintf("Go to %s/settings/authorization to create an access token.\n", baseURI) + errorMessage += errs.AccessTokenInvalidErrMessage(baseURI) + errorMessage += "\n" } errorMessage += fmt.Sprintf("Use `ldcli config --set %s ` to configure the value to persist across CLI commands.\n\n", cliflags.AccessTokenFlag) } else { diff --git a/internal/config/config_service.go b/internal/config/config_service.go new file mode 100644 index 00000000..58b71a18 --- /dev/null +++ b/internal/config/config_service.go @@ -0,0 +1,29 @@ +package config + +import ( + "fmt" + + "github.com/launchdarkly/ldcli/internal/resources" +) + +type Service struct { + client resources.Client +} + +func NewService(client resources.Client) Service { + return Service{ + client: client, + } +} + +// VerifyAccessToken is true if the given access token is valid to make API requests. +func (s Service) VerifyAccessToken(accessToken string, baseURI string) bool { + path := fmt.Sprintf( + "%s/api/v2/account", + baseURI, + ) + + _, err := s.client.MakeRequest(accessToken, "HEAD", path, "application/json", nil, nil) + + return err == nil +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 14b5d7de..5d1c3ca0 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1,12 +1,16 @@ package config_test import ( + "errors" "fmt" - "github.com/launchdarkly/ldcli/internal/config" + "net/http" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/launchdarkly/ldcli/internal/config" + "github.com/launchdarkly/ldcli/internal/resources" ) func TestNewConfig(t *testing.T) { @@ -132,3 +136,24 @@ func TestNewConfig(t *testing.T) { }) }) } + +func TestService_VerifyAccessToken(t *testing.T) { + t.Run("is valid with a valid access token", func(t *testing.T) { + service := config.NewService(&resources.MockClient{}) + + isValid := service.VerifyAccessToken("valid-access-token", "http://test.com") + + assert.True(t, isValid) + }) + + t.Run("is invalid with an invalid access token", func(t *testing.T) { + service := config.NewService(&resources.MockClient{ + StatusCode: http.StatusUnauthorized, + Err: errors.New("invalid access token"), + }) + + isValid := service.VerifyAccessToken("invalid-access-token", "http://test.com") + + assert.False(t, isValid) + }) +} diff --git a/internal/errors/errors.go b/internal/errors/errors.go index 141f892a..7beb10ba 100644 --- a/internal/errors/errors.go +++ b/internal/errors/errors.go @@ -115,3 +115,7 @@ func normalizeUnauthorizedJSON() ([]byte, error) { return errMsg, nil } + +func AccessTokenInvalidErrMessage(baseURI string) string { + return fmt.Sprintf("Go to %s/settings/authorization to create an access token.", baseURI) +} diff --git a/internal/output/resource_output.go b/internal/output/resource_output.go index 0408280f..7e3d59c7 100644 --- a/internal/output/resource_output.go +++ b/internal/output/resource_output.go @@ -2,14 +2,13 @@ package output import ( "encoding/json" + "errors" "fmt" "math" "net/url" "strconv" "strings" - "github.com/pkg/errors" - errs "github.com/launchdarkly/ldcli/internal/errors" ) diff --git a/internal/resources/mock_client.go b/internal/resources/mock_client.go index bb3607fc..f3ddb4c1 100644 --- a/internal/resources/mock_client.go +++ b/internal/resources/mock_client.go @@ -1,10 +1,15 @@ package resources -import "net/url" +import ( + "net/http" + "net/url" +) type MockClient struct { - Input []byte - Response []byte + Err error + Input []byte + Response []byte + StatusCode int } var _ Client = &MockClient{} @@ -16,5 +21,9 @@ func (c *MockClient) MakeRequest( ) ([]byte, error) { c.Input = data + if c.StatusCode > http.StatusBadRequest { + return c.Response, c.Err + } + return c.Response, nil }