From d6512480cd7c19e7035d5caf77a9b865747f3f97 Mon Sep 17 00:00:00 2001 From: Salah-Eddine Saakoun Date: Tue, 9 Jan 2024 16:59:59 +0100 Subject: [PATCH 1/9] feat: migrate token balances controller to base controller v2 --- .../src/TokenBalancesController.test.ts | 173 +++++++------- .../src/TokenBalancesController.ts | 223 ++++++++++++------ .../src/TokenRatesController.test.ts | 4 +- 3 files changed, 230 insertions(+), 170 deletions(-) diff --git a/packages/assets-controllers/src/TokenBalancesController.test.ts b/packages/assets-controllers/src/TokenBalancesController.test.ts index ca77659f4b..c6bbd7397f 100644 --- a/packages/assets-controllers/src/TokenBalancesController.test.ts +++ b/packages/assets-controllers/src/TokenBalancesController.test.ts @@ -1,21 +1,35 @@ +import { ControllerMessenger } from '@metamask/base-controller'; import { BN } from 'ethereumjs-util'; import * as sinon from 'sinon'; import { advanceTime } from '../../../tests/helpers'; -import { - BN as exportedBn, - TokenBalancesController, -} from './TokenBalancesController'; +import { TokenBalancesController } from './TokenBalancesController'; import { getDefaultTokensState, type TokensState } from './TokensController'; +const controllerName = 'TokenBalancesController'; + +/** + * Constructs a restricted controller messenger. + * + * @returns A restricted controller messenger. + */ +function getMessenger() { + return new ControllerMessenger().getRestricted< + typeof controllerName, + never, + never + >({ + name: controllerName, + }); +} + describe('TokenBalancesController', () => { let clock: sinon.SinonFakeTimers; const getToken = ( tokenBalances: TokenBalancesController, address: string, ) => { - const { tokens } = tokenBalances.config; - return tokens.find((token) => token.address === address); + return tokenBalances.getTokens().find((token) => token.address === address); }; beforeEach(() => { clock = sinon.useFakeTimers(); @@ -26,44 +40,28 @@ describe('TokenBalancesController', () => { sinon.restore(); }); - it('should re-export BN', () => { - expect(exportedBn).toStrictEqual(BN); - }); - it('should set default state', () => { const tokenBalances = new TokenBalancesController({ onTokensStateChange: sinon.stub(), getSelectedAddress: () => '0x1234', getERC20BalanceOf: sinon.stub(), + messenger: getMessenger(), }); expect(tokenBalances.state).toStrictEqual({ contractBalances: {} }); }); - it('should set default config', () => { - const tokenBalances = new TokenBalancesController({ - onTokensStateChange: sinon.stub(), - getSelectedAddress: () => '0x1234', - getERC20BalanceOf: sinon.stub(), - }); - expect(tokenBalances.config).toStrictEqual({ - interval: 180000, - tokens: [], - }); - }); - it('should poll and update balances in the right interval', async () => { const mock = sinon.stub( TokenBalancesController.prototype, 'updateBalances', ); - new TokenBalancesController( - { - onTokensStateChange: sinon.stub(), - getSelectedAddress: () => '0x1234', - getERC20BalanceOf: sinon.stub(), - }, - { interval: 10 }, - ); + new TokenBalancesController({ + interval: 10, + onTokensStateChange: sinon.stub(), + getSelectedAddress: () => '0x1234', + getERC20BalanceOf: sinon.stub(), + messenger: getMessenger(), + }); expect(mock.called).toBe(true); expect(mock.calledTwice).toBe(false); @@ -72,32 +70,29 @@ describe('TokenBalancesController', () => { }); it('should not update rates if disabled', async () => { - const tokenBalances = new TokenBalancesController( - { - onTokensStateChange: sinon.stub(), - getSelectedAddress: () => '0x1234', - getERC20BalanceOf: sinon.stub(), - }, - { - disabled: true, - interval: 10, - }, - ); - const mock = sinon.stub(tokenBalances, 'update'); + const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; + const tokenBalances = new TokenBalancesController({ + disabled: true, + tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], + interval: 10, + onTokensStateChange: sinon.stub(), + getSelectedAddress: () => '0x1234', + getERC20BalanceOf: sinon.stub().returns(new BN(1)), + messenger: getMessenger(), + }); await tokenBalances.updateBalances(); - expect(mock.called).toBe(false); + expect(Object.keys(tokenBalances.state.contractBalances)).toStrictEqual({}); }); it('should clear previous interval', async () => { const mock = sinon.stub(global, 'clearTimeout'); - const tokenBalances = new TokenBalancesController( - { - onTokensStateChange: sinon.stub(), - getSelectedAddress: () => '0x1234', - getERC20BalanceOf: sinon.stub(), - }, - { interval: 1337 }, - ); + const tokenBalances = new TokenBalancesController({ + interval: 1337, + onTokensStateChange: sinon.stub(), + getSelectedAddress: () => '0x1234', + getERC20BalanceOf: sinon.stub(), + messenger: getMessenger(), + }); await tokenBalances.poll(1338); await advanceTime({ clock, duration: 1339 }); expect(mock.called).toBe(true); @@ -106,17 +101,14 @@ describe('TokenBalancesController', () => { it('should update all balances', async () => { const selectedAddress = '0x0000000000000000000000000000000000000001'; const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; - const tokenBalances = new TokenBalancesController( - { - onTokensStateChange: jest.fn(), - getSelectedAddress: () => selectedAddress, - getERC20BalanceOf: sinon.stub().returns(new BN(1)), - }, - { - interval: 1337, - tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], - }, - ); + const tokenBalances = new TokenBalancesController({ + interval: 1337, + tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], + onTokensStateChange: jest.fn(), + getSelectedAddress: () => selectedAddress, + getERC20BalanceOf: sinon.stub().returns(new BN(1)), + messenger: getMessenger(), + }); expect(tokenBalances.state.contractBalances).toStrictEqual({}); await tokenBalances.updateBalances(); @@ -127,7 +119,7 @@ describe('TokenBalancesController', () => { ); expect( - tokenBalances.state.contractBalances[address].toNumber(), + tokenBalances.state.contractBalances[address].toString(), ).toBeGreaterThan(0); }); @@ -137,17 +129,14 @@ describe('TokenBalancesController', () => { const getERC20BalanceOfStub = sinon .stub() .returns(Promise.reject(new Error(errorMsg))); - const tokenBalances = new TokenBalancesController( - { - onTokensStateChange: jest.fn(), - getSelectedAddress: jest.fn(), - getERC20BalanceOf: getERC20BalanceOfStub, - }, - { - interval: 1337, - tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], - }, - ); + const tokenBalances = new TokenBalancesController({ + interval: 1337, + tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], + onTokensStateChange: jest.fn(), + getSelectedAddress: jest.fn(), + getERC20BalanceOf: getERC20BalanceOfStub, + messenger: getMessenger(), + }); expect(tokenBalances.state.contractBalances).toStrictEqual({}); await tokenBalances.updateBalances(); @@ -170,16 +159,15 @@ describe('TokenBalancesController', () => { it('should update balances when tokens change', async () => { const tokensStateChangeListeners: ((state: TokensState) => void)[] = []; - const tokenBalances = new TokenBalancesController( - { - onTokensStateChange: (listener) => { - tokensStateChangeListeners.push(listener); - }, - getSelectedAddress: jest.fn(), - getERC20BalanceOf: jest.fn(), + const tokenBalances = new TokenBalancesController({ + onTokensStateChange: (listener) => { + tokensStateChangeListeners.push(listener); }, - { interval: 1337 }, - ); + getSelectedAddress: jest.fn(), + getERC20BalanceOf: jest.fn(), + interval: 1337, + messenger: getMessenger(), + }); const triggerTokensStateChange = (state: TokensState) => { for (const listener of tokensStateChangeListeners) { listener(state); @@ -203,18 +191,15 @@ describe('TokenBalancesController', () => { it('should update token balances when detected tokens are added', async () => { const tokensStateChangeListeners: ((state: TokensState) => void)[] = []; - const tokenBalances = new TokenBalancesController( - { - onTokensStateChange: (listener) => { - tokensStateChangeListeners.push(listener); - }, - getSelectedAddress: () => '0x1234', - getERC20BalanceOf: sinon.stub().returns(new BN(1)), - }, - { - interval: 1337, + const tokenBalances = new TokenBalancesController({ + interval: 1337, + onTokensStateChange: (listener) => { + tokensStateChangeListeners.push(listener); }, - ); + getSelectedAddress: () => '0x1234', + getERC20BalanceOf: sinon.stub().returns(new BN(1)), + messenger: getMessenger(), + }); const triggerTokensStateChange = (state: TokensState) => { for (const listener of tokensStateChangeListeners) { listener(state); diff --git a/packages/assets-controllers/src/TokenBalancesController.ts b/packages/assets-controllers/src/TokenBalancesController.ts index bca6cfa90d..62c30c374b 100644 --- a/packages/assets-controllers/src/TokenBalancesController.ts +++ b/packages/assets-controllers/src/TokenBalancesController.ts @@ -1,5 +1,7 @@ -import type { BaseConfig, BaseState } from '@metamask/base-controller'; -import { BaseControllerV1 } from '@metamask/base-controller'; +import { + type RestrictedControllerMessenger, + BaseController, +} from '@metamask/base-controller'; import { safelyExecute } from '@metamask/controller-utils'; import type { PreferencesState } from '@metamask/preferences-controller'; import { BN } from 'ethereumjs-util'; @@ -8,126 +10,196 @@ import type { AssetsContractController } from './AssetsContractController'; import type { Token } from './TokenRatesController'; import type { TokensState } from './TokensController'; -// TODO: Remove this export in the next major release -export { BN }; +const DEFAULT_INTERVAL = 180000; + +const controllerName = 'TokenBalancesController'; + +const metadata = { + contractBalances: { persist: true, anonymous: false }, +}; /** - * @type TokenBalancesConfig + * @type TokenBalancesControllerOptions * - * Token balances controller configuration - * @property interval - Polling interval used to fetch new token balances - * @property tokens - List of tokens to track balances for + * Token balances controller options + * @property interval - Polling interval used to fetch new token balances. + * @property tokens - List of tokens to track balances for. + * @property disabled - If set to true, all tracked tokens contract balances updates are blocked. + * @property onTokensStateChange - Allows subscribing to assets controller state changes. + * @property getSelectedAddress - Gets the current selected address. + * @property getERC20BalanceOf - Gets the balance of the given account at the given contract address. */ -// This interface was created before this ESLint rule was added. -// Convert to a `type` in a future major version. -// eslint-disable-next-line @typescript-eslint/consistent-type-definitions -export interface TokenBalancesConfig extends BaseConfig { - interval: number; - tokens: Token[]; -} +export type TokenBalancesControllerOptions = { + interval?: number; + tokens?: Token[]; + disabled?: boolean; + onTokensStateChange: (listener: (tokenState: TokensState) => void) => void; + getSelectedAddress: () => PreferencesState['selectedAddress']; + getERC20BalanceOf: AssetsContractController['getERC20BalanceOf']; + messenger: TokenBalancesControllerMessenger; + state?: Partial; +}; + +/** + * Represents a mapping of hash token contract addresses to their balances. + */ +type ContractBalances = Record; /** - * @type TokenBalancesState + * @type TokenBalancesControllerState * * Token balances controller state * @property contractBalances - Hash of token contract addresses to balances */ -// This interface was created before this ESLint rule was added. -// Convert to a `type` in a future major version. -// eslint-disable-next-line @typescript-eslint/consistent-type-definitions -export interface TokenBalancesState extends BaseState { - contractBalances: { [address: string]: BN }; -} +export type TokenBalancesControllerState = { + contractBalances: ContractBalances; +}; + +const getDefaultState = (): TokenBalancesControllerState => { + return { + contractBalances: {}, + }; +}; + +export type TokenBalancesControllerMessenger = RestrictedControllerMessenger< + typeof controllerName, + never, + never, + never, + never +>; /** * Controller that passively polls on a set interval token balances * for tokens stored in the TokensController */ -export class TokenBalancesController extends BaseControllerV1< - TokenBalancesConfig, - TokenBalancesState +export class TokenBalancesController extends BaseController< + typeof controllerName, + TokenBalancesControllerState, + TokenBalancesControllerMessenger > { - private handle?: ReturnType; + #handle?: ReturnType; - /** - * Name of this controller used during composition - */ - override name = 'TokenBalancesController'; + #getSelectedAddress: () => PreferencesState['selectedAddress']; + + #getERC20BalanceOf: AssetsContractController['getERC20BalanceOf']; + + #interval: number; - private readonly getSelectedAddress: () => PreferencesState['selectedAddress']; + #tokens: Token[]; - private readonly getERC20BalanceOf: AssetsContractController['getERC20BalanceOf']; + #disabled: boolean; /** - * Creates a TokenBalancesController instance. + * Construct a Token Balances Controller. * * @param options - The controller options. + * @param options.interval - Polling interval used to fetch new token balances. + * @param options.tokens - List of tokens to track balances for. + * @param options.disabled - If set to true, all tracked tokens contract balances updates are blocked. * @param options.onTokensStateChange - Allows subscribing to assets controller state changes. * @param options.getSelectedAddress - Gets the current selected address. * @param options.getERC20BalanceOf - Gets the balance of the given account at the given contract address. - * @param config - Initial options used to configure this controller. - * @param state - Initial state to set on this controller. + * @param options.state - Initial state to set on this controller. + * @param options.messenger - The controller restricted messenger. */ - constructor( - { - onTokensStateChange, - getSelectedAddress, - getERC20BalanceOf, - }: { - onTokensStateChange: ( - listener: (tokenState: TokensState) => void, - ) => void; - getSelectedAddress: () => PreferencesState['selectedAddress']; - getERC20BalanceOf: AssetsContractController['getERC20BalanceOf']; - }, - config?: Partial, - state?: Partial, - ) { - super(config, state); - this.defaultConfig = { - interval: 180000, - tokens: [], - }; - this.defaultState = { contractBalances: {} }; - this.initialize(); - onTokensStateChange(({ tokens, detectedTokens }) => { - this.configure({ tokens: [...tokens, ...detectedTokens] }); - this.updateBalances(); + constructor({ + interval = DEFAULT_INTERVAL, + tokens = [], + disabled = false, + onTokensStateChange, + getSelectedAddress, + getERC20BalanceOf, + messenger, + state = {}, + }: TokenBalancesControllerOptions) { + super({ + name: controllerName, + metadata, + messenger, + state: { + ...getDefaultState(), + ...state, + }, }); - this.getSelectedAddress = getSelectedAddress; - this.getERC20BalanceOf = getERC20BalanceOf; + + this.#disabled = disabled; + this.#interval = interval; + this.#tokens = tokens; + + onTokensStateChange(this.#tokensStateChangeListener.bind(this)); + + this.#getSelectedAddress = getSelectedAddress; + this.#getERC20BalanceOf = getERC20BalanceOf; + this.poll(); } + /* + * Tokens state changes listener. + */ + #tokensStateChangeListener({ tokens, detectedTokens }: TokensState) { + this.#tokens = [...tokens, ...detectedTokens]; + this.updateBalances(); + } + + /** + * Allows controller to update tracked tokens contract balances. + */ + enable() { + this.#disabled = false; + } + + /** + * Blocks controller from updating tracked tokens contract balances. + */ + disable() { + this.#disabled = true; + } + + /* + * Lists all tracked tokens. + */ + getTokens() { + return this.#tokens; + } + /** * Starts a new polling interval. * * @param interval - Polling interval used to fetch new token balances. */ async poll(interval?: number): Promise { - interval && this.configure({ interval }, false, false); - this.handle && clearTimeout(this.handle); + if (interval) { + this.#interval = interval; + } + + if (this.#handle) { + clearTimeout(this.#handle); + } + await safelyExecute(() => this.updateBalances()); - this.handle = setTimeout(() => { - this.poll(this.config.interval); - }, this.config.interval); + + this.#handle = setTimeout(() => { + this.poll(this.#interval); + }, this.#interval); } /** * Updates balances for all tokens. */ async updateBalances() { - if (this.disabled) { + if (this.#disabled) { return; } - const { tokens } = this.config; - const newContractBalances: { [address: string]: BN } = {}; - for (const token of tokens) { + + const newContractBalances: ContractBalances = {}; + for (const token of this.#tokens) { const { address } = token; try { - newContractBalances[address] = await this.getERC20BalanceOf( + newContractBalances[address] = await this.#getERC20BalanceOf( address, - this.getSelectedAddress(), + this.#getSelectedAddress(), ); token.balanceError = null; } catch (error) { @@ -135,7 +207,10 @@ export class TokenBalancesController extends BaseControllerV1< token.balanceError = error; } } - this.update({ contractBalances: newContractBalances }); + + this.update((state) => { + state.contractBalances = newContractBalances; + }); } } diff --git a/packages/assets-controllers/src/TokenRatesController.test.ts b/packages/assets-controllers/src/TokenRatesController.test.ts index c76812bc52..61c6638615 100644 --- a/packages/assets-controllers/src/TokenRatesController.test.ts +++ b/packages/assets-controllers/src/TokenRatesController.test.ts @@ -17,7 +17,7 @@ import type { TokenPrice, TokenPricesByTokenAddress, } from './token-prices-service/abstract-token-prices-service'; -import type { TokenBalancesState } from './TokenBalancesController'; +import type { TokenBalancesControllerState } from './TokenBalancesController'; import { TokenRatesController } from './TokenRatesController'; import type { TokenRatesConfig, Token } from './TokenRatesController'; import type { TokensState } from './TokensController'; @@ -2228,7 +2228,7 @@ type WithControllerCallback = ({ type PartialConstructorParameters = { options?: Partial[0]>; config?: Partial; - state?: Partial; + state?: Partial; }; type WithControllerArgs = From 1583e10c3cf4695b862d01700f4e74546bdc6790 Mon Sep 17 00:00:00 2001 From: Salah-Eddine Saakoun Date: Tue, 9 Jan 2024 17:28:57 +0100 Subject: [PATCH 2/9] refactor: save balances as string instead of BN --- .../src/TokenBalancesController.test.ts | 34 +++++++++++++------ .../src/TokenBalancesController.ts | 11 +++--- 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/packages/assets-controllers/src/TokenBalancesController.test.ts b/packages/assets-controllers/src/TokenBalancesController.test.ts index c6bbd7397f..54198c6684 100644 --- a/packages/assets-controllers/src/TokenBalancesController.test.ts +++ b/packages/assets-controllers/src/TokenBalancesController.test.ts @@ -47,6 +47,7 @@ describe('TokenBalancesController', () => { getERC20BalanceOf: sinon.stub(), messenger: getMessenger(), }); + expect(tokenBalances.state).toStrictEqual({ contractBalances: {} }); }); @@ -62,10 +63,12 @@ describe('TokenBalancesController', () => { getERC20BalanceOf: sinon.stub(), messenger: getMessenger(), }); + expect(mock.called).toBe(true); expect(mock.calledTwice).toBe(false); await advanceTime({ clock, duration: 15 }); + expect(mock.calledTwice).toBe(true); }); @@ -80,8 +83,10 @@ describe('TokenBalancesController', () => { getERC20BalanceOf: sinon.stub().returns(new BN(1)), messenger: getMessenger(), }); + await tokenBalances.updateBalances(); - expect(Object.keys(tokenBalances.state.contractBalances)).toStrictEqual({}); + + expect(tokenBalances.state.contractBalances).toStrictEqual({}); }); it('should clear previous interval', async () => { @@ -93,8 +98,10 @@ describe('TokenBalancesController', () => { getERC20BalanceOf: sinon.stub(), messenger: getMessenger(), }); + await tokenBalances.poll(1338); await advanceTime({ clock, duration: 1339 }); + expect(mock.called).toBe(true); }); @@ -109,18 +116,20 @@ describe('TokenBalancesController', () => { getERC20BalanceOf: sinon.stub().returns(new BN(1)), messenger: getMessenger(), }); + expect(tokenBalances.state.contractBalances).toStrictEqual({}); await tokenBalances.updateBalances(); + const mytoken = getToken(tokenBalances, address); + expect(mytoken?.balanceError).toBeNull(); expect(Object.keys(tokenBalances.state.contractBalances)).toContain( address, ); - - expect( - tokenBalances.state.contractBalances[address].toString(), - ).toBeGreaterThan(0); + expect(tokenBalances.state.contractBalances[address].toString()).not.toBe( + '0', + ); }); it('should handle `getERC20BalanceOf` error case', async () => { @@ -139,22 +148,25 @@ describe('TokenBalancesController', () => { }); expect(tokenBalances.state.contractBalances).toStrictEqual({}); + await tokenBalances.updateBalances(); + const mytoken = getToken(tokenBalances, address); expect(mytoken?.balanceError).toBeInstanceOf(Error); expect(mytoken?.balanceError).toHaveProperty('message', errorMsg); - expect(tokenBalances.state.contractBalances[address].toNumber()).toBe(0); + expect(tokenBalances.state.contractBalances[address].toString()).toBe('0'); getERC20BalanceOfStub.returns(new BN(1)); + await tokenBalances.updateBalances(); + expect(mytoken?.balanceError).toBeNull(); expect(Object.keys(tokenBalances.state.contractBalances)).toContain( address, ); - - expect( - tokenBalances.state.contractBalances[address].toNumber(), - ).toBeGreaterThan(0); + expect(tokenBalances.state.contractBalances[address].toString()).not.toBe( + 0, + ); }); it('should update balances when tokens change', async () => { @@ -224,7 +236,7 @@ describe('TokenBalancesController', () => { await tokenBalances.updateBalances(); expect(tokenBalances.state.contractBalances).toStrictEqual({ - '0x02': new BN(1), + '0x02': new BN(1).toString(16), }); }); }); diff --git a/packages/assets-controllers/src/TokenBalancesController.ts b/packages/assets-controllers/src/TokenBalancesController.ts index 62c30c374b..369e39425d 100644 --- a/packages/assets-controllers/src/TokenBalancesController.ts +++ b/packages/assets-controllers/src/TokenBalancesController.ts @@ -43,7 +43,7 @@ export type TokenBalancesControllerOptions = { /** * Represents a mapping of hash token contract addresses to their balances. */ -type ContractBalances = Record; +type ContractBalances = Record; /** * @type TokenBalancesControllerState @@ -197,13 +197,12 @@ export class TokenBalancesController extends BaseController< for (const token of this.#tokens) { const { address } = token; try { - newContractBalances[address] = await this.#getERC20BalanceOf( - address, - this.#getSelectedAddress(), - ); + newContractBalances[address] = ( + await this.#getERC20BalanceOf(address, this.#getSelectedAddress()) + ).toString(16); token.balanceError = null; } catch (error) { - newContractBalances[address] = new BN(0); + newContractBalances[address] = new BN(0).toString(16); token.balanceError = error; } } From 35106e2075cf9f192e0d506ea299e3dac216c260 Mon Sep 17 00:00:00 2001 From: Salah-Eddine Saakoun Date: Wed, 10 Jan 2024 14:57:26 +0100 Subject: [PATCH 3/9] fix: token rates controller wrong import --- .../assets-controllers/src/TokenRatesController.test.ts | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/packages/assets-controllers/src/TokenRatesController.test.ts b/packages/assets-controllers/src/TokenRatesController.test.ts index 61c6638615..f865b31d80 100644 --- a/packages/assets-controllers/src/TokenRatesController.test.ts +++ b/packages/assets-controllers/src/TokenRatesController.test.ts @@ -17,9 +17,12 @@ import type { TokenPrice, TokenPricesByTokenAddress, } from './token-prices-service/abstract-token-prices-service'; -import type { TokenBalancesControllerState } from './TokenBalancesController'; import { TokenRatesController } from './TokenRatesController'; -import type { TokenRatesConfig, Token } from './TokenRatesController'; +import type { + TokenRatesConfig, + Token, + TokenRatesState, +} from './TokenRatesController'; import type { TokensState } from './TokensController'; const defaultSelectedAddress = '0x0000000000000000000000000000000000000001'; @@ -2228,7 +2231,7 @@ type WithControllerCallback = ({ type PartialConstructorParameters = { options?: Partial[0]>; config?: Partial; - state?: Partial; + state?: Partial; }; type WithControllerArgs = From 8b2f04264c2b327ae476cce48134f31e170469b1 Mon Sep 17 00:00:00 2001 From: Salah-Eddine Saakoun Date: Wed, 10 Jan 2024 15:50:35 +0100 Subject: [PATCH 4/9] fix: test coverage --- .../src/TokenBalancesController.test.ts | 69 ++++++++++++++++++- 1 file changed, 68 insertions(+), 1 deletion(-) diff --git a/packages/assets-controllers/src/TokenBalancesController.test.ts b/packages/assets-controllers/src/TokenBalancesController.test.ts index 54198c6684..d08b14a830 100644 --- a/packages/assets-controllers/src/TokenBalancesController.test.ts +++ b/packages/assets-controllers/src/TokenBalancesController.test.ts @@ -72,7 +72,7 @@ describe('TokenBalancesController', () => { expect(mock.calledTwice).toBe(true); }); - it('should not update rates if disabled', async () => { + it('should not update banlances if disabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; const tokenBalances = new TokenBalancesController({ disabled: true, @@ -89,6 +89,73 @@ describe('TokenBalancesController', () => { expect(tokenBalances.state.contractBalances).toStrictEqual({}); }); + it('should update banlances if controller is manually enabled', async () => { + const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; + const tokenBalances = new TokenBalancesController({ + disabled: true, + tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], + interval: 10, + onTokensStateChange: sinon.stub(), + getSelectedAddress: () => '0x1234', + getERC20BalanceOf: sinon.stub().returns(new BN(1)), + messenger: getMessenger(), + }); + + await tokenBalances.updateBalances(); + + expect(tokenBalances.state.contractBalances).toStrictEqual({}); + + tokenBalances.enable(); + await tokenBalances.updateBalances(); + + expect(tokenBalances.state.contractBalances).toStrictEqual({ + '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0': '1', + }); + }); + + it('should not update banlances if controller is manually disabled', async () => { + const tokensStateChangeListeners: ((state: TokensState) => void)[] = []; + const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; + const tokenBalances = new TokenBalancesController({ + disabled: false, + tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], + interval: 10, + getSelectedAddress: () => '0x1234', + getERC20BalanceOf: sinon.stub().returns(new BN(1)), + onTokensStateChange: (listener) => { + tokensStateChangeListeners.push(listener); + }, + messenger: getMessenger(), + }); + const triggerTokensStateChange = (state: TokensState) => { + for (const listener of tokensStateChangeListeners) { + listener(state); + } + }; + + await tokenBalances.updateBalances(); + + expect(tokenBalances.state.contractBalances).toStrictEqual({ + '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0': '1', + }); + + tokenBalances.disable(); + triggerTokensStateChange({ + ...getDefaultTokensState(), + tokens: [ + { + address: '0x00', + symbol: 'FOO', + decimals: 18, + }, + ], + }); + + expect(tokenBalances.state.contractBalances).toStrictEqual({ + '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0': '1', + }); + }); + it('should clear previous interval', async () => { const mock = sinon.stub(global, 'clearTimeout'); const tokenBalances = new TokenBalancesController({ From bc092738d959c08c2ed218a59db367f4bdc893fc Mon Sep 17 00:00:00 2001 From: Salah-Eddine Saakoun Date: Wed, 10 Jan 2024 16:54:15 +0100 Subject: [PATCH 5/9] fix: add missing types --- .../src/TokenBalancesController.ts | 33 +++++++++++++++---- packages/assets-controllers/src/index.ts | 10 +++++- 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/packages/assets-controllers/src/TokenBalancesController.ts b/packages/assets-controllers/src/TokenBalancesController.ts index 369e39425d..4d92940ae5 100644 --- a/packages/assets-controllers/src/TokenBalancesController.ts +++ b/packages/assets-controllers/src/TokenBalancesController.ts @@ -1,5 +1,7 @@ import { type RestrictedControllerMessenger, + type ControllerGetStateAction, + type ControllerStateChangeEvent, BaseController, } from '@metamask/base-controller'; import { safelyExecute } from '@metamask/controller-utils'; @@ -55,20 +57,37 @@ export type TokenBalancesControllerState = { contractBalances: ContractBalances; }; -const getDefaultState = (): TokenBalancesControllerState => { - return { - contractBalances: {}, - }; -}; +export type TokenBalancesControllerGetStateAction = ControllerGetStateAction< + typeof controllerName, + TokenBalancesControllerState +>; + +export type TokenBalancesControllerActions = + TokenBalancesControllerGetStateAction; + +export type TokenBalancesControllerStateChangeEvent = + ControllerStateChangeEvent< + typeof controllerName, + TokenBalancesControllerState + >; + +export type TokenBalancesControllerEvents = + TokenBalancesControllerStateChangeEvent; export type TokenBalancesControllerMessenger = RestrictedControllerMessenger< typeof controllerName, - never, - never, + TokenBalancesControllerActions, + TokenBalancesControllerEvents, never, never >; +const getDefaultState = (): TokenBalancesControllerState => { + return { + contractBalances: {}, + }; +}; + /** * Controller that passively polls on a set interval token balances * for tokens stored in the TokensController diff --git a/packages/assets-controllers/src/index.ts b/packages/assets-controllers/src/index.ts index fc1a068472..6dbfdfd39c 100644 --- a/packages/assets-controllers/src/index.ts +++ b/packages/assets-controllers/src/index.ts @@ -3,7 +3,15 @@ export * from './AssetsContractController'; export * from './CurrencyRateController'; export * from './NftController'; export * from './NftDetectionController'; -export * from './TokenBalancesController'; +export type { + TokenBalancesControllerOptions, + TokenBalancesControllerMessenger, + TokenBalancesControllerActions, + TokenBalancesControllerGetStateAction, + TokenBalancesControllerEvents, + TokenBalancesControllerStateChangeEvent, +} from './TokenBalancesController'; +export { TokenBalancesController } from './TokenBalancesController'; export type { TokenDetectionControllerMessenger, TokenDetectionControllerActions, From 3e791e34e3bba164122f72f5fd52158e9b64fdec Mon Sep 17 00:00:00 2001 From: Salah-Eddine Saakoun Date: Wed, 10 Jan 2024 17:59:11 +0100 Subject: [PATCH 6/9] fix: replace sinon fake timers with jest built in timers --- .../src/TokenBalancesController.test.ts | 138 +++++++++--------- .../src/TokenRatesController.test.ts | 14 +- .../src/TransactionController.test.ts | 12 +- tests/helpers.ts | 12 ++ 4 files changed, 80 insertions(+), 96 deletions(-) diff --git a/packages/assets-controllers/src/TokenBalancesController.test.ts b/packages/assets-controllers/src/TokenBalancesController.test.ts index d08b14a830..485bfe5f2e 100644 --- a/packages/assets-controllers/src/TokenBalancesController.test.ts +++ b/packages/assets-controllers/src/TokenBalancesController.test.ts @@ -1,8 +1,7 @@ import { ControllerMessenger } from '@metamask/base-controller'; import { BN } from 'ethereumjs-util'; -import * as sinon from 'sinon'; -import { advanceTime } from '../../../tests/helpers'; +import { flushPromises } from '../../../tests/helpers'; import { TokenBalancesController } from './TokenBalancesController'; import { getDefaultTokensState, type TokensState } from './TokensController'; @@ -23,70 +22,69 @@ function getMessenger() { }); } +const getToken = (controler: TokenBalancesController, address: string) => { + return controler.getTokens().find((token) => token.address === address); +}; + describe('TokenBalancesController', () => { - let clock: sinon.SinonFakeTimers; - const getToken = ( - tokenBalances: TokenBalancesController, - address: string, - ) => { - return tokenBalances.getTokens().find((token) => token.address === address); - }; beforeEach(() => { - clock = sinon.useFakeTimers(); + jest.useFakeTimers(); }); afterEach(() => { - clock.restore(); - sinon.restore(); + jest.restoreAllMocks(); + jest.useRealTimers(); }); it('should set default state', () => { - const tokenBalances = new TokenBalancesController({ - onTokensStateChange: sinon.stub(), + const controller = new TokenBalancesController({ + onTokensStateChange: jest.fn(), getSelectedAddress: () => '0x1234', - getERC20BalanceOf: sinon.stub(), + getERC20BalanceOf: jest.fn(), messenger: getMessenger(), }); - expect(tokenBalances.state).toStrictEqual({ contractBalances: {} }); + expect(controller.state).toStrictEqual({ contractBalances: {} }); }); it('should poll and update balances in the right interval', async () => { - const mock = sinon.stub( + const updateBalancesSpy = jest.spyOn( TokenBalancesController.prototype, 'updateBalances', ); + new TokenBalancesController({ interval: 10, - onTokensStateChange: sinon.stub(), + onTokensStateChange: jest.fn(), getSelectedAddress: () => '0x1234', - getERC20BalanceOf: sinon.stub(), + getERC20BalanceOf: jest.fn(), messenger: getMessenger(), }); + await flushPromises(); - expect(mock.called).toBe(true); - expect(mock.calledTwice).toBe(false); + expect(updateBalancesSpy).toHaveBeenCalled(); + expect(updateBalancesSpy).not.toHaveBeenCalledTimes(2); - await advanceTime({ clock, duration: 15 }); + jest.advanceTimersByTime(15); - expect(mock.calledTwice).toBe(true); + expect(updateBalancesSpy).toHaveBeenCalledTimes(2); }); it('should not update banlances if disabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; - const tokenBalances = new TokenBalancesController({ + const controller = new TokenBalancesController({ disabled: true, tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], interval: 10, - onTokensStateChange: sinon.stub(), + onTokensStateChange: jest.fn(), getSelectedAddress: () => '0x1234', - getERC20BalanceOf: sinon.stub().returns(new BN(1)), + getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), messenger: getMessenger(), }); - await tokenBalances.updateBalances(); + await controller.updateBalances(); - expect(tokenBalances.state.contractBalances).toStrictEqual({}); + expect(controller.state.contractBalances).toStrictEqual({}); }); it('should update banlances if controller is manually enabled', async () => { @@ -95,9 +93,9 @@ describe('TokenBalancesController', () => { disabled: true, tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], interval: 10, - onTokensStateChange: sinon.stub(), + onTokensStateChange: jest.fn(), getSelectedAddress: () => '0x1234', - getERC20BalanceOf: sinon.stub().returns(new BN(1)), + getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), messenger: getMessenger(), }); @@ -121,7 +119,7 @@ describe('TokenBalancesController', () => { tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], interval: 10, getSelectedAddress: () => '0x1234', - getERC20BalanceOf: sinon.stub().returns(new BN(1)), + getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), onTokensStateChange: (listener) => { tokensStateChangeListeners.push(listener); }, @@ -157,55 +155,53 @@ describe('TokenBalancesController', () => { }); it('should clear previous interval', async () => { - const mock = sinon.stub(global, 'clearTimeout'); - const tokenBalances = new TokenBalancesController({ + const controller = new TokenBalancesController({ interval: 1337, - onTokensStateChange: sinon.stub(), + onTokensStateChange: jest.fn(), getSelectedAddress: () => '0x1234', - getERC20BalanceOf: sinon.stub(), + getERC20BalanceOf: jest.fn(), messenger: getMessenger(), }); - await tokenBalances.poll(1338); - await advanceTime({ clock, duration: 1339 }); + const mockClearTimeout = jest.spyOn(global, 'clearTimeout'); + + await controller.poll(1338); + + jest.advanceTimersByTime(1339); - expect(mock.called).toBe(true); + expect(mockClearTimeout).toHaveBeenCalled(); }); it('should update all balances', async () => { const selectedAddress = '0x0000000000000000000000000000000000000001'; const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; - const tokenBalances = new TokenBalancesController({ + const controller = new TokenBalancesController({ interval: 1337, tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], onTokensStateChange: jest.fn(), getSelectedAddress: () => selectedAddress, - getERC20BalanceOf: sinon.stub().returns(new BN(1)), + getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), messenger: getMessenger(), }); - expect(tokenBalances.state.contractBalances).toStrictEqual({}); + expect(controller.state.contractBalances).toStrictEqual({}); - await tokenBalances.updateBalances(); + await controller.updateBalances(); - const mytoken = getToken(tokenBalances, address); + const mytoken = getToken(controller, address); expect(mytoken?.balanceError).toBeNull(); - expect(Object.keys(tokenBalances.state.contractBalances)).toContain( - address, - ); - expect(tokenBalances.state.contractBalances[address].toString()).not.toBe( - '0', - ); + expect(Object.keys(controller.state.contractBalances)).toContain(address); + expect(controller.state.contractBalances[address].toString()).not.toBe('0'); }); it('should handle `getERC20BalanceOf` error case', async () => { const errorMsg = 'Failed to get balance'; const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; - const getERC20BalanceOfStub = sinon - .stub() - .returns(Promise.reject(new Error(errorMsg))); - const tokenBalances = new TokenBalancesController({ + const getERC20BalanceOfStub = jest + .fn() + .mockReturnValue(Promise.reject(new Error(errorMsg))); + const controller = new TokenBalancesController({ interval: 1337, tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], onTokensStateChange: jest.fn(), @@ -214,31 +210,27 @@ describe('TokenBalancesController', () => { messenger: getMessenger(), }); - expect(tokenBalances.state.contractBalances).toStrictEqual({}); + expect(controller.state.contractBalances).toStrictEqual({}); - await tokenBalances.updateBalances(); + await controller.updateBalances(); - const mytoken = getToken(tokenBalances, address); + const mytoken = getToken(controller, address); expect(mytoken?.balanceError).toBeInstanceOf(Error); expect(mytoken?.balanceError).toHaveProperty('message', errorMsg); - expect(tokenBalances.state.contractBalances[address].toString()).toBe('0'); + expect(controller.state.contractBalances[address].toString()).toBe('0'); - getERC20BalanceOfStub.returns(new BN(1)); + getERC20BalanceOfStub.mockReturnValue(new BN(1)); - await tokenBalances.updateBalances(); + await controller.updateBalances(); expect(mytoken?.balanceError).toBeNull(); - expect(Object.keys(tokenBalances.state.contractBalances)).toContain( - address, - ); - expect(tokenBalances.state.contractBalances[address].toString()).not.toBe( - 0, - ); + expect(Object.keys(controller.state.contractBalances)).toContain(address); + expect(controller.state.contractBalances[address].toString()).not.toBe(0); }); it('should update balances when tokens change', async () => { const tokensStateChangeListeners: ((state: TokensState) => void)[] = []; - const tokenBalances = new TokenBalancesController({ + const controller = new TokenBalancesController({ onTokensStateChange: (listener) => { tokensStateChangeListeners.push(listener); }, @@ -252,7 +244,7 @@ describe('TokenBalancesController', () => { listener(state); } }; - const updateBalances = sinon.stub(tokenBalances, 'updateBalances'); + const updateBalancesSpy = jest.spyOn(controller, 'updateBalances'); triggerTokensStateChange({ ...getDefaultTokensState(), @@ -265,18 +257,18 @@ describe('TokenBalancesController', () => { ], }); - expect(updateBalances.called).toBe(true); + expect(updateBalancesSpy).toHaveBeenCalled(); }); it('should update token balances when detected tokens are added', async () => { const tokensStateChangeListeners: ((state: TokensState) => void)[] = []; - const tokenBalances = new TokenBalancesController({ + const controller = new TokenBalancesController({ interval: 1337, onTokensStateChange: (listener) => { tokensStateChangeListeners.push(listener); }, getSelectedAddress: () => '0x1234', - getERC20BalanceOf: sinon.stub().returns(new BN(1)), + getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), messenger: getMessenger(), }); const triggerTokensStateChange = (state: TokensState) => { @@ -284,7 +276,7 @@ describe('TokenBalancesController', () => { listener(state); } }; - expect(tokenBalances.state.contractBalances).toStrictEqual({}); + expect(controller.state.contractBalances).toStrictEqual({}); triggerTokensStateChange({ ...getDefaultTokensState(), @@ -300,9 +292,9 @@ describe('TokenBalancesController', () => { tokens: [], }); - await tokenBalances.updateBalances(); + await controller.updateBalances(); - expect(tokenBalances.state.contractBalances).toStrictEqual({ + expect(controller.state.contractBalances).toStrictEqual({ '0x02': new BN(1).toString(16), }); }); diff --git a/packages/assets-controllers/src/TokenRatesController.test.ts b/packages/assets-controllers/src/TokenRatesController.test.ts index f865b31d80..d77143a607 100644 --- a/packages/assets-controllers/src/TokenRatesController.test.ts +++ b/packages/assets-controllers/src/TokenRatesController.test.ts @@ -10,7 +10,7 @@ import { add0x } from '@metamask/utils'; import nock from 'nock'; import { useFakeTimers } from 'sinon'; -import { advanceTime } from '../../../tests/helpers'; +import { advanceTime, flushPromises } from '../../../tests/helpers'; import { TOKEN_PRICES_BATCH_SIZE } from './assetsUtil'; import type { AbstractTokenPricesService, @@ -2291,18 +2291,6 @@ async function withController( } } -/** - * Resolve all pending promises. - * - * This method is used for async tests that use fake timers. - * See https://stackoverflow.com/a/58716087 and https://jestjs.io/docs/timer-mocks. - * - * TODO: migrate this to @metamask/utils - */ -async function flushPromises(): Promise { - await new Promise(jest.requireActual('timers').setImmediate); -} - /** * Call an "update exchange rates" method with the given parameters. * diff --git a/packages/transaction-controller/src/TransactionController.test.ts b/packages/transaction-controller/src/TransactionController.test.ts index 65ceaea740..00817adfd8 100644 --- a/packages/transaction-controller/src/TransactionController.test.ts +++ b/packages/transaction-controller/src/TransactionController.test.ts @@ -23,6 +23,7 @@ import { errorCodes, providerErrors, rpcErrors } from '@metamask/rpc-errors'; import * as NonceTrackerPackage from 'nonce-tracker'; import { FakeBlockTracker } from '../../../tests/fake-block-tracker'; +import { flushPromises } from '../../../tests/helpers'; import { mockNetwork } from '../../../tests/mock-network'; import { IncomingTransactionHelper } from './helpers/IncomingTransactionHelper'; import { PendingTransactionTracker } from './helpers/PendingTransactionTracker'; @@ -306,15 +307,6 @@ function waitForTransactionFinished( }); } -/** - * Resolve all pending promises. - * This method is used for async tests that use fake timers. - * See https://stackoverflow.com/a/58716087 and https://jestjs.io/docs/timer-mocks. - */ -function flushPromises(): Promise { - return new Promise(jest.requireActual('timers').setImmediate); -} - const MOCK_PREFERENCES = { state: { selectedAddress: 'foo' } }; const INFURA_PROJECT_ID = '341eacb578dd44a1a049cbc5f6fd4035'; const GOERLI_PROVIDER = new HttpProvider( @@ -2640,7 +2632,7 @@ describe('TransactionController', () => { externalBaseFeePerGas, ); - await new Promise(jest.requireActual('timers').setImmediate); + await flushPromises(); expect(mockPostTransactionBalanceUpdatedListener).toHaveBeenCalledTimes( 1, diff --git a/tests/helpers.ts b/tests/helpers.ts index 114d47736c..ed0b2660f8 100644 --- a/tests/helpers.ts +++ b/tests/helpers.ts @@ -27,3 +27,15 @@ export async function advanceTime({ duration -= stepSize; } while (duration > 0); } + +/** + * Resolve all pending promises. + * + * This method is used for async tests that use fake timers. + * See https://stackoverflow.com/a/58716087 and https://jestjs.io/docs/timer-mocks. + * + * TODO: migrate this to @metamask/utils + */ +export async function flushPromises(): Promise { + await new Promise(jest.requireActual('timers').setImmediate); +} From 12d60f262bb8ca8c01877c2aadf512f4676682f0 Mon Sep 17 00:00:00 2001 From: Salah-Eddine Saakoun Date: Thu, 11 Jan 2024 14:25:24 +0100 Subject: [PATCH 7/9] fix: remove getTokens unusable function --- .../src/TokenBalancesController.test.ts | 65 +++++++++++-------- .../src/TokenBalancesController.ts | 25 +++---- packages/assets-controllers/src/index.ts | 1 - 3 files changed, 46 insertions(+), 45 deletions(-) diff --git a/packages/assets-controllers/src/TokenBalancesController.test.ts b/packages/assets-controllers/src/TokenBalancesController.test.ts index 485bfe5f2e..b2dcc8d666 100644 --- a/packages/assets-controllers/src/TokenBalancesController.test.ts +++ b/packages/assets-controllers/src/TokenBalancesController.test.ts @@ -3,6 +3,7 @@ import { BN } from 'ethereumjs-util'; import { flushPromises } from '../../../tests/helpers'; import { TokenBalancesController } from './TokenBalancesController'; +import type { Token } from './TokenRatesController'; import { getDefaultTokensState, type TokensState } from './TokensController'; const controllerName = 'TokenBalancesController'; @@ -22,17 +23,12 @@ function getMessenger() { }); } -const getToken = (controler: TokenBalancesController, address: string) => { - return controler.getTokens().find((token) => token.address === address); -}; - describe('TokenBalancesController', () => { beforeEach(() => { jest.useFakeTimers(); }); afterEach(() => { - jest.restoreAllMocks(); jest.useRealTimers(); }); @@ -89,7 +85,7 @@ describe('TokenBalancesController', () => { it('should update banlances if controller is manually enabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; - const tokenBalances = new TokenBalancesController({ + const controller = new TokenBalancesController({ disabled: true, tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], interval: 10, @@ -99,14 +95,14 @@ describe('TokenBalancesController', () => { messenger: getMessenger(), }); - await tokenBalances.updateBalances(); + await controller.updateBalances(); - expect(tokenBalances.state.contractBalances).toStrictEqual({}); + expect(controller.state.contractBalances).toStrictEqual({}); - tokenBalances.enable(); - await tokenBalances.updateBalances(); + controller.enable(); + await controller.updateBalances(); - expect(tokenBalances.state.contractBalances).toStrictEqual({ + expect(controller.state.contractBalances).toStrictEqual({ '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0': '1', }); }); @@ -114,7 +110,7 @@ describe('TokenBalancesController', () => { it('should not update banlances if controller is manually disabled', async () => { const tokensStateChangeListeners: ((state: TokensState) => void)[] = []; const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; - const tokenBalances = new TokenBalancesController({ + const controller = new TokenBalancesController({ disabled: false, tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], interval: 10, @@ -131,13 +127,13 @@ describe('TokenBalancesController', () => { } }; - await tokenBalances.updateBalances(); + await controller.updateBalances(); - expect(tokenBalances.state.contractBalances).toStrictEqual({ + expect(controller.state.contractBalances).toStrictEqual({ '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0': '1', }); - tokenBalances.disable(); + controller.disable(); triggerTokensStateChange({ ...getDefaultTokensState(), tokens: [ @@ -149,7 +145,7 @@ describe('TokenBalancesController', () => { ], }); - expect(tokenBalances.state.contractBalances).toStrictEqual({ + expect(controller.state.contractBalances).toStrictEqual({ '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0': '1', }); }); @@ -175,9 +171,17 @@ describe('TokenBalancesController', () => { it('should update all balances', async () => { const selectedAddress = '0x0000000000000000000000000000000000000001'; const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; + const tokens: Token[] = [ + { + address, + decimals: 18, + symbol: 'EOS', + aggregators: [], + }, + ]; const controller = new TokenBalancesController({ interval: 1337, - tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], + tokens, onTokensStateChange: jest.fn(), getSelectedAddress: () => selectedAddress, getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), @@ -188,11 +192,9 @@ describe('TokenBalancesController', () => { await controller.updateBalances(); - const mytoken = getToken(controller, address); - - expect(mytoken?.balanceError).toBeNull(); + expect(tokens[0].balanceError).toBeNull(); expect(Object.keys(controller.state.contractBalances)).toContain(address); - expect(controller.state.contractBalances[address].toString()).not.toBe('0'); + expect(controller.state.contractBalances[address]).not.toBe('0'); }); it('should handle `getERC20BalanceOf` error case', async () => { @@ -201,9 +203,17 @@ describe('TokenBalancesController', () => { const getERC20BalanceOfStub = jest .fn() .mockReturnValue(Promise.reject(new Error(errorMsg))); + const tokens: Token[] = [ + { + address, + decimals: 18, + symbol: 'EOS', + aggregators: [], + }, + ]; const controller = new TokenBalancesController({ interval: 1337, - tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], + tokens, onTokensStateChange: jest.fn(), getSelectedAddress: jest.fn(), getERC20BalanceOf: getERC20BalanceOfStub, @@ -214,18 +224,17 @@ describe('TokenBalancesController', () => { await controller.updateBalances(); - const mytoken = getToken(controller, address); - expect(mytoken?.balanceError).toBeInstanceOf(Error); - expect(mytoken?.balanceError).toHaveProperty('message', errorMsg); - expect(controller.state.contractBalances[address].toString()).toBe('0'); + expect(tokens[0].balanceError).toBeInstanceOf(Error); + expect(tokens[0].balanceError).toHaveProperty('message', errorMsg); + expect(controller.state.contractBalances[address]).toBe('0'); getERC20BalanceOfStub.mockReturnValue(new BN(1)); await controller.updateBalances(); - expect(mytoken?.balanceError).toBeNull(); + expect(tokens[0].balanceError).toBeNull(); expect(Object.keys(controller.state.contractBalances)).toContain(address); - expect(controller.state.contractBalances[address].toString()).not.toBe(0); + expect(controller.state.contractBalances[address]).not.toBe(0); }); it('should update balances when tokens change', async () => { diff --git a/packages/assets-controllers/src/TokenBalancesController.ts b/packages/assets-controllers/src/TokenBalancesController.ts index 4d92940ae5..0917b05a29 100644 --- a/packages/assets-controllers/src/TokenBalancesController.ts +++ b/packages/assets-controllers/src/TokenBalancesController.ts @@ -6,7 +6,6 @@ import { } from '@metamask/base-controller'; import { safelyExecute } from '@metamask/controller-utils'; import type { PreferencesState } from '@metamask/preferences-controller'; -import { BN } from 'ethereumjs-util'; import type { AssetsContractController } from './AssetsContractController'; import type { Token } from './TokenRatesController'; @@ -21,8 +20,6 @@ const metadata = { }; /** - * @type TokenBalancesControllerOptions - * * Token balances controller options * @property interval - Polling interval used to fetch new token balances. * @property tokens - List of tokens to track balances for. @@ -31,7 +28,7 @@ const metadata = { * @property getSelectedAddress - Gets the current selected address. * @property getERC20BalanceOf - Gets the balance of the given account at the given contract address. */ -export type TokenBalancesControllerOptions = { +type TokenBalancesControllerOptions = { interval?: number; tokens?: Token[]; disabled?: boolean; @@ -48,8 +45,6 @@ export type TokenBalancesControllerOptions = { type ContractBalances = Record; /** - * @type TokenBalancesControllerState - * * Token balances controller state * @property contractBalances - Hash of token contract addresses to balances */ @@ -82,11 +77,16 @@ export type TokenBalancesControllerMessenger = RestrictedControllerMessenger< never >; -const getDefaultState = (): TokenBalancesControllerState => { +/** + * Get the default TokenBalancesController state. + * + * @returns The default TokenBalancesController state. + */ +function getDefaultState(): TokenBalancesControllerState { return { contractBalances: {}, }; -}; +} /** * Controller that passively polls on a set interval token balances @@ -176,13 +176,6 @@ export class TokenBalancesController extends BaseController< this.#disabled = true; } - /* - * Lists all tracked tokens. - */ - getTokens() { - return this.#tokens; - } - /** * Starts a new polling interval. * @@ -221,7 +214,7 @@ export class TokenBalancesController extends BaseController< ).toString(16); token.balanceError = null; } catch (error) { - newContractBalances[address] = new BN(0).toString(16); + newContractBalances[address] = '0'; token.balanceError = error; } } diff --git a/packages/assets-controllers/src/index.ts b/packages/assets-controllers/src/index.ts index 6dbfdfd39c..1f2786d784 100644 --- a/packages/assets-controllers/src/index.ts +++ b/packages/assets-controllers/src/index.ts @@ -4,7 +4,6 @@ export * from './CurrencyRateController'; export * from './NftController'; export * from './NftDetectionController'; export type { - TokenBalancesControllerOptions, TokenBalancesControllerMessenger, TokenBalancesControllerActions, TokenBalancesControllerGetStateAction, From fbf5c0fa0e097e9701ef692b4fb20decc58026e0 Mon Sep 17 00:00:00 2001 From: Salah-Eddine Saakoun Date: Thu, 11 Jan 2024 15:43:13 +0100 Subject: [PATCH 8/9] fix: add more test cases --- .../src/TokenBalancesController.test.ts | 106 ++++++++++++++++-- 1 file changed, 95 insertions(+), 11 deletions(-) diff --git a/packages/assets-controllers/src/TokenBalancesController.test.ts b/packages/assets-controllers/src/TokenBalancesController.test.ts index b2dcc8d666..c720da5e94 100644 --- a/packages/assets-controllers/src/TokenBalancesController.test.ts +++ b/packages/assets-controllers/src/TokenBalancesController.test.ts @@ -66,7 +66,26 @@ describe('TokenBalancesController', () => { expect(updateBalancesSpy).toHaveBeenCalledTimes(2); }); - it('should not update banlances if disabled', async () => { + it('should update balances if enabled', async () => { + const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; + const controller = new TokenBalancesController({ + disabled: false, + tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], + interval: 10, + onTokensStateChange: jest.fn(), + getSelectedAddress: () => '0x1234', + getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), + messenger: getMessenger(), + }); + + await controller.updateBalances(); + + expect(controller.state.contractBalances).toStrictEqual({ + '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0': '1', + }); + }); + + it('should not update balances if disabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; const controller = new TokenBalancesController({ disabled: true, @@ -83,7 +102,7 @@ describe('TokenBalancesController', () => { expect(controller.state.contractBalances).toStrictEqual({}); }); - it('should update banlances if controller is manually enabled', async () => { + it('should update balances if controller is manually enabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; const controller = new TokenBalancesController({ disabled: true, @@ -107,7 +126,74 @@ describe('TokenBalancesController', () => { }); }); - it('should not update banlances if controller is manually disabled', async () => { + it('should not update balances if controller is manually disabled', async () => { + const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; + const controller = new TokenBalancesController({ + disabled: false, + tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], + interval: 10, + onTokensStateChange: jest.fn(), + getSelectedAddress: () => '0x1234', + getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), + messenger: getMessenger(), + }); + + await controller.updateBalances(); + + expect(controller.state.contractBalances).toStrictEqual({ + '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0': '1', + }); + + controller.disable(); + await controller.updateBalances(); + + expect(controller.state.contractBalances).toStrictEqual({ + '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0': '1', + }); + }); + + it('should update balances if tokens change and controller is manually enabled', async () => { + const tokensStateChangeListeners: ((state: TokensState) => void)[] = []; + const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; + const controller = new TokenBalancesController({ + disabled: true, + tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], + interval: 10, + getSelectedAddress: () => '0x1234', + getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), + onTokensStateChange: (listener) => { + tokensStateChangeListeners.push(listener); + }, + messenger: getMessenger(), + }); + const triggerTokensStateChange = async (state: TokensState) => { + for (const listener of tokensStateChangeListeners) { + listener(state); + } + }; + + await controller.updateBalances(); + + expect(controller.state.contractBalances).toStrictEqual({}); + + controller.enable(); + await triggerTokensStateChange({ + ...getDefaultTokensState(), + tokens: [ + { + address: '0x00', + symbol: 'FOO', + decimals: 18, + }, + ], + }); + + expect(controller.state.contractBalances).toStrictEqual({ + '0x00': '1', + }); + }); + + it('should not update balances if tokens change and controller is manually disabled', async () => { const tokensStateChangeListeners: ((state: TokensState) => void)[] = []; const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; const controller = new TokenBalancesController({ @@ -121,7 +207,7 @@ describe('TokenBalancesController', () => { }, messenger: getMessenger(), }); - const triggerTokensStateChange = (state: TokensState) => { + const triggerTokensStateChange = async (state: TokensState) => { for (const listener of tokensStateChangeListeners) { listener(state); } @@ -134,7 +220,7 @@ describe('TokenBalancesController', () => { }); controller.disable(); - triggerTokensStateChange({ + await triggerTokensStateChange({ ...getDefaultTokensState(), tokens: [ { @@ -248,14 +334,14 @@ describe('TokenBalancesController', () => { interval: 1337, messenger: getMessenger(), }); - const triggerTokensStateChange = (state: TokensState) => { + const triggerTokensStateChange = async (state: TokensState) => { for (const listener of tokensStateChangeListeners) { listener(state); } }; const updateBalancesSpy = jest.spyOn(controller, 'updateBalances'); - triggerTokensStateChange({ + await triggerTokensStateChange({ ...getDefaultTokensState(), tokens: [ { @@ -280,14 +366,14 @@ describe('TokenBalancesController', () => { getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), messenger: getMessenger(), }); - const triggerTokensStateChange = (state: TokensState) => { + const triggerTokensStateChange = async (state: TokensState) => { for (const listener of tokensStateChangeListeners) { listener(state); } }; expect(controller.state.contractBalances).toStrictEqual({}); - triggerTokensStateChange({ + await triggerTokensStateChange({ ...getDefaultTokensState(), detectedTokens: [ { @@ -301,8 +387,6 @@ describe('TokenBalancesController', () => { tokens: [], }); - await controller.updateBalances(); - expect(controller.state.contractBalances).toStrictEqual({ '0x02': new BN(1).toString(16), }); From 9c789c1444dbcbda4d5e66800222f5a32dfe95ad Mon Sep 17 00:00:00 2001 From: Salah-Eddine Saakoun Date: Thu, 11 Jan 2024 18:24:42 +0100 Subject: [PATCH 9/9] fix: use toHex instead of toString(16) --- .../src/TokenBalancesController.test.ts | 21 ++++++++++--------- .../src/TokenBalancesController.ts | 10 ++++----- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/packages/assets-controllers/src/TokenBalancesController.test.ts b/packages/assets-controllers/src/TokenBalancesController.test.ts index c720da5e94..93c08b7508 100644 --- a/packages/assets-controllers/src/TokenBalancesController.test.ts +++ b/packages/assets-controllers/src/TokenBalancesController.test.ts @@ -1,4 +1,5 @@ import { ControllerMessenger } from '@metamask/base-controller'; +import { toHex } from '@metamask/controller-utils'; import { BN } from 'ethereumjs-util'; import { flushPromises } from '../../../tests/helpers'; @@ -81,7 +82,7 @@ describe('TokenBalancesController', () => { await controller.updateBalances(); expect(controller.state.contractBalances).toStrictEqual({ - '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0': '1', + '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0': toHex(new BN(1)), }); }); @@ -122,7 +123,7 @@ describe('TokenBalancesController', () => { await controller.updateBalances(); expect(controller.state.contractBalances).toStrictEqual({ - '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0': '1', + '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0': toHex(new BN(1)), }); }); @@ -141,14 +142,14 @@ describe('TokenBalancesController', () => { await controller.updateBalances(); expect(controller.state.contractBalances).toStrictEqual({ - '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0': '1', + '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0': toHex(new BN(1)), }); controller.disable(); await controller.updateBalances(); expect(controller.state.contractBalances).toStrictEqual({ - '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0': '1', + '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0': toHex(new BN(1)), }); }); @@ -189,7 +190,7 @@ describe('TokenBalancesController', () => { }); expect(controller.state.contractBalances).toStrictEqual({ - '0x00': '1', + '0x00': toHex(new BN(1)), }); }); @@ -216,7 +217,7 @@ describe('TokenBalancesController', () => { await controller.updateBalances(); expect(controller.state.contractBalances).toStrictEqual({ - '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0': '1', + '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0': toHex(new BN(1)), }); controller.disable(); @@ -232,7 +233,7 @@ describe('TokenBalancesController', () => { }); expect(controller.state.contractBalances).toStrictEqual({ - '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0': '1', + '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0': toHex(new BN(1)), }); }); @@ -280,7 +281,7 @@ describe('TokenBalancesController', () => { expect(tokens[0].balanceError).toBeNull(); expect(Object.keys(controller.state.contractBalances)).toContain(address); - expect(controller.state.contractBalances[address]).not.toBe('0'); + expect(controller.state.contractBalances[address]).not.toBe(toHex(0)); }); it('should handle `getERC20BalanceOf` error case', async () => { @@ -312,7 +313,7 @@ describe('TokenBalancesController', () => { expect(tokens[0].balanceError).toBeInstanceOf(Error); expect(tokens[0].balanceError).toHaveProperty('message', errorMsg); - expect(controller.state.contractBalances[address]).toBe('0'); + expect(controller.state.contractBalances[address]).toBe(toHex(0)); getERC20BalanceOfStub.mockReturnValue(new BN(1)); @@ -388,7 +389,7 @@ describe('TokenBalancesController', () => { }); expect(controller.state.contractBalances).toStrictEqual({ - '0x02': new BN(1).toString(16), + '0x02': toHex(new BN(1)), }); }); }); diff --git a/packages/assets-controllers/src/TokenBalancesController.ts b/packages/assets-controllers/src/TokenBalancesController.ts index 0917b05a29..82d94b57a8 100644 --- a/packages/assets-controllers/src/TokenBalancesController.ts +++ b/packages/assets-controllers/src/TokenBalancesController.ts @@ -4,7 +4,7 @@ import { type ControllerStateChangeEvent, BaseController, } from '@metamask/base-controller'; -import { safelyExecute } from '@metamask/controller-utils'; +import { safelyExecute, toHex } from '@metamask/controller-utils'; import type { PreferencesState } from '@metamask/preferences-controller'; import type { AssetsContractController } from './AssetsContractController'; @@ -209,12 +209,12 @@ export class TokenBalancesController extends BaseController< for (const token of this.#tokens) { const { address } = token; try { - newContractBalances[address] = ( - await this.#getERC20BalanceOf(address, this.#getSelectedAddress()) - ).toString(16); + newContractBalances[address] = toHex( + await this.#getERC20BalanceOf(address, this.#getSelectedAddress()), + ); token.balanceError = null; } catch (error) { - newContractBalances[address] = '0'; + newContractBalances[address] = toHex(0); token.balanceError = error; } }