diff --git a/packages/polling-controller/package.json b/packages/polling-controller/package.json index fafe44ebe0..c4e7383a91 100644 --- a/packages/polling-controller/package.json +++ b/packages/polling-controller/package.json @@ -34,6 +34,7 @@ "@metamask/network-controller": "^14.0.0", "@metamask/utils": "^8.1.0", "@types/uuid": "^8.3.0", + "fast-json-stable-stringify": "^2.1.0", "uuid": "^8.3.2" }, "devDependencies": { diff --git a/packages/polling-controller/src/PollingController.test.ts b/packages/polling-controller/src/PollingController.test.ts index 81993e7908..caffce0058 100644 --- a/packages/polling-controller/src/PollingController.test.ts +++ b/packages/polling-controller/src/PollingController.test.ts @@ -211,6 +211,40 @@ describe('PollingController', () => { await Promise.resolve(); expect(controller.executePoll).toHaveBeenCalledTimes(2); }); + it('should start and stop polling sessions for different networkClientIds with the same options', async () => { + jest.useFakeTimers(); + + class MyGasFeeController extends PollingController { + executePoll = createExecutePollMock(); + } + const mockMessenger = new ControllerMessenger(); + + const controller = new MyGasFeeController({ + messenger: mockMessenger, + metadata: {}, + name: 'PollingController', + state: { foo: 'bar' }, + }); + const pollToken1 = controller.startPollingByNetworkClientId('mainnet', { + address: '0x1', + }); + controller.startPollingByNetworkClientId('mainnet', { address: '0x2' }); + controller.startPollingByNetworkClientId('sepolia', { address: '0x2' }); + jest.advanceTimersByTime(TICK_TIME); + await Promise.resolve(); + expect(controller.executePoll).toHaveBeenCalledTimes(3); + controller.stopPollingByNetworkClientId(pollToken1); + jest.advanceTimersByTime(TICK_TIME); + await Promise.resolve(); + expect(controller.executePoll).toHaveBeenCalledTimes(5); + expect(controller.executePoll.mock.calls).toMatchObject([ + ['mainnet', { address: '0x1' }], + ['mainnet', { address: '0x2' }], + ['sepolia', { address: '0x2' }], + ['mainnet', { address: '0x2' }], + ['sepolia', { address: '0x2' }], + ]); + }); }); describe('multiple networkClientIds', () => { it('should poll for each networkClientId', async () => { @@ -231,16 +265,16 @@ describe('PollingController', () => { jest.advanceTimersByTime(TICK_TIME); await Promise.resolve(); expect(controller.executePoll.mock.calls).toMatchObject([ - ['mainnet'], - ['rinkeby'], + ['mainnet', {}], + ['rinkeby', {}], ]); jest.advanceTimersByTime(TICK_TIME); await Promise.resolve(); expect(controller.executePoll.mock.calls).toMatchObject([ - ['mainnet'], - ['rinkeby'], - ['mainnet'], - ['rinkeby'], + ['mainnet', {}], + ['rinkeby', {}], + ['mainnet', {}], + ['rinkeby', {}], ]); controller.stopAllPolling(); }); @@ -267,27 +301,29 @@ describe('PollingController', () => { expect(controller.executePoll.mock.calls).toMatchObject([]); jest.advanceTimersByTime(TICK_TIME); await Promise.resolve(); - expect(controller.executePoll.mock.calls).toMatchObject([['mainnet']]); + expect(controller.executePoll.mock.calls).toMatchObject([ + ['mainnet', {}], + ]); jest.advanceTimersByTime(TICK_TIME); await Promise.resolve(); expect(controller.executePoll.mock.calls).toMatchObject([ - ['mainnet'], - ['sepolia'], + ['mainnet', {}], + ['sepolia', {}], ]); jest.advanceTimersByTime(TICK_TIME); await Promise.resolve(); expect(controller.executePoll.mock.calls).toMatchObject([ - ['mainnet'], - ['sepolia'], - ['mainnet'], + ['mainnet', {}], + ['sepolia', {}], + ['mainnet', {}], ]); jest.advanceTimersByTime(TICK_TIME); await Promise.resolve(); expect(controller.executePoll.mock.calls).toMatchObject([ - ['mainnet'], - ['sepolia'], - ['mainnet'], - ['sepolia'], + ['mainnet', {}], + ['sepolia', {}], + ['mainnet', {}], + ['sepolia', {}], ]); }); }); diff --git a/packages/polling-controller/src/PollingController.ts b/packages/polling-controller/src/PollingController.ts index 7a8098a8db..b8a74f09e2 100644 --- a/packages/polling-controller/src/PollingController.ts +++ b/packages/polling-controller/src/PollingController.ts @@ -1,11 +1,25 @@ import { BaseController, BaseControllerV2 } from '@metamask/base-controller'; import type { NetworkClientId } from '@metamask/network-controller'; +import type { Json } from '@metamask/utils'; +import stringify from 'fast-json-stable-stringify'; import { v4 as random } from 'uuid'; // Mixin classes require a constructor with an `...any[]` parameter // See TS2545 type Constructor = new (...args: any[]) => object; +/** + * Returns a unique key for a networkClientId and options. This is used to group networkClientId polls with the same options + * @param networkClientId - The networkClientId to get a key for + * @param options - The options used to group the polling events + * @returns The unique key + */ +export const getKey = ( + networkClientId: NetworkClientId, + options: Json, +): PollingGroupId => `${networkClientId}:${stringify(options)}`; + +type PollingGroupId = `${NetworkClientId}:${string}`; /** * PollingControllerMixin * @@ -20,10 +34,9 @@ function PollingControllerMixin(Base: TBase) { * */ abstract class PollingControllerBase extends Base { - readonly #networkClientIdTokensMap: Map> = - new Map(); + readonly #pollingTokenSets: Map> = new Map(); - readonly #intervalIds: Record = {}; + readonly #intervalIds: Record = {}; #callbacks: Map< NetworkClientId, @@ -49,28 +62,35 @@ function PollingControllerMixin(Base: TBase) { * Starts polling for a networkClientId * * @param networkClientId - The networkClientId to start polling for + * @param options - The options used to group the polling events * @returns void */ - startPollingByNetworkClientId(networkClientId: NetworkClientId) { - const innerPollToken = random(); - if (this.#networkClientIdTokensMap.has(networkClientId)) { - const set = this.#networkClientIdTokensMap.get(networkClientId); - set?.add(innerPollToken); + startPollingByNetworkClientId( + networkClientId: NetworkClientId, + options: Json = {}, + ) { + const pollToken = random(); + + const key = getKey(networkClientId, options); + + const pollingTokenSet = this.#pollingTokenSets.get(key); + if (pollingTokenSet) { + pollingTokenSet.add(pollToken); } else { const set = new Set(); - set.add(innerPollToken); - this.#networkClientIdTokensMap.set(networkClientId, set); + set.add(pollToken); + this.#pollingTokenSets.set(key, set); } - this.#poll(networkClientId); - return innerPollToken; + this.#poll(networkClientId, options); + return pollToken; } /** * Stops polling for all networkClientIds */ stopAllPolling() { - this.#networkClientIdTokensMap.forEach((tokens, _networkClientId) => { - tokens.forEach((token) => { + this.#pollingTokenSets.forEach((tokenSet, _networkClientId) => { + tokenSet.forEach((token) => { this.stopPollingByNetworkClientId(token); }); }); @@ -86,20 +106,18 @@ function PollingControllerMixin(Base: TBase) { throw new Error('pollingToken required'); } let found = false; - this.#networkClientIdTokensMap.forEach((tokens, networkClientId) => { - if (tokens.has(pollingToken)) { + this.#pollingTokenSets.forEach((tokenSet, key) => { + if (tokenSet.has(pollingToken)) { found = true; - this.#networkClientIdTokensMap - .get(networkClientId) - ?.delete(pollingToken); - if (this.#networkClientIdTokensMap.get(networkClientId)?.size === 0) { - clearTimeout(this.#intervalIds[networkClientId]); - delete this.#intervalIds[networkClientId]; - this.#networkClientIdTokensMap.delete(networkClientId); - this.#callbacks.get(networkClientId)?.forEach((callback) => { - callback(networkClientId); + tokenSet.delete(pollingToken); + if (tokenSet.size === 0) { + clearTimeout(this.#intervalIds[key]); + delete this.#intervalIds[key]; + this.#pollingTokenSets.delete(key); + this.#callbacks.get(key)?.forEach((callback) => { + callback(key); }); - this.#callbacks.get(networkClientId)?.clear(); + this.#callbacks.get(key)?.clear(); } } }); @@ -112,24 +130,29 @@ function PollingControllerMixin(Base: TBase) { * Executes the poll for a networkClientId * * @param networkClientId - The networkClientId to execute the poll for + * @param options - The options passed to startPollingByNetworkClientId */ - abstract executePoll(networkClientId: NetworkClientId): Promise; + abstract executePoll( + networkClientId: NetworkClientId, + options: Json, + ): Promise; - #poll(networkClientId: NetworkClientId) { - if (this.#intervalIds[networkClientId]) { - clearTimeout(this.#intervalIds[networkClientId]); - delete this.#intervalIds[networkClientId]; + #poll(networkClientId: NetworkClientId, options: Json) { + const key = getKey(networkClientId, options); + if (this.#intervalIds[key]) { + clearTimeout(this.#intervalIds[key]); + delete this.#intervalIds[key]; } // setTimeout is not `await`ing this async function, which is expected // We're just using async here for improved stack traces // eslint-disable-next-line @typescript-eslint/no-misused-promises - this.#intervalIds[networkClientId] = setTimeout(async () => { + this.#intervalIds[key] = setTimeout(async () => { try { - await this.executePoll(networkClientId); + await this.executePoll(networkClientId, options); } catch (error) { console.error(error); } - this.#poll(networkClientId); + this.#poll(networkClientId, options); }, this.#intervalLength); } @@ -138,17 +161,22 @@ function PollingControllerMixin(Base: TBase) { * * @param networkClientId - The networkClientId to listen for polling complete events * @param callback - The callback to execute when polling is complete + * @param options - The options used to group the polling events */ onPollingCompleteByNetworkClientId( networkClientId: NetworkClientId, callback: (networkClientId: NetworkClientId) => void, + options: Json = {}, ) { - if (this.#callbacks.has(networkClientId)) { - this.#callbacks.get(networkClientId)?.add(callback); - } else { + const key = getKey(networkClientId, options); + const callbacks = this.#callbacks.get(key); + + if (callbacks === undefined) { const set = new Set(); set.add(callback); - this.#callbacks.set(networkClientId, set); + this.#callbacks.set(key, set); + } else { + callbacks.add(callback); } } } diff --git a/yarn.lock b/yarn.lock index 658b9bde47..b395558b73 100644 --- a/yarn.lock +++ b/yarn.lock @@ -2164,6 +2164,7 @@ __metadata: "@types/jest": ^27.4.1 "@types/uuid": ^8.3.0 deepmerge: ^4.2.2 + fast-json-stable-stringify: ^2.1.0 jest: ^27.5.1 ts-jest: ^27.1.4 typedoc: ^0.24.8