Skip to content

Commit

Permalink
feat: validate access token (#308)
Browse files Browse the repository at this point in the history
Validate access token when set in config
  • Loading branch information
dbolson authored May 29, 2024
1 parent a995cb8 commit 5591a0a
Show file tree
Hide file tree
Showing 9 changed files with 116 additions and 22 deletions.
3 changes: 3 additions & 0 deletions cmd/cmdtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"github.com/stretchr/testify/require"

"github.com/launchdarkly/ldcli/internal/analytics"
"github.com/launchdarkly/ldcli/internal/config"
"github.com/launchdarkly/ldcli/internal/resources"
)

var StubbedSuccessResponse = `{
Expand All @@ -24,6 +26,7 @@ func CallCmd(
args []string,
) ([]byte, error) {
rootCmd, err := NewRootCommand(
config.NewService(&resources.MockClient{}),
trackerFn,
clients,
"test",
Expand Down
49 changes: 35 additions & 14 deletions cmd/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package config

import (
"encoding/json"
"errors"
"fmt"
"os"
"sort"
Expand All @@ -15,7 +16,7 @@ import (
"github.com/launchdarkly/ldcli/cmd/cliflags"
"github.com/launchdarkly/ldcli/internal/analytics"
"github.com/launchdarkly/ldcli/internal/config"
"github.com/launchdarkly/ldcli/internal/errors"
errs "github.com/launchdarkly/ldcli/internal/errors"
"github.com/launchdarkly/ldcli/internal/output"
)

Expand All @@ -38,10 +39,10 @@ func (cmd ConfigCmd) HelpCalled() bool {
return cmd.helpCalled
}

func NewConfigCmd(analyticsTrackerFn analytics.TrackerFn) *ConfigCmd {
func NewConfigCmd(service config.Service, analyticsTrackerFn analytics.TrackerFn) *ConfigCmd {
cmd := &cobra.Command{
Long: "View and modify specific configuration values",
RunE: run(),
RunE: run(service),
Short: "View and modify specific configuration values",
Use: "config",
PreRun: func(cmd *cobra.Command, args []string) {
Expand Down Expand Up @@ -93,7 +94,7 @@ func NewConfigCmd(analyticsTrackerFn analytics.TrackerFn) *ConfigCmd {
return &configCmd
}

func run() func(*cobra.Command, []string) error {
func run(service config.Service) func(*cobra.Command, []string) error {
return func(cmd *cobra.Command, args []string) error {
switch {
case viper.GetBool(ListFlag):
Expand All @@ -120,7 +121,7 @@ func run() func(*cobra.Command, []string) error {
case viper.GetBool(SetFlag):
// flag needs two arguments: a key and value
if len(args)%2 != 0 {
return errors.NewError("flag needs an argument: --set")
return errs.NewError("flag needs an argument: --set")
}

for i := 0; i < len(args)-1; i += 2 {
Expand All @@ -132,7 +133,7 @@ func run() func(*cobra.Command, []string) error {

rawConfig, v, err := getRawConfig()
if err != nil {
return errors.NewError(output.CmdOutputError(viper.GetString(cliflags.OutputFlag), err))
return errs.NewError(output.CmdOutputError(viper.GetString(cliflags.OutputFlag), err))
}

// add arg pairs to config where each argument is --set arg1 val1 --set arg2 val2
Expand All @@ -146,17 +147,37 @@ func run() func(*cobra.Command, []string) error {
}
}

var updatingAccessToken bool
for _, f := range newFields {
if f == cliflags.AccessTokenFlag {
updatingAccessToken = true
break
}
}
if updatingAccessToken {
if !service.VerifyAccessToken(
rawConfig[cliflags.AccessTokenFlag].(string),
viper.GetString(cliflags.BaseURIFlag),
) {
errorMessage := fmt.Sprintf("%s is invalid. ", cliflags.AccessTokenFlag)
errorMessage += errs.AccessTokenInvalidErrMessage(viper.GetString(cliflags.BaseURIFlag))
err := errors.New(errorMessage)

return errs.NewError(output.CmdOutputError(viper.GetString(cliflags.OutputFlag), err))
}
}

configFile, err := config.NewConfig(rawConfig)
if err != nil {
return errors.NewError(output.CmdOutputError(viper.GetString(cliflags.OutputFlag), err))
return errs.NewError(output.CmdOutputError(viper.GetString(cliflags.OutputFlag), err))
}

setKeyFn := func(key string, value interface{}, v *viper.Viper) {
v.Set(key, value)
}
err = writeConfig(configFile, v, setKeyFn)
if err != nil {
return errors.NewError(output.CmdOutputError(viper.GetString(cliflags.OutputFlag), err))
return errs.NewError(output.CmdOutputError(viper.GetString(cliflags.OutputFlag), err))
}

output, err := outputSetAction(newFields)
Expand All @@ -173,7 +194,7 @@ func run() func(*cobra.Command, []string) error {

config, v, err := getConfig()
if err != nil {
return errors.NewError(output.CmdOutputError(viper.GetString(cliflags.OutputFlag), err))
return errs.NewError(output.CmdOutputError(viper.GetString(cliflags.OutputFlag), err))
}

unsetKeyFn := func(key string, value interface{}, v *viper.Viper) {
Expand All @@ -183,7 +204,7 @@ func run() func(*cobra.Command, []string) error {
}
err = writeConfig(config, v, unsetKeyFn)
if err != nil {
return errors.NewError(output.CmdOutputError(viper.GetString(cliflags.OutputFlag), err))
return errs.NewError(output.CmdOutputError(viper.GetString(cliflags.OutputFlag), err))
}

output, err := outputUnsetAction(viper.GetString(UnsetFlag))
Expand Down Expand Up @@ -297,7 +318,7 @@ func writeConfig(
}

func newErr(flag string) error {
err := errors.NewError(
err := errs.NewError(
fmt.Sprintf(
`{
"message": "%s is not a valid configuration option"
Expand All @@ -306,7 +327,7 @@ func newErr(flag string) error {
),
)

return errors.NewError(output.CmdOutputError(viper.GetString(cliflags.OutputFlag), err))
return errs.NewError(output.CmdOutputError(viper.GetString(cliflags.OutputFlag), err))
}

func writeAlphabetizedFlags(sb *strings.Builder) {
Expand All @@ -329,7 +350,7 @@ func outputSetAction(newFields []string) (string, error) {
fieldsJSON, _ := json.Marshal(fields)
output, err := output.CmdOutput("update", viper.GetString(cliflags.OutputFlag), fieldsJSON)
if err != nil {
return "", errors.NewError(err.Error())
return "", errs.NewError(err.Error())
}

return output, nil
Expand All @@ -344,7 +365,7 @@ func outputUnsetAction(newField string) (string, error) {
fieldJSON, _ := json.Marshal(field)
output, err := output.CmdOutput("delete", viper.GetString(cliflags.OutputFlag), fieldJSON)
if err != nil {
return "", errors.NewError(err.Error())
return "", errs.NewError(err.Error())
}

return output, nil
Expand Down
5 changes: 4 additions & 1 deletion cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ func (cmd RootCmd) Execute() error {
}

func NewRootCommand(
configService config.Service,
analyticsTrackerFn analytics.TrackerFn,
clients APIClients,
version string,
Expand Down Expand Up @@ -182,7 +183,7 @@ func NewRootCommand(
return nil, err
}

configCmd := configcmd.NewConfigCmd(analyticsTrackerFn)
configCmd := configcmd.NewConfigCmd(configService, analyticsTrackerFn)
cmd.AddCommand(configCmd.Cmd())
cmd.AddCommand(NewQuickStartCmd(analyticsTrackerFn, clients.EnvironmentsClient, clients.FlagsClient))
cmd.AddCommand(resourcecmd.NewResourcesCmd())
Expand Down Expand Up @@ -213,10 +214,12 @@ func Execute(version string) {
ProjectsClient: projects.NewClient(version),
ResourcesClient: resources.NewClient(version),
}
configService := config.NewService(resources.NewClient(version))
trackerFn := analytics.ClientFn{
ID: uuid.New().String(),
}
rootCmd, err := NewRootCommand(
configService,
trackerFn.Tracker(version),
clients,
version,
Expand Down
3 changes: 2 additions & 1 deletion cmd/validators/validators.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ func CmdError(err error, commandPath string, baseURI string) error {
if strings.Contains(err.Error(), cliflags.AccessTokenFlag) {
errorMessage += "\n\n"
if baseURI != "" {
errorMessage += fmt.Sprintf("Go to %s/settings/authorization to create an access token.\n", baseURI)
errorMessage += errs.AccessTokenInvalidErrMessage(baseURI)
errorMessage += "\n"
}
errorMessage += fmt.Sprintf("Use `ldcli config --set %s <value>` to configure the value to persist across CLI commands.\n\n", cliflags.AccessTokenFlag)
} else {
Expand Down
29 changes: 29 additions & 0 deletions internal/config/config_service.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package config

import (
"fmt"

"github.com/launchdarkly/ldcli/internal/resources"
)

type Service struct {
client resources.Client
}

func NewService(client resources.Client) Service {
return Service{
client: client,
}
}

// VerifyAccessToken is true if the given access token is valid to make API requests.
func (s Service) VerifyAccessToken(accessToken string, baseURI string) bool {
path := fmt.Sprintf(
"%s/api/v2/account",
baseURI,
)

_, err := s.client.MakeRequest(accessToken, "HEAD", path, "application/json", nil, nil)

return err == nil
}
27 changes: 26 additions & 1 deletion internal/config/config_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
package config_test

import (
"errors"
"fmt"
"github.com/launchdarkly/ldcli/internal/config"
"net/http"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/launchdarkly/ldcli/internal/config"
"github.com/launchdarkly/ldcli/internal/resources"
)

func TestNewConfig(t *testing.T) {
Expand Down Expand Up @@ -132,3 +136,24 @@ func TestNewConfig(t *testing.T) {
})
})
}

func TestService_VerifyAccessToken(t *testing.T) {
t.Run("is valid with a valid access token", func(t *testing.T) {
service := config.NewService(&resources.MockClient{})

isValid := service.VerifyAccessToken("valid-access-token", "http://test.com")

assert.True(t, isValid)
})

t.Run("is invalid with an invalid access token", func(t *testing.T) {
service := config.NewService(&resources.MockClient{
StatusCode: http.StatusUnauthorized,
Err: errors.New("invalid access token"),
})

isValid := service.VerifyAccessToken("invalid-access-token", "http://test.com")

assert.False(t, isValid)
})
}
4 changes: 4 additions & 0 deletions internal/errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,7 @@ func normalizeUnauthorizedJSON() ([]byte, error) {

return errMsg, nil
}

func AccessTokenInvalidErrMessage(baseURI string) string {
return fmt.Sprintf("Go to %s/settings/authorization to create an access token.", baseURI)
}
3 changes: 1 addition & 2 deletions internal/output/resource_output.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@ package output

import (
"encoding/json"
"errors"
"fmt"
"math"
"net/url"
"strconv"
"strings"

"github.com/pkg/errors"

errs "github.com/launchdarkly/ldcli/internal/errors"
)

Expand Down
15 changes: 12 additions & 3 deletions internal/resources/mock_client.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
package resources

import "net/url"
import (
"net/http"
"net/url"
)

type MockClient struct {
Input []byte
Response []byte
Err error
Input []byte
Response []byte
StatusCode int
}

var _ Client = &MockClient{}
Expand All @@ -16,5 +21,9 @@ func (c *MockClient) MakeRequest(
) ([]byte, error) {
c.Input = data

if c.StatusCode > http.StatusBadRequest {
return c.Response, c.Err
}

return c.Response, nil
}

0 comments on commit 5591a0a

Please sign in to comment.