diff --git a/packages/assets-controllers/src/AccountTrackerController.ts b/packages/assets-controllers/src/AccountTrackerController.ts index e170d4adb7..5aa65496c0 100644 --- a/packages/assets-controllers/src/AccountTrackerController.ts +++ b/packages/assets-controllers/src/AccountTrackerController.ts @@ -11,7 +11,7 @@ import type { NetworkController, NetworkState, } from '@metamask/network-controller'; -import { PollingControllerV1 } from '@metamask/polling-controller'; +import { StaticIntervalPollingControllerV1 } from '@metamask/polling-controller'; import type { PreferencesState } from '@metamask/preferences-controller'; import { assert } from '@metamask/utils'; import { Mutex } from 'async-mutex'; @@ -61,7 +61,7 @@ export interface AccountTrackerState extends BaseState { /** * Controller that tracks the network balances for all user accounts. */ -export class AccountTrackerController extends PollingControllerV1< +export class AccountTrackerController extends StaticIntervalPollingControllerV1< AccountTrackerConfig, AccountTrackerState > { diff --git a/packages/assets-controllers/src/CurrencyRateController.ts b/packages/assets-controllers/src/CurrencyRateController.ts index 41c1eee8fc..bf3509f310 100644 --- a/packages/assets-controllers/src/CurrencyRateController.ts +++ b/packages/assets-controllers/src/CurrencyRateController.ts @@ -11,7 +11,7 @@ import type { NetworkClientId, NetworkControllerGetNetworkClientByIdAction, } from '@metamask/network-controller'; -import { PollingController } from '@metamask/polling-controller'; +import { StaticIntervalPollingController } from '@metamask/polling-controller'; import { Mutex } from 'async-mutex'; import { fetchExchangeRate as defaultFetchExchangeRate } from './crypto-compare'; @@ -82,7 +82,7 @@ const defaultState = { * Controller that passively polls on a set interval for an exchange rate from the current network * asset to the user's preferred currency. */ -export class CurrencyRateController extends PollingController< +export class CurrencyRateController extends StaticIntervalPollingController< typeof name, CurrencyRateState, CurrencyRateMessenger diff --git a/packages/assets-controllers/src/NftDetectionController.ts b/packages/assets-controllers/src/NftDetectionController.ts index 5ed01bca72..9db76bb073 100644 --- a/packages/assets-controllers/src/NftDetectionController.ts +++ b/packages/assets-controllers/src/NftDetectionController.ts @@ -11,7 +11,7 @@ import type { NetworkState, NetworkClient, } from '@metamask/network-controller'; -import { PollingControllerV1 } from '@metamask/polling-controller'; +import { StaticIntervalPollingControllerV1 } from '@metamask/polling-controller'; import type { PreferencesState } from '@metamask/preferences-controller'; import type { Hex } from '@metamask/utils'; @@ -147,7 +147,7 @@ export interface NftDetectionConfig extends BaseConfig { /** * Controller that passively polls on a set interval for NFT auto detection */ -export class NftDetectionController extends PollingControllerV1< +export class NftDetectionController extends StaticIntervalPollingControllerV1< NftDetectionConfig, BaseState > { diff --git a/packages/assets-controllers/src/TokenDetectionController.ts b/packages/assets-controllers/src/TokenDetectionController.ts index 2b877058e1..30a3249f18 100644 --- a/packages/assets-controllers/src/TokenDetectionController.ts +++ b/packages/assets-controllers/src/TokenDetectionController.ts @@ -8,7 +8,7 @@ import type { NetworkController, NetworkState, } from '@metamask/network-controller'; -import { PollingControllerV1 } from '@metamask/polling-controller'; +import { StaticIntervalPollingControllerV1 } from '@metamask/polling-controller'; import type { PreferencesState } from '@metamask/preferences-controller'; import type { Hex } from '@metamask/utils'; @@ -44,7 +44,7 @@ export interface TokenDetectionConfig extends BaseConfig { /** * Controller that passively polls on a set interval for Tokens auto detection */ -export class TokenDetectionController extends PollingControllerV1< +export class TokenDetectionController extends StaticIntervalPollingControllerV1< TokenDetectionConfig, BaseState > { diff --git a/packages/assets-controllers/src/TokenListController.ts b/packages/assets-controllers/src/TokenListController.ts index 3603760fef..ba89c200c9 100644 --- a/packages/assets-controllers/src/TokenListController.ts +++ b/packages/assets-controllers/src/TokenListController.ts @@ -10,7 +10,7 @@ import type { NetworkState, NetworkControllerGetNetworkClientByIdAction, } from '@metamask/network-controller'; -import { PollingController } from '@metamask/polling-controller'; +import { StaticIntervalPollingController } from '@metamask/polling-controller'; import type { Hex } from '@metamask/utils'; import { Mutex } from 'async-mutex'; @@ -91,7 +91,7 @@ const defaultState: TokenListState = { /** * Controller that passively polls on a set interval for the list of tokens from metaswaps api */ -export class TokenListController extends PollingController< +export class TokenListController extends StaticIntervalPollingController< typeof name, TokenListState, TokenListMessenger diff --git a/packages/assets-controllers/src/TokenRatesController.ts b/packages/assets-controllers/src/TokenRatesController.ts index 8a2a08f69b..14bfdd007a 100644 --- a/packages/assets-controllers/src/TokenRatesController.ts +++ b/packages/assets-controllers/src/TokenRatesController.ts @@ -10,7 +10,7 @@ import type { NetworkController, NetworkState, } from '@metamask/network-controller'; -import { PollingControllerV1 } from '@metamask/polling-controller'; +import { StaticIntervalPollingControllerV1 } from '@metamask/polling-controller'; import type { PreferencesState } from '@metamask/preferences-controller'; import type { Hex } from '@metamask/utils'; import { isDeepStrictEqual } from 'util'; @@ -136,7 +136,7 @@ async function getCurrencyConversionRate({ * Controller that passively polls on a set interval for token-to-fiat exchange rates * for tokens stored in the TokensController */ -export class TokenRatesController extends PollingControllerV1< +export class TokenRatesController extends StaticIntervalPollingControllerV1< TokenRatesConfig, TokenRatesState > { diff --git a/packages/gas-fee-controller/src/GasFeeController.ts b/packages/gas-fee-controller/src/GasFeeController.ts index 94f25576c5..ecb397e607 100644 --- a/packages/gas-fee-controller/src/GasFeeController.ts +++ b/packages/gas-fee-controller/src/GasFeeController.ts @@ -18,7 +18,7 @@ import type { NetworkState, ProviderProxy, } from '@metamask/network-controller'; -import { PollingController } from '@metamask/polling-controller'; +import { StaticIntervalPollingController } from '@metamask/polling-controller'; import type { Hex } from '@metamask/utils'; import { v1 as random } from 'uuid'; @@ -253,7 +253,7 @@ const defaultState: GasFeeState = { /** * Controller that retrieves gas fee estimate data and polls for updated data on a set interval */ -export class GasFeeController extends PollingController< +export class GasFeeController extends StaticIntervalPollingController< typeof name, GasFeeState, GasFeeMessenger diff --git a/packages/polling-controller/src/AbstractPollingController.ts b/packages/polling-controller/src/AbstractPollingController.ts new file mode 100644 index 0000000000..621d481125 --- /dev/null +++ b/packages/polling-controller/src/AbstractPollingController.ts @@ -0,0 +1,138 @@ +import type { NetworkClientId } from '@metamask/network-controller'; +import type { Json } from '@metamask/utils'; +import stringify from 'fast-json-stable-stringify'; +import { v4 as random } from 'uuid'; + +export type IPollingController = { + startPollingByNetworkClientId( + networkClientId: NetworkClientId, + options: Json, + ): string; + + stopAllPolling(): void; + + stopPollingByPollingToken(pollingToken: string): void; + + onPollingCompleteByNetworkClientId( + networkClientId: NetworkClientId, + callback: (networkClientId: NetworkClientId) => void, + options: Json, + ): void; + + _executePoll(networkClientId: NetworkClientId, options: Json): Promise; + _startPollingByNetworkClientId( + networkClientId: NetworkClientId, + options: Json, + ): void; + _stopPollingByPollingTokenSetId(key: PollingTokenSetId): void; +}; + +export const getKey = ( + networkClientId: NetworkClientId, + options: Json, +): PollingTokenSetId => `${networkClientId}:${stringify(options)}`; + +export type PollingTokenSetId = `${NetworkClientId}:${string}`; + +type Constructor = new (...args: any[]) => object; + +/** + * AbstractPollingControllerBaseMixin + * + * @param Base - The base class to mix onto. + * @returns The composed class. + */ +export function AbstractPollingControllerBaseMixin( + Base: TBase, +) { + abstract class AbstractPollingControllerBase + extends Base + implements IPollingController + { + readonly #pollingTokenSets: Map> = new Map(); + + #callbacks: Map< + PollingTokenSetId, + Set<(PollingTokenSetId: PollingTokenSetId) => void> + > = new Map(); + + abstract _executePoll( + networkClientId: NetworkClientId, + options: Json, + ): Promise; + + abstract _startPollingByNetworkClientId( + networkClientId: NetworkClientId, + options: Json, + ): void; + + abstract _stopPollingByPollingTokenSetId(key: PollingTokenSetId): void; + + startPollingByNetworkClientId( + networkClientId: NetworkClientId, + options: Json = {}, + ): string { + const pollToken = random(); + const key = getKey(networkClientId, options); + const pollingTokenSet = + this.#pollingTokenSets.get(key) ?? new Set(); + pollingTokenSet.add(pollToken); + this.#pollingTokenSets.set(key, pollingTokenSet); + + if (pollingTokenSet.size === 1) { + this._startPollingByNetworkClientId(networkClientId, options); + } + + return pollToken; + } + + stopAllPolling() { + this.#pollingTokenSets.forEach((tokenSet, _key) => { + tokenSet.forEach((token) => { + this.stopPollingByPollingToken(token); + }); + }); + } + + stopPollingByPollingToken(pollingToken: string) { + if (!pollingToken) { + throw new Error('pollingToken required'); + } + + let keyToDelete: PollingTokenSetId | null = null; + for (const [key, tokenSet] of this.#pollingTokenSets) { + if (tokenSet.delete(pollingToken)) { + if (tokenSet.size === 0) { + keyToDelete = key; + } + break; + } + } + + if (keyToDelete) { + this._stopPollingByPollingTokenSetId(keyToDelete); + this.#pollingTokenSets.delete(keyToDelete); + const callbacks = this.#callbacks.get(keyToDelete); + if (callbacks) { + for (const callback of callbacks) { + // eslint-disable-next-line n/callback-return + callback(keyToDelete); + } + callbacks.clear(); + } + } + } + + onPollingCompleteByNetworkClientId( + networkClientId: NetworkClientId, + callback: (networkClientId: NetworkClientId) => void, + options: Json = {}, + ) { + const key = getKey(networkClientId, options); + const callbacks = this.#callbacks.get(key) ?? new Set(); + callbacks.add(callback); + this.#callbacks.set(key, callbacks); + } + } + return AbstractPollingControllerBase; +} diff --git a/packages/polling-controller/src/BlockTrackerPollingController.test.ts b/packages/polling-controller/src/BlockTrackerPollingController.test.ts new file mode 100644 index 0000000000..6205fb8642 --- /dev/null +++ b/packages/polling-controller/src/BlockTrackerPollingController.test.ts @@ -0,0 +1,268 @@ +import { ControllerMessenger } from '@metamask/base-controller'; +import type { NetworkClient } from '@metamask/network-controller'; +import EventEmitter from 'events'; +import { useFakeTimers } from 'sinon'; + +import { BlockTrackerPollingController } from './BlockTrackerPollingController'; + +const createExecutePollMock = () => { + const executePollMock = jest.fn().mockImplementation(async () => { + return true; + }); + return executePollMock; +}; + +let getNetworkClientByIdStub: jest.Mock; +class ChildBlockTrackerPollingController extends BlockTrackerPollingController< + any, + any, + any +> { + _executePoll = createExecutePollMock(); + + _getNetworkClientById(networkClientId: string): NetworkClient | undefined { + return getNetworkClientByIdStub(networkClientId); + } +} + +class TestBlockTracker extends EventEmitter { + private latestBlockNumber = 0; + + emitBlockEvent() { + this.latestBlockNumber += 1; + this.emit('latest', this.latestBlockNumber); + } +} + +describe('BlockTrackerPollingController', () => { + let clock: sinon.SinonFakeTimers; + let mockMessenger: any; + let controller: any; + let mainnetBlockTracker: TestBlockTracker; + let goerliBlockTracker: TestBlockTracker; + let sepoliaBlockTracker: TestBlockTracker; + beforeEach(() => { + mockMessenger = new ControllerMessenger(); + controller = new ChildBlockTrackerPollingController({ + messenger: mockMessenger, + metadata: {}, + name: 'PollingController', + state: { foo: 'bar' }, + }); + + mainnetBlockTracker = new TestBlockTracker(); + goerliBlockTracker = new TestBlockTracker(); + sepoliaBlockTracker = new TestBlockTracker(); + + getNetworkClientByIdStub = jest + .fn() + .mockImplementation((networkClientId: string) => { + switch (networkClientId) { + case 'mainnet': + return { + blockTracker: mainnetBlockTracker, + }; + case 'goerli': + return { + blockTracker: goerliBlockTracker, + }; + case 'sepolia': + return { + blockTracker: sepoliaBlockTracker, + }; + default: + throw new Error(`Unknown networkClientId: ${networkClientId}`); + } + }); + clock = useFakeTimers(); + }); + afterEach(() => { + clock.restore(); + }); + + describe('startPollingByNetworkClientId', () => { + it('should call _executePoll on "latest" block events emitted by blockTrackers for each networkClientId passed to startPollingByNetworkClientId', async () => { + controller.startPollingByNetworkClientId('mainnet'); + controller.startPollingByNetworkClientId('goerli'); + // await advanceTime({ clock, duration: 5 }); + mainnetBlockTracker.emitBlockEvent(); + + expect(controller._executePoll).toHaveBeenCalledTimes(1); + expect(controller._executePoll).toHaveBeenCalledWith('mainnet', {}, 1); + + mainnetBlockTracker.emitBlockEvent(); + goerliBlockTracker.emitBlockEvent(); + + expect(controller._executePoll).toHaveBeenNthCalledWith( + 2, + 'mainnet', + {}, + 2, // 2nd block for mainnet + ); + expect(controller._executePoll).toHaveBeenNthCalledWith( + 3, + 'goerli', + {}, + 1, // 1st block for goerli + ); + + mainnetBlockTracker.emitBlockEvent(); + goerliBlockTracker.emitBlockEvent(); + + // sepolioa not being listened to yet, so first block for sepolia will not cause an executePoll + sepoliaBlockTracker.emitBlockEvent(); + + expect(controller._executePoll).toHaveBeenNthCalledWith( + 4, + 'mainnet', + {}, + 3, + ); + expect(controller._executePoll).toHaveBeenNthCalledWith( + 5, + 'goerli', + {}, + 2, + ); + + controller.startPollingByNetworkClientId('sepolia'); + + mainnetBlockTracker.emitBlockEvent(); + sepoliaBlockTracker.emitBlockEvent(); + + expect(controller._executePoll).toHaveBeenNthCalledWith( + 6, + 'mainnet', + {}, + 4, + ); + expect(controller._executePoll).toHaveBeenNthCalledWith( + 7, + 'sepolia', + {}, + 2, + ); + + controller.stopAllPolling(); + }); + }); + + describe('stopPollingByPollingToken', () => { + it('should should stop polling when all polling tokens for a networkClientId are deleted', async () => { + const pollingToken1 = controller.startPollingByNetworkClientId('mainnet'); + + // await advanceTime({ clock, duration: 5 }); + mainnetBlockTracker.emitBlockEvent(); + + expect(controller._executePoll).toHaveBeenCalledTimes(1); + expect(controller._executePoll).toHaveBeenCalledWith('mainnet', {}, 1); + + const pollingToken2 = controller.startPollingByNetworkClientId('mainnet'); + + mainnetBlockTracker.emitBlockEvent(); + + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', {}, 1], + ['mainnet', {}, 2], + ]); + + controller.stopPollingByPollingToken(pollingToken1); + + mainnetBlockTracker.emitBlockEvent(); + + // polling is still active for mainnet because pollingToken2 is still active + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', {}, 1], + ['mainnet', {}, 2], + ['mainnet', {}, 3], + ]); + + controller.stopPollingByPollingToken(pollingToken2); + + mainnetBlockTracker.emitBlockEvent(); + mainnetBlockTracker.emitBlockEvent(); + mainnetBlockTracker.emitBlockEvent(); + + // no further polling should occur regardless of how many blocks are emitted + // because all pollingTokens for mainnet have been deleted + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', {}, 1], + ['mainnet', {}, 2], + ['mainnet', {}, 3], + ]); + }); + + it('should should stop polling for one networkClientId when all polling tokens for that networkClientId are deleted, without stopping polling for networkClientIds with active pollingTokens', async () => { + const pollingToken1 = controller.startPollingByNetworkClientId('mainnet'); + + mainnetBlockTracker.emitBlockEvent(); + + expect(controller._executePoll).toHaveBeenCalledWith('mainnet', {}, 1); + + const pollingToken2 = controller.startPollingByNetworkClientId('mainnet'); + + mainnetBlockTracker.emitBlockEvent(); + + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', {}, 1], + ['mainnet', {}, 2], + ]); + + controller.startPollingByNetworkClientId('goerli'); + + mainnetBlockTracker.emitBlockEvent(); + + // we are polling for mainnet and goerli but goerli has not emitted any blocks yet + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', {}, 1], + ['mainnet', {}, 2], + ['mainnet', {}, 3], + ]); + + controller.stopPollingByPollingToken(pollingToken1); + + mainnetBlockTracker.emitBlockEvent(); + goerliBlockTracker.emitBlockEvent(); + + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', {}, 1], + ['mainnet', {}, 2], + ['mainnet', {}, 3], + ['mainnet', {}, 4], + ['goerli', {}, 1], + ]); + + controller.stopPollingByPollingToken(pollingToken2); + + mainnetBlockTracker.emitBlockEvent(); + mainnetBlockTracker.emitBlockEvent(); + mainnetBlockTracker.emitBlockEvent(); + goerliBlockTracker.emitBlockEvent(); + goerliBlockTracker.emitBlockEvent(); + + // no further polling for mainnet should occur + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', {}, 1], + ['mainnet', {}, 2], + ['mainnet', {}, 3], + ['mainnet', {}, 4], + ['goerli', {}, 1], + ['goerli', {}, 2], + ['goerli', {}, 3], + ]); + + controller.stopAllPolling(); + }); + }); + + describe('onPollingCompleteByNetworkClientId', () => { + it('should publish "pollingComplete" callback function set by "onPollingCompleteByNetworkClientId" when polling stops', async () => { + const pollingComplete: any = jest.fn(); + controller.onPollingCompleteByNetworkClientId('mainnet', pollingComplete); + const pollingToken = controller.startPollingByNetworkClientId('mainnet'); + controller.stopPollingByPollingToken(pollingToken); + expect(pollingComplete).toHaveBeenCalledTimes(1); + expect(pollingComplete).toHaveBeenCalledWith('mainnet:{}'); + }); + }); +}); diff --git a/packages/polling-controller/src/BlockTrackerPollingController.ts b/packages/polling-controller/src/BlockTrackerPollingController.ts new file mode 100644 index 0000000000..07d5395562 --- /dev/null +++ b/packages/polling-controller/src/BlockTrackerPollingController.ts @@ -0,0 +1,87 @@ +import { BaseController, BaseControllerV1 } from '@metamask/base-controller'; +import type { + NetworkClientId, + NetworkClient, +} from '@metamask/network-controller'; +import type { Json } from '@metamask/utils'; + +import { + AbstractPollingControllerBaseMixin, + getKey, +} from './AbstractPollingController'; +import type { PollingTokenSetId } from './AbstractPollingController'; + +type Constructor = new (...args: any[]) => object; + +/** + * BlockTrackerPollingControllerMixin + * A polling controller that polls using a block tracker. + * + * @param Base - The base class to mix onto. + * @returns The composed class. + */ +function BlockTrackerPollingControllerMixin( + Base: TBase, +) { + abstract class BlockTrackerPollingController extends AbstractPollingControllerBaseMixin( + Base, + ) { + #activeListeners: Record Promise> = {}; + + abstract _getNetworkClientById( + networkClientId: NetworkClientId, + ): NetworkClient | undefined; + + _startPollingByNetworkClientId( + networkClientId: NetworkClientId, + options: Json, + ) { + const key = getKey(networkClientId, options); + + if (this.#activeListeners[key]) { + return; + } + + const networkClient = this._getNetworkClientById(networkClientId); + if (networkClient) { + const updateOnNewBlock = this._executePoll.bind( + this, + networkClientId, + options, + ); + networkClient.blockTracker.addListener('latest', updateOnNewBlock); + this.#activeListeners[key] = updateOnNewBlock; + } else { + throw new Error( + `Unable to retrieve blockTracker for networkClientId ${networkClientId}`, + ); + } + } + + _stopPollingByPollingTokenSetId(key: PollingTokenSetId) { + const [networkClientId] = key.split(':'); + const networkClient = this._getNetworkClientById( + networkClientId as NetworkClientId, + ); + + if (networkClient && this.#activeListeners[key]) { + const listener = this.#activeListeners[key]; + if (listener) { + networkClient.blockTracker.removeListener('latest', listener); + delete this.#activeListeners[key]; + } + } + } + } + + return BlockTrackerPollingController; +} + +class Empty {} + +export const BlockTrackerPollingControllerOnly = + BlockTrackerPollingControllerMixin(Empty); +export const BlockTrackerPollingController = + BlockTrackerPollingControllerMixin(BaseController); +export const BlockTrackerPollingControllerV1 = + BlockTrackerPollingControllerMixin(BaseControllerV1); diff --git a/packages/polling-controller/src/PollingController.test.ts b/packages/polling-controller/src/PollingController.test.ts deleted file mode 100644 index b6a9344de5..0000000000 --- a/packages/polling-controller/src/PollingController.test.ts +++ /dev/null @@ -1,386 +0,0 @@ -import { ControllerMessenger } from '@metamask/base-controller'; -import { useFakeTimers } from 'sinon'; - -import { advanceTime } from '../../../tests/helpers'; -import { PollingController, PollingControllerOnly } from './PollingController'; - -const TICK_TIME = 1000; - -const createExecutePollMock = () => { - const executePollMock = jest.fn().mockImplementation(async () => { - return true; - }); - return executePollMock; -}; - -describe('PollingController', () => { - let clock: sinon.SinonFakeTimers; - beforeEach(() => { - clock = useFakeTimers(); - }); - afterEach(() => { - clock.restore(); - }); - describe('start', () => { - it('should start polling if not polling', async () => { - class MyGasFeeController extends PollingController { - _executePoll = createExecutePollMock(); - } - const mockMessenger = new ControllerMessenger(); - - const controller = new MyGasFeeController({ - messenger: mockMessenger, - metadata: {}, - name: 'PollingController', - state: { foo: 'bar' }, - }); - controller.startPollingByNetworkClientId('mainnet'); - await advanceTime({ clock, duration: 0 }); - expect(controller._executePoll).toHaveBeenCalledTimes(1); - await advanceTime({ clock, duration: TICK_TIME }); - expect(controller._executePoll).toHaveBeenCalledTimes(2); - controller.stopAllPolling(); - }); - }); - describe('stop', () => { - it('should stop polling when called with a valid polling that was the only active pollingToken for a given networkClient', async () => { - class MyGasFeeController extends PollingController { - _executePoll = createExecutePollMock(); - } - const mockMessenger = new ControllerMessenger(); - - const controller = new MyGasFeeController({ - messenger: mockMessenger, - metadata: {}, - name: 'PollingController', - state: { foo: 'bar' }, - }); - const pollingToken = controller.startPollingByNetworkClientId('mainnet'); - await advanceTime({ clock, duration: 0 }); - expect(controller._executePoll).toHaveBeenCalledTimes(1); - await advanceTime({ clock, duration: TICK_TIME }); - controller.stopPollingByPollingToken(pollingToken); - await advanceTime({ clock, duration: TICK_TIME }); - expect(controller._executePoll).toHaveBeenCalledTimes(2); - controller.stopAllPolling(); - }); - it('should not stop polling if called with one of multiple active polling tokens for a given networkClient', async () => { - class MyGasFeeController extends PollingController { - _executePoll = createExecutePollMock(); - } - const mockMessenger = new ControllerMessenger(); - - const controller = new MyGasFeeController({ - messenger: mockMessenger, - metadata: {}, - name: 'PollingController', - state: { foo: 'bar' }, - }); - const pollingToken1 = controller.startPollingByNetworkClientId('mainnet'); - await advanceTime({ clock, duration: 0 }); - - controller.startPollingByNetworkClientId('mainnet'); - expect(controller._executePoll).toHaveBeenCalledTimes(1); - await advanceTime({ clock, duration: TICK_TIME }); - controller.stopPollingByPollingToken(pollingToken1); - await advanceTime({ clock, duration: TICK_TIME }); - expect(controller._executePoll).toHaveBeenCalledTimes(3); - controller.stopAllPolling(); - }); - it('should error if no pollingToken is passed', () => { - class MyGasFeeController extends PollingController { - _executePoll = createExecutePollMock(); - } - const mockMessenger = new ControllerMessenger(); - - const controller = new MyGasFeeController({ - messenger: mockMessenger, - metadata: {}, - name: 'PollingController', - state: { foo: 'bar' }, - }); - controller.startPollingByNetworkClientId('mainnet'); - expect(() => { - controller.stopPollingByPollingToken(undefined as unknown as any); - }).toThrow('pollingToken required'); - controller.stopAllPolling(); - }); - it('should error if no matching pollingToken is found', () => { - class MyGasFeeController extends PollingController { - _executePoll = createExecutePollMock(); - } - const mockMessenger = new ControllerMessenger(); - - const controller = new MyGasFeeController({ - messenger: mockMessenger, - metadata: {}, - name: 'PollingController', - state: { foo: 'bar' }, - }); - controller.startPollingByNetworkClientId('mainnet'); - expect(() => { - controller.stopPollingByPollingToken('potato'); - }).toThrow('pollingToken not found'); - controller.stopAllPolling(); - }); - }); - describe('startPollingByNetworkClientId', () => { - it('should call _executePoll immediately and on interval if polling', async () => { - class MyGasFeeController extends PollingController { - _executePoll = createExecutePollMock(); - } - const mockMessenger = new ControllerMessenger(); - - const controller = new MyGasFeeController({ - messenger: mockMessenger, - metadata: {}, - name: 'PollingController', - state: { foo: 'bar' }, - }); - controller.startPollingByNetworkClientId('mainnet'); - await advanceTime({ clock, duration: 0 }); - expect(controller._executePoll).toHaveBeenCalledTimes(1); - await advanceTime({ clock, duration: TICK_TIME * 2 }); - expect(controller._executePoll).toHaveBeenCalledTimes(3); - }); - it('should call _executePoll immediately once and continue calling _executePoll on interval when start is called again with the same networkClientId', async () => { - class MyGasFeeController extends PollingController { - _executePoll = createExecutePollMock(); - } - const mockMessenger = new ControllerMessenger(); - - const controller = new MyGasFeeController({ - messenger: mockMessenger, - metadata: {}, - name: 'PollingController', - state: { foo: 'bar' }, - }); - controller.startPollingByNetworkClientId('mainnet'); - await advanceTime({ clock, duration: 0 }); - - controller.startPollingByNetworkClientId('mainnet'); - await advanceTime({ clock, duration: 0 }); - - expect(controller._executePoll).toHaveBeenCalledTimes(1); - await advanceTime({ clock, duration: TICK_TIME * 2 }); - - expect(controller._executePoll).toHaveBeenCalledTimes(3); - controller.stopAllPolling(); - }); - it('should publish "pollingComplete" when stop is called', async () => { - const pollingComplete: any = jest.fn(); - class MyGasFeeController extends PollingController { - _executePoll = createExecutePollMock(); - } - const name = 'PollingController'; - - const mockMessenger = new ControllerMessenger(); - - const controller = new MyGasFeeController({ - messenger: mockMessenger, - metadata: {}, - name, - state: { foo: 'bar' }, - }); - controller.onPollingCompleteByNetworkClientId('mainnet', pollingComplete); - const pollingToken = controller.startPollingByNetworkClientId('mainnet'); - controller.stopPollingByPollingToken(pollingToken); - expect(pollingComplete).toHaveBeenCalledTimes(1); - }); - it('should poll at the interval length when set via setIntervalLength', async () => { - class MyGasFeeController extends PollingController { - _executePoll = createExecutePollMock(); - } - const mockMessenger = new ControllerMessenger(); - - const controller = new MyGasFeeController({ - messenger: mockMessenger, - metadata: {}, - name: 'PollingController', - state: { foo: 'bar' }, - }); - controller.setIntervalLength(TICK_TIME); - controller.startPollingByNetworkClientId('mainnet'); - await advanceTime({ clock, duration: 0 }); - expect(controller._executePoll).toHaveBeenCalledTimes(1); - await advanceTime({ clock, duration: TICK_TIME / 2 }); - - expect(controller._executePoll).toHaveBeenCalledTimes(1); - await advanceTime({ clock, duration: TICK_TIME / 2 }); - - expect(controller._executePoll).toHaveBeenCalledTimes(2); - }); - it('should start and stop polling sessions for different networkClientIds with the same options', async () => { - class MyGasFeeController extends PollingController { - _executePoll = createExecutePollMock(); - } - const mockMessenger = new ControllerMessenger(); - - const controller = new MyGasFeeController({ - messenger: mockMessenger, - metadata: {}, - name: 'PollingController', - state: { foo: 'bar' }, - }); - const pollToken1 = controller.startPollingByNetworkClientId('mainnet', { - address: '0x1', - }); - controller.startPollingByNetworkClientId('mainnet', { address: '0x2' }); - await advanceTime({ clock, duration: 0 }); - - controller.startPollingByNetworkClientId('sepolia', { address: '0x2' }); - await advanceTime({ clock, duration: 0 }); - - expect(controller._executePoll.mock.calls).toMatchObject([ - ['mainnet', { address: '0x1' }], - ['mainnet', { address: '0x2' }], - ['sepolia', { address: '0x2' }], - ]); - await advanceTime({ clock, duration: TICK_TIME }); - - expect(controller._executePoll.mock.calls).toMatchObject([ - ['mainnet', { address: '0x1' }], - ['mainnet', { address: '0x2' }], - ['sepolia', { address: '0x2' }], - ['mainnet', { address: '0x1' }], - ['mainnet', { address: '0x2' }], - ['sepolia', { address: '0x2' }], - ]); - controller.stopPollingByPollingToken(pollToken1); - await advanceTime({ clock, duration: TICK_TIME }); - - expect(controller._executePoll.mock.calls).toMatchObject([ - ['mainnet', { address: '0x1' }], - ['mainnet', { address: '0x2' }], - ['sepolia', { address: '0x2' }], - ['mainnet', { address: '0x1' }], - ['mainnet', { address: '0x2' }], - ['sepolia', { address: '0x2' }], - ['mainnet', { address: '0x2' }], - ['sepolia', { address: '0x2' }], - ]); - }); - }); - describe('multiple networkClientIds', () => { - it('should poll for each networkClientId', async () => { - class MyGasFeeController extends PollingController { - _executePoll = createExecutePollMock(); - } - const mockMessenger = new ControllerMessenger(); - - const controller = new MyGasFeeController({ - messenger: mockMessenger, - metadata: {}, - name: 'PollingController', - state: { foo: 'bar' }, - }); - controller.startPollingByNetworkClientId('mainnet'); - await advanceTime({ clock, duration: 0 }); - - controller.startPollingByNetworkClientId('rinkeby'); - await advanceTime({ clock, duration: 0 }); - - expect(controller._executePoll.mock.calls).toMatchObject([ - ['mainnet', {}], - ['rinkeby', {}], - ]); - await advanceTime({ clock, duration: TICK_TIME }); - - expect(controller._executePoll.mock.calls).toMatchObject([ - ['mainnet', {}], - ['rinkeby', {}], - ['mainnet', {}], - ['rinkeby', {}], - ]); - await advanceTime({ clock, duration: TICK_TIME }); - - expect(controller._executePoll.mock.calls).toMatchObject([ - ['mainnet', {}], - ['rinkeby', {}], - ['mainnet', {}], - ['rinkeby', {}], - ['mainnet', {}], - ['rinkeby', {}], - ]); - controller.stopAllPolling(); - }); - - it('should poll multiple networkClientIds when setting interval length', async () => { - class MyGasFeeController extends PollingController { - _executePoll = createExecutePollMock(); - } - const mockMessenger = new ControllerMessenger(); - - const controller = new MyGasFeeController({ - messenger: mockMessenger, - metadata: {}, - name: 'PollingController', - state: { foo: 'bar' }, - }); - controller.setIntervalLength(TICK_TIME * 2); - controller.startPollingByNetworkClientId('mainnet'); - await advanceTime({ clock, duration: 0 }); - - expect(controller._executePoll.mock.calls).toMatchObject([ - ['mainnet', {}], - ]); - await advanceTime({ clock, duration: TICK_TIME }); - - controller.startPollingByNetworkClientId('sepolia'); - await advanceTime({ clock, duration: 0 }); - - expect(controller._executePoll.mock.calls).toMatchObject([ - ['mainnet', {}], - ['sepolia', {}], - ]); - await advanceTime({ clock, duration: TICK_TIME }); - - expect(controller._executePoll.mock.calls).toMatchObject([ - ['mainnet', {}], - ['sepolia', {}], - ['mainnet', {}], - ]); - await advanceTime({ clock, duration: TICK_TIME }); - - expect(controller._executePoll.mock.calls).toMatchObject([ - ['mainnet', {}], - ['sepolia', {}], - ['mainnet', {}], - ['sepolia', {}], - ]); - await advanceTime({ clock, duration: TICK_TIME }); - - expect(controller._executePoll.mock.calls).toMatchObject([ - ['mainnet', {}], - ['sepolia', {}], - ['mainnet', {}], - ['sepolia', {}], - ['mainnet', {}], - ]); - await advanceTime({ clock, duration: TICK_TIME }); - - expect(controller._executePoll.mock.calls).toMatchObject([ - ['mainnet', {}], - ['sepolia', {}], - ['mainnet', {}], - ['sepolia', {}], - ['mainnet', {}], - ['sepolia', {}], - ]); - }); - }); - describe('PollingControllerOnly', () => { - it('can be extended from and constructed', async () => { - class MyClass extends PollingControllerOnly { - _executePoll = createExecutePollMock(); - } - const c = new MyClass(); - expect(c._executePoll).toBeDefined(); - expect(c.getIntervalLength).toBeDefined(); - expect(c.setIntervalLength).toBeDefined(); - expect(c.stopAllPolling).toBeDefined(); - expect(c.startPollingByNetworkClientId).toBeDefined(); - expect(c.stopPollingByPollingToken).toBeDefined(); - }); - }); -}); diff --git a/packages/polling-controller/src/PollingController.ts b/packages/polling-controller/src/PollingController.ts deleted file mode 100644 index 0327a7806e..0000000000 --- a/packages/polling-controller/src/PollingController.ts +++ /dev/null @@ -1,194 +0,0 @@ -import { BaseController, BaseControllerV1 } from '@metamask/base-controller'; -import type { NetworkClientId } from '@metamask/network-controller'; -import type { Json } from '@metamask/utils'; -import stringify from 'fast-json-stable-stringify'; -import { v4 as random } from 'uuid'; - -// Mixin classes require a constructor with an `...any[]` parameter -// See TS2545 -type Constructor = new (...args: any[]) => object; - -/** - * Returns a unique key for a networkClientId and options. This is used to group networkClientId polls with the same options - * @param networkClientId - The networkClientId to get a key for - * @param options - The options used to group the polling events - * @returns The unique key - */ -export const getKey = ( - networkClientId: NetworkClientId, - options: Json, -): PollingTokenSetId => `${networkClientId}:${stringify(options)}`; - -type PollingTokenSetId = `${NetworkClientId}:${string}`; -/** - * PollingControllerMixin - * - * @param Base - The base class to mix onto. - * @returns The composed class. - */ -function PollingControllerMixin(Base: TBase) { - /** - * PollingController is an abstract class that implements the polling - * functionality for a controller. It is meant to be extended by a controller - * that needs to poll for data by networkClientId. - * - */ - abstract class PollingControllerBase extends Base { - readonly #pollingTokenSets: Map> = new Map(); - - readonly #intervalIds: Record = {}; - - #callbacks: Map< - NetworkClientId, - Set<(networkClientId: NetworkClientId) => void> - > = new Map(); - - #intervalLength = 1000; - - getIntervalLength() { - return this.#intervalLength; - } - - /** - * Sets the length of the polling interval - * - * @param length - The length of the polling interval in milliseconds - */ - setIntervalLength(length: number) { - this.#intervalLength = length; - } - - /** - * Starts polling for a networkClientId - * - * @param networkClientId - The networkClientId to start polling for - * @param options - The options used to group the polling events - * @returns void - */ - startPollingByNetworkClientId( - networkClientId: NetworkClientId, - options: Json = {}, - ) { - const pollToken = random(); - - const key = getKey(networkClientId, options); - - const pollingTokenSet = this.#pollingTokenSets.get(key); - if (pollingTokenSet) { - pollingTokenSet.add(pollToken); - } else { - const set = new Set(); - set.add(pollToken); - this.#pollingTokenSets.set(key, set); - } - this.#poll(networkClientId, options); - return pollToken; - } - - /** - * Stops polling for all networkClientIds - */ - stopAllPolling() { - this.#pollingTokenSets.forEach((tokenSet, _networkClientId) => { - tokenSet.forEach((token) => { - this.stopPollingByPollingToken(token); - }); - }); - } - - /** - * Stops polling for a networkClientId - * - * @param pollingToken - The polling token to stop polling for - */ - stopPollingByPollingToken(pollingToken: string) { - if (!pollingToken) { - throw new Error('pollingToken required'); - } - let found = false; - this.#pollingTokenSets.forEach((tokenSet, key) => { - if (tokenSet.has(pollingToken)) { - found = true; - tokenSet.delete(pollingToken); - if (tokenSet.size === 0) { - clearTimeout(this.#intervalIds[key]); - delete this.#intervalIds[key]; - this.#pollingTokenSets.delete(key); - this.#callbacks.get(key)?.forEach((callback) => { - callback(key); - }); - this.#callbacks.get(key)?.clear(); - } - } - }); - if (!found) { - throw new Error('pollingToken not found'); - } - } - - /** - * Executes the poll for a networkClientId - * - * @param networkClientId - The networkClientId to execute the poll for - * @param options - The options passed to startPollingByNetworkClientId - */ - abstract _executePoll( - networkClientId: NetworkClientId, - options: Json, - ): Promise; - - #poll(networkClientId: NetworkClientId, options: Json) { - const key = getKey(networkClientId, options); - const interval = this.#intervalIds[key]; - if (interval) { - clearTimeout(interval); - delete this.#intervalIds[key]; - } - // setTimeout is not `await`ing this async function, which is expected - // We're just using async here for improved stack traces - // eslint-disable-next-line @typescript-eslint/no-misused-promises - this.#intervalIds[key] = setTimeout( - async () => { - try { - await this._executePoll(networkClientId, options); - } catch (error) { - console.error(error); - } - this.#poll(networkClientId, options); - }, - interval ? this.#intervalLength : 0, - ); - } - - /** - * Adds a callback to execute when polling is complete - * - * @param networkClientId - The networkClientId to listen for polling complete events - * @param callback - The callback to execute when polling is complete - * @param options - The options used to group the polling events - */ - onPollingCompleteByNetworkClientId( - networkClientId: NetworkClientId, - callback: (networkClientId: NetworkClientId) => void, - options: Json = {}, - ) { - const key = getKey(networkClientId, options); - const callbacks = this.#callbacks.get(key); - - if (callbacks === undefined) { - const set = new Set(); - set.add(callback); - this.#callbacks.set(key, set); - } else { - callbacks.add(callback); - } - } - } - return PollingControllerBase; -} - -class Empty {} - -export const PollingControllerOnly = PollingControllerMixin(Empty); -export const PollingController = PollingControllerMixin(BaseController); -export const PollingControllerV1 = PollingControllerMixin(BaseControllerV1); diff --git a/packages/polling-controller/src/StaticIntervalPollingController.test.ts b/packages/polling-controller/src/StaticIntervalPollingController.test.ts new file mode 100644 index 0000000000..196d886050 --- /dev/null +++ b/packages/polling-controller/src/StaticIntervalPollingController.test.ts @@ -0,0 +1,237 @@ +import { ControllerMessenger } from '@metamask/base-controller'; +import { useFakeTimers } from 'sinon'; + +import { advanceTime } from '../../../tests/helpers'; +import { StaticIntervalPollingController } from './StaticIntervalPollingController'; + +const TICK_TIME = 5; + +const createExecutePollMock = () => { + const executePollMock = jest.fn().mockImplementation(async () => { + return true; + }); + return executePollMock; +}; + +class ChildBlockTrackerPollingController extends StaticIntervalPollingController< + any, + any, + any +> { + _executePoll = createExecutePollMock(); +} + +describe('StaticIntervalPollingController', () => { + let clock: sinon.SinonFakeTimers; + let mockMessenger: any; + let controller: any; + beforeEach(() => { + mockMessenger = new ControllerMessenger(); + controller = new ChildBlockTrackerPollingController({ + messenger: mockMessenger, + metadata: {}, + name: 'PollingController', + state: { foo: 'bar' }, + }); + controller.setIntervalLength(TICK_TIME); + clock = useFakeTimers(); + }); + afterEach(() => { + clock.restore(); + }); + + describe('startPollingByNetworkClientId', () => { + it('should start polling if not already polling', async () => { + controller.startPollingByNetworkClientId('mainnet'); + await advanceTime({ clock, duration: 0 }); + expect(controller._executePoll).toHaveBeenCalledTimes(1); + await advanceTime({ clock, duration: TICK_TIME }); + expect(controller._executePoll).toHaveBeenCalledTimes(2); + controller.stopAllPolling(); + }); + + it('should call _executePoll immediately once and continue calling _executePoll on interval when called again with the same networkClientId', async () => { + controller.startPollingByNetworkClientId('mainnet'); + await advanceTime({ clock, duration: 0 }); + + controller.startPollingByNetworkClientId('mainnet'); + await advanceTime({ clock, duration: 0 }); + + expect(controller._executePoll).toHaveBeenCalledTimes(1); + await advanceTime({ clock, duration: TICK_TIME * 2 }); + + expect(controller._executePoll).toHaveBeenCalledTimes(3); + controller.stopAllPolling(); + }); + describe('multiple networkClientIds', () => { + it('should poll for each networkClientId', async () => { + controller.startPollingByNetworkClientId('mainnet'); + await advanceTime({ clock, duration: 0 }); + + controller.startPollingByNetworkClientId('rinkeby'); + await advanceTime({ clock, duration: 0 }); + + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', {}], + ['rinkeby', {}], + ]); + await advanceTime({ clock, duration: TICK_TIME }); + + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', {}], + ['rinkeby', {}], + ['mainnet', {}], + ['rinkeby', {}], + ]); + await advanceTime({ clock, duration: TICK_TIME }); + + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', {}], + ['rinkeby', {}], + ['mainnet', {}], + ['rinkeby', {}], + ['mainnet', {}], + ['rinkeby', {}], + ]); + controller.stopAllPolling(); + }); + + it('should poll multiple networkClientIds when setting interval length', async () => { + controller.setIntervalLength(TICK_TIME * 2); + controller.startPollingByNetworkClientId('mainnet'); + await advanceTime({ clock, duration: 0 }); + + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', {}], + ]); + await advanceTime({ clock, duration: TICK_TIME }); + + controller.startPollingByNetworkClientId('sepolia'); + await advanceTime({ clock, duration: 0 }); + + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', {}], + ['sepolia', {}], + ]); + await advanceTime({ clock, duration: TICK_TIME }); + + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', {}], + ['sepolia', {}], + ['mainnet', {}], + ]); + await advanceTime({ clock, duration: TICK_TIME }); + + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', {}], + ['sepolia', {}], + ['mainnet', {}], + ['sepolia', {}], + ]); + await advanceTime({ clock, duration: TICK_TIME }); + + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', {}], + ['sepolia', {}], + ['mainnet', {}], + ['sepolia', {}], + ['mainnet', {}], + ]); + await advanceTime({ clock, duration: TICK_TIME }); + + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', {}], + ['sepolia', {}], + ['mainnet', {}], + ['sepolia', {}], + ['mainnet', {}], + ['sepolia', {}], + ]); + }); + }); + }); + + describe('stopPollingByPollingToken', () => { + it('should stop polling when called with a valid polling that was the only active pollingToken for a given networkClient', async () => { + const pollingToken = controller.startPollingByNetworkClientId('mainnet'); + await advanceTime({ clock, duration: 0 }); + expect(controller._executePoll).toHaveBeenCalledTimes(1); + await advanceTime({ clock, duration: TICK_TIME }); + controller.stopPollingByPollingToken(pollingToken); + await advanceTime({ clock, duration: TICK_TIME }); + expect(controller._executePoll).toHaveBeenCalledTimes(2); + controller.stopAllPolling(); + }); + it('should not stop polling if called with one of multiple active polling tokens for a given networkClient', async () => { + const pollingToken1 = controller.startPollingByNetworkClientId('mainnet'); + await advanceTime({ clock, duration: 0 }); + + controller.startPollingByNetworkClientId('mainnet'); + expect(controller._executePoll).toHaveBeenCalledTimes(1); + await advanceTime({ clock, duration: TICK_TIME }); + controller.stopPollingByPollingToken(pollingToken1); + await advanceTime({ clock, duration: TICK_TIME }); + expect(controller._executePoll).toHaveBeenCalledTimes(3); + controller.stopAllPolling(); + }); + it('should error if no pollingToken is passed', () => { + controller.startPollingByNetworkClientId('mainnet'); + expect(() => { + controller.stopPollingByPollingToken(); + }).toThrow('pollingToken required'); + controller.stopAllPolling(); + }); + + it('should start and stop polling sessions for different networkClientIds with the same options', async () => { + controller.setIntervalLength(TICK_TIME); + const pollToken1 = controller.startPollingByNetworkClientId('mainnet', { + address: '0x1', + }); + controller.startPollingByNetworkClientId('mainnet', { address: '0x2' }); + await advanceTime({ clock, duration: 0 }); + + controller.startPollingByNetworkClientId('sepolia', { address: '0x2' }); + await advanceTime({ clock, duration: 0 }); + + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', { address: '0x1' }], + ['mainnet', { address: '0x2' }], + ['sepolia', { address: '0x2' }], + ]); + await advanceTime({ clock, duration: TICK_TIME }); + + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', { address: '0x1' }], + ['mainnet', { address: '0x2' }], + ['sepolia', { address: '0x2' }], + ['mainnet', { address: '0x1' }], + ['mainnet', { address: '0x2' }], + ['sepolia', { address: '0x2' }], + ]); + controller.stopPollingByPollingToken(pollToken1); + await advanceTime({ clock, duration: TICK_TIME }); + + expect(controller._executePoll.mock.calls).toMatchObject([ + ['mainnet', { address: '0x1' }], + ['mainnet', { address: '0x2' }], + ['sepolia', { address: '0x2' }], + ['mainnet', { address: '0x1' }], + ['mainnet', { address: '0x2' }], + ['sepolia', { address: '0x2' }], + ['mainnet', { address: '0x2' }], + ['sepolia', { address: '0x2' }], + ]); + }); + }); + + describe('onPollingCompleteByNetworkClientId', () => { + it('should publish "pollingComplete" callback function set by "onPollingCompleteByNetworkClientId" when polling stops', async () => { + const pollingComplete: any = jest.fn(); + controller.onPollingCompleteByNetworkClientId('mainnet', pollingComplete); + const pollingToken = controller.startPollingByNetworkClientId('mainnet'); + controller.stopPollingByPollingToken(pollingToken); + expect(pollingComplete).toHaveBeenCalledTimes(1); + expect(pollingComplete).toHaveBeenCalledWith('mainnet:{}'); + }); + }); +}); diff --git a/packages/polling-controller/src/StaticIntervalPollingController.ts b/packages/polling-controller/src/StaticIntervalPollingController.ts new file mode 100644 index 0000000000..fb2f5bb4c6 --- /dev/null +++ b/packages/polling-controller/src/StaticIntervalPollingController.ts @@ -0,0 +1,86 @@ +import { BaseController, BaseControllerV1 } from '@metamask/base-controller'; +import type { NetworkClientId } from '@metamask/network-controller'; +import type { Json } from '@metamask/utils'; + +import { + AbstractPollingControllerBaseMixin, + getKey, +} from './AbstractPollingController'; +import type { + IPollingController, + PollingTokenSetId, +} from './AbstractPollingController'; + +type Constructor = new (...args: any[]) => object; + +/** + * StaticIntervalPollingControllerMixin + * A polling controller that polls on a static interval. + * + * @param Base - The base class to mix onto. + * @returns The composed class. + */ +function StaticIntervalPollingControllerMixin( + Base: TBase, +) { + abstract class StaticIntervalPollingController + extends AbstractPollingControllerBaseMixin(Base) + implements IPollingController + { + readonly #intervalIds: Record = {}; + + #intervalLength: number | undefined = 1000; + + setIntervalLength(intervalLength: number) { + this.#intervalLength = intervalLength; + } + + getIntervalLength() { + return this.#intervalLength; + } + + _startPollingByNetworkClientId( + networkClientId: NetworkClientId, + options: Json, + ) { + if (!this.#intervalLength) { + throw new Error('intervalLength must be defined and greater than 0'); + } + + const key = getKey(networkClientId, options); + const existingInterval = this.#intervalIds[key]; + this._stopPollingByPollingTokenSetId(key); + + this.#intervalIds[key] = setTimeout( + async () => { + try { + await this._executePoll(networkClientId, options); + } catch (error) { + console.error(error); + } + this._startPollingByNetworkClientId(networkClientId, options); + }, + existingInterval ? this.#intervalLength : 0, + ); + } + + _stopPollingByPollingTokenSetId(key: PollingTokenSetId) { + const intervalId = this.#intervalIds[key]; + if (intervalId) { + clearTimeout(intervalId); + delete this.#intervalIds[key]; + } + } + } + + return StaticIntervalPollingController; +} + +class Empty {} + +export const StaticIntervalPollingControllerOnly = + StaticIntervalPollingControllerMixin(Empty); +export const StaticIntervalPollingController = + StaticIntervalPollingControllerMixin(BaseController); +export const StaticIntervalPollingControllerV1 = + StaticIntervalPollingControllerMixin(BaseControllerV1); diff --git a/packages/polling-controller/src/index.ts b/packages/polling-controller/src/index.ts index e5ed2df383..5bfd5e5366 100644 --- a/packages/polling-controller/src/index.ts +++ b/packages/polling-controller/src/index.ts @@ -1,5 +1,13 @@ export { - PollingController, - PollingControllerV1, - PollingControllerOnly, -} from './PollingController'; + BlockTrackerPollingControllerOnly, + BlockTrackerPollingController, + BlockTrackerPollingControllerV1, +} from './BlockTrackerPollingController'; + +export { + StaticIntervalPollingControllerOnly, + StaticIntervalPollingController, + StaticIntervalPollingControllerV1, +} from './StaticIntervalPollingController'; + +export type { IPollingController } from './AbstractPollingController'; diff --git a/packages/user-operation-controller/src/helpers/PendingUserOperationTracker.ts b/packages/user-operation-controller/src/helpers/PendingUserOperationTracker.ts index 328720e356..914707f21f 100644 --- a/packages/user-operation-controller/src/helpers/PendingUserOperationTracker.ts +++ b/packages/user-operation-controller/src/helpers/PendingUserOperationTracker.ts @@ -1,7 +1,7 @@ import { query } from '@metamask/controller-utils'; import EthQuery from '@metamask/eth-query'; import type { Provider } from '@metamask/network-controller'; -import { PollingControllerOnly } from '@metamask/polling-controller'; +import { StaticIntervalPollingControllerOnly } from '@metamask/polling-controller'; import type { Json } from '@metamask/utils'; import { createModuleLogger } from '@metamask/utils'; import EventEmitter from 'events'; @@ -34,7 +34,7 @@ export type PendingUserOperationTrackerEventEmitter = EventEmitter & { emit(eventName: T, ...args: Events[T]): boolean; }; -export class PendingUserOperationTracker extends PollingControllerOnly { +export class PendingUserOperationTracker extends StaticIntervalPollingControllerOnly { hub: PendingUserOperationTrackerEventEmitter; #getUserOperations: () => UserOperationMetadata[];