From 86aa65b254e386aee1f3a60d4a0677164cf2481a Mon Sep 17 00:00:00 2001 From: Alec Newman Date: Sat, 7 Feb 2015 14:53:18 -0700 Subject: [PATCH] Simplified the way creds are done with authorizedrequest --- main.go | 65 +++++++++++++++++++++++++++++---------------------------- 1 file changed, 33 insertions(+), 32 deletions(-) diff --git a/main.go b/main.go index d580018..9447d4a 100644 --- a/main.go +++ b/main.go @@ -63,31 +63,35 @@ type FetchResult struct { TempDir string ETag string Key *Key - URL string - User *neturl.Userinfo UnencryptedHash []byte } -func AuthorizedRequest(reqGenerator func() (*http.Request, error)) (*http.Response, *neturl.Userinfo, error) { +var ( + Authorization *neturl.Userinfo +) + +type ReqGenerator func(*neturl.Userinfo) (*http.Request, error) + +func AuthorizedRequest(reqGenerator ReqGenerator) (*http.Response, error) { client := &http.Client{} - req, err := reqGenerator() + req, err := reqGenerator(Authorization) if err != nil { - return nil, nil, err + return nil, err } username, password, _ := req.BasicAuth() res, err := client.Do(req) if err != nil { - return nil, nil, err + return nil, err } if res.StatusCode == http.StatusForbidden { res.Body.Close() if username != "" && password != "" { - return nil, nil, fmt.Errorf("Forbidden") + return nil, fmt.Errorf("Forbidden") } Userprintf("The server responded with %d Forbidden.\n", res.StatusCode) @@ -95,38 +99,40 @@ func AuthorizedRequest(reqGenerator func() (*http.Request, error)) (*http.Respon if username == "" { username, err = Prompt("Username: ") if err != nil { - return nil, nil, err + return nil, err } } if password == "" { password, err = PromptSecret("Password: ") if err != nil { - return nil, nil, err + return nil, err } } - req, err = reqGenerator() + Authorization = neturl.UserPassword(username, password) + + req, err = reqGenerator(Authorization) if err != nil { - return nil, nil, err + return nil, err } req.SetBasicAuth(username, password) res, err = client.Do(req) if err != nil { - return nil, nil, err + return nil, err } if res.StatusCode == http.StatusForbidden { Userprintf("The server responded with %d Forbidden.\n", res.StatusCode) - return nil, nil, fmt.Errorf("Forbidden") + return nil, fmt.Errorf("Forbidden") } else { - return res, neturl.UserPassword(username, password), nil + return res, nil } } else { - return res, neturl.UserPassword(username, password), nil + return res, nil } } @@ -321,7 +327,7 @@ func GetDecryptedFile(path string, key *Key) ([]byte, error) { } } -func GetDecryptedHTTP(url string, key *Key) (data []byte, etag string, user *neturl.Userinfo, err error) { +func GetDecryptedHTTP(url string, key *Key) (data []byte, etag string, err error) { var cachedETagBytes []byte var res *http.Response @@ -335,7 +341,7 @@ func GetDecryptedHTTP(url string, key *Key) (data []byte, etag string, user *net cachedETag := string(cachedETagBytes) - res, user, err = AuthorizedRequest(func() (*http.Request, error) { + res, err = AuthorizedRequest(func(user *neturl.Userinfo) (*http.Request, error) { req, err := AuthNewRequest("GET", url, nil) if err != nil { return nil, err @@ -398,20 +404,20 @@ func GetDecryptedHTTP(url string, key *Key) (data []byte, etag string, user *net return } -func GetDecryptedData(url string, key *Key) ([]byte, string, *neturl.Userinfo, error) { +func GetDecryptedData(url string, key *Key) ([]byte, string, error) { parsedURL, err := neturl.Parse(url) if err != nil { - return nil, "", nil, err + return nil, "", err } switch parsedURL.Scheme { case "", "file": data, err := GetDecryptedFile(parsedURL.Path, key) - return data, "", nil, err + return data, "", err case "http", "https": return GetDecryptedHTTP(url, key) default: - return nil, "", nil, fmt.Errorf("Unsupported URL scheme %#v", parsedURL.Scheme) + return nil, "", fmt.Errorf("Unsupported URL scheme %#v", parsedURL.Scheme) } } @@ -466,7 +472,6 @@ func EmptyFetch(url string, key *Key) (*FetchResult, error) { return &FetchResult{ Initial: true, TempDir: tempdir, - URL: url, Key: key, }, nil } @@ -562,15 +567,15 @@ func Push(fetched *FetchResult, url string) error { // holding on to the file handle from the original fetch. return ioutil.WriteFile(url, encryptedArchive, 0644) case "http", "https": - reqGenerator := func() (*http.Request, error) { + reqGenerator := func(user *neturl.Userinfo) (*http.Request, error) { req, err := AuthNewRequest("PUT", url, bytes.NewReader(encryptedArchive)) if err != nil { return nil, err } - if fetched.User != nil { - pw, _ := fetched.User.Password() - req.SetBasicAuth(fetched.User.Username(), pw) + if user != nil { + pw, _ := user.Password() + req.SetBasicAuth(user.Username(), pw) } if fetched.ETag == "" { @@ -584,7 +589,7 @@ func Push(fetched *FetchResult, url string) error { return req, nil } - res, user, err := AuthorizedRequest(reqGenerator) + res, err := AuthorizedRequest(reqGenerator) if err != nil { return err } else if res.StatusCode == http.StatusConflict { @@ -593,8 +598,6 @@ func Push(fetched *FetchResult, url string) error { return &ErrHTTPStatus{res, url} } - fetched.User = user - return nil default: return fmt.Errorf("Unsupported URL: %s", url) @@ -623,7 +626,7 @@ func Fetch(url string) (*FetchResult, error) { return nil, &Suberr{"GetKey", err} } - decryptedArchive, etag, user, err := GetDecryptedData(url, key) + decryptedArchive, etag, err := GetDecryptedData(url, key) if _, ok := err.(*ErrNotFound); ok { return EmptyFetch(url, key) } else if err != nil { @@ -667,8 +670,6 @@ func Fetch(url string) (*FetchResult, error) { TempDir: tempdir, ETag: etag, Key: key, - URL: url, - User: user, UnencryptedHash: unencryptedHash, }, nil }