diff --git a/.circleci/config.yml b/.circleci/config.yml new file mode 100644 index 0000000..83d34b7 --- /dev/null +++ b/.circleci/config.yml @@ -0,0 +1,60 @@ +version: 2.1 + +executors: + golang: + parameters: + version: + description: Go version + type: string + docker: + - image: circleci/golang:<< parameters.version >> + +commands: + test: + steps: + - run: make + - run: make test + +workflows: + test: + jobs: + - go1_9 + - go1_10 + - go1_11 + - go1_12 + +jobs: + go1_12: &template + executor: + name: golang + version: "1.12" + steps: + - checkout + - test + + go1_11: + <<: *template + executor: + name: golang + version: "1.11" + + go1_10: &nomod_template + <<: *template + environment: + GO111MODULE: off + GO_COMMAND: vgo + executor: + name: golang + version: "1.10" + working_directory: /go/src/github.com/gbrlsnchs/jwt + steps: + - checkout + - run: go get -u golang.org/x/crypto/ed25519 + - run: go get -u golang.org/x/xerrors + - test + + go1_9: + <<: *nomod_template + executor: + name: golang + version: "1.9" diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 8b712a2..0000000 --- a/.travis.yml +++ /dev/null @@ -1,17 +0,0 @@ -os: - - 'linux' - - 'osx' - - 'windows' -sudo: false - -language: 'go' -go: - - '1.11' - - '1.12beta2' - -install: - - 'cd $GOPATH' - - 'if [ "$(go version | awk ''{print $3}'')" == "go1.10" ]; then go get -u golang.org/x/vgo && BIN=vgo; else BIN=go; fi' - - 'mv ${TRAVIS_BUILD_DIR} ${TRAVIS_HOME}/test' - - 'cd ${TRAVIS_HOME}/test' -script: '${BIN} test -v -race -count=10' diff --git a/Makefile b/Makefile index 077a1d1..f429d0c 100644 --- a/Makefile +++ b/Makefile @@ -1,30 +1,19 @@ -install: export GO111MODULE := on -install: - current_dir := ${PWD} - @cd ${GOPATH} - go get github.com/golangci/golangci-lint/cmd/golangci-lint@v1.16.0 +export GO111MODULE ?= on + +all: export GO111MODULE := off +all: go get -u golang.org/x/tools/cmd/goimports + go get -u golang.org/x/lint/golint -lint: golint_cmd := golint -set_exit_status -lint: - @echo "+++ 'lint' (${golint_cmd})" - @${golint_cmd} +fix: + @goimports -w *.go -test-units: export GO111MODULE := on -test-units: go_test_flags := -v -coverprofile=c.out -ifdef GO_TEST_RUN -test-units: go_test_flags += -run=${GO_TEST_RUN} -endif -test-units: GO_TEST_TARGET ?= ./... -test-units: go_test_cmd := go test ${go_test_flags} ${GO_TEST_TARGET} -test-units: - @echo "+++ 'test-units' (${go_test_cmd})" - @${go_test_cmd} +lint: + @! goimports -d . | grep -vF "no errors" + @golint -set_exit_status ./... -test-cover: export GO111MODULE := on -test-cover: go_tool_cover := go tool cover -func=c.out -test-cover: - @echo "+++ 'test-cover' (${go_tool_cover})" - @${go_tool_cover} +bench: + @go test -v -run=^$$ -bench=. -test: test-units lint +test: lint + @go test -v ./... diff --git a/README.md b/README.md index d10d711..bcf90f3 100644 --- a/README.md +++ b/README.md @@ -1,36 +1,22 @@ # jwt (JSON Web Token for Go) +[![JWT compatible](https://jwt.io/img/badge.svg)](https://jwt.io) -

- -

- -

- -

+[![CircleCI](https://circleci.com/gh/gbrlsnchs/jwt.svg?style=shield)](https://circleci.com/gh/gbrlsnchs/jwt) +[![Go Report Card](https://goreportcard.com/badge/github.com/gbrlsnchs/jwt)](https://goreportcard.com/report/github.com/gbrlsnchs/jwt) +[![GoDoc](https://godoc.org/github.com/gbrlsnchs/jwt?status.svg)](https://godoc.org/github.com/gbrlsnchs/jwt) +[![Join the chat at https://gitter.im/gbrlsnchs/jwt](https://badges.gitter.im/gbrlsnchs/jwt.svg)](https://gitter.im/gbrlsnchs/jwt?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) -

- Build Status - Go Report Card - Sourcegraph - GoDoc - Minimal Version - Join the chat at https://gitter.im/gbrlsnchs/jwt -

- -

- JWT compatible -

- -## Important -Branch `master` is unstable, **always** use tagged versions. That way it is possible to differentiate pre-release tags from production ones. -In other words, API changes all the time in `master`. It's a place for public experiment. Thus, make use of the latest stable version via Go modules. +## Compatibility +[![Version Compatibility](https://img.shields.io/badge/go%20modules-go1.11+-5272b4.svg)](https://github.com/gbrlsnchs/jwt#installing) +[![vgo](https://img.shields.io/badge/vgo-go1.10-5272b4.svg)](https://github.com/gbrlsnchs/jwt#installing) +[![go get](https://img.shields.io/badge/go%20get-go1.9.7+,%20go1.10.3+%20and%20go1.11-5272b4.svg)](https://github.com/gbrlsnchs/jwt#installing) ## About This package is a JWT signer, verifier and validator for [Go](https://golang.org) (or Golang). Although there are many JWT packages out there for Go, many lack support for some signing, verifying or validation methods and, when they don't, they're overcomplicated. This package tries to mimic the ease of use from [Node JWT library](https://github.com/auth0/node-jsonwebtoken)'s API while following the [Effective Go](https://golang.org/doc/effective_go.html) guidelines. -Support for [JWE](https://tools.ietf.org/html/rfc7516) isn't provided. Instead, [JWS](https://tools.ietf.org/html/rfc7515) is used, narrowed down to the [JWT specification](https://tools.ietf.org/html/rfc7519). +Support for [JWE](https://tools.ietf.org/html/rfc7516) isn't provided (not yet but is in the roadmap, see #17). Instead, [JWS](https://tools.ietf.org/html/rfc7515) is used, narrowed down to the [JWT specification](https://tools.ietf.org/html/rfc7519). ### Supported signing methods | | SHA-256 | SHA-384 | SHA-512 | @@ -41,117 +27,221 @@ Support for [JWE](https://tools.ietf.org/html/rfc7516) isn't provided. Instead, | ECDSA | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | EdDSA | :heavy_minus_sign: | :heavy_minus_sign: | :heavy_check_mark: | +## Important +Branch `master` is unstable, **always** use tagged versions. That way it is possible to differentiate pre-release tags from production ones. +In other words, API changes all the time in `master`. It's a place for public experiment. Thus, make use of the latest stable version via Go modules. + ## Usage Full documentation [here](https://godoc.org/github.com/gbrlsnchs/jwt). ### Installing -`GO111MODULE=on go get -u github.com/gbrlsnchs/jwt/v3` +
Go 1.12 onward +

+ +```sh +$ go get -u github.com/gbrlsnchs/jwt/v3 +``` + +

+
+ +
Go 1.11 +

-### Importing +```sh +$ GO111MODULE=on go get -u github.com/gbrlsnchs/jwt/v3 +``` + +

+
+ +
Go 1.10 with vgo +

+ +```sh +$ vgo get -u github.com/gbrlsnchs/jwt/v3 +``` + +

+
+ +
Go 1.9.7+, Go 1.10.3+ (without vgo) and Go 1.11 (when GO111MODULE=off) +

+ +```sh +$ go get -u github.com/gbrlsnchs/jwt/v3 +``` + +#### Important +Your project must be inside the `GOPATH`. + +

+
+ +### Signing ```go import ( - // ... + "time" "github.com/gbrlsnchs/jwt/v3" ) -``` -### Signing a simple JWT -```go -now := time.Now() -hs256 := jwt.NewHMAC(jwt.SHA256, []byte("secret")) -h := jwt.Header{KeyID: "kid"} -p := jwt.Payload{ - Issuer: "gbrlsnchs", - Subject: "someone", - Audience: jwt.Audience{"https://golang.org", "https://jwt.io"}, - ExpirationTime: now.Add(24 * 30 * 12 * time.Hour).Unix(), - NotBefore: now.Add(30 * time.Minute).Unix(), - IssuedAt: now.Unix(), - JWTID: "foobar", +type CustomPayload struct { + jwt.Payload + Foo string `json:"foo,omitempty"` + Bar int `json:"bar,omitempty"` } -token, err := jwt.Sign(h, p, hs256) -if err != nil { - // Handle error. + +var hs = jwt.NewHS256([]byte("secret")) + +func main() { + now := time.Now() + pl := CustomPayload{ + Payload: jwt.Payload{ + Issuer: "gbrlsnchs", + Subject: "someone", + Audience: jwt.Audience{"https://golang.org", "https://jwt.io"}, + ExpirationTime: jwt.NumericDate(now.Add(24 * 30 * 12 * time.Hour)), + NotBefore: jwt.NumericDate(now.Add(30 * time.Minute)), + IssuedAt: jwt.NumericDate(now), + JWTID: "foobar", + }, + Foo: "foo", + Bar: 1337, + } + + token, err := jwt.Sign(pl, hs) + if err != nil { + // ... + } + + // ... } -log.Printf("token = %s", token) ``` -### Signing a JWT with public claims -#### First, create a custom type and embed a JWT pointer in it +### Verifying ```go +import "github.com/gbrlsnchs/jwt/v3" + type CustomPayload struct { jwt.Payload - IsLoggedIn bool `json:"isLoggedIn"` - CustomField string `json:"customField,omitempty"` + Foo string `json:"foo,omitempty"` + Bar int `json:"bar,omitempty"` +} + +var hs = jwt.NewHS256([]byte("secret")) + +func main() { + // ... + + var pl CustomPayload + hd, err := jwt.Verify(token, hs, &pl) + if err != nil { + // ... + } + + // ... } ``` -#### Now initialize, marshal and sign it +### Other use case examples +
Setting "cty" and "kid" claims +

+ +The "cty" and "kid" claims can be set by passing options to the `jwt.Sign` function: ```go -now := time.Now() -hs256 := jwt.NewHMAC(jwt.SHA256, []byte("secret")) -h := jwt.Header{KeyID: "kid"} -p := CustomPayload{ - Payload: jwt.Payload{ - Issuer: "gbrlsnchs", - Subject: "someone", - Audience: jwt.Audience{"https://golang.org", "https://jwt.io"}, - ExpirationTime: now.Add(24 * 30 * 12 * time.Hour).Unix(), - NotBefore: now.Add(30 * time.Minute).Unix(), - IssuedAt: now.Unix(), - JWTID: "foobar", - }, - IsLoggedIn: true, - CustomField: "myCustomField", -} -token, err := jwt.Sign(h, p, hs256) -if err != nil { - // Handle error. +import ( + "time" + + "github.com/gbrlsnchs/jwt/v3" +) + +var hs = jwt.NewHS256([]byte("secret")) + +func main() { + pl := jwt.Payload{ + Subject: "gbrlsnchs", + Issuer: "gsr.dev", + IssuedAt: jwt.NumericDate(time.Now()), + } + + token, err := jwt.Sign(pl, hs, jwt.ContentType("JWT"), jwt.KeyID("my_key")) + if err != nil { + // ... + } + + // ... } -log.Printf("token = %s", token) ``` -### Verifying and validating a JWT +

+
+ +
Validating "alg" before verifying +

+ +For validating the "alg" field in a JOSE header **before** verification, the `jwt.ValidateHeader` option must be passed to `jwt.Verify`. ```go -now := time.Now() -hs256 := jwt.NewHMAC(jwt.SHA256, []byte("secret")) -token := []byte("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9." + - "eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ." + - "lZ1zDoGNAv3u-OclJtnoQKejE8_viHlMtGlAxE8AE0Q") - -raw, err := jwt.Parse(token) -if err != nil { - // Handle error. -} -if err = raw.Verify(hs256); err != nil { - // Handle error. +import "github.com/gbrlsnchs/jwt/v3" + +var hs = jwt.NewHS256([]byte("secret")) + +func main() { + // ... + + var pl jwt.Payload + if _, err := jwt.Verify(token, hs, &pl, jwt.ValidateHeader); err != nil { + // ... + } + + // ... } +``` + +

+
+ +
Using an Algorithm resolver +

+ +```go +import ( + "errors" + + "github.com/gbrlsnchs/jwt/v3" + "github.com/gbrlsnchs/jwt/v3/jwtutil" +) + var ( - h jwt.Header - p CustomPayload + // ... + + rs256 = jwt.NewRS256(jwt.RSAPublicKey(myRSAPublicKey)) + es256 = jwt.NewES256(jwt.ECDSAPublicKey(myECDSAPublicKey)) ) -if h, err = raw.Decode(&p); err != nil { - // Handle error. -} -fmt.Println(h.Algorithm) -fmt.Println(h.KeyID) - -iatValidator := jwt.IssuedAtValidator(now) -expValidator := jwt.ExpirationTimeValidator(now, true) -audValidator := jwt.AudienceValidator(jwt.Audience{"https://golang.org", "https://jwt.io", "https://google.com", "https://reddit.com"}) -if err := p.Validate(iatValidator, expValidator, audValidator); err != nil { - switch err { - case jwt.ErrIatValidation: - // handle "iat" validation error - case jwt.ErrExpValidation: - // handle "exp" validation error - case jwt.ErrAudValidation: - // handle "aud" validation error + +func main() { + rv := &jwtutil.Resolver{New: func(hd jwt.Header) { + switch hd.KeyID { + case "foo": + return rs256, nil + case "bar": + return es256, nil + default: + return nil, errors.New(`invalid "kid"`) + } + }} + var pl jwt.Payload + if _, err := jwt.Verify(token, rv, &pl); err != nil { + // ... } + + // ... } ``` +

+
+ ## Contributing ### How to help - For bugs and opinions, please [open an issue](https://github.com/gbrlsnchs/jwt/issues/new) diff --git a/audience_test.go b/audience_test.go index 9d9711b..28aec26 100644 --- a/audience_test.go +++ b/audience_test.go @@ -9,6 +9,21 @@ import ( ) func TestAudienceMarshal(t *testing.T) { + t.Run("omitempty", func(t *testing.T) { + var ( + b []byte + err error + v = struct { + Audience jwt.Audience `json:"aud,omitempty"` + }{} + ) + if b, err = json.Marshal(v); err != nil { + t.Fatal(err) + } + checkAudMarshal(t, "{}", b) + + }) + testCases := []struct { aud jwt.Audience expected string @@ -39,20 +54,6 @@ func TestAudienceMarshal(t *testing.T) { } } -func TestAudienceOmitempty(t *testing.T) { - var ( - b []byte - err error - v = struct { - Audience jwt.Audience `json:"aud,omitempty"` - }{} - ) - if b, err = json.Marshal(v); err != nil { - t.Fatal(err) - } - checkAudMarshal(t, "{}", b) -} - func TestAudienceUnmarshal(t *testing.T) { testCases := []struct { jstr []byte diff --git a/bench_test.go b/bench_test.go new file mode 100644 index 0000000..572d643 --- /dev/null +++ b/bench_test.go @@ -0,0 +1,80 @@ +package jwt_test + +import ( + "testing" + "time" + + "github.com/gbrlsnchs/jwt/v3" +) + +var ( + benchHS256 = jwt.NewHS256([]byte("secret")) + benchRecv []byte +) + +func BenchmarkSign(b *testing.B) { + now := time.Now() + var ( + token []byte + err error + pl = jwt.Payload{ + Issuer: "gbrlsnchs", + Subject: "someone", + Audience: jwt.Audience{"https://golang.org", "https://jwt.io"}, + ExpirationTime: jwt.NumericDate(now.Add(24 * 30 * 12 * time.Hour)), + NotBefore: jwt.NumericDate(now.Add(30 * time.Minute)), + IssuedAt: jwt.NumericDate(now), + JWTID: "foobar", + } + ) + b.Run("Default", func(b *testing.B) { + b.ReportAllocs() + for n := 0; n < b.N; n++ { + token, err = jwt.Sign(pl, benchHS256) + if err != nil { + b.Fatal(err) + } + } + }) + b.Run(`With "kid"`, func(b *testing.B) { + b.ReportAllocs() + for n := 0; n < b.N; n++ { + token, err = jwt.Sign(pl, benchHS256, jwt.KeyID("kid")) + if err != nil { + b.Fatal(err) + } + } + }) + b.Run(`With "cty" and "kid"`, func(b *testing.B) { + b.ReportAllocs() + for n := 0; n < b.N; n++ { + token, err = jwt.Sign(pl, benchHS256, jwt.ContentType("cty"), jwt.KeyID("kid")) + if err != nil { + b.Fatal(err) + } + } + }) + + benchRecv = token + +} + +func BenchmarkVerify(b *testing.B) { + var ( + token = []byte( + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9." + + "eyJpc3MiOiJnYnJsc25jaHMiLCJzdWIiOiJzb21lb25lIiwiYXVkIjpbImh0dHBzOi8vZ29sYW5nLm9yZyIsImh0dHBzOi8vand0LmlvIl0sImV4cCI6MTU5MzM5MTE4MiwibmJmIjoxNTYyMjg4OTgyLCJpYXQiOjE1NjIyODcxODIsImp0aSI6ImZvb2JhciJ9." + + "bKevp7jmMbH9-Hy5g5OxLgq8tg13z9voH7lZ4m9y484", + ) + err error + ) + b.Run("Default", func(b *testing.B) { + b.ReportAllocs() + for n := 0; n < b.N; n++ { + var pl jwt.Payload + if _, err = jwt.Verify(token, benchHS256, &pl); err != nil { + b.Fatal(err) + } + } + }) +} diff --git a/ecdsa_sha.go b/ecdsa_sha.go index 778d639..3e87f16 100644 --- a/ecdsa_sha.go +++ b/ecdsa_sha.go @@ -15,12 +15,26 @@ var ( ErrECDSANilPrivKey = errors.New("jwt: ECDSA private key is nil") // ErrECDSANilPubKey is the error for trying to verify a JWT with a nil public key. ErrECDSANilPubKey = errors.New("jwt: ECDSA public key is nil") - // ErrECDSAVerification is the error for an invalid signature. + // ErrECDSAVerification is the error for an invalid ECDSA signature. ErrECDSAVerification = errors.New("jwt: ECDSA verification failed") - _ Algorithm = new(ecdsaSHA) + _ Algorithm = new(ECDSASHA) ) +// ECDSAPrivateKey is an option to set a private key to the ECDSA-SHA algorithm. +func ECDSAPrivateKey(priv *ecdsa.PrivateKey) func(*ECDSASHA) { + return func(es *ECDSASHA) { + es.priv = priv + } +} + +// ECDSAPublicKey is an option to set a public key to the ECDSA-SHA algorithm. +func ECDSAPublicKey(pub *ecdsa.PublicKey) func(*ECDSASHA) { + return func(es *ECDSASHA) { + es.pub = pub + } +} + func byteSize(bitSize int) int { byteSize := bitSize / 8 if bitSize%8 > 0 { @@ -29,7 +43,8 @@ func byteSize(bitSize int) int { return byteSize } -type ecdsaSHA struct { +// ECDSASHA is an algorithm that uses ECDSA to sign SHA hashes. +type ECDSASHA struct { name string priv *ecdsa.PrivateKey pub *ecdsa.PublicKey @@ -39,39 +54,44 @@ type ecdsaSHA struct { pool *hashPool } -func newECDSASHA(name string, priv *ecdsa.PrivateKey, pub *ecdsa.PublicKey, sha crypto.Hash) *ecdsaSHA { - return &ecdsaSHA{ +func newECDSASHA(name string, opts []func(*ECDSASHA), sha crypto.Hash) *ECDSASHA { + es := ECDSASHA{ name: name, - priv: priv, - pub: pub, sha: sha, - size: byteSize(pub.Params().BitSize) * 2, pool: newHashPool(sha.New), } + for _, opt := range opts { + opt(&es) + } + if es.pub == nil { + es.pub = &es.priv.PublicKey + } + es.size = byteSize(es.pub.Params().BitSize) * 2 + return &es } // NewES256 creates a new algorithm using ECDSA and SHA-256. -func NewES256(priv *ecdsa.PrivateKey, pub *ecdsa.PublicKey) Algorithm { - return newECDSASHA("ES256", priv, pub, crypto.SHA256) +func NewES256(opts ...func(*ECDSASHA)) *ECDSASHA { + return newECDSASHA("ES256", opts, crypto.SHA256) } // NewES384 creates a new algorithm using ECDSA and SHA-384. -func NewES384(priv *ecdsa.PrivateKey, pub *ecdsa.PublicKey) Algorithm { - return newECDSASHA("ES384", priv, pub, crypto.SHA384) +func NewES384(opts ...func(*ECDSASHA)) *ECDSASHA { + return newECDSASHA("ES384", opts, crypto.SHA384) } // NewES512 creates a new algorithm using ECDSA and SHA-512. -func NewES512(priv *ecdsa.PrivateKey, pub *ecdsa.PublicKey) Algorithm { - return newECDSASHA("ES512", priv, pub, crypto.SHA512) +func NewES512(opts ...func(*ECDSASHA)) *ECDSASHA { + return newECDSASHA("ES512", opts, crypto.SHA512) } // Name returns the algorithm's name. -func (es *ecdsaSHA) Name() string { +func (es *ECDSASHA) Name() string { return es.name } // Sign signs headerPayload using the ECDSA-SHA algorithm. -func (es *ecdsaSHA) Sign(headerPayload []byte) ([]byte, error) { +func (es *ECDSASHA) Sign(headerPayload []byte) ([]byte, error) { if es.priv == nil { return nil, ErrECDSANilPrivKey } @@ -79,12 +99,12 @@ func (es *ecdsaSHA) Sign(headerPayload []byte) ([]byte, error) { } // Size returns the signature's byte size. -func (es *ecdsaSHA) Size() int { +func (es *ECDSASHA) Size() int { return es.size } // Verify verifies a signature based on headerPayload using ECDSA-SHA. -func (es *ecdsaSHA) Verify(headerPayload, sig []byte) (err error) { +func (es *ECDSASHA) Verify(headerPayload, sig []byte) (err error) { if es.pub == nil { return ErrECDSANilPubKey } @@ -108,7 +128,7 @@ func (es *ecdsaSHA) Verify(headerPayload, sig []byte) (err error) { return nil } -func (es *ecdsaSHA) sign(headerPayload []byte) ([]byte, error) { +func (es *ECDSASHA) sign(headerPayload []byte) ([]byte, error) { sum, err := es.pool.sign(headerPayload) if err != nil { return nil, err diff --git a/ecdsa_sha_test.go b/ecdsa_sha_test.go index 30619ec..58a6029 100644 --- a/ecdsa_sha_test.go +++ b/ecdsa_sha_test.go @@ -1 +1,26 @@ package jwt_test + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" +) + +var ( + es256PrivateKey1, es256PublicKey1 = genECDSAKeys(elliptic.P256()) + es256PrivateKey2, es256PublicKey2 = genECDSAKeys(elliptic.P256()) + + es384PrivateKey1, es384PublicKey1 = genECDSAKeys(elliptic.P384()) + es384PrivateKey2, es384PublicKey2 = genECDSAKeys(elliptic.P384()) + + es512PrivateKey1, es512PublicKey1 = genECDSAKeys(elliptic.P521()) + es512PrivateKey2, es512PublicKey2 = genECDSAKeys(elliptic.P521()) +) + +func genECDSAKeys(c elliptic.Curve) (*ecdsa.PrivateKey, *ecdsa.PublicKey) { + priv, err := ecdsa.GenerateKey(c, rand.Reader) + if err != nil { + panic(err) + } + return priv, &priv.PublicKey +} diff --git a/ed25519_go1_12.go b/ed25519_go1_12.go index 16d2fc9..070aeb1 100644 --- a/ed25519_go1_12.go +++ b/ed25519_go1_12.go @@ -3,58 +3,82 @@ package jwt import ( + "errors" + "github.com/gbrlsnchs/jwt/v3/internal" "golang.org/x/crypto/ed25519" ) var ( // ErrEd25519PrivKey is the error for trying to sign a JWT with a nil private key. - ErrEd25519PrivKey = internal.NewError("jwt: edDSA private key is nil") + ErrEd25519PrivKey = errors.New("jwt: Ed25519 private key is nil") // ErrEd25519PubKey is the error for trying to verify a JWT with a nil public key. - ErrEd25519PubKey = internal.NewError("jwt: edDSA public key is nil") - // ErrEd25519Verification is the error for when verification with edDSA fails. - ErrEd25519Verification = internal.NewError("jwt: edDSA verification failed") + ErrEd25519PubKey = errors.New("jwt: Ed25519 public key is nil") + // ErrEd25519Verification is the error for when verification with Ed25519 fails. + ErrEd25519Verification = errors.New("jwt: Ed25519 verification failed") - _ Algorithm = new(edDSA) + _ Algorithm = new(Ed25519) ) -type edDSA struct { +// Ed25519PrivateKey is an option to set a private key to the Ed25519 algorithm. +func Ed25519PrivateKey(priv ed25519.PrivateKey) func(*Ed25519) { + return func(ed *Ed25519) { + ed.priv = priv + } +} + +// Ed25519PublicKey is an option to set a public key to the Ed25519 algorithm. +func Ed25519PublicKey(pub ed25519.PublicKey) func(*Ed25519) { + return func(ed *Ed25519) { + ed.pub = pub + } +} + +// Ed25519 is an algorithm that uses EdDSA to sign SHA-512 hashes. +type Ed25519 struct { priv ed25519.PrivateKey pub ed25519.PublicKey } // NewEd25519 creates a new algorithm using EdDSA and SHA-512. -func NewEd25519(priv ed25519.PrivateKey, pub ed25519.PublicKey) Algorithm { - return &edDSA{priv: priv, pub: pub} +func NewEd25519(opts ...func(*Ed25519)) *Ed25519 { + var ed Ed25519 + for _, opt := range opts { + opt(&ed) + } + if ed.pub == nil { + ed.pub = ed.priv.Public().(ed25519.PublicKey) + } + return &ed } // Name returns the algorithm's name. -func (*edDSA) Name() string { +func (*Ed25519) Name() string { return "Ed25519" } // Sign signs headerPayload using the Ed25519 algorithm. -func (e *edDSA) Sign(headerPayload []byte) ([]byte, error) { - if e.priv == nil { +func (ed *Ed25519) Sign(headerPayload []byte) ([]byte, error) { + if ed.priv == nil { return nil, ErrEd25519PrivKey } - return ed25519.Sign(e.priv, headerPayload), nil + return ed25519.Sign(ed.priv, headerPayload), nil } // Size returns the signature byte size. -func (*edDSA) Size() int { +func (*Ed25519) Size() int { return ed25519.SignatureSize } // Verify verifies a payload and a signature. -func (e *edDSA) Verify(payload, sig []byte) (err error) { - if e.pub == nil { +func (ed *Ed25519) Verify(payload, sig []byte) (err error) { + if ed.pub == nil { return ErrEd25519PubKey } if sig, err = internal.DecodeToBytes(sig); err != nil { return err } - if !ed25519.Verify(e.pub, payload, sig) { + if !ed25519.Verify(ed.pub, payload, sig) { return ErrEd25519Verification } return nil diff --git a/ed25519_test.go b/ed25519_test.go new file mode 100644 index 0000000..e57692a --- /dev/null +++ b/ed25519_test.go @@ -0,0 +1,8 @@ +package jwt_test + +import "github.com/gbrlsnchs/jwt/v3/internal" + +var ( + ed25519PrivateKey1, ed25519PublicKey1 = internal.GenerateEd25519Keys() + ed25519PrivateKey2, ed25519PublicKey2 = internal.GenerateEd25519Keys() +) diff --git a/encoded_jwt.png b/encoded_jwt.png deleted file mode 100644 index 0f5ed18..0000000 Binary files a/encoded_jwt.png and /dev/null differ diff --git a/go.mod b/go.mod index c586c25..f93a627 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,5 @@ module github.com/gbrlsnchs/jwt/v3 -require ( - golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67 - golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 -) +require golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67 + +go 1.10 diff --git a/go.sum b/go.sum index 3811b06..ab4e508 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,2 @@ golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67 h1:ng3VDlRp5/DHpSWl02R4rM9I+8M2rhmsuLwAMmkLQWE= golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 h1:PPwnA7z1Pjf7XYaBP9GL1VAMZmcIWyFz7QCMSIIa3Bg= -golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/gopher_head.png b/gopher_head.png deleted file mode 100644 index bfb720f..0000000 Binary files a/gopher_head.png and /dev/null differ diff --git a/header.go b/header.go index 2e48cd4..ae4fb6d 100644 --- a/header.go +++ b/header.go @@ -5,7 +5,7 @@ package jwt // Parameters are ordered according to the RFC 7515. type Header struct { Algorithm string `json:"alg,omitempty"` + ContentType string `json:"cty,omitempty"` KeyID string `json:"kid,omitempty"` Type string `json:"typ,omitempty"` - ContentType string `json:"cty,omitempty"` } diff --git a/hmac_sha.go b/hmac_sha.go index c9cd7ff..91c54d2 100644 --- a/hmac_sha.go +++ b/hmac_sha.go @@ -15,10 +15,11 @@ var ( // ErrHMACVerification is the error for an invalid signature. ErrHMACVerification = errors.New("jwt: HMAC verification failed") - _ Algorithm = new(hmacSHA) + _ Algorithm = new(HMACSHA) ) -type hmacSHA struct { +// HMACSHA is an algorithm that uses HMAC to sign SHA hashes. +type HMACSHA struct { name string key []byte sha crypto.Hash @@ -26,8 +27,8 @@ type hmacSHA struct { pool *hashPool } -func newHMACSHA(name string, key []byte, sha crypto.Hash) *hmacSHA { - return &hmacSHA{ +func newHMACSHA(name string, key []byte, sha crypto.Hash) *HMACSHA { + return &HMACSHA{ name: name, // cache name key: key, sha: sha, @@ -37,27 +38,27 @@ func newHMACSHA(name string, key []byte, sha crypto.Hash) *hmacSHA { } // NewHS256 creates a new algorithm using HMAC and SHA-256. -func NewHS256(key []byte) Algorithm { +func NewHS256(key []byte) *HMACSHA { return newHMACSHA("HS256", key, crypto.SHA256) } // NewHS384 creates a new algorithm using HMAC and SHA-384. -func NewHS384(key []byte) Algorithm { +func NewHS384(key []byte) *HMACSHA { return newHMACSHA("HS384", key, crypto.SHA384) } // NewHS512 creates a new algorithm using HMAC and SHA-512. -func NewHS512(key []byte) Algorithm { +func NewHS512(key []byte) *HMACSHA { return newHMACSHA("HS512", key, crypto.SHA512) } // Name returns the algorithm's name. -func (hs *hmacSHA) Name() string { +func (hs *HMACSHA) Name() string { return hs.name } // Sign signs headerPayload using the HMAC-SHA algorithm. -func (hs *hmacSHA) Sign(headerPayload []byte) ([]byte, error) { +func (hs *HMACSHA) Sign(headerPayload []byte) ([]byte, error) { if string(hs.key) == "" { return nil, ErrHMACMissingKey } @@ -65,12 +66,12 @@ func (hs *hmacSHA) Sign(headerPayload []byte) ([]byte, error) { } // Size returns the signature's byte size. -func (hs *hmacSHA) Size() int { +func (hs *HMACSHA) Size() int { return hs.size } // Verify verifies a signature based on headerPayload using HMAC-SHA. -func (hs *hmacSHA) Verify(headerPayload, sig []byte) (err error) { +func (hs *HMACSHA) Verify(headerPayload, sig []byte) (err error) { if sig, err = internal.DecodeToBytes(sig); err != nil { return err } diff --git a/hmac_sha_test.go b/hmac_sha_test.go index edfdcee..70f52b9 100644 --- a/hmac_sha_test.go +++ b/hmac_sha_test.go @@ -1,3 +1,6 @@ package jwt_test -var hmacKey = []byte("secret") +var ( + hmacKey1 = []byte("secret") + hmacKey2 = []byte("terces") +) diff --git a/internal/decode_test.go b/internal/decode_test.go index 1db91ef..6f172d7 100644 --- a/internal/decode_test.go +++ b/internal/decode_test.go @@ -26,7 +26,7 @@ func TestDecode(t *testing.T) { {rawURLEnc, "{}", "", false}, {rawURLEnc, `{"x":"test"}`, "test", false}, {stdEnc, "{}", "", true}, - {stdEnc, `{"x":"test"}`, "test", false}, + {stdEnc, `{"x":"test"}`, "test", false}, // the output is the same as with RawURLEncoding {nil, "{}", "", true}, {nil, `{"x":"test"}`, "", true}, } @@ -38,11 +38,10 @@ func TestDecode(t *testing.T) { } t.Logf("b64: %s", b64) var ( - dt decodeTest - err = internal.Decode([]byte(b64), &dt) - b64err = new(base64.CorruptInputError) + dt decodeTest + err = internal.Decode([]byte(b64), &dt) ) - if want, got := tc.errors, internal.ErrorAs(err, b64err); want != got { + if want, got := tc.errors, err != nil; want != got { t.Fatalf("want %t, got %t: %v", want, got, err) } if want, got := tc.expected, dt.X; want != got { diff --git a/internal/ed25519_go1_12.go b/internal/ed25519_go1_12.go new file mode 100644 index 0000000..3436a4c --- /dev/null +++ b/internal/ed25519_go1_12.go @@ -0,0 +1,18 @@ +// +build !go1.13 + +package internal + +import ( + "crypto/rand" + + "golang.org/x/crypto/ed25519" +) + +// GenerateEd25519Keys generates a pair of keys for testing purposes. +func GenerateEd25519Keys() (ed25519.PrivateKey, ed25519.PublicKey) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + panic(err) + } + return priv, pub +} diff --git a/internal/epoch.go b/internal/epoch.go new file mode 100644 index 0000000..494c09e --- /dev/null +++ b/internal/epoch.go @@ -0,0 +1,6 @@ +package internal + +import "time" + +// Epoch is 01/01/1970. +var Epoch = time.Date(1970, time.January, 1, 0, 0, 0, 0, time.UTC) diff --git a/internal/error_go1_12.go b/internal/error_go1_12.go deleted file mode 100644 index 0d25232..0000000 --- a/internal/error_go1_12.go +++ /dev/null @@ -1,17 +0,0 @@ -// +build !go1.13 - -package internal - -import "golang.org/x/xerrors" - -// Errorf is a wrapper for xerrors.Errorf. -func Errorf(format string, a ...interface{}) error { return xerrors.Errorf(format, a...) } - -// ErrorAs is a wrapper for xerrors.As. -func ErrorAs(err error, target interface{}) bool { return xerrors.As(err, target) } - -// ErrorIs is a wrapper for xerrors.Is. -func ErrorIs(err, target error) bool { return xerrors.Is(err, target) } - -// NewError is a wrapper for xerrors.New. -func NewError(text string) error { return xerrors.New(text) } diff --git a/internal/rsa_signature_size.go b/internal/rsa_signature_size.go new file mode 100644 index 0000000..ed99e7b --- /dev/null +++ b/internal/rsa_signature_size.go @@ -0,0 +1,10 @@ +// +build go1.11 + +package internal + +import "crypto/rsa" + +// RSASignatureSize returns the signature size of an RSA signature. +func RSASignatureSize(pub *rsa.PublicKey) int { + return pub.Size() +} diff --git a/internal/rsa_signature_size_go1_10.go b/internal/rsa_signature_size_go1_10.go new file mode 100644 index 0000000..55a12aa --- /dev/null +++ b/internal/rsa_signature_size_go1_10.go @@ -0,0 +1,11 @@ +// +build !go1.11 + +package internal + +import "crypto/rsa" + +// RSASignatureSize returns the signature size of an RSA signature. +func RSASignatureSize(pub *rsa.PublicKey) int { + // As defined at https://golang.org/src/crypto/rsa/rsa.go?s=1609:1641#L39. + return (pub.N.BitLen() + 7) / 8 +} diff --git a/jwtutil/resolver.go b/jwtutil/resolver.go new file mode 100644 index 0000000..748291a --- /dev/null +++ b/jwtutil/resolver.go @@ -0,0 +1,46 @@ +package jwtutil + +import ( + "errors" + + "github.com/gbrlsnchs/jwt/v3" +) + +// Resolver is an Algorithm resolver. +type Resolver struct { + New func(jwt.Header) (jwt.Algorithm, error) + alg jwt.Algorithm +} + +// Name returns an Algorithm's name. +func (rv *Resolver) Name() string { + return rv.alg.Name() +} + +// Resolve sets an Algorithm based on a JOSE Header. +func (rv *Resolver) Resolve(hd jwt.Header) error { + if rv.alg != nil { + return nil + } + alg, err := rv.New(hd) + if err != nil { + return err + } + rv.alg = alg + return nil +} + +// Sign returns an error since Resolver doesn't support signing. +func (rv *Resolver) Sign(_ []byte) ([]byte, error) { + return nil, errors.New("jwtutil: Resolver can only verify") +} + +// Size returns an Algorithm's size. +func (rv *Resolver) Size() int { + return rv.alg.Size() +} + +// Verify resolves and Algorithm and verifies using it. +func (rv *Resolver) Verify(headerPayload, sig []byte) error { + return rv.alg.Verify(headerPayload, sig) +} diff --git a/jwtutil/resolver_test.go b/jwtutil/resolver_test.go new file mode 100644 index 0000000..f5e7b0a --- /dev/null +++ b/jwtutil/resolver_test.go @@ -0,0 +1,52 @@ +package jwtutil_test + +import ( + "errors" + "testing" + + "github.com/gbrlsnchs/jwt/v3" + "github.com/gbrlsnchs/jwt/v3/jwtutil" +) + +var hs256 = jwt.NewHS256([]byte("resolver")) + +func TestResolver(t *testing.T) { + testCases := []struct { + signer jwt.Algorithm + signOpts []jwt.SignOption + verifier jwt.Algorithm + }{ + { + signer: hs256, + verifier: &jwtutil.Resolver{ + New: func(hd jwt.Header) (jwt.Algorithm, error) { + return hs256, nil + }, + }, + }, + { + signer: hs256, + signOpts: []jwt.SignOption{jwt.KeyID("test")}, + verifier: &jwtutil.Resolver{ + New: func(hd jwt.Header) (jwt.Algorithm, error) { + if hd.KeyID != "test" { + return nil, errors.New(`wrong "kid"`) + } + return hs256, nil + }, + }, + }, + } + for _, tc := range testCases { + t.Run("", func(t *testing.T) { + token, err := jwt.Sign(jwt.Payload{}, tc.signer, tc.signOpts...) + if err != nil { + t.Fatal(err) + } + var pl jwt.Payload + if _, err = jwt.Verify(token, tc.verifier, &pl); err != nil { + t.Fatal(err) + } + }) + } +} diff --git a/payload.go b/payload.go index d072dbb..c8329a6 100644 --- a/payload.go +++ b/payload.go @@ -1,28 +1,12 @@ package jwt -var ( - _ Validator = new(Payload) - _ Validator = &struct{ Payload }{} - _ Validator = &struct{ *Payload }{} -) - // Payload is a JWT payload according to the RFC 7519. type Payload struct { Issuer string `json:"iss,omitempty"` Subject string `json:"sub,omitempty"` Audience Audience `json:"aud,omitempty"` - ExpirationTime int64 `json:"exp,omitempty"` - NotBefore int64 `json:"nbf,omitempty"` - IssuedAt int64 `json:"iat,omitempty"` + ExpirationTime *Time `json:"exp,omitempty"` + NotBefore *Time `json:"nbf,omitempty"` + IssuedAt *Time `json:"iat,omitempty"` JWTID string `json:"jti,omitempty"` } - -// Validate validates Payload claims. -func (p *Payload) Validate(funcs ...ValidatorFunc) error { - for _, fn := range funcs { - if err := fn(p); err != nil { - return err - } - } - return nil -} diff --git a/raw_token.go b/raw_token.go index c25ec03..e393350 100644 --- a/raw_token.go +++ b/raw_token.go @@ -1,40 +1,52 @@ package jwt -import "github.com/gbrlsnchs/jwt/v3/internal" +import ( + "errors" + + "github.com/gbrlsnchs/jwt/v3/internal" +) // ErrMalformed indicates a token doesn't have a valid format, as per the RFC 7519. -var ErrMalformed = internal.NewError("jwt: malformed token") +var ErrMalformed = errors.New("jwt: malformed token") // RawToken is a representation of a parsed JWT string. type RawToken struct { token []byte sep1, sep2 int - valid bool - hd Header -} + hd Header + alg Algorithm -// Decode decodes a raw JWT into a payload and returns its header. -func (raw RawToken) Decode(payload interface{}) error { - if !raw.valid { - return ErrMalformed - } - return internal.Decode(raw.payload(), payload) + pl *Payload + vds []Validator } -// Header returns a JOSE Header extracted from a JWT. -func (raw RawToken) Header() Header { - return raw.hd +func (rt *RawToken) header() []byte { return rt.token[:rt.sep1] } +func (rt *RawToken) headerPayload() []byte { return rt.token[:rt.sep2] } +func (rt *RawToken) payload() []byte { return rt.token[rt.sep1+1 : rt.sep2] } +func (rt *RawToken) sig() []byte { return rt.token[rt.sep2+1:] } + +func (rt *RawToken) setToken(token []byte, sep1, sep2 int) { + rt.sep1 = sep1 + rt.sep2 = sep1 + 1 + sep2 + rt.token = token } -func (raw RawToken) header() []byte { return raw.token[:raw.sep1] } -func (raw RawToken) headerPayload() []byte { return raw.token[:raw.sep2] } -func (raw RawToken) payload() []byte { return raw.token[raw.sep1+1 : raw.sep2] } -func (raw RawToken) sig() []byte { return raw.token[raw.sep2+1:] } +func (rt *RawToken) decode(payload interface{}) (err error) { + if err = internal.Decode(rt.payload(), payload); err != nil { + return err + } + for _, vd := range rt.vds { + if err = vd(rt.pl); err != nil { + return err + } + } + return nil +} -func (raw RawToken) withToken(token []byte, sep1, sep2 int) RawToken { - raw.sep1 = sep1 - raw.sep2 = sep1 + 1 + sep2 - raw.token = token - return raw +func (rt *RawToken) decodeHeader() error { + if err := internal.Decode(rt.header(), &rt.hd); err != nil { + return err + } + return nil } diff --git a/raw_token_test.go b/raw_token_test.go deleted file mode 100644 index 8693f0b..0000000 --- a/raw_token_test.go +++ /dev/null @@ -1,47 +0,0 @@ -package jwt_test - -import ( - "reflect" - "testing" - - "github.com/gbrlsnchs/jwt/v3" -) - -var ( - testToken = []byte( - "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9." + - "eyJzdHJpbmciOiJmb29iYXIiLCJpbnQiOjEzMzcsImlhdCI6MTUxNjIzOTAyMn0." + - "bVYo9Q0lGouCj1y9zFY17bfxQaRUuM6wtpnIy0m4uD0", - ) - testRaw, _ = jwt.Verify(jwt.NewHS256([]byte("secret")), testToken) -) - -func TestRawTokenDecode(t *testing.T) { - testCases := []struct { - raw jwt.RawToken - wantPayload testPayload - }{ - { - raw: testRaw, - wantPayload: testPayload{ - String: "foobar", - Int: 1337, - Payload: jwt.Payload{ - IssuedAt: 1516239022, - }, - }, - }, - } - for _, tc := range testCases { - t.Run("", func(t *testing.T) { - var payload testPayload - err := tc.raw.Decode(&payload) - if err != nil { - t.Fatal(err) - } - if want, got := tc.wantPayload, payload; !reflect.DeepEqual(got, want) { - t.Errorf("want %#+v, got %#+v", want, got) - } - }) - } -} diff --git a/resolver.go b/resolver.go new file mode 100644 index 0000000..08f7767 --- /dev/null +++ b/resolver.go @@ -0,0 +1,7 @@ +package jwt + +// Resolver is an Algorithm that needs to set some variables +// based on a Header before performing signing and verification. +type Resolver interface { + Resolve(Header) error +} diff --git a/rsa_sha.go b/rsa_sha.go index eff9205..792fd34 100644 --- a/rsa_sha.go +++ b/rsa_sha.go @@ -14,11 +14,28 @@ var ( ErrRSANilPrivKey = errors.New("jwt: RSA private key is nil") // ErrRSANilPubKey is the error for trying to verify a JWT with a nil public key. ErrRSANilPubKey = errors.New("jwt: RSA public key is nil") + // ErrRSAVerification is the error for an invalid RSA signature. + ErrRSAVerification = errors.New("jwt: RSA verification failed") - _ Algorithm = new(rsaSHA) + _ Algorithm = new(RSASHA) ) -type rsaSHA struct { +// RSAPrivateKey is an option to set a private key to the RSA-SHA algorithm. +func RSAPrivateKey(priv *rsa.PrivateKey) func(*RSASHA) { + return func(rs *RSASHA) { + rs.priv = priv + } +} + +// RSAPublicKey is an option to set a public key to the RSA-SHA algorithm. +func RSAPublicKey(pub *rsa.PublicKey) func(*RSASHA) { + return func(rs *RSASHA) { + rs.pub = pub + } +} + +// RSASHA is an algorithm that uses RSA to sign SHA hashes. +type RSASHA struct { name string priv *rsa.PrivateKey pub *rsa.PublicKey @@ -28,64 +45,65 @@ type rsaSHA struct { opts *rsa.PSSOptions } -func newRSASHA(name string, priv *rsa.PrivateKey, pub *rsa.PublicKey, sha crypto.Hash, pss bool) *rsaSHA { - if pub == nil { - pub = &priv.PublicKey - } - rs := &rsaSHA{ +func newRSASHA(name string, opts []func(*RSASHA), sha crypto.Hash, pss bool) *RSASHA { + rs := RSASHA{ name: name, // cache name - priv: priv, - pub: pub, sha: sha, - size: pub.Size(), // cache size pool: newHashPool(sha.New), } + for _, opt := range opts { + opt(&rs) + } + if rs.pub == nil { + rs.pub = &rs.priv.PublicKey + } + rs.size = internal.RSASignatureSize(rs.pub) // cache size if pss { rs.opts = &rsa.PSSOptions{ SaltLength: rsa.PSSSaltLengthAuto, Hash: sha, } } - return rs + return &rs } // NewRS256 creates a new algorithm using RSA and SHA-256. -func NewRS256(priv *rsa.PrivateKey, pub *rsa.PublicKey) Algorithm { - return newRSASHA("RS256", priv, pub, crypto.SHA256, false) +func NewRS256(opts ...func(*RSASHA)) *RSASHA { + return newRSASHA("RS256", opts, crypto.SHA256, false) } // NewRS384 creates a new algorithm using RSA and SHA-384. -func NewRS384(priv *rsa.PrivateKey, pub *rsa.PublicKey) Algorithm { - return newRSASHA("RS384", priv, pub, crypto.SHA384, false) +func NewRS384(opts ...func(*RSASHA)) *RSASHA { + return newRSASHA("RS384", opts, crypto.SHA384, false) } // NewRS512 creates a new algorithm using RSA and SHA-512. -func NewRS512(priv *rsa.PrivateKey, pub *rsa.PublicKey) Algorithm { - return newRSASHA("RS512", priv, pub, crypto.SHA512, false) +func NewRS512(opts ...func(*RSASHA)) *RSASHA { + return newRSASHA("RS512", opts, crypto.SHA512, false) } // NewPS256 creates a new algorithm using RSA-PSS and SHA-256. -func NewPS256(priv *rsa.PrivateKey, pub *rsa.PublicKey) Algorithm { - return newRSASHA("PS256", priv, pub, crypto.SHA256, true) +func NewPS256(opts ...func(*RSASHA)) *RSASHA { + return newRSASHA("PS256", opts, crypto.SHA256, true) } // NewPS384 creates a new algorithm using RSA-PSS and SHA-384. -func NewPS384(priv *rsa.PrivateKey, pub *rsa.PublicKey) Algorithm { - return newRSASHA("PS384", priv, pub, crypto.SHA384, true) +func NewPS384(opts ...func(*RSASHA)) *RSASHA { + return newRSASHA("PS384", opts, crypto.SHA384, true) } // NewPS512 creates a new algorithm using RSA-PSS and SHA-512. -func NewPS512(priv *rsa.PrivateKey, pub *rsa.PublicKey) Algorithm { - return newRSASHA("PS512", priv, pub, crypto.SHA512, true) +func NewPS512(opts ...func(*RSASHA)) *RSASHA { + return newRSASHA("PS512", opts, crypto.SHA512, true) } // Name returns the algorithm's name. -func (rs *rsaSHA) Name() string { +func (rs *RSASHA) Name() string { return rs.name } // Sign signs headerPayload using either RSA-SHA or RSA-PSS-SHA algorithms. -func (rs *rsaSHA) Sign(headerPayload []byte) ([]byte, error) { +func (rs *RSASHA) Sign(headerPayload []byte) ([]byte, error) { if rs.priv == nil { return nil, ErrRSANilPrivKey } @@ -100,15 +118,12 @@ func (rs *rsaSHA) Sign(headerPayload []byte) ([]byte, error) { } // Size returns the signature's byte size. -func (rs *rsaSHA) Size() int { - if rs.pub == nil { - return 0 - } - return rs.pub.Size() +func (rs *RSASHA) Size() int { + return rs.size } // Verify verifies a signature based on headerPayload using either RSA-SHA or RSA-PSS-SHA. -func (rs *rsaSHA) Verify(headerPayload, sig []byte) (err error) { +func (rs *RSASHA) Verify(headerPayload, sig []byte) (err error) { if rs.pub == nil { return ErrRSANilPubKey } @@ -120,7 +135,12 @@ func (rs *rsaSHA) Verify(headerPayload, sig []byte) (err error) { return err } if rs.opts != nil { - return rsa.VerifyPSS(rs.pub, rs.sha, sum, sig, rs.opts) + err = rsa.VerifyPSS(rs.pub, rs.sha, sum, sig, rs.opts) + } else { + err = rsa.VerifyPKCS1v15(rs.pub, rs.sha, sum, sig) + } + if err != nil { + return ErrRSAVerification } - return rsa.VerifyPKCS1v15(rs.pub, rs.sha, sum, sig) + return nil } diff --git a/sign.go b/sign.go index b7de417..6db5b36 100644 --- a/sign.go +++ b/sign.go @@ -5,8 +5,34 @@ import ( "encoding/json" ) -// Sign generates a JWT from hd and payload and signs it with alg. -func Sign(alg Algorithm, hd Header, payload interface{}) ([]byte, error) { +// SignOption is a functional option for signing. +type SignOption func(*Header) + +// ContentType sets the "cty" claim for a Header before signing. +func ContentType(cty string) SignOption { + return func(hd *Header) { + hd.ContentType = cty + } +} + +// KeyID sets the "kid" claim for a Header before signing. +func KeyID(kid string) SignOption { + return func(hd *Header) { + hd.KeyID = kid + } +} + +// Sign signs a payload with alg. +func Sign(payload interface{}, alg Algorithm, opts ...SignOption) ([]byte, error) { + var hd Header + for _, opt := range opts { + opt(&hd) + } + if rv, ok := alg.(Resolver); ok { + if err := rv.Resolve(hd); err != nil { + return nil, err + } + } // Override some values or set them if empty. hd.Algorithm = alg.Name() hd.Type = "JWT" @@ -15,6 +41,10 @@ func Sign(alg Algorithm, hd Header, payload interface{}) ([]byte, error) { if err != nil { return nil, err } + + if payload == nil { + payload = Payload{} + } // Marshal the claims part of the JWT. pb, err := json.Marshal(payload) if err != nil { diff --git a/sign_test.go b/sign_test.go index a6bb90e..1a4976b 100644 --- a/sign_test.go +++ b/sign_test.go @@ -6,7 +6,6 @@ import ( "time" "github.com/gbrlsnchs/jwt/v3" - "github.com/gbrlsnchs/jwt/v3/internal" ) type testPayload struct { @@ -15,23 +14,30 @@ type testPayload struct { Int int `json:"int,omitempty"` } -var tp = testPayload{ - Payload: jwt.Payload{ - Subject: "test", - Audience: jwt.Audience{"github.com", "gsr.dev"}, - NotBefore: time.Now().Unix(), - }, - String: "foobar", - Int: 1337, -} +var ( + now = time.Now() + tp = testPayload{ + Payload: jwt.Payload{ + Issuer: "gbrlsnchs", + Subject: "someone", + Audience: jwt.Audience{"https://golang.org", "https://jwt.io"}, + ExpirationTime: jwt.NumericDate(now.Add(24 * 30 * 12 * time.Hour)), + NotBefore: jwt.NumericDate(now.Add(30 * time.Minute)), + IssuedAt: jwt.NumericDate(now), + JWTID: "foobar", + }, + String: "foobar", + Int: 1337, + } +) func TestSign(t *testing.T) { type testCase struct { alg jwt.Algorithm - hd jwt.Header payload interface{} verifyAlg jwt.Algorithm + opts []func(*jwt.RawToken) wantHeader jwt.Header wantPayload testPayload @@ -41,10 +47,9 @@ func TestSign(t *testing.T) { testCases := map[string][]testCase{ "HMAC": []testCase{ { - alg: jwt.NewHS256(hmacKey), - hd: jwt.Header{}, + alg: jwt.NewHS256(hmacKey1), payload: tp, - verifyAlg: jwt.NewHS256(hmacKey), + verifyAlg: jwt.NewHS256(hmacKey1), wantHeader: jwt.Header{ Algorithm: "HS256", Type: "JWT", @@ -54,10 +59,33 @@ func TestSign(t *testing.T) { verifyErr: nil, }, { - alg: jwt.NewHS384(hmacKey), - hd: jwt.Header{}, + alg: jwt.NewHS256(hmacKey1), + payload: tp, + verifyAlg: jwt.NewHS256(hmacKey2), + wantHeader: jwt.Header{ + Algorithm: "HS256", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrHMACVerification, + }, + { + alg: jwt.NewHS256(hmacKey1), payload: tp, - verifyAlg: jwt.NewHS384(hmacKey), + verifyAlg: jwt.NewHS384(hmacKey1), + wantHeader: jwt.Header{ + Algorithm: "HS256", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrHMACVerification, + }, + { + alg: jwt.NewHS384(hmacKey1), + payload: tp, + verifyAlg: jwt.NewHS384(hmacKey1), wantHeader: jwt.Header{ Algorithm: "HS384", Type: "JWT", @@ -67,10 +95,33 @@ func TestSign(t *testing.T) { verifyErr: nil, }, { - alg: jwt.NewHS512(hmacKey), - hd: jwt.Header{}, + alg: jwt.NewHS384(hmacKey1), payload: tp, - verifyAlg: jwt.NewHS512(hmacKey), + verifyAlg: jwt.NewHS384(hmacKey2), + wantHeader: jwt.Header{ + Algorithm: "HS384", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrHMACVerification, + }, + { + alg: jwt.NewHS384(hmacKey1), + payload: tp, + verifyAlg: jwt.NewHS256(hmacKey1), + wantHeader: jwt.Header{ + Algorithm: "HS384", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrHMACVerification, + }, + { + alg: jwt.NewHS512(hmacKey1), + payload: tp, + verifyAlg: jwt.NewHS512(hmacKey1), wantHeader: jwt.Header{ Algorithm: "HS512", Type: "JWT", @@ -79,13 +130,36 @@ func TestSign(t *testing.T) { signErr: nil, verifyErr: nil, }, + { + alg: jwt.NewHS512(hmacKey1), + payload: tp, + verifyAlg: jwt.NewHS512(hmacKey2), + wantHeader: jwt.Header{ + Algorithm: "HS512", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrHMACVerification, + }, + { + alg: jwt.NewHS512(hmacKey1), + payload: tp, + verifyAlg: jwt.NewHS256(hmacKey1), + wantHeader: jwt.Header{ + Algorithm: "HS512", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrHMACVerification, + }, }, "RSA": []testCase{ { - alg: jwt.NewRS256(rsaPrivateKey1, nil), - hd: jwt.Header{}, + alg: jwt.NewRS256(jwt.RSAPrivateKey(rsaPrivateKey1)), payload: tp, - verifyAlg: jwt.NewRS256(rsaPrivateKey1, nil), + verifyAlg: jwt.NewRS256(jwt.RSAPrivateKey(rsaPrivateKey1)), wantHeader: jwt.Header{ Algorithm: "RS256", Type: "JWT", @@ -95,10 +169,33 @@ func TestSign(t *testing.T) { verifyErr: nil, }, { - alg: jwt.NewRS256(rsaPrivateKey1, nil), - hd: jwt.Header{}, + alg: jwt.NewRS256(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewRS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "RS256", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewRS256(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewRS256(jwt.RSAPrivateKey(rsaPrivateKey2)), + wantHeader: jwt.Header{ + Algorithm: "RS256", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewRS256(jwt.RSAPrivateKey(rsaPrivateKey1)), payload: tp, - verifyAlg: jwt.NewRS256(nil, rsaPublicKey1), + verifyAlg: jwt.NewRS256(jwt.RSAPublicKey(rsaPublicKey1)), wantHeader: jwt.Header{ Algorithm: "RS256", Type: "JWT", @@ -108,10 +205,21 @@ func TestSign(t *testing.T) { verifyErr: nil, }, { - alg: jwt.NewRS384(rsaPrivateKey1, nil), - hd: jwt.Header{}, + alg: jwt.NewRS256(jwt.RSAPrivateKey(rsaPrivateKey1)), payload: tp, - verifyAlg: jwt.NewRS384(rsaPrivateKey1, nil), + verifyAlg: jwt.NewRS256(jwt.RSAPublicKey(rsaPublicKey2)), + wantHeader: jwt.Header{ + Algorithm: "RS256", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewRS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewRS384(jwt.RSAPrivateKey(rsaPrivateKey1)), wantHeader: jwt.Header{ Algorithm: "RS384", Type: "JWT", @@ -121,10 +229,33 @@ func TestSign(t *testing.T) { verifyErr: nil, }, { - alg: jwt.NewRS384(rsaPrivateKey1, nil), - hd: jwt.Header{}, + alg: jwt.NewRS384(jwt.RSAPrivateKey(rsaPrivateKey1)), payload: tp, - verifyAlg: jwt.NewRS384(nil, rsaPublicKey1), + verifyAlg: jwt.NewRS256(jwt.RSAPrivateKey(rsaPrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "RS384", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewRS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewRS384(jwt.RSAPrivateKey(rsaPrivateKey2)), + wantHeader: jwt.Header{ + Algorithm: "RS384", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewRS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewRS384(jwt.RSAPublicKey(rsaPublicKey1)), wantHeader: jwt.Header{ Algorithm: "RS384", Type: "JWT", @@ -134,10 +265,21 @@ func TestSign(t *testing.T) { verifyErr: nil, }, { - alg: jwt.NewRS512(rsaPrivateKey1, nil), - hd: jwt.Header{}, + alg: jwt.NewRS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewRS384(jwt.RSAPublicKey(rsaPublicKey2)), + wantHeader: jwt.Header{ + Algorithm: "RS384", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewRS512(jwt.RSAPrivateKey(rsaPrivateKey1)), payload: tp, - verifyAlg: jwt.NewRS512(rsaPrivateKey1, nil), + verifyAlg: jwt.NewRS512(jwt.RSAPrivateKey(rsaPrivateKey1)), wantHeader: jwt.Header{ Algorithm: "RS512", Type: "JWT", @@ -147,10 +289,33 @@ func TestSign(t *testing.T) { verifyErr: nil, }, { - alg: jwt.NewRS512(rsaPrivateKey1, nil), - hd: jwt.Header{}, + alg: jwt.NewRS512(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewRS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "RS512", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewRS512(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewRS512(jwt.RSAPrivateKey(rsaPrivateKey2)), + wantHeader: jwt.Header{ + Algorithm: "RS512", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewRS512(jwt.RSAPrivateKey(rsaPrivateKey1)), payload: tp, - verifyAlg: jwt.NewRS512(nil, rsaPublicKey1), + verifyAlg: jwt.NewRS512(jwt.RSAPublicKey(rsaPublicKey1)), wantHeader: jwt.Header{ Algorithm: "RS512", Type: "JWT", @@ -159,13 +324,24 @@ func TestSign(t *testing.T) { signErr: nil, verifyErr: nil, }, + { + alg: jwt.NewRS512(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewRS512(jwt.RSAPublicKey(rsaPublicKey2)), + wantHeader: jwt.Header{ + Algorithm: "RS512", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, }, "RSA-PSS": []testCase{ { - alg: jwt.NewPS256(rsaPrivateKey1, nil), - hd: jwt.Header{}, + alg: jwt.NewPS256(jwt.RSAPrivateKey(rsaPrivateKey1)), payload: tp, - verifyAlg: jwt.NewPS256(rsaPrivateKey1, nil), + verifyAlg: jwt.NewPS256(jwt.RSAPrivateKey(rsaPrivateKey1)), wantHeader: jwt.Header{ Algorithm: "PS256", Type: "JWT", @@ -175,10 +351,45 @@ func TestSign(t *testing.T) { verifyErr: nil, }, { - alg: jwt.NewPS256(rsaPrivateKey1, nil), - hd: jwt.Header{}, + alg: jwt.NewPS256(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewRS256(jwt.RSAPrivateKey(rsaPrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "PS256", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewPS256(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewPS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "PS256", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewPS256(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewPS256(jwt.RSAPrivateKey(rsaPrivateKey2)), + wantHeader: jwt.Header{ + Algorithm: "PS256", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewPS256(jwt.RSAPrivateKey(rsaPrivateKey1)), payload: tp, - verifyAlg: jwt.NewPS256(nil, rsaPublicKey1), + verifyAlg: jwt.NewPS256(jwt.RSAPublicKey(rsaPublicKey1)), wantHeader: jwt.Header{ Algorithm: "PS256", Type: "JWT", @@ -188,10 +399,21 @@ func TestSign(t *testing.T) { verifyErr: nil, }, { - alg: jwt.NewPS384(rsaPrivateKey1, nil), - hd: jwt.Header{}, + alg: jwt.NewPS256(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewPS256(jwt.RSAPublicKey(rsaPublicKey2)), + wantHeader: jwt.Header{ + Algorithm: "PS256", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewPS384(jwt.RSAPrivateKey(rsaPrivateKey1)), payload: tp, - verifyAlg: jwt.NewPS384(rsaPrivateKey1, nil), + verifyAlg: jwt.NewPS384(jwt.RSAPrivateKey(rsaPrivateKey1)), wantHeader: jwt.Header{ Algorithm: "PS384", Type: "JWT", @@ -201,10 +423,45 @@ func TestSign(t *testing.T) { verifyErr: nil, }, { - alg: jwt.NewPS384(rsaPrivateKey1, nil), - hd: jwt.Header{}, + alg: jwt.NewPS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewRS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "PS384", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewPS384(jwt.RSAPrivateKey(rsaPrivateKey1)), payload: tp, - verifyAlg: jwt.NewPS384(nil, rsaPublicKey1), + verifyAlg: jwt.NewPS256(jwt.RSAPrivateKey(rsaPrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "PS384", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewPS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewPS384(jwt.RSAPrivateKey(rsaPrivateKey2)), + wantHeader: jwt.Header{ + Algorithm: "PS384", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewPS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewPS384(jwt.RSAPublicKey(rsaPublicKey1)), wantHeader: jwt.Header{ Algorithm: "PS384", Type: "JWT", @@ -214,10 +471,21 @@ func TestSign(t *testing.T) { verifyErr: nil, }, { - alg: jwt.NewPS512(rsaPrivateKey1, nil), - hd: jwt.Header{}, + alg: jwt.NewPS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewPS384(jwt.RSAPublicKey(rsaPublicKey2)), + wantHeader: jwt.Header{ + Algorithm: "PS384", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewPS512(jwt.RSAPrivateKey(rsaPrivateKey1)), payload: tp, - verifyAlg: jwt.NewPS512(rsaPrivateKey1, nil), + verifyAlg: jwt.NewPS512(jwt.RSAPrivateKey(rsaPrivateKey1)), wantHeader: jwt.Header{ Algorithm: "PS512", Type: "JWT", @@ -227,10 +495,45 @@ func TestSign(t *testing.T) { verifyErr: nil, }, { - alg: jwt.NewPS512(rsaPrivateKey1, nil), - hd: jwt.Header{}, + alg: jwt.NewPS512(jwt.RSAPrivateKey(rsaPrivateKey1)), payload: tp, - verifyAlg: jwt.NewPS512(nil, rsaPublicKey1), + verifyAlg: jwt.NewRS512(jwt.RSAPrivateKey(rsaPrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "PS512", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewPS512(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewPS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "PS512", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewPS512(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewPS512(jwt.RSAPrivateKey(rsaPrivateKey2)), + wantHeader: jwt.Header{ + Algorithm: "PS512", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewPS512(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewPS512(jwt.RSAPublicKey(rsaPublicKey1)), wantHeader: jwt.Header{ Algorithm: "PS512", Type: "JWT", @@ -239,36 +542,275 @@ func TestSign(t *testing.T) { signErr: nil, verifyErr: nil, }, + { + alg: jwt.NewPS512(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewPS512(jwt.RSAPublicKey(rsaPublicKey2)), + wantHeader: jwt.Header{ + Algorithm: "PS512", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + }, + "ECDSA": []testCase{ + { + alg: jwt.NewES256(jwt.ECDSAPrivateKey(es256PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewES256(jwt.ECDSAPublicKey(es256PublicKey1)), + wantHeader: jwt.Header{ + Algorithm: "ES256", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewES256(jwt.ECDSAPrivateKey(es256PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewES384(jwt.ECDSAPublicKey(es256PublicKey1)), + wantHeader: jwt.Header{ + Algorithm: "ES256", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrECDSAVerification, + }, + { + alg: jwt.NewES256(jwt.ECDSAPrivateKey(es256PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewES256(jwt.ECDSAPublicKey(es256PublicKey2)), + wantHeader: jwt.Header{ + Algorithm: "ES256", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrECDSAVerification, + }, + { + alg: jwt.NewES256(jwt.ECDSAPrivateKey(es256PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewES256(jwt.ECDSAPrivateKey(es256PrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "ES256", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewES256(jwt.ECDSAPrivateKey(es256PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewES256(jwt.ECDSAPrivateKey(es256PrivateKey2)), + wantHeader: jwt.Header{ + Algorithm: "ES256", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrECDSAVerification, + }, + { + alg: jwt.NewES384(jwt.ECDSAPrivateKey(es384PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewES384(jwt.ECDSAPublicKey(es384PublicKey1)), + wantHeader: jwt.Header{ + Algorithm: "ES384", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewES384(jwt.ECDSAPrivateKey(es384PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewES256(jwt.ECDSAPublicKey(es384PublicKey1)), + wantHeader: jwt.Header{ + Algorithm: "ES384", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrECDSAVerification, + }, + { + alg: jwt.NewES384(jwt.ECDSAPrivateKey(es384PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewES384(jwt.ECDSAPublicKey(es384PublicKey2)), + wantHeader: jwt.Header{ + Algorithm: "ES384", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrECDSAVerification, + }, + { + alg: jwt.NewES384(jwt.ECDSAPrivateKey(es384PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewES384(jwt.ECDSAPrivateKey(es384PrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "ES384", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewES384(jwt.ECDSAPrivateKey(es384PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewES384(jwt.ECDSAPrivateKey(es384PrivateKey2)), + wantHeader: jwt.Header{ + Algorithm: "ES384", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrECDSAVerification, + }, + { + alg: jwt.NewES512(jwt.ECDSAPrivateKey(es512PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewES512(jwt.ECDSAPublicKey(es512PublicKey1)), + wantHeader: jwt.Header{ + Algorithm: "ES512", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewES512(jwt.ECDSAPrivateKey(es512PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewES384(jwt.ECDSAPublicKey(es512PublicKey1)), + wantHeader: jwt.Header{ + Algorithm: "ES512", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrECDSAVerification, + }, + { + alg: jwt.NewES512(jwt.ECDSAPrivateKey(es512PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewES512(jwt.ECDSAPublicKey(es512PublicKey2)), + wantHeader: jwt.Header{ + Algorithm: "ES512", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrECDSAVerification, + }, + { + alg: jwt.NewES512(jwt.ECDSAPrivateKey(es512PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewES512(jwt.ECDSAPrivateKey(es512PrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "ES512", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewES512(jwt.ECDSAPrivateKey(es512PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewES512(jwt.ECDSAPrivateKey(es512PrivateKey2)), + wantHeader: jwt.Header{ + Algorithm: "ES512", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrECDSAVerification, + }, + }, + "Ed25519": []testCase{ + { + alg: jwt.NewEd25519(jwt.Ed25519PrivateKey(ed25519PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewEd25519(jwt.Ed25519PrivateKey(ed25519PrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "Ed25519", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewEd25519(jwt.Ed25519PrivateKey(ed25519PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewEd25519(jwt.Ed25519PublicKey(ed25519PublicKey1)), + wantHeader: jwt.Header{ + Algorithm: "Ed25519", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewEd25519(jwt.Ed25519PrivateKey(ed25519PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewEd25519(jwt.Ed25519PrivateKey(ed25519PrivateKey2)), + wantHeader: jwt.Header{ + Algorithm: "Ed25519", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrEd25519Verification, + }, + { + alg: jwt.NewEd25519(jwt.Ed25519PrivateKey(ed25519PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewEd25519(jwt.Ed25519PublicKey(ed25519PublicKey2)), + wantHeader: jwt.Header{ + Algorithm: "Ed25519", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrEd25519Verification, + }, }, - "ECDSA": []testCase{}, - "Ed25519": []testCase{}, } for k, v := range testCases { t.Run(k, func(t *testing.T) { for _, tc := range v { t.Run(tc.alg.Name(), func(t *testing.T) { - token, err := jwt.Sign(tc.alg, tc.hd, tc.payload) - if want, got := tc.signErr, err; !internal.ErrorIs(got, want) { + token, err := jwt.Sign(tc.payload, tc.alg) + if want, got := tc.signErr, err; got != want { t.Fatalf("want %v, got %v", want, got) } if err != nil { return } - raw, err := jwt.Verify(tc.verifyAlg, token) - if want, got := tc.verifyErr, err; !internal.ErrorIs(got, want) { + var ( + hd jwt.Header + payload testPayload + ) + hd, err = jwt.Verify(token, tc.verifyAlg, &payload) + if want, got := tc.verifyErr, err; got != want { t.Fatalf("want %v, got %v", want, got) } - if err != nil { - return - } - if want, got := tc.wantHeader, raw.Header(); !reflect.DeepEqual(got, want) { + if want, got := tc.wantHeader, hd; !reflect.DeepEqual(got, want) { t.Errorf("want %#+v, got %#+v", want, got) } - var payload testPayload - if err = raw.Decode(&payload); err != nil { - t.Fatal(err) - } if want, got := tc.wantPayload, payload; !reflect.DeepEqual(got, want) { t.Errorf("want %#+v, got %#+v", want, got) } diff --git a/time.go b/time.go new file mode 100644 index 0000000..ed3bbbb --- /dev/null +++ b/time.go @@ -0,0 +1,46 @@ +package jwt + +import ( + "encoding/json" + "time" + + "github.com/gbrlsnchs/jwt/v3/internal" +) + +// Time is the allowed format for time, as per the RFC 7519. +type Time struct { + time.Time +} + +// NumericDate is a resolved Unix time. +func NumericDate(tt time.Time) *Time { + if tt.Before(internal.Epoch) { + tt = internal.Epoch + } + return &Time{time.Unix(tt.Unix(), 0)} // set time using Unix time +} + +// MarshalJSON implements a marshaling function for time-related claims. +func (t Time) MarshalJSON() ([]byte, error) { + if t.Before(internal.Epoch) { + return json.Marshal(0) + } + return json.Marshal(t.Unix()) +} + +// UnmarshalJSON implements an unmarshaling function for time-related claims. +func (t *Time) UnmarshalJSON(b []byte) error { + var unix *int64 + if err := json.Unmarshal(b, &unix); err != nil { + return err + } + if unix == nil { + return nil + } + tt := time.Unix(*unix, 0) + if tt.Before(internal.Epoch) { + tt = internal.Epoch + } + t.Time = tt + return nil +} diff --git a/time_test.go b/time_test.go new file mode 100644 index 0000000..eccef04 --- /dev/null +++ b/time_test.go @@ -0,0 +1,72 @@ +package jwt_test + +import ( + "encoding/json" + "testing" + "time" + + "github.com/gbrlsnchs/jwt/v3" + "github.com/gbrlsnchs/jwt/v3/internal" +) + +func TestTimeMarshalJSON(t *testing.T) { + now := time.Now() + testCases := []struct { + tt jwt.Time + want int64 + }{ + {jwt.Time{}, 0}, + {jwt.Time{now}, now.Unix()}, + {jwt.Time{now.Add(24 * time.Hour)}, now.Add(24 * time.Hour).Unix()}, + {jwt.Time{now.Add(24 * 30 * 12 * time.Hour)}, now.Add(24 * 30 * 12 * time.Hour).Unix()}, + } + for _, tc := range testCases { + t.Run("", func(t *testing.T) { + b, err := tc.tt.MarshalJSON() + if err != nil { + t.Fatal(err) + } + var n int64 + if err = json.Unmarshal(b, &n); err != nil { + t.Fatal(err) + } + if want, got := tc.want, n; got != want { + t.Errorf("want %d, got %d", want, got) + } + }) + } +} + +func TestTimeUnmarshalJSON(t *testing.T) { + now := time.Now() + testCases := []struct { + n int64 + want jwt.Time + isNil bool + }{ + {now.Unix(), jwt.Time{now}, false}, + {internal.Epoch.Unix() - 1337, jwt.Time{internal.Epoch}, false}, + {internal.Epoch.Unix(), jwt.Time{internal.Epoch}, false}, + {internal.Epoch.Unix() + 1337, jwt.Time{internal.Epoch.Add(1337 * time.Second)}, false}, + {0, jwt.Time{}, true}, + } + for _, tc := range testCases { + t.Run("", func(t *testing.T) { + var n *int64 + if !tc.isNil { + n = &tc.n + } + b, err := json.Marshal(n) + if err != nil { + t.Fatal(err) + } + var tt jwt.Time + if err = tt.UnmarshalJSON(b); err != nil { + t.Fatal(err) + } + if want, got := tc.want, tt; got.Unix() != want.Unix() { + t.Errorf("want %d, got %d", want.Unix(), got.Unix()) + } + }) + } +} diff --git a/validator.go b/validator.go deleted file mode 100644 index 35064ca..0000000 --- a/validator.go +++ /dev/null @@ -1,5 +0,0 @@ -package jwt - -type Validator interface { - Validate(...ValidatorFunc) error -} diff --git a/validators.go b/validators.go index d9353c1..f16b90f 100644 --- a/validators.go +++ b/validators.go @@ -22,16 +22,15 @@ var ( ErrSubValidation = errors.New("jwt: sub claim is invalid") ) -// ValidatorFunc is a function for running extra -// validators when parsing a Payload string. -type ValidatorFunc func(*Payload) error +// Validator is a function that validates a Payload pointer. +type Validator func(*Payload) error // AudienceValidator validates the "aud" claim. // It checks if at least one of the audiences in the JWT's payload is listed in aud. -func AudienceValidator(aud Audience) ValidatorFunc { - return func(p *Payload) error { +func AudienceValidator(aud Audience) Validator { + return func(pl *Payload) error { for _, serverAud := range aud { - for _, clientAud := range p.Audience { + for _, clientAud := range pl.Audience { if clientAud == serverAud { return nil } @@ -42,13 +41,9 @@ func AudienceValidator(aud Audience) ValidatorFunc { } // ExpirationTimeValidator validates the "exp" claim. -func ExpirationTimeValidator(now time.Time, validateZero bool) ValidatorFunc { - return func(p *Payload) error { - expint := p.ExpirationTime - if !validateZero && expint == 0 { - return nil - } - if exp := time.Unix(expint, 0); now.After(exp) { +func ExpirationTimeValidator(now time.Time) Validator { + return func(pl *Payload) error { + if pl.ExpirationTime == nil || NumericDate(now).After(pl.ExpirationTime.Time) { return ErrExpValidation } return nil @@ -56,9 +51,9 @@ func ExpirationTimeValidator(now time.Time, validateZero bool) ValidatorFunc { } // IssuedAtValidator validates the "iat" claim. -func IssuedAtValidator(now time.Time) ValidatorFunc { - return func(p *Payload) error { - if iat := time.Unix(p.IssuedAt, 0); now.Before(iat) { +func IssuedAtValidator(now time.Time) Validator { + return func(pl *Payload) error { + if pl.IssuedAt != nil && NumericDate(now).Before(pl.IssuedAt.Time) { return ErrIatValidation } return nil @@ -66,9 +61,9 @@ func IssuedAtValidator(now time.Time) ValidatorFunc { } // IssuerValidator validates the "iss" claim. -func IssuerValidator(iss string) ValidatorFunc { - return func(p *Payload) error { - if p.Issuer != iss { +func IssuerValidator(iss string) Validator { + return func(pl *Payload) error { + if pl.Issuer != iss { return ErrIssValidation } return nil @@ -76,9 +71,9 @@ func IssuerValidator(iss string) ValidatorFunc { } // IDValidator validates the "jti" claim. -func IDValidator(jti string) ValidatorFunc { - return func(p *Payload) error { - if p.JWTID != jti { +func IDValidator(jti string) Validator { + return func(pl *Payload) error { + if pl.JWTID != jti { return ErrJtiValidation } return nil @@ -86,9 +81,9 @@ func IDValidator(jti string) ValidatorFunc { } // NotBeforeValidator validates the "nbf" claim. -func NotBeforeValidator(now time.Time) ValidatorFunc { - return func(p *Payload) error { - if nbf := time.Unix(p.NotBefore, 0); now.Before(nbf) { +func NotBeforeValidator(now time.Time) Validator { + return func(pl *Payload) error { + if pl.NotBefore != nil && NumericDate(now).Before(pl.NotBefore.Time) { return ErrNbfValidation } return nil @@ -96,9 +91,9 @@ func NotBeforeValidator(now time.Time) ValidatorFunc { } // SubjectValidator validates the "sub" claim. -func SubjectValidator(sub string) ValidatorFunc { - return func(p *Payload) error { - if p.Subject != sub { +func SubjectValidator(sub string) Validator { + return func(pl *Payload) error { + if pl.Subject != sub { return ErrSubValidation } return nil diff --git a/validators_test.go b/validators_test.go index 09e1b40..69ddee1 100644 --- a/validators_test.go +++ b/validators_test.go @@ -1,65 +1,55 @@ package jwt_test import ( - "reflect" - "runtime" - "strings" "testing" "time" - . "github.com/gbrlsnchs/jwt/v3" + "github.com/gbrlsnchs/jwt/v3" ) func TestValidators(t *testing.T) { now := time.Now() - iat := now.Unix() - exp := now.Add(24 * time.Hour).Unix() - nbf := now.Add(15 * time.Second).Unix() + iat := jwt.NumericDate(now) + exp := jwt.NumericDate(now.Add(24 * time.Hour)) + nbf := jwt.NumericDate(now.Add(15 * time.Second)) jti := "jti" - aud := Audience{"aud", "aud1", "aud2", "aud3"} + aud := jwt.Audience{"aud", "aud1", "aud2", "aud3"} sub := "sub" iss := "iss" testCases := []struct { - p Payload - vl ValidatorFunc - err error + claim string + pl *jwt.Payload + vl jwt.Validator + err error }{ - {Payload{Issuer: iss}, IssuerValidator("iss"), nil}, - {Payload{Issuer: iss}, IssuerValidator("not_iss"), ErrIssValidation}, - {Payload{Subject: sub}, SubjectValidator("sub"), nil}, - {Payload{Subject: sub}, SubjectValidator("not_sub"), ErrSubValidation}, - {Payload{Audience: aud}, AudienceValidator(Audience{"aud"}), nil}, - {Payload{Audience: aud}, AudienceValidator(Audience{"foo", "aud1"}), nil}, - {Payload{Audience: aud}, AudienceValidator(Audience{"bar", "aud2"}), nil}, - {Payload{Audience: aud}, AudienceValidator(Audience{"baz", "aud3"}), nil}, - {Payload{Audience: aud}, AudienceValidator(Audience{"qux", "aud4"}), ErrAudValidation}, - {Payload{Audience: aud}, AudienceValidator(Audience{"not_aud"}), ErrAudValidation}, - {Payload{ExpirationTime: exp}, ExpirationTimeValidator(now, true), nil}, - {Payload{ExpirationTime: exp}, ExpirationTimeValidator(now, false), nil}, - {Payload{ExpirationTime: exp}, ExpirationTimeValidator(time.Unix(now.Unix()-int64(24*time.Hour), 0), true), nil}, - {Payload{ExpirationTime: exp}, ExpirationTimeValidator(time.Unix(now.Unix()-int64(24*time.Hour), 0), false), nil}, - {Payload{ExpirationTime: exp}, ExpirationTimeValidator(time.Unix(now.Unix()+int64(24*time.Hour), 0), true), ErrExpValidation}, - {Payload{ExpirationTime: exp}, ExpirationTimeValidator(time.Unix(now.Unix()+int64(24*time.Hour), 0), false), ErrExpValidation}, - {Payload{}, ExpirationTimeValidator(time.Now(), false), nil}, - {Payload{}, ExpirationTimeValidator(time.Now(), true), ErrExpValidation}, - {Payload{NotBefore: nbf}, NotBeforeValidator(now), ErrNbfValidation}, - {Payload{NotBefore: nbf}, NotBeforeValidator(time.Unix(now.Unix()+int64(15*time.Second), 0)), nil}, - {Payload{NotBefore: nbf}, NotBeforeValidator(time.Unix(now.Unix()-int64(15*time.Second), 0)), ErrNbfValidation}, - {Payload{}, NotBeforeValidator(time.Now()), nil}, - {Payload{IssuedAt: iat}, IssuedAtValidator(now), nil}, - {Payload{IssuedAt: iat}, IssuedAtValidator(time.Unix(now.Unix()+1, 0)), nil}, - {Payload{IssuedAt: iat}, IssuedAtValidator(time.Unix(now.Unix()-1, 0)), ErrIatValidation}, - {Payload{}, IssuedAtValidator(time.Now()), nil}, - {Payload{JWTID: jti}, IDValidator("jti"), nil}, - {Payload{JWTID: jti}, IDValidator("not_jti"), ErrJtiValidation}, + {"iss", &jwt.Payload{Issuer: iss}, jwt.IssuerValidator("iss"), nil}, + {"iss", &jwt.Payload{Issuer: iss}, jwt.IssuerValidator("not_iss"), jwt.ErrIssValidation}, + {"sub", &jwt.Payload{Subject: sub}, jwt.SubjectValidator("sub"), nil}, + {"sub", &jwt.Payload{Subject: sub}, jwt.SubjectValidator("not_sub"), jwt.ErrSubValidation}, + {"aud", &jwt.Payload{Audience: aud}, jwt.AudienceValidator(jwt.Audience{"aud"}), nil}, + {"aud", &jwt.Payload{Audience: aud}, jwt.AudienceValidator(jwt.Audience{"foo", "aud1"}), nil}, + {"aud", &jwt.Payload{Audience: aud}, jwt.AudienceValidator(jwt.Audience{"bar", "aud2"}), nil}, + {"aud", &jwt.Payload{Audience: aud}, jwt.AudienceValidator(jwt.Audience{"baz", "aud3"}), nil}, + {"aud", &jwt.Payload{Audience: aud}, jwt.AudienceValidator(jwt.Audience{"qux", "aud4"}), jwt.ErrAudValidation}, + {"aud", &jwt.Payload{Audience: aud}, jwt.AudienceValidator(jwt.Audience{"not_aud"}), jwt.ErrAudValidation}, + {"exp", &jwt.Payload{ExpirationTime: exp}, jwt.ExpirationTimeValidator(now), nil}, + {"exp", &jwt.Payload{ExpirationTime: exp}, jwt.ExpirationTimeValidator(time.Unix(now.Unix()-int64(24*time.Hour), 0)), nil}, + {"exp", &jwt.Payload{ExpirationTime: exp}, jwt.ExpirationTimeValidator(time.Unix(now.Unix()+int64(24*time.Hour), 0)), jwt.ErrExpValidation}, + {"exp", &jwt.Payload{}, jwt.ExpirationTimeValidator(time.Now()), jwt.ErrExpValidation}, + {"nbf", &jwt.Payload{NotBefore: nbf}, jwt.NotBeforeValidator(now), jwt.ErrNbfValidation}, + {"nbf", &jwt.Payload{NotBefore: nbf}, jwt.NotBeforeValidator(time.Unix(now.Unix()+int64(15*time.Second), 0)), nil}, + {"nbf", &jwt.Payload{NotBefore: nbf}, jwt.NotBeforeValidator(time.Unix(now.Unix()-int64(15*time.Second), 0)), jwt.ErrNbfValidation}, + {"nbf", &jwt.Payload{}, jwt.NotBeforeValidator(time.Now()), nil}, + {"iat", &jwt.Payload{IssuedAt: iat}, jwt.IssuedAtValidator(now), nil}, + {"iat", &jwt.Payload{IssuedAt: iat}, jwt.IssuedAtValidator(time.Unix(now.Unix()+1, 0)), nil}, + {"iat", &jwt.Payload{IssuedAt: iat}, jwt.IssuedAtValidator(time.Unix(now.Unix()-1, 0)), jwt.ErrIatValidation}, + {"iat", &jwt.Payload{}, jwt.IssuedAtValidator(time.Now()), nil}, + {"jti", &jwt.Payload{JWTID: jti}, jwt.IDValidator("jti"), nil}, + {"jti", &jwt.Payload{JWTID: jti}, jwt.IDValidator("not_jti"), jwt.ErrJtiValidation}, } for _, tc := range testCases { - fn := runtime.FuncForPC(reflect.ValueOf(tc.vl).Pointer()) - name := fn.Name()[:] - name = strings.TrimPrefix(name, "github.com/gbrlsnchs/jwt/v3.") - name = strings.TrimSuffix(name, ".func1") - t.Run(name, func(t *testing.T) { - if want, got := tc.err, tc.vl(&tc.p); want != got { + t.Run(tc.claim, func(t *testing.T) { + if want, got := tc.err, tc.vl(tc.pl); want != got { t.Errorf("want %v, got %v", want, got) } }) diff --git a/verify.go b/verify.go index 2edac85..ecbd9a8 100644 --- a/verify.go +++ b/verify.go @@ -2,36 +2,71 @@ package jwt import ( "bytes" - - "github.com/gbrlsnchs/jwt/v3/internal" + "errors" ) // ErrAlgValidation indicates an incoming JWT's "alg" field mismatches the Validator's. -var ErrAlgValidation = internal.NewError(`"alg" field mismatch`) +var ErrAlgValidation = errors.New(`"alg" field mismatch`) + +// VerifyOption is a functional option for verifying. +type VerifyOption func(*RawToken) error -// Verify verifies a token's signature. -func Verify(alg Algorithm, token []byte) (RawToken, error) { - var raw RawToken +// Verify verifies a token's signature using alg. Before verification, opts is iterated and +// each option in it is run. +func Verify(token []byte, alg Algorithm, payload interface{}, opts ...VerifyOption) (Header, error) { + rt := &RawToken{ + alg: alg, + } sep1 := bytes.IndexByte(token, '.') if sep1 < 0 { - return raw, ErrMalformed + return rt.hd, ErrMalformed } cbytes := token[sep1+1:] sep2 := bytes.IndexByte(cbytes, '.') if sep2 < 0 { - return raw, ErrMalformed + return rt.hd, ErrMalformed } - raw = raw.withToken(token, sep1, sep2) + rt.setToken(token, sep1, sep2) - if err := internal.Decode(raw.header(), &raw.hd); err != nil { - return raw, err + var err error + if err = rt.decodeHeader(); err != nil { + return rt.hd, err + } + if rv, ok := alg.(Resolver); ok { + if err = rv.Resolve(rt.hd); err != nil { + return rt.hd, err + } + } + for _, opt := range opts { + if err = opt(rt); err != nil { + return rt.hd, err + } + } + if err = alg.Verify(rt.headerPayload(), rt.sig()); err != nil { + return rt.hd, err } - raw.valid = true + return rt.hd, rt.decode(payload) +} - if alg.Name() != raw.hd.Algorithm { - return raw, internal.Errorf("jwt: unexpected algorithm %q: %w", raw.hd.Algorithm, ErrAlgValidation) +// ValidateHeader checks whether the algorithm contained +// in the JOSE header is the same used by the algorithm. +func ValidateHeader(rt *RawToken) error { + if rt.alg.Name() != rt.hd.Algorithm { + return ErrAlgValidation } - return raw, alg.Verify(raw.headerPayload(), raw.sig()) + return nil } + +// ValidatePayload runs validators against a Payload after it's been decoded. +func ValidatePayload(pl *Payload, vds ...Validator) VerifyOption { + return func(rt *RawToken) error { + rt.pl = pl + rt.vds = vds + return nil + } +} + +// Compile-time checks. +var _ VerifyOption = ValidateHeader diff --git a/verify_test.go b/verify_test.go new file mode 100644 index 0000000..0bd4cee --- /dev/null +++ b/verify_test.go @@ -0,0 +1,790 @@ +package jwt_test + +import ( + "reflect" + "testing" + + "github.com/gbrlsnchs/jwt/v3" +) + +func TestVerify(t *testing.T) { + type testCase struct { + alg jwt.Algorithm + payload interface{} + + verifyAlg jwt.Algorithm + opts []func(*jwt.RawToken) + wantHeader jwt.Header + wantPayload testPayload + + signErr error + verifyErr error + } + testCases := map[string][]testCase{ + "HMAC": []testCase{ + { + alg: jwt.NewHS256(hmacKey1), + payload: tp, + verifyAlg: jwt.NewHS256(hmacKey1), + wantHeader: jwt.Header{ + Algorithm: "HS256", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewHS256(hmacKey1), + payload: tp, + verifyAlg: jwt.NewHS256(hmacKey2), + wantHeader: jwt.Header{ + Algorithm: "HS256", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrHMACVerification, + }, + { + alg: jwt.NewHS256(hmacKey1), + payload: tp, + verifyAlg: jwt.NewHS384(hmacKey1), + wantHeader: jwt.Header{ + Algorithm: "HS256", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrHMACVerification, + }, + { + alg: jwt.NewHS384(hmacKey1), + payload: tp, + verifyAlg: jwt.NewHS384(hmacKey1), + wantHeader: jwt.Header{ + Algorithm: "HS384", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewHS384(hmacKey1), + payload: tp, + verifyAlg: jwt.NewHS384(hmacKey2), + wantHeader: jwt.Header{ + Algorithm: "HS384", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrHMACVerification, + }, + { + alg: jwt.NewHS384(hmacKey1), + payload: tp, + verifyAlg: jwt.NewHS256(hmacKey1), + wantHeader: jwt.Header{ + Algorithm: "HS384", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrHMACVerification, + }, + { + alg: jwt.NewHS512(hmacKey1), + payload: tp, + verifyAlg: jwt.NewHS512(hmacKey1), + wantHeader: jwt.Header{ + Algorithm: "HS512", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewHS512(hmacKey1), + payload: tp, + verifyAlg: jwt.NewHS512(hmacKey2), + wantHeader: jwt.Header{ + Algorithm: "HS512", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrHMACVerification, + }, + { + alg: jwt.NewHS512(hmacKey1), + payload: tp, + verifyAlg: jwt.NewHS256(hmacKey1), + wantHeader: jwt.Header{ + Algorithm: "HS512", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrHMACVerification, + }, + }, + "RSA": []testCase{ + { + alg: jwt.NewRS256(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewRS256(jwt.RSAPrivateKey(rsaPrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "RS256", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewRS256(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewRS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "RS256", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewRS256(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewRS256(jwt.RSAPrivateKey(rsaPrivateKey2)), + wantHeader: jwt.Header{ + Algorithm: "RS256", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewRS256(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewRS256(jwt.RSAPublicKey(rsaPublicKey1)), + wantHeader: jwt.Header{ + Algorithm: "RS256", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewRS256(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewRS256(jwt.RSAPublicKey(rsaPublicKey2)), + wantHeader: jwt.Header{ + Algorithm: "RS256", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewRS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewRS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "RS384", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewRS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewRS256(jwt.RSAPrivateKey(rsaPrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "RS384", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewRS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewRS384(jwt.RSAPrivateKey(rsaPrivateKey2)), + wantHeader: jwt.Header{ + Algorithm: "RS384", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewRS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewRS384(jwt.RSAPublicKey(rsaPublicKey1)), + wantHeader: jwt.Header{ + Algorithm: "RS384", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewRS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewRS384(jwt.RSAPublicKey(rsaPublicKey2)), + wantHeader: jwt.Header{ + Algorithm: "RS384", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewRS512(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewRS512(jwt.RSAPrivateKey(rsaPrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "RS512", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewRS512(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewRS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "RS512", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewRS512(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewRS512(jwt.RSAPrivateKey(rsaPrivateKey2)), + wantHeader: jwt.Header{ + Algorithm: "RS512", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewRS512(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewRS512(jwt.RSAPublicKey(rsaPublicKey1)), + wantHeader: jwt.Header{ + Algorithm: "RS512", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewRS512(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewRS512(jwt.RSAPublicKey(rsaPublicKey2)), + wantHeader: jwt.Header{ + Algorithm: "RS512", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + }, + "RSA-PSS": []testCase{ + { + alg: jwt.NewPS256(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewPS256(jwt.RSAPrivateKey(rsaPrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "PS256", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewPS256(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewRS256(jwt.RSAPrivateKey(rsaPrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "PS256", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewPS256(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewPS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "PS256", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewPS256(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewPS256(jwt.RSAPrivateKey(rsaPrivateKey2)), + wantHeader: jwt.Header{ + Algorithm: "PS256", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewPS256(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewPS256(jwt.RSAPublicKey(rsaPublicKey1)), + wantHeader: jwt.Header{ + Algorithm: "PS256", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewPS256(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewPS256(jwt.RSAPublicKey(rsaPublicKey2)), + wantHeader: jwt.Header{ + Algorithm: "PS256", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewPS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewPS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "PS384", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewPS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewRS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "PS384", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewPS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewPS256(jwt.RSAPrivateKey(rsaPrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "PS384", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewPS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewPS384(jwt.RSAPrivateKey(rsaPrivateKey2)), + wantHeader: jwt.Header{ + Algorithm: "PS384", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewPS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewPS384(jwt.RSAPublicKey(rsaPublicKey1)), + wantHeader: jwt.Header{ + Algorithm: "PS384", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewPS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewPS384(jwt.RSAPublicKey(rsaPublicKey2)), + wantHeader: jwt.Header{ + Algorithm: "PS384", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewPS512(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewPS512(jwt.RSAPrivateKey(rsaPrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "PS512", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewPS512(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewRS512(jwt.RSAPrivateKey(rsaPrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "PS512", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewPS512(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewPS384(jwt.RSAPrivateKey(rsaPrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "PS512", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewPS512(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewPS512(jwt.RSAPrivateKey(rsaPrivateKey2)), + wantHeader: jwt.Header{ + Algorithm: "PS512", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + { + alg: jwt.NewPS512(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewPS512(jwt.RSAPublicKey(rsaPublicKey1)), + wantHeader: jwt.Header{ + Algorithm: "PS512", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewPS512(jwt.RSAPrivateKey(rsaPrivateKey1)), + payload: tp, + verifyAlg: jwt.NewPS512(jwt.RSAPublicKey(rsaPublicKey2)), + wantHeader: jwt.Header{ + Algorithm: "PS512", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrRSAVerification, + }, + }, + "ECDSA": []testCase{ + { + alg: jwt.NewES256(jwt.ECDSAPrivateKey(es256PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewES256(jwt.ECDSAPublicKey(es256PublicKey1)), + wantHeader: jwt.Header{ + Algorithm: "ES256", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewES256(jwt.ECDSAPrivateKey(es256PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewES384(jwt.ECDSAPublicKey(es256PublicKey1)), + wantHeader: jwt.Header{ + Algorithm: "ES256", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrECDSAVerification, + }, + { + alg: jwt.NewES256(jwt.ECDSAPrivateKey(es256PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewES256(jwt.ECDSAPublicKey(es256PublicKey2)), + wantHeader: jwt.Header{ + Algorithm: "ES256", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrECDSAVerification, + }, + { + alg: jwt.NewES256(jwt.ECDSAPrivateKey(es256PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewES256(jwt.ECDSAPrivateKey(es256PrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "ES256", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewES256(jwt.ECDSAPrivateKey(es256PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewES256(jwt.ECDSAPrivateKey(es256PrivateKey2)), + wantHeader: jwt.Header{ + Algorithm: "ES256", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrECDSAVerification, + }, + { + alg: jwt.NewES384(jwt.ECDSAPrivateKey(es384PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewES384(jwt.ECDSAPublicKey(es384PublicKey1)), + wantHeader: jwt.Header{ + Algorithm: "ES384", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewES384(jwt.ECDSAPrivateKey(es384PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewES256(jwt.ECDSAPublicKey(es384PublicKey1)), + wantHeader: jwt.Header{ + Algorithm: "ES384", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrECDSAVerification, + }, + { + alg: jwt.NewES384(jwt.ECDSAPrivateKey(es384PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewES384(jwt.ECDSAPublicKey(es384PublicKey2)), + wantHeader: jwt.Header{ + Algorithm: "ES384", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrECDSAVerification, + }, + { + alg: jwt.NewES384(jwt.ECDSAPrivateKey(es384PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewES384(jwt.ECDSAPrivateKey(es384PrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "ES384", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewES384(jwt.ECDSAPrivateKey(es384PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewES384(jwt.ECDSAPrivateKey(es384PrivateKey2)), + wantHeader: jwt.Header{ + Algorithm: "ES384", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrECDSAVerification, + }, + { + alg: jwt.NewES512(jwt.ECDSAPrivateKey(es512PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewES512(jwt.ECDSAPublicKey(es512PublicKey1)), + wantHeader: jwt.Header{ + Algorithm: "ES512", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewES512(jwt.ECDSAPrivateKey(es512PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewES384(jwt.ECDSAPublicKey(es512PublicKey1)), + wantHeader: jwt.Header{ + Algorithm: "ES512", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrECDSAVerification, + }, + { + alg: jwt.NewES512(jwt.ECDSAPrivateKey(es512PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewES512(jwt.ECDSAPublicKey(es512PublicKey2)), + wantHeader: jwt.Header{ + Algorithm: "ES512", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrECDSAVerification, + }, + { + alg: jwt.NewES512(jwt.ECDSAPrivateKey(es512PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewES512(jwt.ECDSAPrivateKey(es512PrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "ES512", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewES512(jwt.ECDSAPrivateKey(es512PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewES512(jwt.ECDSAPrivateKey(es512PrivateKey2)), + wantHeader: jwt.Header{ + Algorithm: "ES512", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrECDSAVerification, + }, + }, + "Ed25519": []testCase{ + { + alg: jwt.NewEd25519(jwt.Ed25519PrivateKey(ed25519PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewEd25519(jwt.Ed25519PrivateKey(ed25519PrivateKey1)), + wantHeader: jwt.Header{ + Algorithm: "Ed25519", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewEd25519(jwt.Ed25519PrivateKey(ed25519PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewEd25519(jwt.Ed25519PublicKey(ed25519PublicKey1)), + wantHeader: jwt.Header{ + Algorithm: "Ed25519", + Type: "JWT", + }, + wantPayload: tp, + signErr: nil, + verifyErr: nil, + }, + { + alg: jwt.NewEd25519(jwt.Ed25519PrivateKey(ed25519PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewEd25519(jwt.Ed25519PrivateKey(ed25519PrivateKey2)), + wantHeader: jwt.Header{ + Algorithm: "Ed25519", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrEd25519Verification, + }, + { + alg: jwt.NewEd25519(jwt.Ed25519PrivateKey(ed25519PrivateKey1)), + payload: tp, + verifyAlg: jwt.NewEd25519(jwt.Ed25519PublicKey(ed25519PublicKey2)), + wantHeader: jwt.Header{ + Algorithm: "Ed25519", + Type: "JWT", + }, + wantPayload: testPayload{}, + signErr: nil, + verifyErr: jwt.ErrEd25519Verification, + }, + }, + } + for k, v := range testCases { + t.Run(k, func(t *testing.T) { + for _, tc := range v { + t.Run(tc.verifyAlg.Name(), func(t *testing.T) { + token, err := jwt.Sign(tc.payload, tc.alg) + if err != nil { + t.Fatal(err) + } + var pl testPayload + hd, err := jwt.Verify(token, tc.verifyAlg, &pl) + if want, got := tc.verifyErr, err; got != want { + t.Errorf("want %v, got %v", want, got) + } + if want, got := tc.wantHeader, hd; !reflect.DeepEqual(got, want) { + t.Errorf("want %#+v, got %#+v", want, got) + } + if want, got := tc.wantPayload, pl; !reflect.DeepEqual(got, want) { + t.Errorf("want %#+v, got %#+v", want, got) + } + }) + } + }) + } +}