From fa8b639c277e1694b08a08c7152341b22ec1725d Mon Sep 17 00:00:00 2001 From: Filip Skokan Date: Fri, 10 May 2024 15:02:58 +0200 Subject: [PATCH] feat: allow observing remote JWKS resolver state and its manual reload --- src/jwks/local.ts | 17 ++++++++++-- src/jwks/remote.ts | 52 +++++++++++++++++++++++++++++++++-- tap/jwks.ts | 9 ++++++ test/jwks/local.test.mjs | 9 ++++++ test/jwks/remote.test.mjs | 56 ++++++++++++++++++++++++++++++++++++++ test/types/index.test-d.ts | 9 ++++++ 6 files changed, 148 insertions(+), 4 deletions(-) diff --git a/src/jwks/local.ts b/src/jwks/local.ts index 474eb2c1d8..0936a7eec8 100644 --- a/src/jwks/local.ts +++ b/src/jwks/local.ts @@ -59,7 +59,7 @@ function clone(obj: T): T { /** @private */ export class LocalJWKSet { - protected _jwks?: JSONWebKeySet + private _jwks?: JSONWebKeySet private _cached: WeakMap> = new WeakMap() @@ -252,8 +252,21 @@ async function importWithAlgCache( */ export function createLocalJWKSet(jwks: JSONWebKeySet) { const set = new LocalJWKSet(jwks) - return async ( + + const localJWKSet = async ( protectedHeader?: JWSHeaderParameters, token?: FlattenedJWSInput, ): Promise => set.getKey(protectedHeader, token) + + Object.defineProperties(localJWKSet, { + jwks: { + // @ts-expect-error + value: () => clone(set._jwks), + enumerable: true, + configurable: false, + writable: false, + }, + }) + + return localJWKSet } diff --git a/src/jwks/remote.ts b/src/jwks/remote.ts index 3dff8cb026..64b8fe76dd 100644 --- a/src/jwks/remote.ts +++ b/src/jwks/remote.ts @@ -225,10 +225,58 @@ class RemoteJWKSet { export function createRemoteJWKSet( url: URL, options?: RemoteJWKSetOptions, -) { +): { + (protectedHeader?: JWSHeaderParameters, token?: FlattenedJWSInput): Promise + /** @ignore */ + coolingDown: boolean + /** @ignore */ + fresh: boolean + /** @ignore */ + reloading: boolean + /** @ignore */ + reload: () => Promise + /** @ignore */ + jwks: () => JSONWebKeySet | undefined +} { const set = new RemoteJWKSet(url, options) - return async ( + + const remoteJWKSet = async ( protectedHeader?: JWSHeaderParameters, token?: FlattenedJWSInput, ): Promise => set.getKey(protectedHeader, token) + + Object.defineProperties(remoteJWKSet, { + coolingDown: { + get: () => set.coolingDown(), + enumerable: true, + configurable: false, + }, + fresh: { + get: () => set.fresh(), + enumerable: true, + configurable: false, + }, + reload: { + value: () => set.reload(), + enumerable: true, + configurable: false, + writable: false, + }, + reloading: { + // @ts-expect-error + get: () => !!set._pendingFetch, + enumerable: true, + configurable: false, + }, + jwks: { + // @ts-expect-error + value: () => set._local?.jwks(), + enumerable: true, + configurable: false, + writable: false, + }, + }) + + // @ts-expect-error + return remoteJWKSet } diff --git a/tap/jwks.ts b/tap/jwks.ts index ab8c265e40..b2e0783819 100644 --- a/tap/jwks.ts +++ b/tap/jwks.ts @@ -10,13 +10,22 @@ export default (QUnit: QUnit, lib: typeof jose) => { test('[createRemoteJWKSet] fetches the JWKSet', async (t: typeof QUnit.assert) => { const response = await fetch(jwksUri).then((r) => r.json()) const { alg, kid } = response.keys[0] + const jwks = lib.createRemoteJWKSet(new URL(jwksUri)) + t.false(jwks.coolingDown) + t.false(jwks.fresh) + t.equal(jwks.jwks(), undefined) + await t.rejects(jwks({ alg: 'RS256' }), 'multiple matching keys found in the JSON Web Key Set') await t.rejects( jwks({ kid: 'foo', alg: 'RS256' }), 'no applicable key found in the JSON Web Key Set', ) t.ok(await Promise.all([jwks({ alg, kid }), jwks({ alg, kid })])) + + t.true(jwks.coolingDown) + t.true(jwks.fresh) + t.ok(jwks.jwks()) }) test('[createLocalJWKSet] establishes local JWKSet', async (t: typeof QUnit.assert) => { diff --git a/test/jwks/local.test.mjs b/test/jwks/local.test.mjs index 02682790f7..0181b8fbc4 100644 --- a/test/jwks/local.test.mjs +++ b/test/jwks/local.test.mjs @@ -17,4 +17,13 @@ test('LocalJWKSet', async (t) => { ]) { t.throws(() => createLocalJWKSet(f), { code: 'ERR_JWKS_INVALID' }) } + + const jwks = { keys: [] } + const set = createLocalJWKSet(jwks) + + const clone = set.jwks() + t.false(clone === jwks) + t.false(clone === set.jwks()) + t.deepEqual(clone, jwks) + t.deepEqual(clone, set.jwks()) }) diff --git a/test/jwks/remote.test.mjs b/test/jwks/remote.test.mjs index ef09e48a3c..2639c1d09a 100644 --- a/test/jwks/remote.test.mjs +++ b/test/jwks/remote.test.mjs @@ -267,6 +267,62 @@ test.serial('refreshes the JWKS once off cooldown', async (t) => { } }) +test.serial('createRemoteJWKSet manual reload', async (t) => { + timekeeper.freeze(now * 1000) + const jwk = { + crv: 'P-256', + x: 'fqCXPnWs3sSfwztvwYU9SthmRdoT4WCXxS8eD8icF6U', + y: 'nP6GIc42c61hoKqPcZqkvzhzIJkBV3Jw3g8sGG7UeP8', + d: 'XikZvoy8ayRpOnuz7ont2DkgMxp_kmmg1EKcuIJWX_E', + kty: 'EC', + } + const jwks = { + keys: [ + { + crv: 'P-256', + x: 'fqCXPnWs3sSfwztvwYU9SthmRdoT4WCXxS8eD8icF6U', + y: 'nP6GIc42c61hoKqPcZqkvzhzIJkBV3Jw3g8sGG7UeP8', + kty: 'EC', + kid: 'one', + }, + ], + } + + const scope = nock('https://as.example.com').get('/jwks').once().reply(200, jwks) + + const url = new URL('https://as.example.com/jwks') + const JWKS = createRemoteJWKSet(url) + t.false(JWKS.coolingDown) + t.false(JWKS.fresh) + t.false(JWKS.reloading) + t.is(JWKS.jwks(), undefined) + const key = await importJWK({ ...jwk, alg: 'ES256' }) + { + const jwt = await new SignJWT().setProtectedHeader({ alg: 'ES256', kid: 'two' }).sign(key) + await t.throwsAsync(jwtVerify(jwt, JWKS), { + code: 'ERR_JWKS_NO_MATCHING_KEY', + message: 'no applicable key found in the JSON Web Key Set', + }) + jwks.keys[0].kid = 'two' + scope.get('/jwks').once().reply(200, jwks) + t.true(JWKS.coolingDown) + t.true(JWKS.fresh) + t.false(JWKS.reloading) + t.notDeepEqual(JWKS.jwks(), jwks) + const reload = JWKS.reload() + t.true(JWKS.reloading) + await reload + t.true(JWKS.coolingDown) + t.true(JWKS.fresh) + t.false(JWKS.reloading) + t.deepEqual(JWKS.jwks(), jwks) + await t.notThrowsAsync(jwtVerify(jwt, JWKS)) + JWKS.jwks().keys = [] + t.deepEqual(JWKS.jwks(), jwks) + await t.notThrowsAsync(jwtVerify(jwt, JWKS)) + } +}) + test.serial('refreshes the JWKS once stale', async (t) => { timekeeper.freeze(now * 1000) const jwk = { diff --git a/test/types/index.test-d.ts b/test/types/index.test-d.ts index 5d3e90823f..06019a9c89 100644 --- a/test/types/index.test-d.ts +++ b/test/types/index.test-d.ts @@ -205,6 +205,15 @@ expectType(await lib.createRemoteJWKSet(new URL(''))()) expectType(await lib.createRemoteJWKSet(new URL(''))()) expectType(await lib.createRemoteJWKSet(new URL(''))()) +{ + const jwks = lib.createRemoteJWKSet(new URL('')) + expectType(jwks.fresh) + expectType(jwks.coolingDown) + expectType(jwks.reloading) + expectType>(jwks.reload()) + expectType(jwks.jwks()) +} + expectType(await lib.EmbeddedJWK()) expectType(await lib.EmbeddedJWK()) expectType(await lib.EmbeddedJWK())