diff --git a/packages/assets-controllers/src/TokenRatesController.test.ts b/packages/assets-controllers/src/TokenRatesController.test.ts index 12b178f2dd..c7c9b9a4b4 100644 --- a/packages/assets-controllers/src/TokenRatesController.test.ts +++ b/packages/assets-controllers/src/TokenRatesController.test.ts @@ -1972,6 +1972,79 @@ describe('TokenRatesController', () => { }, ); }); + + it('only updates rates once when called twice', async () => { + const tokenAddresses = [ + '0x0000000000000000000000000000000000000001', + '0x0000000000000000000000000000000000000002', + ]; + const fetchTokenPricesMock = jest.fn().mockResolvedValue({ + [tokenAddresses[0]]: { + currency: 'ETH', + tokenContractAddress: tokenAddresses[0], + value: 0.001, + }, + [tokenAddresses[1]]: { + currency: 'ETH', + tokenContractAddress: tokenAddresses[1], + value: 0.002, + }, + }); + const tokenPricesService = buildMockTokenPricesService({ + fetchTokenPrices: fetchTokenPricesMock, + }); + await withController( + { options: { tokenPricesService } }, + async ({ controller, controllerEvents }) => { + const updateExchangeRates = async () => + await callUpdateExchangeRatesMethod({ + allTokens: { + [toHex(1)]: { + [controller.config.selectedAddress]: [ + { + address: tokenAddresses[0], + decimals: 18, + symbol: 'TST1', + aggregators: [], + }, + { + address: tokenAddresses[1], + decimals: 18, + symbol: 'TST2', + aggregators: [], + }, + ], + }, + }, + chainId: toHex(1), + controller, + controllerEvents, + method, + nativeCurrency: 'ETH', + }); + + await Promise.all([updateExchangeRates(), updateExchangeRates()]); + + expect(fetchTokenPricesMock).toHaveBeenCalledTimes(1); + expect(controller.state).toMatchInlineSnapshot(` + Object { + "contractExchangeRates": Object { + "0x0000000000000000000000000000000000000001": 0.001, + "0x0000000000000000000000000000000000000002": 0.002, + }, + "contractExchangeRatesByChainId": Object { + "0x1": Object { + "ETH": Object { + "0x0000000000000000000000000000000000000001": 0.001, + "0x0000000000000000000000000000000000000002": 0.002, + }, + }, + }, + } + `); + }, + ); + }); }); }); @@ -2059,9 +2132,22 @@ async function withController( }); } finally { controller.stop(); + await flushPromises(); } } +/** + * Resolve all pending promises. + * + * This method is used for async tests that use fake timers. + * See https://stackoverflow.com/a/58716087 and https://jestjs.io/docs/timer-mocks. + * + * TODO: migrate this to @metamask/utils + */ +async function flushPromises(): Promise { + await new Promise(jest.requireActual('timers').setImmediate); +} + /** * Call an "update exchange rates" method with the given parameters. * diff --git a/packages/assets-controllers/src/TokenRatesController.ts b/packages/assets-controllers/src/TokenRatesController.ts index dfc0068480..8a2a08f69b 100644 --- a/packages/assets-controllers/src/TokenRatesController.ts +++ b/packages/assets-controllers/src/TokenRatesController.ts @@ -146,6 +146,8 @@ export class TokenRatesController extends PollingControllerV1< #tokenPricesService: AbstractTokenPricesService; + #inProcessExchangeRateUpdates: Record<`${Hex}:${string}`, Promise> = {}; + /** * Name of this controller used during composition */ @@ -360,36 +362,60 @@ export class TokenRatesController extends PollingControllerV1< return; } - const newContractExchangeRates = await this.#fetchAndMapExchangeRates({ - tokenContractAddresses, - chainId, - nativeCurrency, - }); + const updateKey: `${Hex}:${string}` = `${chainId}:${nativeCurrency}`; + if (updateKey in this.#inProcessExchangeRateUpdates) { + // This prevents redundant updates + // This promise is resolved after the in-progress update has finished, + // and state has been updated. + await this.#inProcessExchangeRateUpdates[updateKey]; + return; + } + + const { + promise: inProgressUpdate, + resolve: updateSucceeded, + reject: updateFailed, + } = createDeferredPromise({ suppressUnhandledRejection: true }); + this.#inProcessExchangeRateUpdates[updateKey] = inProgressUpdate; + + try { + const newContractExchangeRates = await this.#fetchAndMapExchangeRates({ + tokenContractAddresses, + chainId, + nativeCurrency, + }); - const existingContractExchangeRates = this.state.contractExchangeRates; - const updatedContractExchangeRates = - chainId === this.config.chainId && - nativeCurrency === this.config.nativeCurrency - ? newContractExchangeRates - : existingContractExchangeRates; - - const existingContractExchangeRatesForChainId = - this.state.contractExchangeRatesByChainId[chainId] ?? {}; - const updatedContractExchangeRatesForChainId = { - ...this.state.contractExchangeRatesByChainId, - [chainId]: { - ...existingContractExchangeRatesForChainId, - [nativeCurrency]: { - ...existingContractExchangeRatesForChainId[nativeCurrency], - ...newContractExchangeRates, + const existingContractExchangeRates = this.state.contractExchangeRates; + const updatedContractExchangeRates = + chainId === this.config.chainId && + nativeCurrency === this.config.nativeCurrency + ? newContractExchangeRates + : existingContractExchangeRates; + + const existingContractExchangeRatesForChainId = + this.state.contractExchangeRatesByChainId[chainId] ?? {}; + const updatedContractExchangeRatesForChainId = { + ...this.state.contractExchangeRatesByChainId, + [chainId]: { + ...existingContractExchangeRatesForChainId, + [nativeCurrency]: { + ...existingContractExchangeRatesForChainId[nativeCurrency], + ...newContractExchangeRates, + }, }, - }, - }; + }; - this.update({ - contractExchangeRates: updatedContractExchangeRates, - contractExchangeRatesByChainId: updatedContractExchangeRatesForChainId, - }); + this.update({ + contractExchangeRates: updatedContractExchangeRates, + contractExchangeRatesByChainId: updatedContractExchangeRatesForChainId, + }); + updateSucceeded(); + } catch (error: unknown) { + updateFailed(error); + throw error; + } finally { + delete this.#inProcessExchangeRateUpdates[updateKey]; + } } /** @@ -548,4 +574,60 @@ export class TokenRatesController extends PollingControllerV1< } } +/** + * A deferred Promise. + * + * A deferred Promise is one that can be resolved or rejected independently of + * the Promise construction. + */ +type DeferredPromise = { + /** + * The Promise that has been deferred. + */ + promise: Promise; + /** + * A function that resolves the Promise. + */ + resolve: () => void; + /** + * A function that rejects the Promise. + */ + reject: (error: unknown) => void; +}; + +/** + * Create a defered Promise. + * + * TODO: Migrate this to utils + * + * @param args - The arguments. + * @param args.suppressUnhandledRejection - This option adds an empty error handler + * to the Promise to suppress the UnhandledPromiseRejection error. This can be + * useful if the deferred Promise is sometimes intentionally not used. + * @returns A deferred Promise. + */ +function createDeferredPromise({ + suppressUnhandledRejection = false, +}: { + suppressUnhandledRejection: boolean; +}): DeferredPromise { + let resolve: DeferredPromise['resolve']; + let reject: DeferredPromise['reject']; + const promise = new Promise( + (innerResolve: () => void, innerReject: () => void) => { + resolve = innerResolve; + reject = innerReject; + }, + ); + + if (suppressUnhandledRejection) { + promise.catch((_error) => { + // This handler is used to suppress the UnhandledPromiseRejection error + }); + } + + // @ts-expect-error We know that these are assigned, but TypeScript doesn't + return { promise, resolve, reject }; +} + export default TokenRatesController;