diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1f59c0326..db41886aa 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,7 +11,7 @@ jobs: strategy: matrix: go_tags: [ 'stdlib', 'goccy', 'es256k', 'alltags'] - go: [ '1.19', '1.18', '1.17' ] + go: [ '1.20', '1.19', '1.18'] name: "Test [ Go ${{ matrix.go }} / Tags ${{ matrix.go_tags }} ]" steps: - name: Checkout repository @@ -57,7 +57,8 @@ jobs: uses: codecov/codecov-action@v3 with: file: ./coverage.out + - uses: bazelbuild/setup-bazelisk@v2 + - run: bazel run //:gazelle-update-repos - name: Check difference between generation code and commit code run: make check_diffs - - uses: bazelbuild/setup-bazelisk@v2 - - run: bazel build //... + diff --git a/Changes b/Changes index e20b75771..3b7330000 100644 --- a/Changes +++ b/Changes @@ -4,6 +4,18 @@ Changes v2 has many incompatibilities with v1. To see the full list of differences between v1 and v2, please read the Changes-v2.md file (https://github.com/lestrrat-go/jwx/blob/develop/v2/Changes-v2.md) +v2.0.10 - UNRELEASED +[Bug fixes] + * Registering JWS signers/verifiers did not work since v2.0.0, because the + way we handle algorithm names changed in 2aa98ce6884187180a7145b73da78c859dd46c84. + (We previously thought that this would be checked by the example code, but it + apparently failed to flag us properly) + + The logic behind managing the internal database has been fixed, and + `jws.RegisterSigner` and `jws.RegisterVerifier` now properly hooks into the new + `jwa.RegisterSignatureAlgorithm` to automatically register new algorithm names + (#910, #911) + v2.0.9 - 21 Mar 2023 [Security Fixes] * Updated use of golang.org/x/crypto to v0.7.0 diff --git a/jwa/compression_gen.go b/jwa/compression_gen.go index ad27ea3f3..9fb65220d 100644 --- a/jwa/compression_gen.go +++ b/jwa/compression_gen.go @@ -17,25 +17,55 @@ const ( NoCompress CompressionAlgorithm = "" // No compression ) -var allCompressionAlgorithms = map[CompressionAlgorithm]struct{}{ - Deflate: {}, - NoCompress: {}, +var muCompressionAlgorithms sync.RWMutex +var allCompressionAlgorithms map[CompressionAlgorithm]struct{} +var listCompressionAlgorithm []CompressionAlgorithm + +func init() { + muCompressionAlgorithms.Lock() + defer muCompressionAlgorithms.Unlock() + allCompressionAlgorithms = make(map[CompressionAlgorithm]struct{}) + allCompressionAlgorithms[Deflate] = struct{}{} + allCompressionAlgorithms[NoCompress] = struct{}{} + rebuildCompressionAlgorithm() } -var listCompressionAlgorithmOnce sync.Once -var listCompressionAlgorithm []CompressionAlgorithm +// RegisterCompressionAlgorithm registers a new CompressionAlgorithm so that the jwx can properly handle the new value. +// Duplicates will silently be ignored +func RegisterCompressionAlgorithm(v CompressionAlgorithm) { + muCompressionAlgorithms.Lock() + defer muCompressionAlgorithms.Unlock() + if _, ok := allCompressionAlgorithms[v]; !ok { + allCompressionAlgorithms[v] = struct{}{} + rebuildCompressionAlgorithm() + } +} + +// UnregisterCompressionAlgorithm unregisters a CompressionAlgorithm from its known database. +// Non-existentn entries will silently be ignored +func UnregisterCompressionAlgorithm(v CompressionAlgorithm) { + muCompressionAlgorithms.Lock() + defer muCompressionAlgorithms.Unlock() + if _, ok := allCompressionAlgorithms[v]; ok { + delete(allCompressionAlgorithms, v) + rebuildCompressionAlgorithm() + } +} + +func rebuildCompressionAlgorithm() { + listCompressionAlgorithm = make([]CompressionAlgorithm, 0, len(allCompressionAlgorithms)) + for v := range allCompressionAlgorithms { + listCompressionAlgorithm = append(listCompressionAlgorithm, v) + } + sort.Slice(listCompressionAlgorithm, func(i, j int) bool { + return string(listCompressionAlgorithm[i]) < string(listCompressionAlgorithm[j]) + }) +} // CompressionAlgorithms returns a list of all available values for CompressionAlgorithm func CompressionAlgorithms() []CompressionAlgorithm { - listCompressionAlgorithmOnce.Do(func() { - listCompressionAlgorithm = make([]CompressionAlgorithm, 0, len(allCompressionAlgorithms)) - for v := range allCompressionAlgorithms { - listCompressionAlgorithm = append(listCompressionAlgorithm, v) - } - sort.Slice(listCompressionAlgorithm, func(i, j int) bool { - return string(listCompressionAlgorithm[i]) < string(listCompressionAlgorithm[j]) - }) - }) + muCompressionAlgorithms.RLock() + defer muCompressionAlgorithms.RUnlock() return listCompressionAlgorithm } diff --git a/jwa/content_encryption_gen.go b/jwa/content_encryption_gen.go index bc82d5058..115fa18e0 100644 --- a/jwa/content_encryption_gen.go +++ b/jwa/content_encryption_gen.go @@ -21,29 +21,59 @@ const ( A256GCM ContentEncryptionAlgorithm = "A256GCM" // AES-GCM (256) ) -var allContentEncryptionAlgorithms = map[ContentEncryptionAlgorithm]struct{}{ - A128CBC_HS256: {}, - A128GCM: {}, - A192CBC_HS384: {}, - A192GCM: {}, - A256CBC_HS512: {}, - A256GCM: {}, +var muContentEncryptionAlgorithms sync.RWMutex +var allContentEncryptionAlgorithms map[ContentEncryptionAlgorithm]struct{} +var listContentEncryptionAlgorithm []ContentEncryptionAlgorithm + +func init() { + muContentEncryptionAlgorithms.Lock() + defer muContentEncryptionAlgorithms.Unlock() + allContentEncryptionAlgorithms = make(map[ContentEncryptionAlgorithm]struct{}) + allContentEncryptionAlgorithms[A128CBC_HS256] = struct{}{} + allContentEncryptionAlgorithms[A128GCM] = struct{}{} + allContentEncryptionAlgorithms[A192CBC_HS384] = struct{}{} + allContentEncryptionAlgorithms[A192GCM] = struct{}{} + allContentEncryptionAlgorithms[A256CBC_HS512] = struct{}{} + allContentEncryptionAlgorithms[A256GCM] = struct{}{} + rebuildContentEncryptionAlgorithm() } -var listContentEncryptionAlgorithmOnce sync.Once -var listContentEncryptionAlgorithm []ContentEncryptionAlgorithm +// RegisterContentEncryptionAlgorithm registers a new ContentEncryptionAlgorithm so that the jwx can properly handle the new value. +// Duplicates will silently be ignored +func RegisterContentEncryptionAlgorithm(v ContentEncryptionAlgorithm) { + muContentEncryptionAlgorithms.Lock() + defer muContentEncryptionAlgorithms.Unlock() + if _, ok := allContentEncryptionAlgorithms[v]; !ok { + allContentEncryptionAlgorithms[v] = struct{}{} + rebuildContentEncryptionAlgorithm() + } +} + +// UnregisterContentEncryptionAlgorithm unregisters a ContentEncryptionAlgorithm from its known database. +// Non-existentn entries will silently be ignored +func UnregisterContentEncryptionAlgorithm(v ContentEncryptionAlgorithm) { + muContentEncryptionAlgorithms.Lock() + defer muContentEncryptionAlgorithms.Unlock() + if _, ok := allContentEncryptionAlgorithms[v]; ok { + delete(allContentEncryptionAlgorithms, v) + rebuildContentEncryptionAlgorithm() + } +} + +func rebuildContentEncryptionAlgorithm() { + listContentEncryptionAlgorithm = make([]ContentEncryptionAlgorithm, 0, len(allContentEncryptionAlgorithms)) + for v := range allContentEncryptionAlgorithms { + listContentEncryptionAlgorithm = append(listContentEncryptionAlgorithm, v) + } + sort.Slice(listContentEncryptionAlgorithm, func(i, j int) bool { + return string(listContentEncryptionAlgorithm[i]) < string(listContentEncryptionAlgorithm[j]) + }) +} // ContentEncryptionAlgorithms returns a list of all available values for ContentEncryptionAlgorithm func ContentEncryptionAlgorithms() []ContentEncryptionAlgorithm { - listContentEncryptionAlgorithmOnce.Do(func() { - listContentEncryptionAlgorithm = make([]ContentEncryptionAlgorithm, 0, len(allContentEncryptionAlgorithms)) - for v := range allContentEncryptionAlgorithms { - listContentEncryptionAlgorithm = append(listContentEncryptionAlgorithm, v) - } - sort.Slice(listContentEncryptionAlgorithm, func(i, j int) bool { - return string(listContentEncryptionAlgorithm[i]) < string(listContentEncryptionAlgorithm[j]) - }) - }) + muContentEncryptionAlgorithms.RLock() + defer muContentEncryptionAlgorithms.RUnlock() return listContentEncryptionAlgorithm } diff --git a/jwa/elliptic_gen.go b/jwa/elliptic_gen.go index 6e813989e..fbfe466aa 100644 --- a/jwa/elliptic_gen.go +++ b/jwa/elliptic_gen.go @@ -23,30 +23,60 @@ const ( X448 EllipticCurveAlgorithm = "X448" ) -var allEllipticCurveAlgorithms = map[EllipticCurveAlgorithm]struct{}{ - Ed25519: {}, - Ed448: {}, - P256: {}, - P384: {}, - P521: {}, - X25519: {}, - X448: {}, +var muEllipticCurveAlgorithms sync.RWMutex +var allEllipticCurveAlgorithms map[EllipticCurveAlgorithm]struct{} +var listEllipticCurveAlgorithm []EllipticCurveAlgorithm + +func init() { + muEllipticCurveAlgorithms.Lock() + defer muEllipticCurveAlgorithms.Unlock() + allEllipticCurveAlgorithms = make(map[EllipticCurveAlgorithm]struct{}) + allEllipticCurveAlgorithms[Ed25519] = struct{}{} + allEllipticCurveAlgorithms[Ed448] = struct{}{} + allEllipticCurveAlgorithms[P256] = struct{}{} + allEllipticCurveAlgorithms[P384] = struct{}{} + allEllipticCurveAlgorithms[P521] = struct{}{} + allEllipticCurveAlgorithms[X25519] = struct{}{} + allEllipticCurveAlgorithms[X448] = struct{}{} + rebuildEllipticCurveAlgorithm() } -var listEllipticCurveAlgorithmOnce sync.Once -var listEllipticCurveAlgorithm []EllipticCurveAlgorithm +// RegisterEllipticCurveAlgorithm registers a new EllipticCurveAlgorithm so that the jwx can properly handle the new value. +// Duplicates will silently be ignored +func RegisterEllipticCurveAlgorithm(v EllipticCurveAlgorithm) { + muEllipticCurveAlgorithms.Lock() + defer muEllipticCurveAlgorithms.Unlock() + if _, ok := allEllipticCurveAlgorithms[v]; !ok { + allEllipticCurveAlgorithms[v] = struct{}{} + rebuildEllipticCurveAlgorithm() + } +} + +// UnregisterEllipticCurveAlgorithm unregisters a EllipticCurveAlgorithm from its known database. +// Non-existentn entries will silently be ignored +func UnregisterEllipticCurveAlgorithm(v EllipticCurveAlgorithm) { + muEllipticCurveAlgorithms.Lock() + defer muEllipticCurveAlgorithms.Unlock() + if _, ok := allEllipticCurveAlgorithms[v]; ok { + delete(allEllipticCurveAlgorithms, v) + rebuildEllipticCurveAlgorithm() + } +} + +func rebuildEllipticCurveAlgorithm() { + listEllipticCurveAlgorithm = make([]EllipticCurveAlgorithm, 0, len(allEllipticCurveAlgorithms)) + for v := range allEllipticCurveAlgorithms { + listEllipticCurveAlgorithm = append(listEllipticCurveAlgorithm, v) + } + sort.Slice(listEllipticCurveAlgorithm, func(i, j int) bool { + return string(listEllipticCurveAlgorithm[i]) < string(listEllipticCurveAlgorithm[j]) + }) +} // EllipticCurveAlgorithms returns a list of all available values for EllipticCurveAlgorithm func EllipticCurveAlgorithms() []EllipticCurveAlgorithm { - listEllipticCurveAlgorithmOnce.Do(func() { - listEllipticCurveAlgorithm = make([]EllipticCurveAlgorithm, 0, len(allEllipticCurveAlgorithms)) - for v := range allEllipticCurveAlgorithms { - listEllipticCurveAlgorithm = append(listEllipticCurveAlgorithm, v) - } - sort.Slice(listEllipticCurveAlgorithm, func(i, j int) bool { - return string(listEllipticCurveAlgorithm[i]) < string(listEllipticCurveAlgorithm[j]) - }) - }) + muEllipticCurveAlgorithms.RLock() + defer muEllipticCurveAlgorithms.RUnlock() return listEllipticCurveAlgorithm } diff --git a/jwa/key_encryption_gen.go b/jwa/key_encryption_gen.go index f85574ffa..49ed1f678 100644 --- a/jwa/key_encryption_gen.go +++ b/jwa/key_encryption_gen.go @@ -32,40 +32,70 @@ const ( RSA_OAEP_256 KeyEncryptionAlgorithm = "RSA-OAEP-256" // RSA-OAEP-SHA256 ) -var allKeyEncryptionAlgorithms = map[KeyEncryptionAlgorithm]struct{}{ - A128GCMKW: {}, - A128KW: {}, - A192GCMKW: {}, - A192KW: {}, - A256GCMKW: {}, - A256KW: {}, - DIRECT: {}, - ECDH_ES: {}, - ECDH_ES_A128KW: {}, - ECDH_ES_A192KW: {}, - ECDH_ES_A256KW: {}, - PBES2_HS256_A128KW: {}, - PBES2_HS384_A192KW: {}, - PBES2_HS512_A256KW: {}, - RSA1_5: {}, - RSA_OAEP: {}, - RSA_OAEP_256: {}, +var muKeyEncryptionAlgorithms sync.RWMutex +var allKeyEncryptionAlgorithms map[KeyEncryptionAlgorithm]struct{} +var listKeyEncryptionAlgorithm []KeyEncryptionAlgorithm + +func init() { + muKeyEncryptionAlgorithms.Lock() + defer muKeyEncryptionAlgorithms.Unlock() + allKeyEncryptionAlgorithms = make(map[KeyEncryptionAlgorithm]struct{}) + allKeyEncryptionAlgorithms[A128GCMKW] = struct{}{} + allKeyEncryptionAlgorithms[A128KW] = struct{}{} + allKeyEncryptionAlgorithms[A192GCMKW] = struct{}{} + allKeyEncryptionAlgorithms[A192KW] = struct{}{} + allKeyEncryptionAlgorithms[A256GCMKW] = struct{}{} + allKeyEncryptionAlgorithms[A256KW] = struct{}{} + allKeyEncryptionAlgorithms[DIRECT] = struct{}{} + allKeyEncryptionAlgorithms[ECDH_ES] = struct{}{} + allKeyEncryptionAlgorithms[ECDH_ES_A128KW] = struct{}{} + allKeyEncryptionAlgorithms[ECDH_ES_A192KW] = struct{}{} + allKeyEncryptionAlgorithms[ECDH_ES_A256KW] = struct{}{} + allKeyEncryptionAlgorithms[PBES2_HS256_A128KW] = struct{}{} + allKeyEncryptionAlgorithms[PBES2_HS384_A192KW] = struct{}{} + allKeyEncryptionAlgorithms[PBES2_HS512_A256KW] = struct{}{} + allKeyEncryptionAlgorithms[RSA1_5] = struct{}{} + allKeyEncryptionAlgorithms[RSA_OAEP] = struct{}{} + allKeyEncryptionAlgorithms[RSA_OAEP_256] = struct{}{} + rebuildKeyEncryptionAlgorithm() } -var listKeyEncryptionAlgorithmOnce sync.Once -var listKeyEncryptionAlgorithm []KeyEncryptionAlgorithm +// RegisterKeyEncryptionAlgorithm registers a new KeyEncryptionAlgorithm so that the jwx can properly handle the new value. +// Duplicates will silently be ignored +func RegisterKeyEncryptionAlgorithm(v KeyEncryptionAlgorithm) { + muKeyEncryptionAlgorithms.Lock() + defer muKeyEncryptionAlgorithms.Unlock() + if _, ok := allKeyEncryptionAlgorithms[v]; !ok { + allKeyEncryptionAlgorithms[v] = struct{}{} + rebuildKeyEncryptionAlgorithm() + } +} + +// UnregisterKeyEncryptionAlgorithm unregisters a KeyEncryptionAlgorithm from its known database. +// Non-existentn entries will silently be ignored +func UnregisterKeyEncryptionAlgorithm(v KeyEncryptionAlgorithm) { + muKeyEncryptionAlgorithms.Lock() + defer muKeyEncryptionAlgorithms.Unlock() + if _, ok := allKeyEncryptionAlgorithms[v]; ok { + delete(allKeyEncryptionAlgorithms, v) + rebuildKeyEncryptionAlgorithm() + } +} + +func rebuildKeyEncryptionAlgorithm() { + listKeyEncryptionAlgorithm = make([]KeyEncryptionAlgorithm, 0, len(allKeyEncryptionAlgorithms)) + for v := range allKeyEncryptionAlgorithms { + listKeyEncryptionAlgorithm = append(listKeyEncryptionAlgorithm, v) + } + sort.Slice(listKeyEncryptionAlgorithm, func(i, j int) bool { + return string(listKeyEncryptionAlgorithm[i]) < string(listKeyEncryptionAlgorithm[j]) + }) +} // KeyEncryptionAlgorithms returns a list of all available values for KeyEncryptionAlgorithm func KeyEncryptionAlgorithms() []KeyEncryptionAlgorithm { - listKeyEncryptionAlgorithmOnce.Do(func() { - listKeyEncryptionAlgorithm = make([]KeyEncryptionAlgorithm, 0, len(allKeyEncryptionAlgorithms)) - for v := range allKeyEncryptionAlgorithms { - listKeyEncryptionAlgorithm = append(listKeyEncryptionAlgorithm, v) - } - sort.Slice(listKeyEncryptionAlgorithm, func(i, j int) bool { - return string(listKeyEncryptionAlgorithm[i]) < string(listKeyEncryptionAlgorithm[j]) - }) - }) + muKeyEncryptionAlgorithms.RLock() + defer muKeyEncryptionAlgorithms.RUnlock() return listKeyEncryptionAlgorithm } diff --git a/jwa/key_type_gen.go b/jwa/key_type_gen.go index 2b602c67a..e1f9e3896 100644 --- a/jwa/key_type_gen.go +++ b/jwa/key_type_gen.go @@ -20,27 +20,57 @@ const ( RSA KeyType = "RSA" // RSA ) -var allKeyTypes = map[KeyType]struct{}{ - EC: {}, - OKP: {}, - OctetSeq: {}, - RSA: {}, +var muKeyTypes sync.RWMutex +var allKeyTypes map[KeyType]struct{} +var listKeyType []KeyType + +func init() { + muKeyTypes.Lock() + defer muKeyTypes.Unlock() + allKeyTypes = make(map[KeyType]struct{}) + allKeyTypes[EC] = struct{}{} + allKeyTypes[OKP] = struct{}{} + allKeyTypes[OctetSeq] = struct{}{} + allKeyTypes[RSA] = struct{}{} + rebuildKeyType() } -var listKeyTypeOnce sync.Once -var listKeyType []KeyType +// RegisterKeyType registers a new KeyType so that the jwx can properly handle the new value. +// Duplicates will silently be ignored +func RegisterKeyType(v KeyType) { + muKeyTypes.Lock() + defer muKeyTypes.Unlock() + if _, ok := allKeyTypes[v]; !ok { + allKeyTypes[v] = struct{}{} + rebuildKeyType() + } +} + +// UnregisterKeyType unregisters a KeyType from its known database. +// Non-existentn entries will silently be ignored +func UnregisterKeyType(v KeyType) { + muKeyTypes.Lock() + defer muKeyTypes.Unlock() + if _, ok := allKeyTypes[v]; ok { + delete(allKeyTypes, v) + rebuildKeyType() + } +} + +func rebuildKeyType() { + listKeyType = make([]KeyType, 0, len(allKeyTypes)) + for v := range allKeyTypes { + listKeyType = append(listKeyType, v) + } + sort.Slice(listKeyType, func(i, j int) bool { + return string(listKeyType[i]) < string(listKeyType[j]) + }) +} // KeyTypes returns a list of all available values for KeyType func KeyTypes() []KeyType { - listKeyTypeOnce.Do(func() { - listKeyType = make([]KeyType, 0, len(allKeyTypes)) - for v := range allKeyTypes { - listKeyType = append(listKeyType, v) - } - sort.Slice(listKeyType, func(i, j int) bool { - return string(listKeyType[i]) < string(listKeyType[j]) - }) - }) + muKeyTypes.RLock() + defer muKeyTypes.RUnlock() return listKeyType } diff --git a/jwa/signature_gen.go b/jwa/signature_gen.go index bc2cbb91c..eaa2f8662 100644 --- a/jwa/signature_gen.go +++ b/jwa/signature_gen.go @@ -30,38 +30,68 @@ const ( RS512 SignatureAlgorithm = "RS512" // RSASSA-PKCS-v1.5 using SHA-512 ) -var allSignatureAlgorithms = map[SignatureAlgorithm]struct{}{ - ES256: {}, - ES256K: {}, - ES384: {}, - ES512: {}, - EdDSA: {}, - HS256: {}, - HS384: {}, - HS512: {}, - NoSignature: {}, - PS256: {}, - PS384: {}, - PS512: {}, - RS256: {}, - RS384: {}, - RS512: {}, +var muSignatureAlgorithms sync.RWMutex +var allSignatureAlgorithms map[SignatureAlgorithm]struct{} +var listSignatureAlgorithm []SignatureAlgorithm + +func init() { + muSignatureAlgorithms.Lock() + defer muSignatureAlgorithms.Unlock() + allSignatureAlgorithms = make(map[SignatureAlgorithm]struct{}) + allSignatureAlgorithms[ES256] = struct{}{} + allSignatureAlgorithms[ES256K] = struct{}{} + allSignatureAlgorithms[ES384] = struct{}{} + allSignatureAlgorithms[ES512] = struct{}{} + allSignatureAlgorithms[EdDSA] = struct{}{} + allSignatureAlgorithms[HS256] = struct{}{} + allSignatureAlgorithms[HS384] = struct{}{} + allSignatureAlgorithms[HS512] = struct{}{} + allSignatureAlgorithms[NoSignature] = struct{}{} + allSignatureAlgorithms[PS256] = struct{}{} + allSignatureAlgorithms[PS384] = struct{}{} + allSignatureAlgorithms[PS512] = struct{}{} + allSignatureAlgorithms[RS256] = struct{}{} + allSignatureAlgorithms[RS384] = struct{}{} + allSignatureAlgorithms[RS512] = struct{}{} + rebuildSignatureAlgorithm() } -var listSignatureAlgorithmOnce sync.Once -var listSignatureAlgorithm []SignatureAlgorithm +// RegisterSignatureAlgorithm registers a new SignatureAlgorithm so that the jwx can properly handle the new value. +// Duplicates will silently be ignored +func RegisterSignatureAlgorithm(v SignatureAlgorithm) { + muSignatureAlgorithms.Lock() + defer muSignatureAlgorithms.Unlock() + if _, ok := allSignatureAlgorithms[v]; !ok { + allSignatureAlgorithms[v] = struct{}{} + rebuildSignatureAlgorithm() + } +} + +// UnregisterSignatureAlgorithm unregisters a SignatureAlgorithm from its known database. +// Non-existentn entries will silently be ignored +func UnregisterSignatureAlgorithm(v SignatureAlgorithm) { + muSignatureAlgorithms.Lock() + defer muSignatureAlgorithms.Unlock() + if _, ok := allSignatureAlgorithms[v]; ok { + delete(allSignatureAlgorithms, v) + rebuildSignatureAlgorithm() + } +} + +func rebuildSignatureAlgorithm() { + listSignatureAlgorithm = make([]SignatureAlgorithm, 0, len(allSignatureAlgorithms)) + for v := range allSignatureAlgorithms { + listSignatureAlgorithm = append(listSignatureAlgorithm, v) + } + sort.Slice(listSignatureAlgorithm, func(i, j int) bool { + return string(listSignatureAlgorithm[i]) < string(listSignatureAlgorithm[j]) + }) +} // SignatureAlgorithms returns a list of all available values for SignatureAlgorithm func SignatureAlgorithms() []SignatureAlgorithm { - listSignatureAlgorithmOnce.Do(func() { - listSignatureAlgorithm = make([]SignatureAlgorithm, 0, len(allSignatureAlgorithms)) - for v := range allSignatureAlgorithms { - listSignatureAlgorithm = append(listSignatureAlgorithm, v) - } - sort.Slice(listSignatureAlgorithm, func(i, j int) bool { - return string(listSignatureAlgorithm[i]) < string(listSignatureAlgorithm[j]) - }) - }) + muSignatureAlgorithms.RLock() + defer muSignatureAlgorithms.RUnlock() return listSignatureAlgorithm } diff --git a/jws/jws_test.go b/jws/jws_test.go index 06e9eb3b4..53d1803b1 100644 --- a/jws/jws_test.go +++ b/jws/jws_test.go @@ -8,8 +8,10 @@ import ( "crypto/ecdsa" "crypto/ed25519" "crypto/rsa" + "crypto/sha256" "crypto/sha512" "encoding/asn1" + "errors" "fmt" "io" "math/big" @@ -1989,3 +1991,56 @@ func TestGH888(t *testing.T) { require.NoError(t, err, `jws.Verify should succeed`) require.Equal(t, []byte(`foo`), verified) } + +// Some stuff required for testing #910 +// The original code used an external library to sign/verify, but here +// we just use a simple SHA256 digest here so that we don't force +// users to download an optional dependency +type s256SignerVerifier struct{} + +const sha256Algo jwa.SignatureAlgorithm = "SillyTest256" + +func (s256SignerVerifier) Algorithm() jwa.SignatureAlgorithm { + return sha256Algo +} + +func (s256SignerVerifier) Sign(payload []byte, _ interface{}) ([]byte, error) { + h := sha256.Sum256(payload) + return h[:], nil +} + +func (s256SignerVerifier) Verify(payload, signature []byte, _ interface{}) error { + h := sha256.Sum256(payload) + if !bytes.Equal(h[:], signature) { + return errors.New("invalid signature") + } + return nil +} + +func TestGH910(t *testing.T) { + // Note: This has global effect. You can't run this in parallel with other tests + jws.RegisterSigner(sha256Algo, jws.SignerFactoryFn(func() (jws.Signer, error) { + return s256SignerVerifier{}, nil + })) + defer jws.UnregisterSigner(sha256Algo) + + jws.RegisterVerifier(sha256Algo, jws.VerifierFactoryFn(func() (jws.Verifier, error) { + return s256SignerVerifier{}, nil + })) + defer jws.UnregisterVerifier(sha256Algo) + defer jwa.UnregisterSignatureAlgorithm(sha256Algo) + + var sa jwa.SignatureAlgorithm + require.NoError(t, sa.Accept(sha256Algo.String()), `jwa.SignatureAlgorithm.Accept should succeed`) + + // Now that we have established that the signature algorithm works, + // we can proceed with the test + const src = `Lorem Ipsum` + signed, err := jws.Sign([]byte(src), jws.WithKey(sha256Algo, nil)) + require.NoError(t, err, `jws.Sign should succeed`) + + verified, err := jws.Verify(signed, jws.WithKey(sha256Algo, nil)) + require.NoError(t, err, `jws.Verify should succeed`) + + require.Equal(t, src, string(verified), `verified payload should match`) +} diff --git a/jws/signer.go b/jws/signer.go index 279386c5f..44c8bfb76 100644 --- a/jws/signer.go +++ b/jws/signer.go @@ -2,6 +2,7 @@ package jws import ( "fmt" + "sync" "github.com/lestrrat-go/jwx/v2/jwa" ) @@ -15,6 +16,7 @@ func (fn SignerFactoryFn) Create() (Signer, error) { return fn() } +var muSignerDB sync.RWMutex var signerDB map[jwa.SignatureAlgorithm]SignerFactory // RegisterSigner is used to register a factory object that creates @@ -23,8 +25,30 @@ var signerDB map[jwa.SignatureAlgorithm]SignerFactory // For example, if you would like to provide a custom signer for // jwa.EdDSA, use this function to register a `SignerFactory` // (probably in your `init()`) +// +// Unlike the `UnregisterSigner` function, this function automatically +// calls `jwa.RegisterSignatureAlgorithm` to register the algorithm +// in the known algorithms database. func RegisterSigner(alg jwa.SignatureAlgorithm, f SignerFactory) { + jwa.RegisterSignatureAlgorithm(alg) + muSignerDB.Lock() signerDB[alg] = f + muSignerDB.Unlock() +} + +// UnregisterSigner removes the signer factory associated with +// the given algorithm. +// +// Note that when you call this function, the algorithm itself is +// not automatically unregistered from the known algorithms database. +// This is because the algorithm may still be required for verification or +// some other operation (however unlikely, it is still possible). +// Therefore, in order to completely remove the algorithm, you must +// call `jwa.UnregisterSignatureAlgorithm` yourself. +func UnregisterSigner(alg jwa.SignatureAlgorithm) { + muSignerDB.Lock() + delete(signerDB, alg) + muSignerDB.Unlock() } func init() { @@ -61,7 +85,10 @@ func init() { // NewSigner creates a signer that signs payloads using the given signature algorithm. func NewSigner(alg jwa.SignatureAlgorithm) (Signer, error) { + muSignerDB.RLock() f, ok := signerDB[alg] + muSignerDB.RUnlock() + if ok { return f.Create() } diff --git a/jws/verifier.go b/jws/verifier.go index 8093f8795..2dd29c848 100644 --- a/jws/verifier.go +++ b/jws/verifier.go @@ -2,6 +2,7 @@ package jws import ( "fmt" + "sync" "github.com/lestrrat-go/jwx/v2/jwa" ) @@ -15,6 +16,7 @@ func (fn VerifierFactoryFn) Create() (Verifier, error) { return fn() } +var muVerifierDB sync.RWMutex var verifierDB map[jwa.SignatureAlgorithm]VerifierFactory // RegisterVerifier is used to register a factory object that creates @@ -23,8 +25,30 @@ var verifierDB map[jwa.SignatureAlgorithm]VerifierFactory // For example, if you would like to provide a custom verifier for // jwa.EdDSA, use this function to register a `VerifierFactory` // (probably in your `init()`) +// +// Unlike the `UnregisterVerifier` function, this function automatically +// calls `jwa.RegisterSignatureAlgorithm` to register the algorithm +// in the known algorithms database. func RegisterVerifier(alg jwa.SignatureAlgorithm, f VerifierFactory) { + jwa.RegisterSignatureAlgorithm(alg) + muVerifierDB.Lock() verifierDB[alg] = f + muVerifierDB.Unlock() +} + +// UnregisterVerifier removes the signer factory associated with +// the given algorithm. +// +// Note that when you call this function, the algorithm itself is +// not automatically unregistered from the known algorithms database. +// This is because the algorithm may still be required for signing or +// some other operation (however unlikely, it is still possible). +// Therefore, in order to completely remove the algorithm, you must +// call `jwa.UnregisterSignatureAlgorithm` yourself. +func UnregisterVerifier(alg jwa.SignatureAlgorithm) { + muVerifierDB.Lock() + delete(verifierDB, alg) + muVerifierDB.Unlock() } func init() { @@ -61,7 +85,10 @@ func init() { // NewVerifier creates a verifier that signs payloads using the given signature algorithm. func NewVerifier(alg jwa.SignatureAlgorithm) (Verifier, error) { + muVerifierDB.RLock() f, ok := verifierDB[alg] + muVerifierDB.RUnlock() + if ok { return f.Create() } diff --git a/tools/cmd/genjwa/main.go b/tools/cmd/genjwa/main.go index fd123c3e8..b7f8a7f66 100644 --- a/tools/cmd/genjwa/main.go +++ b/tools/cmd/genjwa/main.go @@ -400,19 +400,48 @@ func (t typ) Generate() error { } o.L(")") // end const - o.L("var all%[1]ss = map[%[1]s]struct{} {", t.name) + // Register%s and related tools are provided so users can register their own types. + // This triggers some re-building of data structures that are otherwise + // reused for efficiency + o.LL("var mu%[1]ss sync.RWMutex", t.name) + o.L("var all%[1]ss map[%[1]s]struct{}", t.name) + o.L("var list%[1]s []%[1]s", t.name) + + o.LL("func init() {") + o.L("mu%[1]ss.Lock()", t.name) + o.L("defer mu%[1]ss.Unlock()", t.name) + o.L("all%[1]ss = make(map[%[1]s]struct{})", t.name) for _, e := range t.elements { if !e.invalid { - o.L("%s: {},", e.name) + o.L("all%[1]ss[%[2]s] = struct{}{}", t.name, e.name) } } + o.L("rebuild%[1]s()", t.name) o.L("}") - o.LL("var list%sOnce sync.Once", t.name) - o.L("var list%[1]s []%[1]s", t.name) - o.LL("// %[1]ss returns a list of all available values for %[1]s", t.name) - o.L("func %[1]ss() []%[1]s {", t.name) - o.L("list%sOnce.Do(func() {", t.name) + o.LL("// Register%[1]s registers a new %[1]s so that the jwx can properly handle the new value.", t.name) + o.L("// Duplicates will silently be ignored") + o.L("func Register%[1]s(v %[1]s) {", t.name) + o.L("mu%[1]ss.Lock()", t.name) + o.L("defer mu%[1]ss.Unlock()", t.name) + o.L("if _, ok := all%[1]ss[v]; !ok {", t.name) + o.L("all%[1]ss[v] = struct{}{}", t.name) + o.L("rebuild%[1]s()", t.name) + o.L("}") + o.L("}") + + o.LL("// Unregister%[1]s unregisters a %[1]s from its known database.", t.name) + o.L("// Non-existentn entries will silently be ignored") + o.L("func Unregister%[1]s(v %[1]s) {", t.name) + o.L("mu%[1]ss.Lock()", t.name) + o.L("defer mu%[1]ss.Unlock()", t.name) + o.L("if _, ok := all%[1]ss[v]; ok {", t.name) + o.L("delete(all%[1]ss, v)", t.name) + o.L("rebuild%[1]s()", t.name) + o.L("}") + o.L("}") + + o.LL("func rebuild%[1]s() {", t.name) o.L("list%[1]s = make([]%[1]s, 0, len(all%[1]ss))", t.name) o.L("for v := range all%ss {", t.name) o.L("list%[1]s = append(list%[1]s, v)", t.name) @@ -420,7 +449,12 @@ func (t typ) Generate() error { o.L("sort.Slice(list%s, func(i, j int) bool {", t.name) o.L("return string(list%[1]s[i]) < string(list%[1]s[j])", t.name) o.L("})") - o.L("})") + o.L("}") + + o.LL("// %[1]ss returns a list of all available values for %[1]s", t.name) + o.L("func %[1]ss() []%[1]s {", t.name) + o.L("mu%[1]ss.RLock()", t.name) + o.L("defer mu%[1]ss.RUnlock()", t.name) o.L("return list%s", t.name) o.L("}")