diff --git a/packages/assets-controllers/src/TokenBalancesController.test.ts b/packages/assets-controllers/src/TokenBalancesController.test.ts index ca77659f4b..93c08b7508 100644 --- a/packages/assets-controllers/src/TokenBalancesController.test.ts +++ b/packages/assets-controllers/src/TokenBalancesController.test.ts @@ -1,193 +1,348 @@ +import { ControllerMessenger } from '@metamask/base-controller'; +import { toHex } from '@metamask/controller-utils'; import { BN } from 'ethereumjs-util'; -import * as sinon from 'sinon'; -import { advanceTime } from '../../../tests/helpers'; -import { - BN as exportedBn, - TokenBalancesController, -} from './TokenBalancesController'; +import { flushPromises } from '../../../tests/helpers'; +import { TokenBalancesController } from './TokenBalancesController'; +import type { Token } from './TokenRatesController'; import { getDefaultTokensState, type TokensState } from './TokensController'; +const controllerName = 'TokenBalancesController'; + +/** + * Constructs a restricted controller messenger. + * + * @returns A restricted controller messenger. + */ +function getMessenger() { + return new ControllerMessenger().getRestricted< + typeof controllerName, + never, + never + >({ + name: controllerName, + }); +} + describe('TokenBalancesController', () => { - let clock: sinon.SinonFakeTimers; - const getToken = ( - tokenBalances: TokenBalancesController, - address: string, - ) => { - const { tokens } = tokenBalances.config; - return tokens.find((token) => token.address === address); - }; beforeEach(() => { - clock = sinon.useFakeTimers(); + jest.useFakeTimers(); }); afterEach(() => { - clock.restore(); - sinon.restore(); + jest.useRealTimers(); }); - it('should re-export BN', () => { - expect(exportedBn).toStrictEqual(BN); + it('should set default state', () => { + const controller = new TokenBalancesController({ + onTokensStateChange: jest.fn(), + getSelectedAddress: () => '0x1234', + getERC20BalanceOf: jest.fn(), + messenger: getMessenger(), + }); + + expect(controller.state).toStrictEqual({ contractBalances: {} }); }); - it('should set default state', () => { - const tokenBalances = new TokenBalancesController({ - onTokensStateChange: sinon.stub(), + it('should poll and update balances in the right interval', async () => { + const updateBalancesSpy = jest.spyOn( + TokenBalancesController.prototype, + 'updateBalances', + ); + + new TokenBalancesController({ + interval: 10, + onTokensStateChange: jest.fn(), getSelectedAddress: () => '0x1234', - getERC20BalanceOf: sinon.stub(), + getERC20BalanceOf: jest.fn(), + messenger: getMessenger(), }); - expect(tokenBalances.state).toStrictEqual({ contractBalances: {} }); + await flushPromises(); + + expect(updateBalancesSpy).toHaveBeenCalled(); + expect(updateBalancesSpy).not.toHaveBeenCalledTimes(2); + + jest.advanceTimersByTime(15); + + expect(updateBalancesSpy).toHaveBeenCalledTimes(2); }); - it('should set default config', () => { - const tokenBalances = new TokenBalancesController({ - onTokensStateChange: sinon.stub(), + it('should update balances if enabled', async () => { + const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; + const controller = new TokenBalancesController({ + disabled: false, + tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], + interval: 10, + onTokensStateChange: jest.fn(), getSelectedAddress: () => '0x1234', - getERC20BalanceOf: sinon.stub(), + getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), + messenger: getMessenger(), }); - expect(tokenBalances.config).toStrictEqual({ - interval: 180000, - tokens: [], + + await controller.updateBalances(); + + expect(controller.state.contractBalances).toStrictEqual({ + '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0': toHex(new BN(1)), }); }); - it('should poll and update balances in the right interval', async () => { - const mock = sinon.stub( - TokenBalancesController.prototype, - 'updateBalances', - ); - new TokenBalancesController( - { - onTokensStateChange: sinon.stub(), - getSelectedAddress: () => '0x1234', - getERC20BalanceOf: sinon.stub(), - }, - { interval: 10 }, - ); - expect(mock.called).toBe(true); - expect(mock.calledTwice).toBe(false); + it('should not update balances if disabled', async () => { + const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; + const controller = new TokenBalancesController({ + disabled: true, + tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], + interval: 10, + onTokensStateChange: jest.fn(), + getSelectedAddress: () => '0x1234', + getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), + messenger: getMessenger(), + }); + + await controller.updateBalances(); - await advanceTime({ clock, duration: 15 }); - expect(mock.calledTwice).toBe(true); + expect(controller.state.contractBalances).toStrictEqual({}); }); - it('should not update rates if disabled', async () => { - const tokenBalances = new TokenBalancesController( - { - onTokensStateChange: sinon.stub(), - getSelectedAddress: () => '0x1234', - getERC20BalanceOf: sinon.stub(), + it('should update balances if controller is manually enabled', async () => { + const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; + const controller = new TokenBalancesController({ + disabled: true, + tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], + interval: 10, + onTokensStateChange: jest.fn(), + getSelectedAddress: () => '0x1234', + getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), + messenger: getMessenger(), + }); + + await controller.updateBalances(); + + expect(controller.state.contractBalances).toStrictEqual({}); + + controller.enable(); + await controller.updateBalances(); + + expect(controller.state.contractBalances).toStrictEqual({ + '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0': toHex(new BN(1)), + }); + }); + + it('should not update balances if controller is manually disabled', async () => { + const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; + const controller = new TokenBalancesController({ + disabled: false, + tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], + interval: 10, + onTokensStateChange: jest.fn(), + getSelectedAddress: () => '0x1234', + getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), + messenger: getMessenger(), + }); + + await controller.updateBalances(); + + expect(controller.state.contractBalances).toStrictEqual({ + '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0': toHex(new BN(1)), + }); + + controller.disable(); + await controller.updateBalances(); + + expect(controller.state.contractBalances).toStrictEqual({ + '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0': toHex(new BN(1)), + }); + }); + + it('should update balances if tokens change and controller is manually enabled', async () => { + const tokensStateChangeListeners: ((state: TokensState) => void)[] = []; + const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; + const controller = new TokenBalancesController({ + disabled: true, + tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], + interval: 10, + getSelectedAddress: () => '0x1234', + getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), + onTokensStateChange: (listener) => { + tokensStateChangeListeners.push(listener); }, - { - disabled: true, - interval: 10, + messenger: getMessenger(), + }); + const triggerTokensStateChange = async (state: TokensState) => { + for (const listener of tokensStateChangeListeners) { + listener(state); + } + }; + + await controller.updateBalances(); + + expect(controller.state.contractBalances).toStrictEqual({}); + + controller.enable(); + await triggerTokensStateChange({ + ...getDefaultTokensState(), + tokens: [ + { + address: '0x00', + symbol: 'FOO', + decimals: 18, + }, + ], + }); + + expect(controller.state.contractBalances).toStrictEqual({ + '0x00': toHex(new BN(1)), + }); + }); + + it('should not update balances if tokens change and controller is manually disabled', async () => { + const tokensStateChangeListeners: ((state: TokensState) => void)[] = []; + const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; + const controller = new TokenBalancesController({ + disabled: false, + tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], + interval: 10, + getSelectedAddress: () => '0x1234', + getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), + onTokensStateChange: (listener) => { + tokensStateChangeListeners.push(listener); }, - ); - const mock = sinon.stub(tokenBalances, 'update'); - await tokenBalances.updateBalances(); - expect(mock.called).toBe(false); + messenger: getMessenger(), + }); + const triggerTokensStateChange = async (state: TokensState) => { + for (const listener of tokensStateChangeListeners) { + listener(state); + } + }; + + await controller.updateBalances(); + + expect(controller.state.contractBalances).toStrictEqual({ + '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0': toHex(new BN(1)), + }); + + controller.disable(); + await triggerTokensStateChange({ + ...getDefaultTokensState(), + tokens: [ + { + address: '0x00', + symbol: 'FOO', + decimals: 18, + }, + ], + }); + + expect(controller.state.contractBalances).toStrictEqual({ + '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0': toHex(new BN(1)), + }); }); it('should clear previous interval', async () => { - const mock = sinon.stub(global, 'clearTimeout'); - const tokenBalances = new TokenBalancesController( - { - onTokensStateChange: sinon.stub(), - getSelectedAddress: () => '0x1234', - getERC20BalanceOf: sinon.stub(), - }, - { interval: 1337 }, - ); - await tokenBalances.poll(1338); - await advanceTime({ clock, duration: 1339 }); - expect(mock.called).toBe(true); + const controller = new TokenBalancesController({ + interval: 1337, + onTokensStateChange: jest.fn(), + getSelectedAddress: () => '0x1234', + getERC20BalanceOf: jest.fn(), + messenger: getMessenger(), + }); + + const mockClearTimeout = jest.spyOn(global, 'clearTimeout'); + + await controller.poll(1338); + + jest.advanceTimersByTime(1339); + + expect(mockClearTimeout).toHaveBeenCalled(); }); it('should update all balances', async () => { const selectedAddress = '0x0000000000000000000000000000000000000001'; const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; - const tokenBalances = new TokenBalancesController( - { - onTokensStateChange: jest.fn(), - getSelectedAddress: () => selectedAddress, - getERC20BalanceOf: sinon.stub().returns(new BN(1)), - }, + const tokens: Token[] = [ { - interval: 1337, - tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], + address, + decimals: 18, + symbol: 'EOS', + aggregators: [], }, - ); - expect(tokenBalances.state.contractBalances).toStrictEqual({}); + ]; + const controller = new TokenBalancesController({ + interval: 1337, + tokens, + onTokensStateChange: jest.fn(), + getSelectedAddress: () => selectedAddress, + getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), + messenger: getMessenger(), + }); - await tokenBalances.updateBalances(); - const mytoken = getToken(tokenBalances, address); - expect(mytoken?.balanceError).toBeNull(); - expect(Object.keys(tokenBalances.state.contractBalances)).toContain( - address, - ); + expect(controller.state.contractBalances).toStrictEqual({}); + + await controller.updateBalances(); - expect( - tokenBalances.state.contractBalances[address].toNumber(), - ).toBeGreaterThan(0); + expect(tokens[0].balanceError).toBeNull(); + expect(Object.keys(controller.state.contractBalances)).toContain(address); + expect(controller.state.contractBalances[address]).not.toBe(toHex(0)); }); it('should handle `getERC20BalanceOf` error case', async () => { const errorMsg = 'Failed to get balance'; const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; - const getERC20BalanceOfStub = sinon - .stub() - .returns(Promise.reject(new Error(errorMsg))); - const tokenBalances = new TokenBalancesController( + const getERC20BalanceOfStub = jest + .fn() + .mockReturnValue(Promise.reject(new Error(errorMsg))); + const tokens: Token[] = [ { - onTokensStateChange: jest.fn(), - getSelectedAddress: jest.fn(), - getERC20BalanceOf: getERC20BalanceOfStub, + address, + decimals: 18, + symbol: 'EOS', + aggregators: [], }, - { - interval: 1337, - tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], - }, - ); + ]; + const controller = new TokenBalancesController({ + interval: 1337, + tokens, + onTokensStateChange: jest.fn(), + getSelectedAddress: jest.fn(), + getERC20BalanceOf: getERC20BalanceOfStub, + messenger: getMessenger(), + }); - expect(tokenBalances.state.contractBalances).toStrictEqual({}); - await tokenBalances.updateBalances(); - const mytoken = getToken(tokenBalances, address); - expect(mytoken?.balanceError).toBeInstanceOf(Error); - expect(mytoken?.balanceError).toHaveProperty('message', errorMsg); - expect(tokenBalances.state.contractBalances[address].toNumber()).toBe(0); - - getERC20BalanceOfStub.returns(new BN(1)); - await tokenBalances.updateBalances(); - expect(mytoken?.balanceError).toBeNull(); - expect(Object.keys(tokenBalances.state.contractBalances)).toContain( - address, - ); + expect(controller.state.contractBalances).toStrictEqual({}); + + await controller.updateBalances(); + + expect(tokens[0].balanceError).toBeInstanceOf(Error); + expect(tokens[0].balanceError).toHaveProperty('message', errorMsg); + expect(controller.state.contractBalances[address]).toBe(toHex(0)); + + getERC20BalanceOfStub.mockReturnValue(new BN(1)); + + await controller.updateBalances(); - expect( - tokenBalances.state.contractBalances[address].toNumber(), - ).toBeGreaterThan(0); + expect(tokens[0].balanceError).toBeNull(); + expect(Object.keys(controller.state.contractBalances)).toContain(address); + expect(controller.state.contractBalances[address]).not.toBe(0); }); it('should update balances when tokens change', async () => { const tokensStateChangeListeners: ((state: TokensState) => void)[] = []; - const tokenBalances = new TokenBalancesController( - { - onTokensStateChange: (listener) => { - tokensStateChangeListeners.push(listener); - }, - getSelectedAddress: jest.fn(), - getERC20BalanceOf: jest.fn(), + const controller = new TokenBalancesController({ + onTokensStateChange: (listener) => { + tokensStateChangeListeners.push(listener); }, - { interval: 1337 }, - ); - const triggerTokensStateChange = (state: TokensState) => { + getSelectedAddress: jest.fn(), + getERC20BalanceOf: jest.fn(), + interval: 1337, + messenger: getMessenger(), + }); + const triggerTokensStateChange = async (state: TokensState) => { for (const listener of tokensStateChangeListeners) { listener(state); } }; - const updateBalances = sinon.stub(tokenBalances, 'updateBalances'); + const updateBalancesSpy = jest.spyOn(controller, 'updateBalances'); - triggerTokensStateChange({ + await triggerTokensStateChange({ ...getDefaultTokensState(), tokens: [ { @@ -198,31 +353,28 @@ describe('TokenBalancesController', () => { ], }); - expect(updateBalances.called).toBe(true); + expect(updateBalancesSpy).toHaveBeenCalled(); }); it('should update token balances when detected tokens are added', async () => { const tokensStateChangeListeners: ((state: TokensState) => void)[] = []; - const tokenBalances = new TokenBalancesController( - { - onTokensStateChange: (listener) => { - tokensStateChangeListeners.push(listener); - }, - getSelectedAddress: () => '0x1234', - getERC20BalanceOf: sinon.stub().returns(new BN(1)), - }, - { - interval: 1337, + const controller = new TokenBalancesController({ + interval: 1337, + onTokensStateChange: (listener) => { + tokensStateChangeListeners.push(listener); }, - ); - const triggerTokensStateChange = (state: TokensState) => { + getSelectedAddress: () => '0x1234', + getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), + messenger: getMessenger(), + }); + const triggerTokensStateChange = async (state: TokensState) => { for (const listener of tokensStateChangeListeners) { listener(state); } }; - expect(tokenBalances.state.contractBalances).toStrictEqual({}); + expect(controller.state.contractBalances).toStrictEqual({}); - triggerTokensStateChange({ + await triggerTokensStateChange({ ...getDefaultTokensState(), detectedTokens: [ { @@ -236,10 +388,8 @@ describe('TokenBalancesController', () => { tokens: [], }); - await tokenBalances.updateBalances(); - - expect(tokenBalances.state.contractBalances).toStrictEqual({ - '0x02': new BN(1), + expect(controller.state.contractBalances).toStrictEqual({ + '0x02': toHex(new BN(1)), }); }); }); diff --git a/packages/assets-controllers/src/TokenBalancesController.ts b/packages/assets-controllers/src/TokenBalancesController.ts index bca6cfa90d..82d94b57a8 100644 --- a/packages/assets-controllers/src/TokenBalancesController.ts +++ b/packages/assets-controllers/src/TokenBalancesController.ts @@ -1,141 +1,227 @@ -import type { BaseConfig, BaseState } from '@metamask/base-controller'; -import { BaseControllerV1 } from '@metamask/base-controller'; -import { safelyExecute } from '@metamask/controller-utils'; +import { + type RestrictedControllerMessenger, + type ControllerGetStateAction, + type ControllerStateChangeEvent, + BaseController, +} from '@metamask/base-controller'; +import { safelyExecute, toHex } from '@metamask/controller-utils'; import type { PreferencesState } from '@metamask/preferences-controller'; -import { BN } from 'ethereumjs-util'; import type { AssetsContractController } from './AssetsContractController'; import type { Token } from './TokenRatesController'; import type { TokensState } from './TokensController'; -// TODO: Remove this export in the next major release -export { BN }; +const DEFAULT_INTERVAL = 180000; + +const controllerName = 'TokenBalancesController'; + +const metadata = { + contractBalances: { persist: true, anonymous: false }, +}; /** - * @type TokenBalancesConfig - * - * Token balances controller configuration - * @property interval - Polling interval used to fetch new token balances - * @property tokens - List of tokens to track balances for + * Token balances controller options + * @property interval - Polling interval used to fetch new token balances. + * @property tokens - List of tokens to track balances for. + * @property disabled - If set to true, all tracked tokens contract balances updates are blocked. + * @property onTokensStateChange - Allows subscribing to assets controller state changes. + * @property getSelectedAddress - Gets the current selected address. + * @property getERC20BalanceOf - Gets the balance of the given account at the given contract address. */ -// This interface was created before this ESLint rule was added. -// Convert to a `type` in a future major version. -// eslint-disable-next-line @typescript-eslint/consistent-type-definitions -export interface TokenBalancesConfig extends BaseConfig { - interval: number; - tokens: Token[]; -} +type TokenBalancesControllerOptions = { + interval?: number; + tokens?: Token[]; + disabled?: boolean; + onTokensStateChange: (listener: (tokenState: TokensState) => void) => void; + getSelectedAddress: () => PreferencesState['selectedAddress']; + getERC20BalanceOf: AssetsContractController['getERC20BalanceOf']; + messenger: TokenBalancesControllerMessenger; + state?: Partial; +}; + +/** + * Represents a mapping of hash token contract addresses to their balances. + */ +type ContractBalances = Record; /** - * @type TokenBalancesState - * * Token balances controller state * @property contractBalances - Hash of token contract addresses to balances */ -// This interface was created before this ESLint rule was added. -// Convert to a `type` in a future major version. -// eslint-disable-next-line @typescript-eslint/consistent-type-definitions -export interface TokenBalancesState extends BaseState { - contractBalances: { [address: string]: BN }; +export type TokenBalancesControllerState = { + contractBalances: ContractBalances; +}; + +export type TokenBalancesControllerGetStateAction = ControllerGetStateAction< + typeof controllerName, + TokenBalancesControllerState +>; + +export type TokenBalancesControllerActions = + TokenBalancesControllerGetStateAction; + +export type TokenBalancesControllerStateChangeEvent = + ControllerStateChangeEvent< + typeof controllerName, + TokenBalancesControllerState + >; + +export type TokenBalancesControllerEvents = + TokenBalancesControllerStateChangeEvent; + +export type TokenBalancesControllerMessenger = RestrictedControllerMessenger< + typeof controllerName, + TokenBalancesControllerActions, + TokenBalancesControllerEvents, + never, + never +>; + +/** + * Get the default TokenBalancesController state. + * + * @returns The default TokenBalancesController state. + */ +function getDefaultState(): TokenBalancesControllerState { + return { + contractBalances: {}, + }; } /** * Controller that passively polls on a set interval token balances * for tokens stored in the TokensController */ -export class TokenBalancesController extends BaseControllerV1< - TokenBalancesConfig, - TokenBalancesState +export class TokenBalancesController extends BaseController< + typeof controllerName, + TokenBalancesControllerState, + TokenBalancesControllerMessenger > { - private handle?: ReturnType; + #handle?: ReturnType; - /** - * Name of this controller used during composition - */ - override name = 'TokenBalancesController'; + #getSelectedAddress: () => PreferencesState['selectedAddress']; + + #getERC20BalanceOf: AssetsContractController['getERC20BalanceOf']; - private readonly getSelectedAddress: () => PreferencesState['selectedAddress']; + #interval: number; - private readonly getERC20BalanceOf: AssetsContractController['getERC20BalanceOf']; + #tokens: Token[]; + + #disabled: boolean; /** - * Creates a TokenBalancesController instance. + * Construct a Token Balances Controller. * * @param options - The controller options. + * @param options.interval - Polling interval used to fetch new token balances. + * @param options.tokens - List of tokens to track balances for. + * @param options.disabled - If set to true, all tracked tokens contract balances updates are blocked. * @param options.onTokensStateChange - Allows subscribing to assets controller state changes. * @param options.getSelectedAddress - Gets the current selected address. * @param options.getERC20BalanceOf - Gets the balance of the given account at the given contract address. - * @param config - Initial options used to configure this controller. - * @param state - Initial state to set on this controller. + * @param options.state - Initial state to set on this controller. + * @param options.messenger - The controller restricted messenger. */ - constructor( - { - onTokensStateChange, - getSelectedAddress, - getERC20BalanceOf, - }: { - onTokensStateChange: ( - listener: (tokenState: TokensState) => void, - ) => void; - getSelectedAddress: () => PreferencesState['selectedAddress']; - getERC20BalanceOf: AssetsContractController['getERC20BalanceOf']; - }, - config?: Partial, - state?: Partial, - ) { - super(config, state); - this.defaultConfig = { - interval: 180000, - tokens: [], - }; - this.defaultState = { contractBalances: {} }; - this.initialize(); - onTokensStateChange(({ tokens, detectedTokens }) => { - this.configure({ tokens: [...tokens, ...detectedTokens] }); - this.updateBalances(); + constructor({ + interval = DEFAULT_INTERVAL, + tokens = [], + disabled = false, + onTokensStateChange, + getSelectedAddress, + getERC20BalanceOf, + messenger, + state = {}, + }: TokenBalancesControllerOptions) { + super({ + name: controllerName, + metadata, + messenger, + state: { + ...getDefaultState(), + ...state, + }, }); - this.getSelectedAddress = getSelectedAddress; - this.getERC20BalanceOf = getERC20BalanceOf; + + this.#disabled = disabled; + this.#interval = interval; + this.#tokens = tokens; + + onTokensStateChange(this.#tokensStateChangeListener.bind(this)); + + this.#getSelectedAddress = getSelectedAddress; + this.#getERC20BalanceOf = getERC20BalanceOf; + this.poll(); } + /* + * Tokens state changes listener. + */ + #tokensStateChangeListener({ tokens, detectedTokens }: TokensState) { + this.#tokens = [...tokens, ...detectedTokens]; + this.updateBalances(); + } + + /** + * Allows controller to update tracked tokens contract balances. + */ + enable() { + this.#disabled = false; + } + + /** + * Blocks controller from updating tracked tokens contract balances. + */ + disable() { + this.#disabled = true; + } + /** * Starts a new polling interval. * * @param interval - Polling interval used to fetch new token balances. */ async poll(interval?: number): Promise { - interval && this.configure({ interval }, false, false); - this.handle && clearTimeout(this.handle); + if (interval) { + this.#interval = interval; + } + + if (this.#handle) { + clearTimeout(this.#handle); + } + await safelyExecute(() => this.updateBalances()); - this.handle = setTimeout(() => { - this.poll(this.config.interval); - }, this.config.interval); + + this.#handle = setTimeout(() => { + this.poll(this.#interval); + }, this.#interval); } /** * Updates balances for all tokens. */ async updateBalances() { - if (this.disabled) { + if (this.#disabled) { return; } - const { tokens } = this.config; - const newContractBalances: { [address: string]: BN } = {}; - for (const token of tokens) { + + const newContractBalances: ContractBalances = {}; + for (const token of this.#tokens) { const { address } = token; try { - newContractBalances[address] = await this.getERC20BalanceOf( - address, - this.getSelectedAddress(), + newContractBalances[address] = toHex( + await this.#getERC20BalanceOf(address, this.#getSelectedAddress()), ); token.balanceError = null; } catch (error) { - newContractBalances[address] = new BN(0); + newContractBalances[address] = toHex(0); token.balanceError = error; } } - this.update({ contractBalances: newContractBalances }); + + this.update((state) => { + state.contractBalances = newContractBalances; + }); } } diff --git a/packages/assets-controllers/src/TokenRatesController.test.ts b/packages/assets-controllers/src/TokenRatesController.test.ts index c76812bc52..d77143a607 100644 --- a/packages/assets-controllers/src/TokenRatesController.test.ts +++ b/packages/assets-controllers/src/TokenRatesController.test.ts @@ -10,16 +10,19 @@ import { add0x } from '@metamask/utils'; import nock from 'nock'; import { useFakeTimers } from 'sinon'; -import { advanceTime } from '../../../tests/helpers'; +import { advanceTime, flushPromises } from '../../../tests/helpers'; import { TOKEN_PRICES_BATCH_SIZE } from './assetsUtil'; import type { AbstractTokenPricesService, TokenPrice, TokenPricesByTokenAddress, } from './token-prices-service/abstract-token-prices-service'; -import type { TokenBalancesState } from './TokenBalancesController'; import { TokenRatesController } from './TokenRatesController'; -import type { TokenRatesConfig, Token } from './TokenRatesController'; +import type { + TokenRatesConfig, + Token, + TokenRatesState, +} from './TokenRatesController'; import type { TokensState } from './TokensController'; const defaultSelectedAddress = '0x0000000000000000000000000000000000000001'; @@ -2228,7 +2231,7 @@ type WithControllerCallback = ({ type PartialConstructorParameters = { options?: Partial[0]>; config?: Partial; - state?: Partial; + state?: Partial; }; type WithControllerArgs = @@ -2288,18 +2291,6 @@ async function withController( } } -/** - * Resolve all pending promises. - * - * This method is used for async tests that use fake timers. - * See https://stackoverflow.com/a/58716087 and https://jestjs.io/docs/timer-mocks. - * - * TODO: migrate this to @metamask/utils - */ -async function flushPromises(): Promise { - await new Promise(jest.requireActual('timers').setImmediate); -} - /** * Call an "update exchange rates" method with the given parameters. * diff --git a/packages/assets-controllers/src/index.ts b/packages/assets-controllers/src/index.ts index fc1a068472..1f2786d784 100644 --- a/packages/assets-controllers/src/index.ts +++ b/packages/assets-controllers/src/index.ts @@ -3,7 +3,14 @@ export * from './AssetsContractController'; export * from './CurrencyRateController'; export * from './NftController'; export * from './NftDetectionController'; -export * from './TokenBalancesController'; +export type { + TokenBalancesControllerMessenger, + TokenBalancesControllerActions, + TokenBalancesControllerGetStateAction, + TokenBalancesControllerEvents, + TokenBalancesControllerStateChangeEvent, +} from './TokenBalancesController'; +export { TokenBalancesController } from './TokenBalancesController'; export type { TokenDetectionControllerMessenger, TokenDetectionControllerActions, diff --git a/packages/transaction-controller/src/TransactionController.test.ts b/packages/transaction-controller/src/TransactionController.test.ts index 65ceaea740..00817adfd8 100644 --- a/packages/transaction-controller/src/TransactionController.test.ts +++ b/packages/transaction-controller/src/TransactionController.test.ts @@ -23,6 +23,7 @@ import { errorCodes, providerErrors, rpcErrors } from '@metamask/rpc-errors'; import * as NonceTrackerPackage from 'nonce-tracker'; import { FakeBlockTracker } from '../../../tests/fake-block-tracker'; +import { flushPromises } from '../../../tests/helpers'; import { mockNetwork } from '../../../tests/mock-network'; import { IncomingTransactionHelper } from './helpers/IncomingTransactionHelper'; import { PendingTransactionTracker } from './helpers/PendingTransactionTracker'; @@ -306,15 +307,6 @@ function waitForTransactionFinished( }); } -/** - * Resolve all pending promises. - * This method is used for async tests that use fake timers. - * See https://stackoverflow.com/a/58716087 and https://jestjs.io/docs/timer-mocks. - */ -function flushPromises(): Promise { - return new Promise(jest.requireActual('timers').setImmediate); -} - const MOCK_PREFERENCES = { state: { selectedAddress: 'foo' } }; const INFURA_PROJECT_ID = '341eacb578dd44a1a049cbc5f6fd4035'; const GOERLI_PROVIDER = new HttpProvider( @@ -2640,7 +2632,7 @@ describe('TransactionController', () => { externalBaseFeePerGas, ); - await new Promise(jest.requireActual('timers').setImmediate); + await flushPromises(); expect(mockPostTransactionBalanceUpdatedListener).toHaveBeenCalledTimes( 1, diff --git a/tests/helpers.ts b/tests/helpers.ts index 114d47736c..ed0b2660f8 100644 --- a/tests/helpers.ts +++ b/tests/helpers.ts @@ -27,3 +27,15 @@ export async function advanceTime({ duration -= stepSize; } while (duration > 0); } + +/** + * Resolve all pending promises. + * + * This method is used for async tests that use fake timers. + * See https://stackoverflow.com/a/58716087 and https://jestjs.io/docs/timer-mocks. + * + * TODO: migrate this to @metamask/utils + */ +export async function flushPromises(): Promise { + await new Promise(jest.requireActual('timers').setImmediate); +}