diff --git a/sdk/core/core-rest-pipeline/CHANGELOG.md b/sdk/core/core-rest-pipeline/CHANGELOG.md index df5eefd206b6..6dafc7b7338b 100644 --- a/sdk/core/core-rest-pipeline/CHANGELOG.md +++ b/sdk/core/core-rest-pipeline/CHANGELOG.md @@ -2,6 +2,7 @@ ## 1.0.4 (Unreleased) +- Rewrote `bearerTokenAuthenticationPolicy` to use a new backend that refreshes tokens only when they're about to expire and not multiple times before. This is based on a similar fix implemented on `@azure/core-http@1.2.4` ([PR with the changes](https://github.com/Azure/azure-sdk-for-js/pull/14223)). This fixes the issue: [13369](https://github.com/Azure/azure-sdk-for-js/issues/13369). ## 1.0.3 (2021-03-30) diff --git a/sdk/core/core-rest-pipeline/src/policies/bearerTokenAuthenticationPolicy.ts b/sdk/core/core-rest-pipeline/src/policies/bearerTokenAuthenticationPolicy.ts index 8ca6c095df2c..927a87539c74 100644 --- a/sdk/core/core-rest-pipeline/src/policies/bearerTokenAuthenticationPolicy.ts +++ b/sdk/core/core-rest-pipeline/src/policies/bearerTokenAuthenticationPolicy.ts @@ -3,8 +3,8 @@ import { PipelineResponse, PipelineRequest, SendRequest } from "../interfaces"; import { PipelinePolicy } from "../pipeline"; -import { TokenCredential, GetTokenOptions } from "@azure/core-auth"; -import { AccessTokenCache, ExpiringAccessTokenCache } from "../accessTokenCache"; +import { TokenCredential } from "@azure/core-auth"; +import { createTokenCycler } from "../util/tokenCycler"; /** * The programmatic identifier of the bearerTokenAuthenticationPolicy. @@ -33,20 +33,17 @@ export function bearerTokenAuthenticationPolicy( options: BearerTokenAuthenticationPolicyOptions ): PipelinePolicy { const { credential, scopes } = options; - const tokenCache: AccessTokenCache = new ExpiringAccessTokenCache(); - async function getToken(tokenOptions: GetTokenOptions): Promise { - let accessToken = tokenCache.getCachedToken(); - if (accessToken === undefined) { - accessToken = (await credential.getToken(scopes, tokenOptions)) || undefined; - tokenCache.setCachedToken(accessToken); - } - return accessToken ? accessToken.token : undefined; - } + // This function encapsulates the entire process of reliably retrieving the token + // The options are left out of the public API until there's demand to configure this. + // Remember to extend `BearerTokenAuthenticationPolicyOptions` with `TokenCyclerOptions` + // in order to pass through the `options` object. + const getToken = createTokenCycler(credential, scopes /* , options */); + return { name: bearerTokenAuthenticationPolicyName, async sendRequest(request: PipelineRequest, next: SendRequest): Promise { - const token = await getToken({ + const { token } = await getToken({ abortSignal: request.abortSignal, tracingOptions: request.tracingOptions }); diff --git a/sdk/core/core-rest-pipeline/src/util/tokenCycler.ts b/sdk/core/core-rest-pipeline/src/util/tokenCycler.ts new file mode 100644 index 000000000000..30022631b125 --- /dev/null +++ b/sdk/core/core-rest-pipeline/src/util/tokenCycler.ts @@ -0,0 +1,205 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +import { AccessToken, GetTokenOptions, TokenCredential } from "@azure/core-auth"; +import { delay } from "./helpers"; + +/** + * A function that gets a promise of an access token and allows providing + * options. + * + * @param options - the options to pass to the underlying token provider + */ +type AccessTokenGetter = (options: GetTokenOptions) => Promise; + +export interface TokenCyclerOptions { + /** + * The window of time before token expiration during which the token will be + * considered unusable due to risk of the token expiring before sending the + * request. + * + * This will only become meaningful if the refresh fails for over + * (refreshWindow - forcedRefreshWindow) milliseconds. + */ + forcedRefreshWindowInMs: number; + /** + * Interval in milliseconds to retry failed token refreshes. + */ + retryIntervalInMs: number; + /** + * The window of time before token expiration during which + * we will attempt to refresh the token. + */ + refreshWindowInMs: number; +} + +// Default options for the cycler if none are provided +export const DEFAULT_CYCLER_OPTIONS: TokenCyclerOptions = { + forcedRefreshWindowInMs: 1000, // Force waiting for a refresh 1s before the token expires + retryIntervalInMs: 3000, // Allow refresh attempts every 3s + refreshWindowInMs: 1000 * 60 * 2 // Start refreshing 2m before expiry +}; + +/** + * Converts an an unreliable access token getter (which may resolve with null) + * into an AccessTokenGetter by retrying the unreliable getter in a regular + * interval. + * + * @param getAccessToken - A function that produces a promise of an access token that may fail by returning null. + * @param retryIntervalInMs - The time (in milliseconds) to wait between retry attempts. + * @param refreshTimeout - The timestamp after which the refresh attempt will fail, throwing an exception. + * @returns - A promise that, if it resolves, will resolve with an access token. + */ +async function beginRefresh( + getAccessToken: () => Promise, + retryIntervalInMs: number, + refreshTimeout: number +): Promise { + // This wrapper handles exceptions gracefully as long as we haven't exceeded + // the timeout. + async function tryGetAccessToken(): Promise { + if (Date.now() < refreshTimeout) { + try { + return await getAccessToken(); + } catch { + return null; + } + } else { + const finalToken = await getAccessToken(); + + // Timeout is up, so throw if it's still null + if (finalToken === null) { + throw new Error("Failed to refresh access token."); + } + + return finalToken; + } + } + + let token: AccessToken | null = await tryGetAccessToken(); + + while (token === null) { + await delay(retryIntervalInMs); + + token = await tryGetAccessToken(); + } + + return token; +} + +/** + * Creates a token cycler from a credential, scopes, and optional settings. + * + * A token cycler represents a way to reliably retrieve a valid access token + * from a TokenCredential. It will handle initializing the token, refreshing it + * when it nears expiration, and synchronizes refresh attempts to avoid + * concurrency hazards. + * + * @param credential - the underlying TokenCredential that provides the access + * token + * @param scopes - the scopes to request authorization for + * @param tokenCyclerOptions - optionally override default settings for the cycler + * + * @returns - a function that reliably produces a valid access token + */ +export function createTokenCycler( + credential: TokenCredential, + scopes: string | string[], + tokenCyclerOptions?: Partial +): AccessTokenGetter { + let refreshWorker: Promise | null = null; + let token: AccessToken | null = null; + + const options = { + ...DEFAULT_CYCLER_OPTIONS, + ...tokenCyclerOptions + }; + + /** + * This little holder defines several predicates that we use to construct + * the rules of refreshing the token. + */ + const cycler = { + /** + * Produces true if a refresh job is currently in progress. + */ + get isRefreshing(): boolean { + return refreshWorker !== null; + }, + /** + * Produces true if the cycler SHOULD refresh (we are within the refresh + * window and not already refreshing) + */ + get shouldRefresh(): boolean { + return ( + !cycler.isRefreshing && + (token?.expiresOnTimestamp ?? 0) - options.refreshWindowInMs < Date.now() + ); + }, + /** + * Produces true if the cycler MUST refresh (null or nearly-expired + * token). + */ + get mustRefresh(): boolean { + return ( + token === null || token.expiresOnTimestamp - options.forcedRefreshWindowInMs < Date.now() + ); + } + }; + + /** + * Starts a refresh job or returns the existing job if one is already + * running. + */ + function refresh(getTokenOptions: GetTokenOptions): Promise { + if (!cycler.isRefreshing) { + // We bind `scopes` here to avoid passing it around a lot + const tryGetAccessToken = (): Promise => + credential.getToken(scopes, getTokenOptions); + + // Take advantage of promise chaining to insert an assignment to `token` + // before the refresh can be considered done. + refreshWorker = beginRefresh( + tryGetAccessToken, + options.retryIntervalInMs, + // If we don't have a token, then we should timeout immediately + token?.expiresOnTimestamp ?? Date.now() + ) + .then((_token) => { + refreshWorker = null; + token = _token; + return token; + }) + .catch((reason) => { + // We also should reset the refresher if we enter a failed state. All + // existing awaiters will throw, but subsequent requests will start a + // new retry chain. + refreshWorker = null; + token = null; + throw reason; + }); + } + + return refreshWorker as Promise; + } + + return async (tokenOptions: GetTokenOptions): Promise => { + // + // Simple rules: + // - If we MUST refresh, then return the refresh task, blocking + // the pipeline until a token is available. + // - If we SHOULD refresh, then run refresh but don't return it + // (we can still use the cached token). + // - Return the token, since it's fine if we didn't return in + // step 1. + // + + if (cycler.mustRefresh) return refresh(tokenOptions); + + if (cycler.shouldRefresh) { + refresh(tokenOptions); + } + + return token as AccessToken; + }; +} diff --git a/sdk/core/core-rest-pipeline/test/bearerTokenAuthenticationPolicy.spec.ts b/sdk/core/core-rest-pipeline/test/bearerTokenAuthenticationPolicy.spec.ts index 6118e781c46f..5072a69ba91e 100644 --- a/sdk/core/core-rest-pipeline/test/bearerTokenAuthenticationPolicy.spec.ts +++ b/sdk/core/core-rest-pipeline/test/bearerTokenAuthenticationPolicy.spec.ts @@ -5,7 +5,6 @@ import { assert } from "chai"; import * as sinon from "sinon"; import { TokenCredential, AccessToken } from "@azure/core-auth"; import {} from "../src/policies/bearerTokenAuthenticationPolicy"; -import { DefaultTokenRefreshBufferMs } from "../src/accessTokenCache"; import { PipelinePolicy, createPipelineRequest, @@ -14,8 +13,20 @@ import { bearerTokenAuthenticationPolicy, SendRequest } from "../src"; +import { DEFAULT_CYCLER_OPTIONS } from "../src/util/tokenCycler"; + +const { refreshWindowInMs: defaultRefreshWindow } = DEFAULT_CYCLER_OPTIONS; describe("BearerTokenAuthenticationPolicy", function() { + let clock: sinon.SinonFakeTimers; + + beforeEach(() => { + clock = sinon.useFakeTimers(Date.now()); + }); + afterEach(() => { + clock.restore(); + }); + it("correctly adds an Authentication header with the Bearer token", async function() { const mockToken = "token"; const tokenScopes = ["scope1", "scope2"]; @@ -48,19 +59,100 @@ describe("BearerTokenAuthenticationPolicy", function() { assert.strictEqual(request.headers.get("Authorization"), `Bearer ${mockToken}`); }); - it("refreshes access tokens when they expire", async () => { - const now = Date.now(); - const refreshCred1 = new MockRefreshAzureCredential(now); - const refreshCred2 = new MockRefreshAzureCredential(now + DefaultTokenRefreshBufferMs); - const notRefreshCred1 = new MockRefreshAzureCredential( - now + DefaultTokenRefreshBufferMs + 5000 - ); + it("refreshes the token on initial request", async () => { + const expiresOn = Date.now() + 1000 * 60; // One minute later. + const credential = new MockRefreshAzureCredential(expiresOn); + const request = createPipelineRequest({ url: "https://example.com" }); + + const successResponse: PipelineResponse = { + headers: createHttpHeaders(), + request, + status: 200 + }; + const next = sinon.stub, ReturnType>(); + next.resolves(successResponse); + + const policy = createBearerTokenPolicy("test-scope", credential); + + await policy.sendRequest(request, next); + assert.strictEqual(credential.authCount, 1); + }); + + it("refreshes the token during the refresh window", async () => { + const expireDelayMs = defaultRefreshWindow + 5000; + let tokenExpiration = Date.now() + expireDelayMs; + const credential = new MockRefreshAzureCredential(tokenExpiration); + + const request = createPipelineRequest({ url: "https://example.com" }); + const successResponse: PipelineResponse = { + headers: createHttpHeaders(), + request, + status: 200 + }; + const next = sinon.stub, ReturnType>(); + next.resolves(successResponse); + + const policy = createBearerTokenPolicy("test-scope", credential); + + // The token is cached and remains cached for a bit. + await policy.sendRequest(request, next); + await policy.sendRequest(request, next); + assert.strictEqual(credential.authCount, 1); + + // The token will remain cached until tokenExpiration - testTokenRefreshBufferMs, so in (5000 - 1000) milliseconds. + + // For safe measure, we test the token is still cached a second earlier than the forced refresh expectation. + clock.tick(expireDelayMs - defaultRefreshWindow - 1000); + await policy.sendRequest(request, next); + assert.strictEqual(credential.authCount, 1); + + // The new token will last for a few minutes again. + tokenExpiration = Date.now() + expireDelayMs; + credential.expiresOnTimestamp = tokenExpiration; + + // Now we wait until it expires: + clock.tick(expireDelayMs + 1000); + await policy.sendRequest(request, next); + assert.strictEqual(credential.authCount, 2); + }); + + it("access token refresher should prevent multiple initial getToken requests to break", async () => { + const expireDelayMs = 5000; + const startTime = Date.now(); + const tokenExpiration = startTime + expireDelayMs; + const getTokenDelay = 100; + const credential = new MockRefreshAzureCredential(tokenExpiration, getTokenDelay, clock); + + const request = createPipelineRequest({ url: "https://example.com" }); + const successResponse: PipelineResponse = { + headers: createHttpHeaders(), + request, + status: 200 + }; + const next = sinon.stub, ReturnType>(); + next.resolves(successResponse); + + const policy = createBearerTokenPolicy("test-scope", credential); - const credentialsToTest: [MockRefreshAzureCredential, number][] = [ - [refreshCred1, 2], - [refreshCred2, 2], - [notRefreshCred1, 1] + // Now we send some requests. + const promises = [ + policy.sendRequest(request, next), + policy.sendRequest(request, next), + policy.sendRequest(request, next) ]; + // Now we wait until they're all resolved. + for (const promise of promises) { + await promise; + } + assert.strictEqual(credential.authCount, 1, "The first authentication should have happened"); + }); + + it("credential errors should bubble up", async () => { + const expireDelayMs = 5000; + const startTime = Date.now(); + const tokenExpiration = startTime + expireDelayMs; + const getTokenDelay = 100; + const credential = new MockRefreshAzureCredential(tokenExpiration, getTokenDelay, clock); const request = createPipelineRequest({ url: "https://example.com" }); const successResponse: PipelineResponse = { @@ -71,12 +163,65 @@ describe("BearerTokenAuthenticationPolicy", function() { const next = sinon.stub, ReturnType>(); next.resolves(successResponse); - for (const [credentialToTest, expectedCalls] of credentialsToTest) { - const policy = createBearerTokenPolicy("testscope", credentialToTest); - await policy.sendRequest(request, next); + const policy = createBearerTokenPolicy("test-scope", credential); + + credential.shouldThrow = true; + + let error: Error | undefined; + try { await policy.sendRequest(request, next); - assert.strictEqual(credentialToTest.authCount, expectedCalls); + } catch (e) { + error = e; } + assert.equal(error?.message, "Failed to retrieve the token"); + + assert.strictEqual( + credential.authCount, + 1, + "The first authentication attempt should have happened" + ); + }); + + it("access token refresher should prevent refreshers to happen too fast while the token is about to expire", async () => { + const expireDelayMs = 5000; + const startTime = Date.now(); + const tokenExpiration = startTime + defaultRefreshWindow + expireDelayMs; + const getTokenDelay = 100; + const credential = new MockRefreshAzureCredential(tokenExpiration, getTokenDelay, clock); + + const request = createPipelineRequest({ url: "https://example.com" }); + const successResponse: PipelineResponse = { + headers: createHttpHeaders(), + request, + status: 200 + }; + const next = sinon.stub, ReturnType>(); + next.resolves(successResponse); + + const policy = createBearerTokenPolicy("test-scope", credential); + + await policy.sendRequest(request, next); + assert.strictEqual(credential.authCount, 1, "The first authentication should have happened"); + + clock.tick(tokenExpiration - startTime - defaultRefreshWindow); // Until we start refreshing the token + + // Now we wait until some requests are all resolved. + await Promise.all([ + policy.sendRequest(request, next), + policy.sendRequest(request, next), + policy.sendRequest(request, next) + ]); + + // Only getTokenDelay should have passed, and only one refresh should have happened. + assert.strictEqual( + credential.authCount, + 2, + "authCode should have been called once during the refresh time" + ); + + const exceptionMessage = + "the total time passed should be in the refresh room, plus the many getTokens that have happened so far"; + assert.equal(expireDelayMs + 2 * getTokenDelay, Date.now() - startTime, exceptionMessage); }); function createBearerTokenPolicy( @@ -91,15 +236,27 @@ describe("BearerTokenAuthenticationPolicy", function() { }); class MockRefreshAzureCredential implements TokenCredential { - private _expiresOnTimestamp: number; public authCount = 0; + public shouldThrow: boolean = false; - constructor(expiresOnTimestamp: number) { - this._expiresOnTimestamp = expiresOnTimestamp; - } + constructor( + public expiresOnTimestamp: number, + public getTokenDelay?: number, + public clock?: sinon.SinonFakeTimers + ) {} public async getToken(): Promise { this.authCount++; - return { token: "mocktoken", expiresOnTimestamp: this._expiresOnTimestamp }; + + if (this.shouldThrow) { + throw new Error("Failed to retrieve the token"); + } + + // Allowing getToken to take a while + if (this.getTokenDelay && this.clock) { + this.clock.tick(this.getTokenDelay); + } + + return { token: "mock-token", expiresOnTimestamp: this.expiresOnTimestamp }; } }