Skip to content

Commit

Permalink
Incorporate feedback from PR
Browse files Browse the repository at this point in the history
- Refactor `HasWriteAccess` and `HasReadAccess` to `VerfiyWriteAccess`
and `VerifyReadAccess`. Both methods now only return an error instead of
an error and bool.
- make use of `errors.Wrapf(..)`
- make use of `assert.EqualError(..)`

Updated error message format:
```
[prepare] Error verifying write access to "dev.registry.pivotal.io/garbage": GET https://dev.registry.pivotal.io/service/token?scope=repository%3Agarbage%3Apush%2Cpull&service=harbor-registry: unsupported status code 500
```

Signed-off-by: Sukhil Suresh <ssuresh@pivotal.io>
  • Loading branch information
sukhil-suresh committed May 11, 2020
1 parent 576542f commit e9ba538
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 81 deletions.
16 changes: 4 additions & 12 deletions cmd/build-init/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,22 +80,14 @@ func main() {
}
}

hasImageWriteAccess, err := dockercreds.HasWriteAccess(creds, *imageTag)
err = dockercreds.VerifyWriteAccess(creds, *imageTag)
if err != nil {
logger.Fatal(err)
}

if !hasImageWriteAccess {
logger.Fatalf("invalid credentials to build to %s", *imageTag)
logger.Fatal(errors.Wrapf(err, "Error verifying write access to %q", *imageTag))
}

hasRunImageReadAccess, err := dockercreds.HasReadAccess(creds, *runImage)
err = dockercreds.VerifyReadAccess(creds, *runImage)
if err != nil {
logger.Fatal(errors.Wrapf(err, "validating read access to run image"))
}

if !hasRunImageReadAccess {
logger.Fatalf("could not read run image: %s", *runImage)
logger.Fatal(errors.Wrapf(err, "Error verifying read access to run image %q", *runImage))
}

err = fetchSource(logger, creds)
Expand Down
35 changes: 15 additions & 20 deletions pkg/dockercreds/access_checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@ import (
"github.com/pkg/errors"
)

func HasWriteAccess(keychain authn.Keychain, tag string) (bool, error) {
func VerifyWriteAccess(keychain authn.Keychain, tag string) error {
var auth authn.Authenticator
ref, err := name.ParseReference(tag, name.WeakValidation)
if err != nil {
return false, err
return errors.Wrapf(err, "Error parsing reference %q", tag)
}

auth, err = keychain.Resolve(ref.Context().Registry)
if err != nil {
return false, err
return errors.Wrap(err, "Error resolving credentials")
}

scopes := []string{ref.Scope(transport.PushScope)}
Expand All @@ -30,17 +30,16 @@ func HasWriteAccess(keychain authn.Keychain, tag string) (bool, error) {
if transportError, ok := err.(*transport.Error); ok {
for _, diagnosticError := range transportError.Errors {
if diagnosticError.Code == transport.UnauthorizedErrorCode {
return false, nil
return errors.Wrap(err, "Unauthorized")
}
}

if transportError.StatusCode == 401 {
return false, nil
return errors.Wrap(err, "Unauthorized")
}
}

err = errors.Errorf("Error validating write permission to %s. %s", tag, err.Error())
return false, errors.WithStack(err)
return errors.WithStack(err)
}

client := &http.Client{Transport: tr}
Expand All @@ -54,30 +53,26 @@ func HasWriteAccess(keychain authn.Keychain, tag string) (bool, error) {
// Make the request to initiate the blob upload.
resp, err := client.Post(u.String(), "application/json", nil)
if err != nil {
return false, errors.WithStack(err)
return errors.WithStack(err)
}
defer resp.Body.Close()

if err := transport.CheckError(resp, http.StatusCreated, http.StatusAccepted); err != nil {
return false, nil
if err = transport.CheckError(resp, http.StatusCreated, http.StatusAccepted); err != nil {
return errors.WithStack(err)
}

return true, nil
return nil
}

func HasReadAccess(keychain authn.Keychain, tag string) (bool, error) {
func VerifyReadAccess(keychain authn.Keychain, tag string) error {
ref, err := name.ParseReference(tag, name.WeakValidation)
if err != nil {
return false, errors.Wrapf(err, "parse reference '%s'", tag)
return errors.Wrapf(err, "Error parsing reference %q", tag)
}
_, err = remote.Get(ref, remote.WithAuthFromKeychain(keychain), remote.WithTransport(http.DefaultTransport))
if err != nil {
if _, ok := err.(*transport.Error); ok {
return false, nil
}

return false, errors.Wrapf(err, "validating read access to: %s", tag)
if _, err = remote.Get(ref, remote.WithAuthFromKeychain(keychain), remote.WithTransport(http.DefaultTransport)); err != nil {
return errors.WithStack(err)
}

return true, nil
return nil
}
79 changes: 30 additions & 49 deletions pkg/dockercreds/access_checker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ func testAccessChecker(t *testing.T, when spec.G, it spec.S) {
tagName = fmt.Sprintf("%s/some/image:tag", server.URL[7:])
)

when("HasWriteAccess", func() {
it("true when has permission", func() {
when("VerifyWriteAccess", func() {
it("does not error when has write access", func() {
handler.HandleFunc("/v2/some/image/blobs/uploads/", func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(201)
})
Expand All @@ -34,9 +34,8 @@ func testAccessChecker(t *testing.T, when spec.G, it spec.S) {
writer.WriteHeader(200)
})

hasAccess, err := HasWriteAccess(testKeychain{}, tagName)
err := VerifyWriteAccess(testKeychain{}, tagName)
require.NoError(t, err)
assert.True(t, hasAccess)
})

it("requests scope push permission", func() {
Expand All @@ -51,10 +50,10 @@ func testAccessChecker(t *testing.T, when spec.G, it spec.S) {
writer.WriteHeader(401)
})

_, _ = HasWriteAccess(testKeychain{}, tagName)
_ = VerifyWriteAccess(testKeychain{}, tagName)
})

it("false when fetching token is unauthorized", func() {
it("errors when fetching token is unauthorized", func() {
handler.HandleFunc("/unauthorized-token/", func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(401)
writer.Write([]byte(`{"errors": [{"code": "UNAUTHORIZED"}]}`))
Expand All @@ -65,12 +64,12 @@ func testAccessChecker(t *testing.T, when spec.G, it spec.S) {
writer.WriteHeader(401)
})

hasAccess, err := HasWriteAccess(testKeychain{}, tagName)
require.NoError(t, err)
assert.False(t, hasAccess)
err := VerifyWriteAccess(testKeychain{}, tagName)
require.Error(t, err)
assert.Contains(t, err.Error(), "Unauthorized")
})

it("false when server responds with unauthorized but without a code such as on artifactory", func() {
it("errors when server responds with unauthorized but without a code such as on artifactory", func() {
handler.HandleFunc("/unauthorized-token/", func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(401)
writer.Write([]byte(`{"statusCode":401,"details":"BAD_CREDENTIAL"}`))
Expand All @@ -81,12 +80,12 @@ func testAccessChecker(t *testing.T, when spec.G, it spec.S) {
writer.WriteHeader(401)
})

hasAccess, err := HasWriteAccess(testKeychain{}, tagName)
require.NoError(t, err)
assert.False(t, hasAccess)
err := VerifyWriteAccess(testKeychain{}, tagName)
require.Error(t, err)
assert.Contains(t, err.Error(), "Unauthorized")
})

it("false when does not have permission", func() {
it("errors when does not have permission", func() {
handler.HandleFunc("/v2/some/image/blobs/uploads/", func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(403)
})
Expand All @@ -95,36 +94,31 @@ func testAccessChecker(t *testing.T, when spec.G, it spec.S) {
writer.WriteHeader(200)
})

hasAccess, err := HasWriteAccess(testKeychain{}, tagName)
require.NoError(t, err)
assert.False(t, hasAccess)
err := VerifyWriteAccess(testKeychain{}, tagName)
assert.EqualError(t, err, fmt.Sprintf("POST %s/v2/some/image/blobs/uploads/: unsupported status code 403", server.URL))
})

it("false when cannot reach server with an error", func() {
it("errors when cannot reach server", func() {
handler.HandleFunc("/v2/", func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(404)
})

hasAccess, err := HasWriteAccess(testKeychain{}, tagName)
require.Error(t, err)
assert.False(t, hasAccess)
err := VerifyWriteAccess(testKeychain{}, tagName)
assert.EqualError(t, err, fmt.Sprintf("GET %s/v2/: unsupported status code 404", server.URL))
})

it("wraps unhandled server errors with a reasonable error message", func() {
it("errors when server errors", func() {
handler.HandleFunc("/v2/", func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(500)
})

hasAccess, err := HasWriteAccess(testKeychain{}, tagName)
require.Error(t, err)
expectedErrorMessage := fmt.Sprintf("Error validating write permission to %s. GET %s/v2/: unsupported status code 500", tagName, server.URL)
assert.Equal(t, expectedErrorMessage, err.Error())
assert.False(t, hasAccess)
err := VerifyWriteAccess(testKeychain{}, tagName)
assert.EqualError(t, err, fmt.Sprintf("GET %s/v2/: unsupported status code 500", server.URL))
})
})

when("#HasReadAccess", func() {
it("returns true when we do have read access", func() {
when("#VerifyReadAccess", func() {
it("does not error when has read access", func() {
handler.HandleFunc("/v2/", func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(200)
})
Expand All @@ -133,12 +127,11 @@ func testAccessChecker(t *testing.T, when spec.G, it spec.S) {
writer.WriteHeader(200)
})

canRead, err := HasReadAccess(testKeychain{}, tagName)
err := VerifyReadAccess(testKeychain{}, tagName)
require.NoError(t, err)
assert.True(t, canRead)
})

it("returns false when we do not have read access", func() {
it("errors when has no read access", func() {
handler.HandleFunc("/v2/", func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(200)
})
Expand All @@ -147,29 +140,17 @@ func testAccessChecker(t *testing.T, when spec.G, it spec.S) {
writer.WriteHeader(401)
})

canRead, err := HasReadAccess(testKeychain{}, tagName)
require.NoError(t, err)
assert.False(t, canRead)
})

it("returns false when server responds with 404", func() {
handler.HandleFunc("/v2/", func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(404)
})

canRead, err := HasReadAccess(testKeychain{}, tagName)
require.NoError(t, err)
assert.False(t, canRead)
err := VerifyReadAccess(testKeychain{}, tagName)
assert.EqualError(t, err, fmt.Sprintf("GET %s/v2/some/image/manifests/tag: unsupported status code 401", server.URL))
})

it("returns false with error when we cannot reach the server", func() {
it("errors when cannot reach server", func() {
handler.HandleFunc("/v2/", func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(404)
})

canRead, err := HasReadAccess(testKeychain{}, "localhost:9999/blj")
require.Error(t, err)
assert.False(t, canRead)
err := VerifyReadAccess(testKeychain{}, tagName)
assert.EqualError(t, err, fmt.Sprintf("GET %s/v2/: unsupported status code 404", server.URL))
})
})
}
Expand Down

0 comments on commit e9ba538

Please sign in to comment.