Skip to content

Commit

Permalink
Unify home dir logic between shared config and sso (#1960)
Browse files Browse the repository at this point in the history
Take the logic (for grabbing a home dir) from shared config. move it to internal shareddefaults. and then use shareddefaults for both sso and shared config
  • Loading branch information
isaiahvita authored Dec 15, 2022
1 parent 6fad028 commit 6067fb2
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 41 deletions.
10 changes: 10 additions & 0 deletions .changelog/49406fdf70c541d7b12e656807214245.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"id": "49406fdf-70c5-41d7-b12e-656807214245",
"type": "bugfix",
"description": "Unify logic between shared config and in finding home directory",
"modules": [
".",
"config",
"credentials"
]
}
22 changes: 3 additions & 19 deletions config/shared_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ import (
"io"
"io/ioutil"
"os"
"os/user"
"path/filepath"
"strings"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go-v2/internal/ini"
"github.com/aws/aws-sdk-go-v2/internal/shareddefaults"
"github.com/aws/smithy-go/logging"
)

Expand Down Expand Up @@ -108,7 +108,7 @@ var defaultSharedConfigProfile = DefaultSharedConfigProfile
// - Linux/Unix: $HOME/.aws/credentials
// - Windows: %USERPROFILE%\.aws\credentials
func DefaultSharedCredentialsFilename() string {
return filepath.Join(userHomeDir(), ".aws", "credentials")
return filepath.Join(shareddefaults.UserHomeDir(), ".aws", "credentials")
}

// DefaultSharedConfigFilename returns the SDK's default file path for
Expand All @@ -119,7 +119,7 @@ func DefaultSharedCredentialsFilename() string {
// - Linux/Unix: $HOME/.aws/config
// - Windows: %USERPROFILE%\.aws\config
func DefaultSharedConfigFilename() string {
return filepath.Join(userHomeDir(), ".aws", "config")
return filepath.Join(shareddefaults.UserHomeDir(), ".aws", "config")
}

// DefaultSharedConfigFiles is a slice of the default shared config files that
Expand Down Expand Up @@ -1268,22 +1268,6 @@ func (e CredentialRequiresARNError) Error() string {
)
}

func userHomeDir() string {
// Ignore errors since we only care about Windows and *nix.
home, _ := os.UserHomeDir()

if len(home) > 0 {
return home
}

currUser, _ := user.Current()
if currUser != nil {
home = currUser.HomeDir
}

return home
}

func oneOrNone(bs ...bool) bool {
var count int

Expand Down
12 changes: 6 additions & 6 deletions credentials/ssocreds/sso_cached_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ import (
"time"

"github.com/aws/aws-sdk-go-v2/internal/sdk"
"github.com/aws/aws-sdk-go-v2/internal/shareddefaults"
)

var osUserHomeDur = os.UserHomeDir
var osUserHomeDur = shareddefaults.UserHomeDir

// StandardCachedTokenFilepath returns the filepath for the cached SSO token file, or
// error if unable get derive the path. Key that will be used to compute a SHA1
Expand All @@ -25,13 +26,12 @@ var osUserHomeDur = os.UserHomeDir
//
// ~/.aws/sso/cache/<sha1-hex-encoded-key>.json
func StandardCachedTokenFilepath(key string) (string, error) {
homeDir, err := osUserHomeDur()
if err != nil {
return "", fmt.Errorf("unable to get USER's home directory for cached token, %w", err)
homeDir := osUserHomeDur()
if len(homeDir) == 0 {
return "", fmt.Errorf("unable to get USER's home directory for cached token")
}

hash := sha1.New()
if _, err = hash.Write([]byte(key)); err != nil {
if _, err := hash.Write([]byte(key)); err != nil {
return "", fmt.Errorf("unable to compute cached token filepath key SHA1 hash, %w", err)
}

Expand Down
14 changes: 5 additions & 9 deletions credentials/ssocreds/sso_cached_token_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package ssocreds

import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
Expand All @@ -25,22 +24,22 @@ func TestStandardSSOCacheTokenFilepath(t *testing.T) {

cases := map[string]struct {
key string
osUserHomeDir func() (string, error)
osUserHomeDir func() string
expectFilename string
expectErr string
}{
"success": {
key: "https://example.awsapps.com/start",
osUserHomeDir: func() (string, error) {
return os.TempDir(), nil
osUserHomeDir: func() string {
return os.TempDir()
},
expectFilename: filepath.Join(os.TempDir(), ".aws", "sso", "cache",
"e8be5486177c5b5392bd9aa76563515b29358e6e.json"),
},
"failure": {
key: "https://example.awsapps.com/start",
osUserHomeDir: func() (string, error) {
return "", fmt.Errorf("some error")
osUserHomeDir: func() string {
return ""
},
expectErr: "some error",
},
Expand All @@ -55,9 +54,6 @@ func TestStandardSSOCacheTokenFilepath(t *testing.T) {
if err == nil {
t.Fatalf("expect error, got none")
}
if e, a := c.expectErr, err.Error(); !strings.Contains(a, e) {
t.Fatalf("expect %v error in %v", e, a)
}
return
}
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions credentials/ssocreds/sso_credentials_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ func TestProvider(t *testing.T) {
osUserHomeDur = origHomeDir
}()

osUserHomeDur = func() (string, error) {
return "testdata", nil
osUserHomeDur = func() string {
return "testdata"
}

restoreTime := sdk.TestingUseReferenceTime(time.Date(2021, 01, 19, 19, 50, 0, 0, time.UTC))
Expand Down
17 changes: 12 additions & 5 deletions internal/shareddefaults/shared_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package shareddefaults

import (
"os"
"os/user"
"path/filepath"
"runtime"
)

// SharedCredentialsFilename returns the SDK's default file path
Expand Down Expand Up @@ -31,10 +31,17 @@ func SharedConfigFilename() string {
// UserHomeDir returns the home directory for the user the process is
// running under.
func UserHomeDir() string {
if runtime.GOOS == "windows" { // Windows
return os.Getenv("USERPROFILE")
// Ignore errors since we only care about Windows and *nix.
home, _ := os.UserHomeDir()

if len(home) > 0 {
return home
}

currUser, _ := user.Current()
if currUser != nil {
home = currUser.HomeDir
}

// *nix
return os.Getenv("HOME")
return home
}

0 comments on commit 6067fb2

Please sign in to comment.