Skip to content

Commit

Permalink
Simplified the way creds are done with authorizedrequest
Browse files Browse the repository at this point in the history
  • Loading branch information
rovaughn committed Feb 7, 2015
1 parent 4dc16ff commit 86aa65b
Showing 1 changed file with 33 additions and 32 deletions.
65 changes: 33 additions & 32 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,70 +63,76 @@ type FetchResult struct {
TempDir string
ETag string
Key *Key
URL string
User *neturl.Userinfo
UnencryptedHash []byte
}

func AuthorizedRequest(reqGenerator func() (*http.Request, error)) (*http.Response, *neturl.Userinfo, error) {
var (
Authorization *neturl.Userinfo
)

type ReqGenerator func(*neturl.Userinfo) (*http.Request, error)

func AuthorizedRequest(reqGenerator ReqGenerator) (*http.Response, error) {
client := &http.Client{}

req, err := reqGenerator()
req, err := reqGenerator(Authorization)
if err != nil {
return nil, nil, err
return nil, err
}

username, password, _ := req.BasicAuth()

res, err := client.Do(req)
if err != nil {
return nil, nil, err
return nil, err
}

if res.StatusCode == http.StatusForbidden {
res.Body.Close()

if username != "" && password != "" {
return nil, nil, fmt.Errorf("Forbidden")
return nil, fmt.Errorf("Forbidden")
}

Userprintf("The server responded with %d Forbidden.\n", res.StatusCode)

if username == "" {
username, err = Prompt("Username: ")
if err != nil {
return nil, nil, err
return nil, err
}
}

if password == "" {
password, err = PromptSecret("Password: ")
if err != nil {
return nil, nil, err
return nil, err
}
}

req, err = reqGenerator()
Authorization = neturl.UserPassword(username, password)

req, err = reqGenerator(Authorization)
if err != nil {
return nil, nil, err
return nil, err
}

req.SetBasicAuth(username, password)

res, err = client.Do(req)
if err != nil {
return nil, nil, err
return nil, err
}

if res.StatusCode == http.StatusForbidden {
Userprintf("The server responded with %d Forbidden.\n", res.StatusCode)

return nil, nil, fmt.Errorf("Forbidden")
return nil, fmt.Errorf("Forbidden")
} else {
return res, neturl.UserPassword(username, password), nil
return res, nil
}
} else {
return res, neturl.UserPassword(username, password), nil
return res, nil
}
}

Expand Down Expand Up @@ -321,7 +327,7 @@ func GetDecryptedFile(path string, key *Key) ([]byte, error) {
}
}

func GetDecryptedHTTP(url string, key *Key) (data []byte, etag string, user *neturl.Userinfo, err error) {
func GetDecryptedHTTP(url string, key *Key) (data []byte, etag string, err error) {
var cachedETagBytes []byte
var res *http.Response

Expand All @@ -335,7 +341,7 @@ func GetDecryptedHTTP(url string, key *Key) (data []byte, etag string, user *net

cachedETag := string(cachedETagBytes)

res, user, err = AuthorizedRequest(func() (*http.Request, error) {
res, err = AuthorizedRequest(func(user *neturl.Userinfo) (*http.Request, error) {
req, err := AuthNewRequest("GET", url, nil)
if err != nil {
return nil, err
Expand Down Expand Up @@ -398,20 +404,20 @@ func GetDecryptedHTTP(url string, key *Key) (data []byte, etag string, user *net
return
}

func GetDecryptedData(url string, key *Key) ([]byte, string, *neturl.Userinfo, error) {
func GetDecryptedData(url string, key *Key) ([]byte, string, error) {
parsedURL, err := neturl.Parse(url)
if err != nil {
return nil, "", nil, err
return nil, "", err
}

switch parsedURL.Scheme {
case "", "file":
data, err := GetDecryptedFile(parsedURL.Path, key)
return data, "", nil, err
return data, "", err
case "http", "https":
return GetDecryptedHTTP(url, key)
default:
return nil, "", nil, fmt.Errorf("Unsupported URL scheme %#v", parsedURL.Scheme)
return nil, "", fmt.Errorf("Unsupported URL scheme %#v", parsedURL.Scheme)
}
}

Expand Down Expand Up @@ -466,7 +472,6 @@ func EmptyFetch(url string, key *Key) (*FetchResult, error) {
return &FetchResult{
Initial: true,
TempDir: tempdir,
URL: url,
Key: key,
}, nil
}
Expand Down Expand Up @@ -562,15 +567,15 @@ func Push(fetched *FetchResult, url string) error {
// holding on to the file handle from the original fetch.
return ioutil.WriteFile(url, encryptedArchive, 0644)
case "http", "https":
reqGenerator := func() (*http.Request, error) {
reqGenerator := func(user *neturl.Userinfo) (*http.Request, error) {
req, err := AuthNewRequest("PUT", url, bytes.NewReader(encryptedArchive))
if err != nil {
return nil, err
}

if fetched.User != nil {
pw, _ := fetched.User.Password()
req.SetBasicAuth(fetched.User.Username(), pw)
if user != nil {
pw, _ := user.Password()
req.SetBasicAuth(user.Username(), pw)
}

if fetched.ETag == "" {
Expand All @@ -584,7 +589,7 @@ func Push(fetched *FetchResult, url string) error {
return req, nil
}

res, user, err := AuthorizedRequest(reqGenerator)
res, err := AuthorizedRequest(reqGenerator)
if err != nil {
return err
} else if res.StatusCode == http.StatusConflict {
Expand All @@ -593,8 +598,6 @@ func Push(fetched *FetchResult, url string) error {
return &ErrHTTPStatus{res, url}
}

fetched.User = user

return nil
default:
return fmt.Errorf("Unsupported URL: %s", url)
Expand Down Expand Up @@ -623,7 +626,7 @@ func Fetch(url string) (*FetchResult, error) {
return nil, &Suberr{"GetKey", err}
}

decryptedArchive, etag, user, err := GetDecryptedData(url, key)
decryptedArchive, etag, err := GetDecryptedData(url, key)
if _, ok := err.(*ErrNotFound); ok {
return EmptyFetch(url, key)
} else if err != nil {
Expand Down Expand Up @@ -667,8 +670,6 @@ func Fetch(url string) (*FetchResult, error) {
TempDir: tempdir,
ETag: etag,
Key: key,
URL: url,
User: user,
UnencryptedHash: unencryptedHash,
}, nil
}
Expand Down

0 comments on commit 86aa65b

Please sign in to comment.