diff --git a/packages/assets-controllers/CHANGELOG.md b/packages/assets-controllers/CHANGELOG.md index 2f879da7c7..9c4e2b57c0 100644 --- a/packages/assets-controllers/CHANGELOG.md +++ b/packages/assets-controllers/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING:** Adds `@metamask/accounts-controller` ^8.0.0 and `@metamask/keyring-controller` ^12.0.0 as dependencies and peer dependencies. ([#3775](https://github.com/MetaMask/core/pull/3775/)). - **BREAKING:** `TokenDetectionController` newly subscribes to the `PreferencesController:stateChange`, `AccountsController:selectedAccountChange`, `KeyringController:lock`, `KeyringController:unlock` events, and allows the `PreferencesController:getState` messenger action. ([#3775](https://github.com/MetaMask/core/pull/3775/)) +- `TokensController` now exports `TokensControllerActions`, `TokensControllerGetStateAction`, `TokensControllerAddDetectedTokensAction`, `TokensControllerEvents`, `TokensControllerStateChangeEvent`. ([#3690](https://github.com/MetaMask/core/pull/3690/)) ### Changed @@ -21,11 +22,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING:** The `detectTokens` method now excludes tokens that are already included in the `TokensController`'s `detectedTokens` list from the batch of incoming tokens it sends to the `TokensController` `addDetectedTokens` method. - **BREAKING:** The constructor for `TokenDetectionController` expects a new required proprerty `trackMetaMetricsEvent`, which defines the callback that is called in the `detectTokens` method. - **BREAKING:** In Mainnet, even if the `PreferenceController`'s `useTokenDetection` option is set to false, automatic token detection is performed on the legacy token list (token data from the contract-metadata repo). + - **BREAKING:** The `TokensState` type is now defined as a type alias rather than an interface. ([#3690](https://github.com/MetaMask/core/pull/3690/)) + - This is breaking because it could affect how this type is used with other types, such as `Json`, which does not support TypeScript interfaces. ### Removed -- **BREAKING:** `TokenDetectionController` constructor no longer accepts options `onPreferencesStateChange`, `getPreferencesState`. ([#3775](https://github.com/MetaMask/core/pull/3775/)) +- **BREAKING:** `TokenDetectionController` constructor no longer accepts options `onPreferencesStateChange`, `getPreferencesState`, `getTokensState`, `addDetectedTokens`. ([#3690](https://github.com/MetaMask/core/pull/3690/), [#3775](https://github.com/MetaMask/core/pull/3775/)) - **BREAKING:** `TokenDetectionController` no longer allows the `NetworkController:stateChange` event. The `NetworkController:networkDidChange` event can be used instead. ([#3775](https://github.com/MetaMask/core/pull/3775/)) +- **BREAKING:** `TokensController` constructor no longer accepts options `onPreferencesStateChange`, `onNetworkDidChange`, `onTokenListStateChange`, `getNetworkClientById`. ([#3690](https://github.com/MetaMask/core/pull/3690/)) +- **BREAKING:** `TokenBalancesController` constructor no longer accepts options `onTokensStateChange`, `getSelectedAddress`. ([#3690](https://github.com/MetaMask/core/pull/3690/)) ## [25.0.0] @@ -136,8 +141,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - These are needed for the new "polling by `networkClientId`" feature - **BREAKING:** `AccountTrackerController` has a new required state property, `accountByChainId`([#3586](https://github.com/MetaMask/core/pull/3586)) - This is needed to track balances accross chains. It was introduced for the "polling by `networkClientId`" feature, but is useful on its own as well. -- **BREAKING**: `AccountTrackerController` adds a mutex to `refresh` making it only possible for one call to be executed at time ([#3586](https://github.com/MetaMask/core/pull/3586)) -- **BREAKING**: `TokensController.watchAsset` now performs on-chain validation of the asset's symbol and decimals, if they're defined in the contract ([#1745](https://github.com/MetaMask/core/pull/1745)) +- **BREAKING:** `AccountTrackerController` adds a mutex to `refresh` making it only possible for one call to be executed at time ([#3586](https://github.com/MetaMask/core/pull/3586)) +- **BREAKING:** `TokensController.watchAsset` now performs on-chain validation of the asset's symbol and decimals, if they're defined in the contract ([#1745](https://github.com/MetaMask/core/pull/1745)) - The `TokensController` constructor no longer accepts a `getERC20TokenName` option. It was no longer needed due to this change. - Add new method `_getProvider`, though this is intended for internal use and should not be called externally. - Additionally, if the symbol and decimals are defined in the contract, they are no longer required to be passed to `watchAsset` @@ -168,10 +173,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - This method was previously used in TokenRatesController to access the CoinGecko API. There is no equivalent. - **BREAKING:** Remove `CoinGeckoResponse` and `CoinGeckoPlatform` types ([#3600](https://github.com/MetaMask/core/pull/3600)) - These types were previously used in TokenRatesController to represent data returned from the CoinGecko API. There is no equivalent. -- **BREAKING**: The TokenRatesController now only supports updating and polling rates for tokens tracked by the TokensController ([#3639](https://github.com/MetaMask/core/pull/3639)) +- **BREAKING:** The TokenRatesController now only supports updating and polling rates for tokens tracked by the TokensController ([#3639](https://github.com/MetaMask/core/pull/3639)) - The `tokenAddresses` option has been removed from `startPollingByNetworkClientId` - The `tokenContractAddresses` option has been removed from `updateExchangeRatesByChainId` -- **BREAKING**: `TokenRatesController.fetchAndMapExchangeRates` is no longer exposed publicly ([#3621](https://github.com/MetaMask/core/pull/3621)) +- **BREAKING:** `TokenRatesController.fetchAndMapExchangeRates` is no longer exposed publicly ([#3621](https://github.com/MetaMask/core/pull/3621)) ### Fixed @@ -255,7 +260,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ``` - **BREAKING**: `CurrencyRateController` now extends `PollingController` ([#1805](https://github.com/MetaMask/core/pull/1805)) - `start()` and `stop()` methods replaced with `startPollingByNetworkClientId()`, `stopPollingByPollingToken()`, and `stopAllPolling()` -- **BREAKING**: `CurrencyRateController` now sends the `NetworkController:getNetworkClientById` action via messaging controller ([#1805](https://github.com/MetaMask/core/pull/1805)) +- **BREAKING:** `CurrencyRateController` now sends the `NetworkController:getNetworkClientById` action via messaging controller ([#1805](https://github.com/MetaMask/core/pull/1805)) ### Fixed @@ -360,7 +365,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 networkClientId?: NetworkClientId; } ``` -- **BREAKING**: Bump peer dependency on `@metamask/network-controller` to ^13.0.0 ([#1633](https://github.com/MetaMask/core/pull/1633)) +- **BREAKING:** Bump peer dependency on `@metamask/network-controller` to ^13.0.0 ([#1633](https://github.com/MetaMask/core/pull/1633)) - **CHANGED**: `TokensController.addToken` will use the chain ID value derived from state for `networkClientId` if provided ([#1676](https://github.com/MetaMask/core/pull/1676)) - **CHANGED**: `TokensController.addTokens` now accepts an optional `networkClientId` as the last parameter ([#1676](https://github.com/MetaMask/core/pull/1676)) - **CHANGED**: `TokensController.addTokens` will use the chain ID value derived from state for `networkClientId` if provided ([#1676](https://github.com/MetaMask/core/pull/1676)) @@ -425,13 +430,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: New required constructor parameters for the `TokenRatesController` ([#1497](https://github.com/MetaMask/core/pull/1497), [#1511](https://github.com/MetaMask/core/pull/1511)) - The new required parameters are `ticker`, `onSelectedAddress`, and `onPreferencesStateChange` -- **BREAKING**: Remove `onCurrencyRateStateChange` constructor parameter from `TokenRatesController` ([#1496](https://github.com/MetaMask/core/pull/1496)) -- **BREAKING**: Disable `TokenRatesController` automatic polling ([#1501](https://github.com/MetaMask/core/pull/1501)) +- **BREAKING:** Remove `onCurrencyRateStateChange` constructor parameter from `TokenRatesController` ([#1496](https://github.com/MetaMask/core/pull/1496)) +- **BREAKING:** Disable `TokenRatesController` automatic polling ([#1501](https://github.com/MetaMask/core/pull/1501)) - Polling must be started explicitly by calling the `start` method - The token rates are not updated upon state changes when polling is disabled. -- **BREAKING**: Replace the `poll` method with `start` ([#1501](https://github.com/MetaMask/core/pull/1501)) +- **BREAKING:** Replace the `poll` method with `start` ([#1501](https://github.com/MetaMask/core/pull/1501)) - The `start` method does not offer a way to change the interval. That must be done by calling `.configure` instead -- **BREAKING**: Remove `TokenRatecontroller` setter for `chainId` and `tokens` properties ([#1505](https://github.com/MetaMask/core/pull/1505)) +- **BREAKING:** Remove `TokenRatecontroller` setter for `chainId` and `tokens` properties ([#1505](https://github.com/MetaMask/core/pull/1505)) - Bump @metamask/abi-utils from 1.2.0 to 2.0.1 ([#1525](https://github.com/MetaMask/core/pull/1525)) - Update `@metamask/utils` to `^6.2.0` ([#1514](https://github.com/MetaMask/core/pull/1514)) - Remove unnecessary `babel-runtime` dependency ([#1504](https://github.com/MetaMask/core/pull/1504)) @@ -520,7 +525,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - The tokens controller `addDetectedTokens` method now accepts the `chainId` property of the `detectionDetails` parameter to be of type `Hex` rather than decimal `string`. - The tokens controller state properties `allTokens`, `allIgnoredTokens`, and `allDetectedTokens` are now keyed by chain ID in `Hex` format rather than decimal `string`. - This requires a state migration -- **BREAKING**: Use approval controller for suggested assets ([#1261](https://github.com/MetaMask/core/pull/1261), [#1268](https://github.com/MetaMask/core/pull/1268)) +- **BREAKING:** Use approval controller for suggested assets ([#1261](https://github.com/MetaMask/core/pull/1261), [#1268](https://github.com/MetaMask/core/pull/1268)) - The actions `ApprovalController:acceptRequest` and `ApprovalController:rejectRequest` are no longer required by the token controller messenger. - The `suggestedAssets` state has been removed, which means that suggested assets are no longer persisted in state - The return type for `watchAsset` has changed. It now returns a Promise that settles after the request has been confirmed or rejected. diff --git a/packages/assets-controllers/jest.config.js b/packages/assets-controllers/jest.config.js index 34e91df691..71baa6a1b6 100644 --- a/packages/assets-controllers/jest.config.js +++ b/packages/assets-controllers/jest.config.js @@ -17,10 +17,10 @@ module.exports = merge(baseConfig, { // An object that configures minimum threshold enforcement for coverage results coverageThreshold: { global: { - branches: 88.3, + branches: 88.22, functions: 95.32, - lines: 96.69, - statements: 96.7, + lines: 96.68, + statements: 96.68, }, }, diff --git a/packages/assets-controllers/src/TokenBalancesController.test.ts b/packages/assets-controllers/src/TokenBalancesController.test.ts index 93c08b7508..0797292ccb 100644 --- a/packages/assets-controllers/src/TokenBalancesController.test.ts +++ b/packages/assets-controllers/src/TokenBalancesController.test.ts @@ -3,6 +3,11 @@ import { toHex } from '@metamask/controller-utils'; import { BN } from 'ethereumjs-util'; import { flushPromises } from '../../../tests/helpers'; +import type { + AllowedActions, + AllowedEvents, + TokenBalancesControllerMessenger, +} from './TokenBalancesController'; import { TokenBalancesController } from './TokenBalancesController'; import type { Token } from './TokenRatesController'; import { getDefaultTokensState, type TokensState } from './TokensController'; @@ -12,21 +17,30 @@ const controllerName = 'TokenBalancesController'; /** * Constructs a restricted controller messenger. * + * @param controllerMessenger - The controller messenger to restrict. * @returns A restricted controller messenger. */ -function getMessenger() { - return new ControllerMessenger().getRestricted< - typeof controllerName, - never, - never - >({ +function getMessenger( + controllerMessenger = new ControllerMessenger< + AllowedActions, + AllowedEvents + >(), +): TokenBalancesControllerMessenger { + return controllerMessenger.getRestricted({ name: controllerName, + allowedActions: ['PreferencesController:getState'], + allowedEvents: ['TokensController:stateChange'], }); } describe('TokenBalancesController', () => { + let controllerMessenger: ControllerMessenger; + let messenger: TokenBalancesControllerMessenger; + beforeEach(() => { jest.useFakeTimers(); + controllerMessenger = new ControllerMessenger(); + messenger = getMessenger(controllerMessenger); }); afterEach(() => { @@ -34,17 +48,23 @@ describe('TokenBalancesController', () => { }); it('should set default state', () => { + controllerMessenger.registerActionHandler( + 'PreferencesController:getState', + jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + ); const controller = new TokenBalancesController({ - onTokensStateChange: jest.fn(), - getSelectedAddress: () => '0x1234', getERC20BalanceOf: jest.fn(), - messenger: getMessenger(), + messenger, }); expect(controller.state).toStrictEqual({ contractBalances: {} }); }); it('should poll and update balances in the right interval', async () => { + controllerMessenger.registerActionHandler( + 'PreferencesController:getState', + jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + ); const updateBalancesSpy = jest.spyOn( TokenBalancesController.prototype, 'updateBalances', @@ -52,10 +72,8 @@ describe('TokenBalancesController', () => { new TokenBalancesController({ interval: 10, - onTokensStateChange: jest.fn(), - getSelectedAddress: () => '0x1234', getERC20BalanceOf: jest.fn(), - messenger: getMessenger(), + messenger, }); await flushPromises(); @@ -69,14 +87,16 @@ describe('TokenBalancesController', () => { it('should update balances if enabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; + controllerMessenger.registerActionHandler( + 'PreferencesController:getState', + jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + ); const controller = new TokenBalancesController({ disabled: false, tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], interval: 10, - onTokensStateChange: jest.fn(), - getSelectedAddress: () => '0x1234', getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), - messenger: getMessenger(), + messenger, }); await controller.updateBalances(); @@ -88,14 +108,16 @@ describe('TokenBalancesController', () => { it('should not update balances if disabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; + controllerMessenger.registerActionHandler( + 'PreferencesController:getState', + jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + ); const controller = new TokenBalancesController({ disabled: true, tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], interval: 10, - onTokensStateChange: jest.fn(), - getSelectedAddress: () => '0x1234', getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), - messenger: getMessenger(), + messenger, }); await controller.updateBalances(); @@ -105,14 +127,16 @@ describe('TokenBalancesController', () => { it('should update balances if controller is manually enabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; + controllerMessenger.registerActionHandler( + 'PreferencesController:getState', + jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + ); const controller = new TokenBalancesController({ disabled: true, tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], interval: 10, - onTokensStateChange: jest.fn(), - getSelectedAddress: () => '0x1234', getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), - messenger: getMessenger(), + messenger, }); await controller.updateBalances(); @@ -129,14 +153,16 @@ describe('TokenBalancesController', () => { it('should not update balances if controller is manually disabled', async () => { const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; + controllerMessenger.registerActionHandler( + 'PreferencesController:getState', + jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + ); const controller = new TokenBalancesController({ disabled: false, tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], interval: 10, - onTokensStateChange: jest.fn(), - getSelectedAddress: () => '0x1234', getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), - messenger: getMessenger(), + messenger, }); await controller.updateBalances(); @@ -154,23 +180,20 @@ describe('TokenBalancesController', () => { }); it('should update balances if tokens change and controller is manually enabled', async () => { - const tokensStateChangeListeners: ((state: TokensState) => void)[] = []; const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; + controllerMessenger.registerActionHandler( + 'PreferencesController:getState', + jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + ); const controller = new TokenBalancesController({ disabled: true, tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], interval: 10, - getSelectedAddress: () => '0x1234', getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), - onTokensStateChange: (listener) => { - tokensStateChangeListeners.push(listener); - }, - messenger: getMessenger(), + messenger, }); const triggerTokensStateChange = async (state: TokensState) => { - for (const listener of tokensStateChangeListeners) { - listener(state); - } + controllerMessenger.publish('TokensController:stateChange', state, []); }; await controller.updateBalances(); @@ -195,23 +218,20 @@ describe('TokenBalancesController', () => { }); it('should not update balances if tokens change and controller is manually disabled', async () => { - const tokensStateChangeListeners: ((state: TokensState) => void)[] = []; const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; + controllerMessenger.registerActionHandler( + 'PreferencesController:getState', + jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + ); const controller = new TokenBalancesController({ disabled: false, tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }], interval: 10, - getSelectedAddress: () => '0x1234', getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), - onTokensStateChange: (listener) => { - tokensStateChangeListeners.push(listener); - }, - messenger: getMessenger(), + messenger, }); const triggerTokensStateChange = async (state: TokensState) => { - for (const listener of tokensStateChangeListeners) { - listener(state); - } + controllerMessenger.publish('TokensController:stateChange', state, []); }; await controller.updateBalances(); @@ -238,12 +258,14 @@ describe('TokenBalancesController', () => { }); it('should clear previous interval', async () => { + controllerMessenger.registerActionHandler( + 'PreferencesController:getState', + jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + ); const controller = new TokenBalancesController({ interval: 1337, - onTokensStateChange: jest.fn(), - getSelectedAddress: () => '0x1234', getERC20BalanceOf: jest.fn(), - messenger: getMessenger(), + messenger, }); const mockClearTimeout = jest.spyOn(global, 'clearTimeout'); @@ -266,13 +288,15 @@ describe('TokenBalancesController', () => { aggregators: [], }, ]; + controllerMessenger.registerActionHandler( + 'PreferencesController:getState', + jest.fn().mockReturnValue({ selectedAddress }), + ); const controller = new TokenBalancesController({ interval: 1337, tokens, - onTokensStateChange: jest.fn(), - getSelectedAddress: () => selectedAddress, getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), - messenger: getMessenger(), + messenger, }); expect(controller.state.contractBalances).toStrictEqual({}); @@ -298,13 +322,16 @@ describe('TokenBalancesController', () => { aggregators: [], }, ]; + + controllerMessenger.registerActionHandler( + 'PreferencesController:getState', + jest.fn().mockReturnValue({}), + ); const controller = new TokenBalancesController({ interval: 1337, tokens, - onTokensStateChange: jest.fn(), - getSelectedAddress: jest.fn(), getERC20BalanceOf: getERC20BalanceOfStub, - messenger: getMessenger(), + messenger, }); expect(controller.state.contractBalances).toStrictEqual({}); @@ -325,20 +352,17 @@ describe('TokenBalancesController', () => { }); it('should update balances when tokens change', async () => { - const tokensStateChangeListeners: ((state: TokensState) => void)[] = []; + controllerMessenger.registerActionHandler( + 'PreferencesController:getState', + jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + ); const controller = new TokenBalancesController({ - onTokensStateChange: (listener) => { - tokensStateChangeListeners.push(listener); - }, - getSelectedAddress: jest.fn(), getERC20BalanceOf: jest.fn(), interval: 1337, - messenger: getMessenger(), + messenger, }); const triggerTokensStateChange = async (state: TokensState) => { - for (const listener of tokensStateChangeListeners) { - listener(state); - } + controllerMessenger.publish('TokensController:stateChange', state, []); }; const updateBalancesSpy = jest.spyOn(controller, 'updateBalances'); @@ -357,20 +381,17 @@ describe('TokenBalancesController', () => { }); it('should update token balances when detected tokens are added', async () => { - const tokensStateChangeListeners: ((state: TokensState) => void)[] = []; + controllerMessenger.registerActionHandler( + 'PreferencesController:getState', + jest.fn().mockReturnValue({ selectedAddress: '0x1234' }), + ); const controller = new TokenBalancesController({ interval: 1337, - onTokensStateChange: (listener) => { - tokensStateChangeListeners.push(listener); - }, - getSelectedAddress: () => '0x1234', getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)), - messenger: getMessenger(), + messenger, }); const triggerTokensStateChange = async (state: TokensState) => { - for (const listener of tokensStateChangeListeners) { - listener(state); - } + controllerMessenger.publish('TokensController:stateChange', state, []); }; expect(controller.state.contractBalances).toStrictEqual({}); diff --git a/packages/assets-controllers/src/TokenBalancesController.ts b/packages/assets-controllers/src/TokenBalancesController.ts index 82d94b57a8..1acc2f226c 100644 --- a/packages/assets-controllers/src/TokenBalancesController.ts +++ b/packages/assets-controllers/src/TokenBalancesController.ts @@ -5,11 +5,11 @@ import { BaseController, } from '@metamask/base-controller'; import { safelyExecute, toHex } from '@metamask/controller-utils'; -import type { PreferencesState } from '@metamask/preferences-controller'; +import type { PreferencesControllerGetStateAction } from '@metamask/preferences-controller'; import type { AssetsContractController } from './AssetsContractController'; import type { Token } from './TokenRatesController'; -import type { TokensState } from './TokensController'; +import type { TokensControllerStateChangeEvent } from './TokensController'; const DEFAULT_INTERVAL = 180000; @@ -24,16 +24,12 @@ const metadata = { * @property interval - Polling interval used to fetch new token balances. * @property tokens - List of tokens to track balances for. * @property disabled - If set to true, all tracked tokens contract balances updates are blocked. - * @property onTokensStateChange - Allows subscribing to assets controller state changes. - * @property getSelectedAddress - Gets the current selected address. * @property getERC20BalanceOf - Gets the balance of the given account at the given contract address. */ type TokenBalancesControllerOptions = { interval?: number; tokens?: Token[]; disabled?: boolean; - onTokensStateChange: (listener: (tokenState: TokensState) => void) => void; - getSelectedAddress: () => PreferencesState['selectedAddress']; getERC20BalanceOf: AssetsContractController['getERC20BalanceOf']; messenger: TokenBalancesControllerMessenger; state?: Partial; @@ -60,6 +56,8 @@ export type TokenBalancesControllerGetStateAction = ControllerGetStateAction< export type TokenBalancesControllerActions = TokenBalancesControllerGetStateAction; +export type AllowedActions = PreferencesControllerGetStateAction; + export type TokenBalancesControllerStateChangeEvent = ControllerStateChangeEvent< typeof controllerName, @@ -69,12 +67,14 @@ export type TokenBalancesControllerStateChangeEvent = export type TokenBalancesControllerEvents = TokenBalancesControllerStateChangeEvent; +export type AllowedEvents = TokensControllerStateChangeEvent; + export type TokenBalancesControllerMessenger = RestrictedControllerMessenger< typeof controllerName, - TokenBalancesControllerActions, - TokenBalancesControllerEvents, - never, - never + TokenBalancesControllerActions | AllowedActions, + TokenBalancesControllerEvents | AllowedEvents, + AllowedActions['type'], + AllowedEvents['type'] >; /** @@ -82,7 +82,7 @@ export type TokenBalancesControllerMessenger = RestrictedControllerMessenger< * * @returns The default TokenBalancesController state. */ -function getDefaultState(): TokenBalancesControllerState { +export function getDefaultTokenBalancesState(): TokenBalancesControllerState { return { contractBalances: {}, }; @@ -99,8 +99,6 @@ export class TokenBalancesController extends BaseController< > { #handle?: ReturnType; - #getSelectedAddress: () => PreferencesState['selectedAddress']; - #getERC20BalanceOf: AssetsContractController['getERC20BalanceOf']; #interval: number; @@ -116,8 +114,6 @@ export class TokenBalancesController extends BaseController< * @param options.interval - Polling interval used to fetch new token balances. * @param options.tokens - List of tokens to track balances for. * @param options.disabled - If set to true, all tracked tokens contract balances updates are blocked. - * @param options.onTokensStateChange - Allows subscribing to assets controller state changes. - * @param options.getSelectedAddress - Gets the current selected address. * @param options.getERC20BalanceOf - Gets the balance of the given account at the given contract address. * @param options.state - Initial state to set on this controller. * @param options.messenger - The controller restricted messenger. @@ -126,8 +122,6 @@ export class TokenBalancesController extends BaseController< interval = DEFAULT_INTERVAL, tokens = [], disabled = false, - onTokensStateChange, - getSelectedAddress, getERC20BalanceOf, messenger, state = {}, @@ -137,7 +131,7 @@ export class TokenBalancesController extends BaseController< metadata, messenger, state: { - ...getDefaultState(), + ...getDefaultTokenBalancesState(), ...state, }, }); @@ -146,22 +140,19 @@ export class TokenBalancesController extends BaseController< this.#interval = interval; this.#tokens = tokens; - onTokensStateChange(this.#tokensStateChangeListener.bind(this)); + this.messagingSystem.subscribe( + 'TokensController:stateChange', + ({ tokens: newTokens, detectedTokens }) => { + this.#tokens = [...newTokens, ...detectedTokens]; + this.updateBalances(); + }, + ); - this.#getSelectedAddress = getSelectedAddress; this.#getERC20BalanceOf = getERC20BalanceOf; this.poll(); } - /* - * Tokens state changes listener. - */ - #tokensStateChangeListener({ tokens, detectedTokens }: TokensState) { - this.#tokens = [...tokens, ...detectedTokens]; - this.updateBalances(); - } - /** * Allows controller to update tracked tokens contract balances. */ @@ -208,9 +199,12 @@ export class TokenBalancesController extends BaseController< const newContractBalances: ContractBalances = {}; for (const token of this.#tokens) { const { address } = token; + const { selectedAddress } = this.messagingSystem.call( + 'PreferencesController:getState', + ); try { newContractBalances[address] = toHex( - await this.#getERC20BalanceOf(address, this.#getSelectedAddress()), + await this.#getERC20BalanceOf(address, selectedAddress), ); token.balanceError = null; } catch (error) { diff --git a/packages/assets-controllers/src/TokenDetectionController.test.ts b/packages/assets-controllers/src/TokenDetectionController.test.ts index be64826344..ba6522cedb 100644 --- a/packages/assets-controllers/src/TokenDetectionController.test.ts +++ b/packages/assets-controllers/src/TokenDetectionController.test.ts @@ -40,6 +40,7 @@ import { type TokenListState, type TokenListToken, } from './TokenListController'; +import type { TokensState } from './TokensController'; import { getDefaultTokensState } from './TokensController'; const DEFAULT_INTERVAL = 180000; @@ -138,6 +139,8 @@ function buildTokenDetectionControllerMessenger( allowedActions: [ 'KeyringController:getState', 'NetworkController:getNetworkConfigurationByNetworkClientId', + 'TokensController:getState', + 'TokensController:addDetectedTokens', 'TokenListController:getState', 'PreferencesController:getState', ], @@ -222,18 +225,20 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const mockAddDetectedTokens = jest.fn(); const selectedAddress = '0x0000000000000000000000000000000000000001'; await withController( { options: { - addDetectedTokens: mockAddDetectedTokens, getBalancesInSingleCall: mockGetBalancesInSingleCall, networkClientId: NetworkType.mainnet, selectedAddress, }, }, - async ({ controller, mockTokenListGetState }) => { + async ({ + controller, + mockTokenListGetState, + mockAddDetectedTokens, + }) => { mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -251,10 +256,14 @@ describe('TokenDetectionController', () => { await controller.start(); - expect(mockAddDetectedTokens).toHaveBeenCalledWith([sampleTokenA], { - chainId: ChainId.mainnet, - selectedAddress, - }); + expect(mockAddDetectedTokens).toHaveBeenCalledWith( + 'TokensController:addDetectedTokens', + [sampleTokenA], + { + chainId: ChainId.mainnet, + selectedAddress, + }, + ); }, ); }); @@ -263,18 +272,20 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const mockAddDetectedTokens = jest.fn(); const selectedAddress = '0x0000000000000000000000000000000000000001'; await withController( { options: { - addDetectedTokens: mockAddDetectedTokens, getBalancesInSingleCall: mockGetBalancesInSingleCall, networkClientId: 'polygon', selectedAddress, }, }, - async ({ controller, mockTokenListGetState }) => { + async ({ + controller, + mockTokenListGetState, + mockAddDetectedTokens, + }) => { mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -292,10 +303,14 @@ describe('TokenDetectionController', () => { await controller.start(); - expect(mockAddDetectedTokens).toHaveBeenCalledWith([sampleTokenA], { - chainId: '0x89', - selectedAddress, - }); + expect(mockAddDetectedTokens).toHaveBeenCalledWith( + 'TokensController:addDetectedTokens', + [sampleTokenA], + { + chainId: '0x89', + selectedAddress, + }, + ); }, ); }); @@ -305,20 +320,22 @@ describe('TokenDetectionController', () => { [sampleTokenA.address]: new BN(1), [sampleTokenB.address]: new BN(1), }); - const mockAddDetectedTokens = jest.fn(); const selectedAddress = '0x0000000000000000000000000000000000000001'; const interval = 100; await withController( { options: { - addDetectedTokens: mockAddDetectedTokens, getBalancesInSingleCall: mockGetBalancesInSingleCall, interval, networkClientId: NetworkType.mainnet, selectedAddress, }, }, - async ({ controller, mockTokenListGetState }) => { + async ({ + controller, + mockTokenListGetState, + mockAddDetectedTokens, + }) => { const tokenListState = { ...getDefaultTokenListState(), tokenList: { @@ -335,7 +352,6 @@ describe('TokenDetectionController', () => { }; mockTokenListGetState(tokenListState); await controller.start(); - mockAddDetectedTokens.mockReset(); tokenListState.tokenList[sampleTokenB.address] = { name: sampleTokenB.name as string, @@ -350,6 +366,7 @@ describe('TokenDetectionController', () => { await advanceTime({ clock, duration: interval }); expect(mockAddDetectedTokens).toHaveBeenCalledWith( + 'TokensController:addDetectedTokens', [sampleTokenA, sampleTokenB], { chainId: ChainId.mainnet, @@ -364,23 +381,25 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const mockAddDetectedTokens = jest.fn(); - const mockGetTokensState = jest.fn().mockReturnValue({ - ...getDefaultTokensState(), - ignoredTokens: [sampleTokenA.address], - }); const selectedAddress = '0x0000000000000000000000000000000000000001'; await withController( { options: { - addDetectedTokens: mockAddDetectedTokens, getBalancesInSingleCall: mockGetBalancesInSingleCall, - getTokensState: mockGetTokensState, networkClientId: NetworkType.mainnet, selectedAddress, }, }, - async ({ controller, mockTokenListGetState }) => { + async ({ + controller, + mockTokensGetState, + mockTokenListGetState, + mockAddDetectedTokens, + }) => { + mockTokensGetState({ + ...getDefaultTokensState(), + ignoredTokens: [sampleTokenA.address], + }); mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -398,7 +417,9 @@ describe('TokenDetectionController', () => { await controller.start(); - expect(mockAddDetectedTokens).not.toHaveBeenCalled(); + expect(mockAddDetectedTokens).not.toHaveBeenCalledWith( + 'TokensController:addDetectedTokens', + ); }, ); }); @@ -407,17 +428,19 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const mockAddDetectedTokens = jest.fn(); await withController( { options: { - addDetectedTokens: mockAddDetectedTokens, getBalancesInSingleCall: mockGetBalancesInSingleCall, networkClientId: NetworkType.mainnet, selectedAddress: '', }, }, - async ({ controller, mockTokenListGetState }) => { + async ({ + controller, + mockTokenListGetState, + mockAddDetectedTokens, + }) => { mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -435,7 +458,9 @@ describe('TokenDetectionController', () => { await controller.start(); - expect(mockAddDetectedTokens).not.toHaveBeenCalled(); + expect(mockAddDetectedTokens).not.toHaveBeenCalledWith( + 'TokensController:addDetectedTokens', + ); }, ); }); @@ -456,7 +481,6 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const mockAddDetectedTokens = jest.fn(); const firstSelectedAddress = '0x0000000000000000000000000000000000000001'; const secondSelectedAddress = @@ -464,14 +488,17 @@ describe('TokenDetectionController', () => { await withController( { options: { - addDetectedTokens: mockAddDetectedTokens, disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, networkClientId: NetworkType.mainnet, selectedAddress: firstSelectedAddress, }, }, - async ({ mockTokenListGetState, triggerPreferencesStateChange }) => { + async ({ + mockTokenListGetState, + triggerPreferencesStateChange, + mockAddDetectedTokens, + }) => { mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -494,10 +521,14 @@ describe('TokenDetectionController', () => { }); await advanceTime({ clock, duration: 1 }); - expect(mockAddDetectedTokens).toHaveBeenCalledWith([sampleTokenA], { - chainId: ChainId.mainnet, - selectedAddress: secondSelectedAddress, - }); + expect(mockAddDetectedTokens).toHaveBeenCalledWith( + 'TokensController:addDetectedTokens', + [sampleTokenA], + { + chainId: ChainId.mainnet, + selectedAddress: secondSelectedAddress, + }, + ); }, ); }); @@ -506,19 +537,21 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const mockAddDetectedTokens = jest.fn(); const selectedAddress = '0x0000000000000000000000000000000000000001'; await withController( { options: { - addDetectedTokens: mockAddDetectedTokens, disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, networkClientId: NetworkType.mainnet, selectedAddress, }, }, - async ({ mockTokenListGetState, triggerPreferencesStateChange }) => { + async ({ + mockTokenListGetState, + triggerPreferencesStateChange, + mockAddDetectedTokens, + }) => { mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -548,10 +581,14 @@ describe('TokenDetectionController', () => { }); await advanceTime({ clock, duration: 1 }); - expect(mockAddDetectedTokens).toHaveBeenCalledWith([sampleTokenA], { - chainId: ChainId.mainnet, - selectedAddress, - }); + expect(mockAddDetectedTokens).toHaveBeenCalledWith( + 'TokensController:addDetectedTokens', + [sampleTokenA], + { + chainId: ChainId.mainnet, + selectedAddress, + }, + ); }, ); }); @@ -560,7 +597,6 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const mockAddDetectedTokens = jest.fn(); const firstSelectedAddress = '0x0000000000000000000000000000000000000001'; const secondSelectedAddress = @@ -568,14 +604,17 @@ describe('TokenDetectionController', () => { await withController( { options: { - addDetectedTokens: mockAddDetectedTokens, disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, networkClientId: NetworkType.mainnet, selectedAddress: firstSelectedAddress, }, }, - async ({ mockTokenListGetState, triggerPreferencesStateChange }) => { + async ({ + mockTokenListGetState, + triggerPreferencesStateChange, + mockAddDetectedTokens, + }) => { mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -598,7 +637,9 @@ describe('TokenDetectionController', () => { }); await advanceTime({ clock, duration: 1 }); - expect(mockAddDetectedTokens).not.toHaveBeenCalled(); + expect(mockAddDetectedTokens).not.toHaveBeenCalledWith( + 'TokensController:addDetectedTokens', + ); }, ); }); @@ -607,19 +648,21 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const mockAddDetectedTokens = jest.fn(); const selectedAddress = '0x0000000000000000000000000000000000000001'; await withController( { options: { - addDetectedTokens: mockAddDetectedTokens, disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, networkClientId: NetworkType.mainnet, selectedAddress, }, }, - async ({ mockTokenListGetState, triggerPreferencesStateChange }) => { + async ({ + mockTokenListGetState, + triggerPreferencesStateChange, + mockAddDetectedTokens, + }) => { mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -642,7 +685,9 @@ describe('TokenDetectionController', () => { }); await advanceTime({ clock, duration: 1 }); - expect(mockAddDetectedTokens).not.toHaveBeenCalled(); + expect(mockAddDetectedTokens).not.toHaveBeenCalledWith( + 'TokensController:addDetectedTokens', + ); }, ); }); @@ -653,7 +698,6 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const mockAddDetectedTokens = jest.fn(); const firstSelectedAddress = '0x0000000000000000000000000000000000000001'; const secondSelectedAddress = @@ -661,14 +705,17 @@ describe('TokenDetectionController', () => { await withController( { options: { - addDetectedTokens: mockAddDetectedTokens, disabled: true, getBalancesInSingleCall: mockGetBalancesInSingleCall, networkClientId: NetworkType.mainnet, selectedAddress: firstSelectedAddress, }, }, - async ({ mockTokenListGetState, triggerPreferencesStateChange }) => { + async ({ + mockTokenListGetState, + triggerPreferencesStateChange, + mockAddDetectedTokens, + }) => { mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -691,7 +738,9 @@ describe('TokenDetectionController', () => { }); await advanceTime({ clock, duration: 1 }); - expect(mockAddDetectedTokens).not.toHaveBeenCalled(); + expect(mockAddDetectedTokens).not.toHaveBeenCalledWith( + 'TokensController:addDetectedTokens', + ); }, ); }); @@ -700,19 +749,21 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const mockAddDetectedTokens = jest.fn(); const selectedAddress = '0x0000000000000000000000000000000000000001'; await withController( { options: { - addDetectedTokens: mockAddDetectedTokens, disabled: true, getBalancesInSingleCall: mockGetBalancesInSingleCall, networkClientId: NetworkType.mainnet, selectedAddress, }, }, - async ({ mockTokenListGetState, triggerPreferencesStateChange }) => { + async ({ + mockTokenListGetState, + triggerPreferencesStateChange, + mockAddDetectedTokens, + }) => { mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -742,7 +793,9 @@ describe('TokenDetectionController', () => { }); await advanceTime({ clock, duration: 1 }); - expect(mockAddDetectedTokens).not.toHaveBeenCalled(); + expect(mockAddDetectedTokens).not.toHaveBeenCalledWith( + 'TokensController:addDetectedTokens', + ); }, ); }); @@ -764,24 +817,21 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const mockAddDetectedTokens = jest.fn(); const selectedAddress = '0x0000000000000000000000000000000000000001'; - const messenger = new ControllerMessenger< - AllowedActions, - AllowedEvents - >(); await withController( { options: { - addDetectedTokens: mockAddDetectedTokens, disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, networkClientId: NetworkType.mainnet, selectedAddress, }, - messenger, }, - async ({ mockTokenListGetState }) => { + async ({ + mockTokenListGetState, + mockAddDetectedTokens, + triggerNetworkDidChange, + }) => { mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -797,16 +847,20 @@ describe('TokenDetectionController', () => { }, }); - messenger.publish('NetworkController:networkDidChange', { + triggerNetworkDidChange({ ...defaultNetworkState, selectedNetworkClientId: 'polygon', }); await advanceTime({ clock, duration: 1 }); - expect(mockAddDetectedTokens).toHaveBeenCalledWith([sampleTokenA], { - chainId: '0x89', - selectedAddress, - }); + expect(mockAddDetectedTokens).toHaveBeenCalledWith( + 'TokensController:addDetectedTokens', + [sampleTokenA], + { + chainId: '0x89', + selectedAddress, + }, + ); }, ); }); @@ -815,24 +869,21 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const mockAddDetectedTokens = jest.fn(); const selectedAddress = '0x0000000000000000000000000000000000000001'; - const messenger = new ControllerMessenger< - AllowedActions, - AllowedEvents - >(); await withController( { options: { - addDetectedTokens: mockAddDetectedTokens, disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, networkClientId: NetworkType.mainnet, selectedAddress, }, - messenger, }, - async ({ mockTokenListGetState }) => { + async ({ + mockTokenListGetState, + mockAddDetectedTokens, + triggerNetworkDidChange, + }) => { mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -848,13 +899,15 @@ describe('TokenDetectionController', () => { }, }); - messenger.publish('NetworkController:networkDidChange', { + triggerNetworkDidChange({ ...defaultNetworkState, selectedNetworkClientId: 'goerli', }); await advanceTime({ clock, duration: 1 }); - expect(mockAddDetectedTokens).not.toHaveBeenCalled(); + expect(mockAddDetectedTokens).not.toHaveBeenCalledWith( + 'TokensController:addDetectedTokens', + ); }, ); }); @@ -863,24 +916,21 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const mockAddDetectedTokens = jest.fn(); const selectedAddress = '0x0000000000000000000000000000000000000001'; - const messenger = new ControllerMessenger< - AllowedActions, - AllowedEvents - >(); await withController( { options: { - addDetectedTokens: mockAddDetectedTokens, disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, networkClientId: NetworkType.mainnet, selectedAddress, }, - messenger, }, - async ({ mockTokenListGetState }) => { + async ({ + mockTokenListGetState, + mockAddDetectedTokens, + triggerNetworkDidChange, + }) => { mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -896,13 +946,15 @@ describe('TokenDetectionController', () => { }, }); - messenger.publish('NetworkController:networkDidChange', { + triggerNetworkDidChange({ ...defaultNetworkState, selectedNetworkClientId: 'mainnet', }); await advanceTime({ clock, duration: 1 }); - expect(mockAddDetectedTokens).not.toHaveBeenCalled(); + expect(mockAddDetectedTokens).not.toHaveBeenCalledWith( + 'TokensController:addDetectedTokens', + ); }, ); }); @@ -913,24 +965,21 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const mockAddDetectedTokens = jest.fn(); const selectedAddress = '0x0000000000000000000000000000000000000001'; - const messenger = new ControllerMessenger< - AllowedActions, - AllowedEvents - >(); await withController( { options: { - addDetectedTokens: mockAddDetectedTokens, disabled: true, getBalancesInSingleCall: mockGetBalancesInSingleCall, networkClientId: NetworkType.mainnet, selectedAddress, }, - messenger, }, - async ({ mockTokenListGetState }) => { + async ({ + mockTokenListGetState, + mockAddDetectedTokens, + triggerNetworkDidChange, + }) => { mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -946,13 +995,15 @@ describe('TokenDetectionController', () => { }, }); - messenger.publish('NetworkController:networkDidChange', { + triggerNetworkDidChange({ ...defaultNetworkState, selectedNetworkClientId: 'polygon', }); await advanceTime({ clock, duration: 1 }); - expect(mockAddDetectedTokens).not.toHaveBeenCalled(); + expect(mockAddDetectedTokens).not.toHaveBeenCalledWith( + 'TokensController:addDetectedTokens', + ); }, ); }); @@ -974,24 +1025,21 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const mockAddDetectedTokens = jest.fn(); const selectedAddress = '0x0000000000000000000000000000000000000001'; - const messenger = new ControllerMessenger< - AllowedActions, - AllowedEvents - >(); await withController( { options: { - addDetectedTokens: mockAddDetectedTokens, disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, networkClientId: NetworkType.mainnet, selectedAddress, }, - messenger, }, - async ({ mockTokenListGetState }) => { + async ({ + mockTokenListGetState, + mockAddDetectedTokens, + triggerTokenListStateChange, + }) => { const tokenListState = { ...getDefaultTokenListState(), tokenList: { @@ -1008,17 +1056,17 @@ describe('TokenDetectionController', () => { }; mockTokenListGetState(tokenListState); - messenger.publish( - 'TokenListController:stateChange', - tokenListState, - [], - ); + triggerTokenListStateChange(tokenListState); await advanceTime({ clock, duration: 1 }); - expect(mockAddDetectedTokens).toHaveBeenCalledWith([sampleTokenA], { - chainId: ChainId.mainnet, - selectedAddress, - }); + expect(mockAddDetectedTokens).toHaveBeenCalledWith( + 'TokensController:addDetectedTokens', + [sampleTokenA], + { + chainId: ChainId.mainnet, + selectedAddress, + }, + ); }, ); }); @@ -1027,38 +1075,33 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const mockAddDetectedTokens = jest.fn(); const selectedAddress = '0x0000000000000000000000000000000000000001'; - const messenger = new ControllerMessenger< - AllowedActions, - AllowedEvents - >(); await withController( { options: { - addDetectedTokens: mockAddDetectedTokens, disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, networkClientId: NetworkType.mainnet, selectedAddress, }, - messenger, }, - async ({ mockTokenListGetState }) => { + async ({ + mockTokenListGetState, + mockAddDetectedTokens, + triggerTokenListStateChange, + }) => { const tokenListState = { ...getDefaultTokenListState(), tokenList: {}, }; mockTokenListGetState(tokenListState); - messenger.publish( - 'TokenListController:stateChange', - tokenListState, - [], - ); + triggerTokenListStateChange(tokenListState); await advanceTime({ clock, duration: 1 }); - expect(mockAddDetectedTokens).not.toHaveBeenCalled(); + expect(mockAddDetectedTokens).not.toHaveBeenCalledWith( + 'TokensController:addDetectedTokens', + ); }, ); }); @@ -1069,24 +1112,21 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const mockAddDetectedTokens = jest.fn(); const selectedAddress = '0x0000000000000000000000000000000000000001'; - const messenger = new ControllerMessenger< - AllowedActions, - AllowedEvents - >(); await withController( { options: { - addDetectedTokens: mockAddDetectedTokens, disabled: true, getBalancesInSingleCall: mockGetBalancesInSingleCall, networkClientId: NetworkType.mainnet, selectedAddress, }, - messenger, }, - async ({ mockTokenListGetState }) => { + async ({ + mockTokenListGetState, + mockAddDetectedTokens, + triggerTokenListStateChange, + }) => { const tokenListState = { ...getDefaultTokenListState(), tokenList: { @@ -1103,14 +1143,12 @@ describe('TokenDetectionController', () => { }; mockTokenListGetState(tokenListState); - messenger.publish( - 'TokenListController:stateChange', - tokenListState, - [], - ); + triggerTokenListStateChange(tokenListState); await advanceTime({ clock, duration: 1 }); - expect(mockAddDetectedTokens).not.toHaveBeenCalled(); + expect(mockAddDetectedTokens).not.toHaveBeenCalledWith( + 'TokensController:addDetectedTokens', + ); }, ); }); @@ -1131,22 +1169,15 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const mockAddDetectedTokens = jest.fn(); const selectedAddress = '0x0000000000000000000000000000000000000001'; - const messenger = new ControllerMessenger< - AllowedActions, - AllowedEvents - >(); await withController( { options: { - addDetectedTokens: mockAddDetectedTokens, disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, networkClientId: NetworkType.mainnet, selectedAddress, }, - messenger, }, async ({ controller, mockTokenListGetState }) => { mockTokenListGetState({ @@ -1205,24 +1236,21 @@ describe('TokenDetectionController', () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - const mockAddDetectedTokens = jest.fn(); const selectedAddress = '0x0000000000000000000000000000000000000001'; - const messenger = new ControllerMessenger< - AllowedActions, - AllowedEvents - >(); await withController( { options: { - addDetectedTokens: mockAddDetectedTokens, disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, networkClientId: NetworkType.mainnet, selectedAddress, }, - messenger, }, - async ({ controller, mockTokenListGetState }) => { + async ({ + controller, + mockTokenListGetState, + mockAddDetectedTokens, + }) => { mockTokenListGetState({ ...getDefaultTokenListState(), tokenList: { @@ -1243,10 +1271,14 @@ describe('TokenDetectionController', () => { accountAddress: selectedAddress, }); - expect(mockAddDetectedTokens).toHaveBeenCalledWith([sampleTokenA], { - chainId: ChainId.mainnet, - selectedAddress, - }); + expect(mockAddDetectedTokens).toHaveBeenCalledWith( + 'TokensController:addDetectedTokens', + [sampleTokenA], + { + chainId: ChainId.mainnet, + selectedAddress, + }, + ); }, ); }); @@ -1268,8 +1300,10 @@ function getTokensPath(chainId: Hex) { type WithControllerCallback = ({ controller, mockKeyringGetState, + mockTokensGetState, mockTokenListGetState, mockPreferencesGetState, + mockAddDetectedTokens, triggerKeyringUnlock, triggerKeyringLock, triggerTokenListStateChange, @@ -1279,11 +1313,13 @@ type WithControllerCallback = ({ }: { controller: TokenDetectionController; mockKeyringGetState: (state: KeyringControllerState) => void; + mockTokensGetState: (state: TokensState) => void; mockTokenListGetState: (state: TokenListState) => void; mockPreferencesGetState: (state: PreferencesState) => void; mockGetNetworkConfigurationByNetworkClientId: ( handler: (networkClientId: string) => NetworkConfiguration, ) => void; + mockAddDetectedTokens: jest.SpyInstance; triggerKeyringUnlock: () => void; triggerKeyringLock: () => void; triggerTokenListStateChange: (state: TokenListState) => void; @@ -1323,7 +1359,7 @@ async function withController( 'KeyringController:getState', mockKeyringState.mockReturnValue({ isUnlocked: true, - } as unknown as KeyringControllerState), + } as KeyringControllerState), ); const mockGetNetworkConfigurationByNetworkClientId = jest.fn< ReturnType, @@ -1337,6 +1373,11 @@ async function withController( }, ), ); + const mockTokensState = jest.fn(); + controllerMessenger.registerActionHandler( + 'TokensController:getState', + mockTokensState.mockReturnValue({ ...getDefaultTokensState() }), + ); const mockTokenListState = jest.fn(); controllerMessenger.registerActionHandler( 'TokenListController:getState', @@ -1349,12 +1390,11 @@ async function withController( ...getDefaultPreferencesState(), }), ); + const mockAddDetectedTokens = jest.spyOn(controllerMessenger, 'call'); const controller = new TokenDetectionController({ networkClientId: NetworkType.mainnet, getBalancesInSingleCall: jest.fn(), - addDetectedTokens: jest.fn(), - getTokensState: jest.fn().mockReturnValue(getDefaultTokensState()), trackMetaMetricsEvent: jest.fn(), messenger: buildTokenDetectionControllerMessenger(controllerMessenger), ...options, @@ -1365,6 +1405,9 @@ async function withController( mockKeyringGetState: (state: KeyringControllerState) => { mockKeyringState.mockReturnValue(state); }, + mockTokensGetState: (state: TokensState) => { + mockTokensState.mockReturnValue(state); + }, mockPreferencesGetState: (state: PreferencesState) => { mockPreferencesState.mockReturnValue(state); }, @@ -1378,6 +1421,7 @@ async function withController( handler, ); }, + mockAddDetectedTokens, triggerKeyringUnlock: () => { controllerMessenger.publish('KeyringController:unlock'); }, diff --git a/packages/assets-controllers/src/TokenDetectionController.ts b/packages/assets-controllers/src/TokenDetectionController.ts index 13175385f2..e570cce454 100644 --- a/packages/assets-controllers/src/TokenDetectionController.ts +++ b/packages/assets-controllers/src/TokenDetectionController.ts @@ -35,7 +35,10 @@ import type { TokenListToken, } from './TokenListController'; import type { Token } from './TokenRatesController'; -import type { TokensController, TokensState } from './TokensController'; +import type { + TokensControllerAddDetectedTokensAction, + TokensControllerGetStateAction, +} from './TokensController'; const DEFAULT_INTERVAL = 180000; @@ -61,7 +64,12 @@ type LegacyToken = Omit< export const STATIC_MAINNET_TOKEN_LIST = Object.entries( contractMap, -).reduce>>((acc, [base, contract]) => { +).reduce< + Record< + string, + Partial & Pick + > +>((acc, [base, contract]) => { const { logo, ...tokenMetadata } = contract; return { ...acc, @@ -90,7 +98,9 @@ export type AllowedActions = | NetworkControllerGetNetworkConfigurationByNetworkClientId | GetTokenListState | KeyringControllerGetStateAction - | PreferencesControllerGetStateAction; + | PreferencesControllerGetStateAction + | TokensControllerGetStateAction + | TokensControllerAddDetectedTokensAction; export type TokenDetectionControllerStateChangeEvent = ControllerStateChangeEvent; @@ -146,12 +156,8 @@ export class TokenDetectionController extends StaticIntervalPollingController< #isDetectionEnabledForNetwork: boolean; - readonly #addDetectedTokens: TokensController['addDetectedTokens']; - readonly #getBalancesInSingleCall: AssetsContractController['getBalancesInSingleCall']; - readonly #getTokensState: () => TokensState; - readonly #trackMetaMetricsEvent: (options: { event: string; category: string; @@ -171,9 +177,7 @@ export class TokenDetectionController extends StaticIntervalPollingController< * @param options.interval - Polling interval used to fetch new token rates * @param options.networkClientId - The selected network client ID of the current network * @param options.selectedAddress - Vault selected address - * @param options.addDetectedTokens - Add a list of detected tokens. * @param options.getBalancesInSingleCall - Gets the balances of a list of tokens for the given address. - * @param options.getTokensState - Gets the current state of the Tokens controller. * @param options.trackMetaMetricsEvent - Sets options for MetaMetrics event tracking. */ constructor({ @@ -182,8 +186,6 @@ export class TokenDetectionController extends StaticIntervalPollingController< interval = DEFAULT_INTERVAL, disabled = true, getBalancesInSingleCall, - addDetectedTokens, - getTokensState, trackMetaMetricsEvent, messenger, }: { @@ -191,9 +193,7 @@ export class TokenDetectionController extends StaticIntervalPollingController< selectedAddress?: string; interval?: number; disabled?: boolean; - addDetectedTokens: TokensController['addDetectedTokens']; getBalancesInSingleCall: AssetsContractController['getBalancesInSingleCall']; - getTokensState: () => TokensState; trackMetaMetricsEvent: (options: { event: string; category: string; @@ -226,9 +226,7 @@ export class TokenDetectionController extends StaticIntervalPollingController< this.#chainId, ); - this.#addDetectedTokens = addDetectedTokens; this.#getBalancesInSingleCall = getBalancesInSingleCall; - this.#getTokensState = getTokensState; this.#trackMetaMetricsEvent = trackMetaMetricsEvent; @@ -460,9 +458,10 @@ export class TokenDetectionController extends StaticIntervalPollingController< ? STATIC_MAINNET_TOKEN_LIST : tokenList; - const { tokens, detectedTokens } = this.#getTokensState(); + const { tokens, detectedTokens, ignoredTokens } = this.messagingSystem.call( + 'TokensController:getState', + ); const tokensToDetect: string[] = []; - for (const tokenAddress of Object.keys(tokenListUsed)) { if ( !findCaseInsensitiveMatch( @@ -500,11 +499,9 @@ export class TokenDetectionController extends StaticIntervalPollingController< tokensSlice, ); const tokensToAdd: Token[] = []; - const eventTokensDetails = []; + const eventTokensDetails: string[] = []; + let ignored; for (const tokenAddress of Object.keys(balances)) { - let ignored; - /* istanbul ignore else */ - const { ignoredTokens } = this.#getTokensState(); if (ignoredTokens.length) { ignored = ignoredTokens.find( (ignoredTokenAddress) => @@ -519,7 +516,7 @@ export class TokenDetectionController extends StaticIntervalPollingController< if (ignored === undefined) { const { decimals, symbol, aggregators, iconUrl, name } = - tokenList[caseInsensitiveTokenKey]; + tokenListUsed[caseInsensitiveTokenKey]; eventTokensDetails.push(`${symbol} - ${tokenAddress}`); tokensToAdd.push({ address: tokenAddress, @@ -543,10 +540,14 @@ export class TokenDetectionController extends StaticIntervalPollingController< asset_type: 'TOKEN', }, }); - await this.#addDetectedTokens(tokensToAdd, { - selectedAddress, - chainId, - }); + await this.messagingSystem.call( + 'TokensController:addDetectedTokens', + tokensToAdd, + { + selectedAddress, + chainId, + }, + ); } }); } diff --git a/packages/assets-controllers/src/TokensController.test.ts b/packages/assets-controllers/src/TokensController.test.ts index 3f5aaaa35d..78c050f92f 100644 --- a/packages/assets-controllers/src/TokensController.test.ts +++ b/packages/assets-controllers/src/TokensController.test.ts @@ -1,8 +1,7 @@ +import type { ApprovalControllerEvents } from '@metamask/approval-controller'; import { ApprovalController, - type AddApprovalRequest, type ApprovalControllerState, - type ApprovalControllerEvents, } from '@metamask/approval-controller'; import { ControllerMessenger } from '@metamask/base-controller'; import contractMaps from '@metamask/contract-metadata'; @@ -17,10 +16,15 @@ import { toHex, } from '@metamask/controller-utils'; import type { - NetworkState, + BlockTrackerProxy, + NetworkController, ProviderConfig, + ProviderProxy, +} from '@metamask/network-controller'; +import { + defaultState as defaultNetworkState, + NetworkClientType, } from '@metamask/network-controller'; -import { defaultState as defaultNetworkState } from '@metamask/network-controller'; import { getDefaultPreferencesState, type PreferencesState, @@ -28,12 +32,15 @@ import { import nock from 'nock'; import * as sinon from 'sinon'; +import { FakeBlockTracker } from '../../../tests/fake-block-tracker'; +import { FakeProvider } from '../../../tests/fake-provider'; import { ERC20Standard } from './Standards/ERC20Standard'; import { ERC1155Standard } from './Standards/NftStandards/ERC1155/ERC1155Standard'; import { TOKEN_END_POINT_API } from './token-service'; +import type { TokenListState } from './TokenListController'; import type { Token } from './TokenRatesController'; import { TokensController } from './TokensController'; -import type { TokensControllerMessenger } from './TokensController'; +import type { AllowedActions, AllowedEvents } from './TokensController'; jest.mock('uuid', () => { return { @@ -51,7 +58,21 @@ const stubCreateEthers = (ctrl: TokensController, res: () => boolean) => { } as any; }); }; - +const MAINNET = { + chainId: ChainId.mainnet, + type: NetworkType.mainnet, + ticker: NetworksTicker.mainnet, +}; +const mockMainnetClient = { + configuration: { + network: 'mainnet', + ...MAINNET, + type: NetworkClientType.Infura, + }, + provider: {} as ProviderProxy, + blockTracker: {} as BlockTrackerProxy, + destroy: jest.fn(), +}; const SEPOLIA = { chainId: toHex(11155111), type: NetworkType.sepolia, @@ -65,76 +86,79 @@ const GOERLI = { const controllerName = 'TokensController' as const; -type ApprovalActions = AddApprovalRequest; - describe('TokensController', () => { let tokensController: TokensController; - let triggerPreferencesStateChange: (state: PreferencesState) => void; - const messenger = new ControllerMessenger< - ApprovalActions, - ApprovalControllerEvents - >(); - - const approvalControllerMessenger = messenger.getRestricted({ - name: 'ApprovalController', - }); - - const approvalController = new ApprovalController({ - messenger: approvalControllerMessenger, - showApprovalRequest: jest.fn(), - typesExcludedFromRateLimiting: [ApprovalType.WatchAsset], - }); + let approvalController: ApprovalController; + let messenger: ControllerMessenger< + AllowedActions, + AllowedEvents | ApprovalControllerEvents + >; + let tokensControllerMessenger; + let approvalControllerMessenger; + let getNetworkClientByIdHandler: jest.Mock< + ReturnType, + Parameters + >; - const tokensControllerMessenger = messenger.getRestricted< - typeof controllerName, - ApprovalActions['type'], - never - >({ - name: controllerName, - allowedActions: ['ApprovalController:addRequest'], - }) as TokensControllerMessenger; - - let onNetworkDidChangeListener: (state: NetworkState) => void; const changeNetwork = (providerConfig: ProviderConfig) => { - onNetworkDidChangeListener({ + messenger.publish(`NetworkController:networkDidChange`, { ...defaultNetworkState, providerConfig, }); }; - // TODO: Replace `any` with type - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let tokenListStateChangeListener: (state: any) => void; - const onTokenListStateChange = sinon.stub().callsFake((listener) => { - tokenListStateChangeListener = listener; - }); + const triggerPreferencesStateChange = (state: PreferencesState) => { + messenger.publish('PreferencesController:stateChange', state, []); + }; beforeEach(async () => { const defaultSelectedAddress = '0x1'; - const preferencesStateChangeListeners: (( - state: PreferencesState, - ) => void)[] = []; + messenger = new ControllerMessenger(); + + approvalControllerMessenger = messenger.getRestricted< + 'ApprovalController', + never, + never + >({ + name: 'ApprovalController', + }); + + tokensControllerMessenger = messenger.getRestricted({ + name: controllerName, + allowedActions: [ + 'ApprovalController:addRequest', + 'NetworkController:getNetworkClientById', + ], + allowedEvents: [ + 'NetworkController:networkDidChange', + 'PreferencesController:stateChange', + 'TokenListController:stateChange', + ], + }); tokensController = new TokensController({ chainId: ChainId.mainnet, - onPreferencesStateChange: (listener) => { - preferencesStateChangeListeners.push(listener); - }, - onNetworkDidChange: (listener) => (onNetworkDidChangeListener = listener), - onTokenListStateChange, config: { selectedAddress: defaultSelectedAddress, provider: sinon.stub(), }, - // TODO: Replace `any` with type - // eslint-disable-next-line @typescript-eslint/no-explicit-any - getNetworkClientById: sinon.stub() as any, messenger: tokensControllerMessenger, }); - triggerPreferencesStateChange = (state: PreferencesState) => { - for (const listener of preferencesStateChangeListeners) { - listener(state); - } - }; + + approvalController = new ApprovalController({ + messenger: approvalControllerMessenger, + showApprovalRequest: jest.fn(), + typesExcludedFromRateLimiting: [ApprovalType.WatchAsset], + }); + + getNetworkClientByIdHandler = jest.fn(); + messenger.registerActionHandler( + `NetworkController:getNetworkClientById`, + getNetworkClientByIdHandler.mockReturnValue( + mockMainnetClient as unknown as ReturnType< + NetworkController['getNetworkClientById'] + >, + ), + ); }); afterEach(() => { @@ -398,11 +422,9 @@ describe('TokensController', () => { it('should add token to the correct chainId when passed a networkClientId', async () => { const stub = stubCreateEthers(tokensController, () => false); - const getNetworkClientByIdStub = jest - // TODO: Replace `any` with type - // eslint-disable-next-line @typescript-eslint/no-explicit-any - .spyOn(tokensController as any, 'getNetworkClientById') - .mockReturnValue({ configuration: { chainId: '0x5' } }); + getNetworkClientByIdHandler.mockReturnValue({ + configuration: { chainId: '0x5' }, + } as unknown as ReturnType); await tokensController.addToken({ address: '0x01', symbol: 'bar', @@ -432,7 +454,9 @@ describe('TokensController', () => { }, ]); - expect(getNetworkClientByIdStub).toHaveBeenCalledWith('networkClientId1'); + expect(getNetworkClientByIdHandler).toHaveBeenCalledWith( + 'networkClientId1', + ); stub.restore(); }); @@ -1083,12 +1107,9 @@ describe('TokensController', () => { }); it('should add tokens to the correct chainId when passed a networkClientId', async () => { - const getNetworkClientByIdStub = jest - // TODO: Replace `any` with type - // eslint-disable-next-line @typescript-eslint/no-explicit-any - .spyOn(tokensController as any, 'getNetworkClientById') - .mockReturnValue({ configuration: { chainId: '0x5' } }); - + getNetworkClientByIdHandler.mockReturnValue({ + configuration: { chainId: '0x5' }, + } as unknown as ReturnType); const dummyTokens: Token[] = [ { address: '0x01', @@ -1114,7 +1135,9 @@ describe('TokensController', () => { expect(tokensController.state.allTokens['0x5']['0x1']).toStrictEqual( dummyTokens, ); - expect(getNetworkClientByIdStub).toHaveBeenCalledWith('networkClientId1'); + expect(getNetworkClientByIdHandler).toHaveBeenCalledWith( + 'networkClientId1', + ); }); }); @@ -1203,7 +1226,6 @@ describe('TokensController', () => { .spyOn(ERC20Standard.prototype as any, 'getTokenDecimals') .mockImplementationOnce(() => a.decimals?.toString()); }); - let createEthersStub: sinon.SinonStub; beforeEach(function () { type = ERC20; @@ -1539,22 +1561,29 @@ describe('TokensController', () => { }); it('stores token correctly when passed a networkClientId', async function () { - const getNetworkClientByIdStub = jest - // TODO: Replace `any` with type - // eslint-disable-next-line @typescript-eslint/no-explicit-any - .spyOn(tokensController as any, 'getNetworkClientById') - .mockReturnValue({ + getNetworkClientByIdHandler.mockImplementation((networkClientId) => { + expect(networkClientId).toBe('networkClientId1'); + return { configuration: { chainId: '0x5' }, - provider: sinon.stub(), - }); + provider: new FakeProvider({ + stubs: [], + }), + blockTracker: new FakeBlockTracker(), + destroy: jest.fn(), + } as unknown as ReturnType; + }); + + const addRequestHandler = jest.fn(); + messenger.unregisterActionHandler(`ApprovalController:addRequest`); + messenger.registerActionHandler( + `ApprovalController:addRequest`, + addRequestHandler, + ); + const generateRandomIdStub = jest .spyOn(tokensController, '_generateRandomId') .mockReturnValue(requestId); - const callActionSpy = jest - .spyOn(messenger, 'call') - .mockResolvedValue(undefined); - await tokensController.watchAsset({ asset, type, @@ -1562,6 +1591,20 @@ describe('TokensController', () => { networkClientId: 'networkClientId1', }); + expect(addRequestHandler).toHaveBeenCalledWith( + { + id: requestId, + origin: ORIGIN_METAMASK, + type: ApprovalType.WatchAsset, + requestData: { + id: requestId, + interactingAddress, + asset, + }, + }, + true, + ); + expect(tokensController.state.tokens).toHaveLength(0); expect(tokensController.state.tokens).toStrictEqual([]); expect( @@ -1576,22 +1619,6 @@ describe('TokensController', () => { ...asset, }, ]); - expect(callActionSpy).toHaveBeenCalledTimes(1); - expect(callActionSpy).toHaveBeenCalledWith( - 'ApprovalController:addRequest', - { - id: requestId, - origin: ORIGIN_METAMASK, - type: ApprovalType.WatchAsset, - requestData: { - id: requestId, - interactingAddress, - asset, - }, - }, - true, - ); - expect(getNetworkClientByIdStub).toHaveBeenCalledWith('networkClientId1'); generateRandomIdStub.mockRestore(); }); @@ -1630,7 +1657,7 @@ describe('TokensController', () => { generateRandomIdStub.mockRestore(); }); - it('stores multiple tokens from a batched watchAsset confirmation screen correctly when user confirms', async function () { + it('stores multiple tokens from a batched watchAsset confirmation screen correctly when user confirms', async () => { const generateRandomIdStub = jest .spyOn(tokensController, '_generateRandomId') .mockImplementationOnce(() => requestId) @@ -1657,36 +1684,32 @@ describe('TokensController', () => { mockContract([asset, anotherAsset]); + const promiseForApprovals = new Promise((resolve) => { + const listener = (state: ApprovalControllerState) => { + if (state.pendingApprovalCount === 2) { + messenger.unsubscribe('ApprovalController:stateChange', listener); + resolve(); + } + }; + messenger.subscribe('ApprovalController:stateChange', listener); + }); + + // eslint-disable-next-line @typescript-eslint/no-floating-promises tokensController.watchAsset({ asset, type, interactingAddress }); + + // eslint-disable-next-line @typescript-eslint/no-floating-promises tokensController.watchAsset({ asset: anotherAsset, type, interactingAddress, }); - await new Promise((resolve) => { - const listener = (state: ApprovalControllerState) => { - if (state.pendingApprovalCount === 2) { - approvalControllerMessenger.unsubscribe( - 'ApprovalController:stateChange', - listener, - ); - resolve(); - } - }; - approvalControllerMessenger.subscribe( - 'ApprovalController:stateChange', - listener, - ); - }); + await promiseForApprovals; await approvalController.accept(requestId); await approvalController.accept('67890'); await acceptedRequest; - expect( - tokensController.state.allTokens[ChainId.mainnet][interactingAddress], - ).toHaveLength(2); expect( tokensController.state.allTokens[ChainId.mainnet][interactingAddress], ).toStrictEqual([ @@ -1957,8 +1980,13 @@ describe('TokensController', () => { aggregators: ['Aave'], }, }; - - await tokenListStateChangeListener({ tokenList: sampleMainnetTokenList }); + messenger.publish( + 'TokenListController:stateChange', + { + tokenList: sampleMainnetTokenList, + } as unknown as TokenListState, + [], + ); expect(tokensController.state.tokens[0]).toStrictEqual({ address: '0x01', diff --git a/packages/assets-controllers/src/TokensController.ts b/packages/assets-controllers/src/TokensController.ts index f263f331ca..02153e952d 100644 --- a/packages/assets-controllers/src/TokensController.ts +++ b/packages/assets-controllers/src/TokensController.ts @@ -5,6 +5,8 @@ import type { BaseConfig, BaseState, RestrictedControllerMessenger, + ControllerGetStateAction, + ControllerStateChangeEvent, } from '@metamask/base-controller'; import { BaseControllerV1 } from '@metamask/base-controller'; import contractsMap from '@metamask/contract-metadata'; @@ -22,10 +24,10 @@ import { import { abiERC721 } from '@metamask/metamask-eth-abis'; import type { NetworkClientId, - NetworkController, - NetworkState, + NetworkControllerGetNetworkClientByIdAction, + NetworkControllerNetworkDidChangeEvent, } from '@metamask/network-controller'; -import type { PreferencesState } from '@metamask/preferences-controller'; +import type { PreferencesControllerStateChangeEvent } from '@metamask/preferences-controller'; import { rpcErrors } from '@metamask/rpc-errors'; import type { Hex } from '@metamask/utils'; import { Mutex } from 'async-mutex'; @@ -41,7 +43,7 @@ import { } from './token-service'; import type { TokenListMap, - TokenListState, + TokenListStateChange, TokenListToken, } from './TokenListController'; import type { Token } from './TokenRatesController'; @@ -92,37 +94,62 @@ type SuggestedAssetMeta = { * @property allIgnoredTokens - Object containing hidden/ignored tokens by network and account * @property allDetectedTokens - Object containing tokens detected with non-zero balances */ -// This interface was created before this ESLint rule was added. -// Convert to a `type` in a future major version. -// eslint-disable-next-line @typescript-eslint/consistent-type-definitions -export interface TokensState extends BaseState { +export type TokensState = { tokens: Token[]; ignoredTokens: string[]; detectedTokens: Token[]; allTokens: { [chainId: Hex]: { [key: string]: Token[] } }; allIgnoredTokens: { [chainId: Hex]: { [key: string]: string[] } }; allDetectedTokens: { [chainId: Hex]: { [key: string]: Token[] } }; -} +}; /** * The name of the {@link TokensController}. */ const controllerName = 'TokensController'; +export type TokensControllerActions = + | TokensControllerGetStateAction + | TokensControllerAddDetectedTokensAction; + +export type TokensControllerGetStateAction = ControllerGetStateAction< + typeof controllerName, + TokensState +>; + +export type TokensControllerAddDetectedTokensAction = { + type: `${typeof controllerName}:addDetectedTokens`; + handler: TokensController['addDetectedTokens']; +}; + /** * The external actions available to the {@link TokensController}. */ -type AllowedActions = AddApprovalRequest; +export type AllowedActions = + | AddApprovalRequest + | NetworkControllerGetNetworkClientByIdAction; + +export type TokensControllerStateChangeEvent = ControllerStateChangeEvent< + typeof controllerName, + TokensState +>; + +export type TokensControllerEvents = TokensControllerStateChangeEvent; + +export type AllowedEvents = + | NetworkControllerNetworkDidChangeEvent + | PreferencesControllerStateChangeEvent + | TokenListStateChange; /** * The messenger of the {@link TokensController}. */ export type TokensControllerMessenger = RestrictedControllerMessenger< typeof controllerName, - AllowedActions, - never, + TokensControllerActions | AllowedActions, + TokensControllerEvents | AllowedEvents, AllowedActions['type'], - never + AllowedEvents['type'] >; export const getDefaultTokensState = (): TokensState => { @@ -141,7 +168,7 @@ export const getDefaultTokensState = (): TokensState => { */ export class TokensController extends BaseControllerV1< TokensConfig, - TokensState + TokensState & BaseState > { private readonly mutex = new Mutex(); @@ -186,42 +213,22 @@ export class TokensController extends BaseControllerV1< */ override name = 'TokensController'; - private readonly getNetworkClientById: NetworkController['getNetworkClientById']; - /** * Creates a TokensController instance. * * @param options - The controller options. * @param options.chainId - The chain ID of the current network. - * @param options.onPreferencesStateChange - Allows subscribing to preference controller state changes. - * @param options.onNetworkDidChange - Allows subscribing to network controller networkDidChange events. - * @param options.onTokenListStateChange - Allows subscribing to token list controller state changes. - * @param options.getNetworkClientById - Gets the network client with the given id from the NetworkController. * @param options.config - Initial options used to configure this controller. * @param options.state - Initial state to set on this controller. * @param options.messenger - The controller messenger. */ constructor({ chainId: initialChainId, - onPreferencesStateChange, - onNetworkDidChange, - onTokenListStateChange, - getNetworkClientById, config, state, messenger, }: { chainId: Hex; - onPreferencesStateChange: ( - listener: (preferencesState: PreferencesState) => void, - ) => void; - onNetworkDidChange: ( - listener: (networkState: NetworkState) => void, - ) => void; - onTokenListStateChange: ( - listener: (tokenListState: TokenListState) => void, - ) => void; - getNetworkClientById: NetworkController['getNetworkClientById']; config?: Partial; state?: Partial; messenger: TokensControllerMessenger; @@ -242,41 +249,54 @@ export class TokensController extends BaseControllerV1< this.initialize(); this.abortController = new AbortController(); - this.getNetworkClientById = getNetworkClientById; this.messagingSystem = messenger; - onPreferencesStateChange(({ selectedAddress }) => { - const { allTokens, allIgnoredTokens, allDetectedTokens } = this.state; - const { chainId } = this.config; - this.configure({ selectedAddress }); - this.update({ - tokens: allTokens[chainId]?.[selectedAddress] || [], - ignoredTokens: allIgnoredTokens[chainId]?.[selectedAddress] || [], - detectedTokens: allDetectedTokens[chainId]?.[selectedAddress] || [], - }); - }); + this.messagingSystem.registerActionHandler( + `${controllerName}:addDetectedTokens` as const, + this.addDetectedTokens.bind(this), + ); - onNetworkDidChange(({ providerConfig }) => { - const { allTokens, allIgnoredTokens, allDetectedTokens } = this.state; - const { selectedAddress } = this.config; - const { chainId } = providerConfig; - this.abortController.abort(); - this.abortController = new AbortController(); - this.configure({ chainId }); - this.update({ - tokens: allTokens[chainId]?.[selectedAddress] || [], - ignoredTokens: allIgnoredTokens[chainId]?.[selectedAddress] || [], - detectedTokens: allDetectedTokens[chainId]?.[selectedAddress] || [], - }); - }); + this.messagingSystem.subscribe( + 'PreferencesController:stateChange', + ({ selectedAddress }) => { + const { allTokens, allIgnoredTokens, allDetectedTokens } = this.state; + const { chainId } = this.config; + this.configure({ selectedAddress }); + this.update({ + tokens: allTokens[chainId]?.[selectedAddress] ?? [], + ignoredTokens: allIgnoredTokens[chainId]?.[selectedAddress] ?? [], + detectedTokens: allDetectedTokens[chainId]?.[selectedAddress] ?? [], + }); + }, + ); - onTokenListStateChange(({ tokenList }) => { - const { tokens } = this.state; - if (tokens.length && !tokens[0].name) { - this.updateTokensAttribute(tokenList, 'name'); - } - }); + this.messagingSystem.subscribe( + 'NetworkController:networkDidChange', + ({ providerConfig }) => { + const { allTokens, allIgnoredTokens, allDetectedTokens } = this.state; + const { selectedAddress } = this.config; + const { chainId } = providerConfig; + this.abortController.abort(); + this.abortController = new AbortController(); + this.configure({ chainId }); + this.update({ + tokens: allTokens[chainId]?.[selectedAddress] || [], + ignoredTokens: allIgnoredTokens[chainId]?.[selectedAddress] || [], + detectedTokens: allDetectedTokens[chainId]?.[selectedAddress] || [], + }); + }, + ); + + this.messagingSystem.subscribe( + 'TokenListController:stateChange', + ({ tokenList }) => { + const { tokens } = this.state; + if (tokens.length && !tokens[0].name) { + this.updateTokensAttribute(tokenList, 'name'); + } + }, + ); } /** @@ -314,8 +334,10 @@ export class TokensController extends BaseControllerV1< const { allTokens, allIgnoredTokens, allDetectedTokens } = this.state; let currentChainId = chainId; if (networkClientId) { - currentChainId = - this.getNetworkClientById(networkClientId).configuration.chainId; + currentChainId = this.messagingSystem.call( + 'NetworkController:getNetworkClientById', + networkClientId, + ).configuration.chainId; } const accountAddress = interactingAddress || selectedAddress; @@ -445,8 +467,10 @@ export class TokensController extends BaseControllerV1< let interactingChainId; if (networkClientId) { - interactingChainId = - this.getNetworkClientById(networkClientId).configuration.chainId; + interactingChainId = this.messagingSystem.call( + 'NetworkController:getNetworkClientById', + networkClientId, + ).configuration.chainId; } const { newAllTokens, newAllDetectedTokens, newAllIgnoredTokens } = @@ -695,8 +719,11 @@ export class TokensController extends BaseControllerV1< _getProvider(networkClientId?: NetworkClientId): Web3Provider { return new Web3Provider( networkClientId - ? this.getNetworkClientById(networkClientId).provider - : this.config?.provider, + ? this.messagingSystem.call( + 'NetworkController:getNetworkClientById', + networkClientId, + ).provider + : this.config.provider, ); } diff --git a/packages/assets-controllers/src/index.ts b/packages/assets-controllers/src/index.ts index 1f2786d784..242498de3e 100644 --- a/packages/assets-controllers/src/index.ts +++ b/packages/assets-controllers/src/index.ts @@ -19,9 +19,35 @@ export type { TokenDetectionControllerStateChangeEvent, } from './TokenDetectionController'; export { TokenDetectionController } from './TokenDetectionController'; -export * from './TokenListController'; -export * from './TokenRatesController'; -export * from './TokensController'; +export type { + TokenListState, + TokenListToken, + TokenListMap, + TokenListStateChange, + TokenListControllerEvents, + GetTokenListState, + TokenListControllerActions, + TokenListControllerMessenger, +} from './TokenListController'; +export { TokenListController } from './TokenListController'; +export type { + Token, + TokenRatesConfig, + ContractExchangeRates, + TokenRatesState, +} from './TokenRatesController'; +export { TokenRatesController } from './TokenRatesController'; +export type { + TokensConfig, + TokensState, + TokensControllerActions, + TokensControllerGetStateAction, + TokensControllerAddDetectedTokensAction, + TokensControllerEvents, + TokensControllerStateChangeEvent, + TokensControllerMessenger, +} from './TokensController'; +export { TokensController } from './TokensController'; export { isTokenDetectionSupportedForNetwork, formatIconUrlWithProxy,