Skip to content

Commit

Permalink
[token-balances-controller] test: Use `PreferencesController:getState…
Browse files Browse the repository at this point in the history
…` action, `TokensController:stateChange` event
  • Loading branch information
MajorLift committed Jan 24, 2024
1 parent 12971a5 commit 669bb4f
Showing 1 changed file with 88 additions and 82 deletions.
170 changes: 88 additions & 82 deletions packages/assets-controllers/src/TokenBalancesController.test.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import { ControllerMessenger } from '@metamask/base-controller';
import { toHex } from '@metamask/controller-utils';
import type {
NetworkControllerActions,
NetworkControllerEvents,
} from '@metamask/network-controller';
import {} from '@metamask/network-controller';
import { BN } from 'ethereumjs-util';

import { flushPromises } from '../../../tests/helpers';
import type {
AllowedActions,
AllowedEvents,
TokenBalancesControllerMessenger,
} from './TokenBalancesController';
import { TokenBalancesController } from './TokenBalancesController';
import type { TokenListStateChange } from './TokenListController';
import type { Token } from './TokenRatesController';
import { getDefaultTokensState, type TokensState } from './TokensController';

Expand All @@ -18,59 +17,63 @@ const controllerName = 'TokenBalancesController';
/**
* Constructs a restricted controller messenger.
*
* @param controllerMessenger - The controller messenger to restrict.
* @returns A restricted controller messenger.
*/
function getMessenger() {
return new ControllerMessenger().getRestricted<
typeof controllerName,
never,
never
>({
function getMessenger(
controllerMessenger = new ControllerMessenger<
AllowedActions,
AllowedEvents
>(),
): TokenBalancesControllerMessenger {
return controllerMessenger.getRestricted({
name: controllerName,
allowedActions: ['PreferencesController:getState'],
allowedEvents: ['TokensController:stateChange'],
});
}

describe('TokenBalancesController', () => {
let controllerMessenger: ControllerMessenger<
NetworkControllerActions,
NetworkControllerEvents | TokenListStateChange
>;
let controllerMessenger: ControllerMessenger<AllowedActions, AllowedEvents>;
let messenger: TokenBalancesControllerMessenger;

beforeEach(() => {
jest.useFakeTimers();
controllerMessenger = new ControllerMessenger();
messenger = getMessenger(controllerMessenger);
});

afterEach(() => {
jest.useRealTimers();
controllerMessenger.clearEventSubscriptions(
'NetworkController:networkDidChange',
);
});

it('should set default state', () => {
controllerMessenger.registerActionHandler(
'PreferencesController:getState',
jest.fn().mockReturnValue({ selectedAddress: '0x1234' }),
);
const controller = new TokenBalancesController({
onTokensStateChange: jest.fn(),
getSelectedAddress: () => '0x1234',
getERC20BalanceOf: jest.fn(),
messenger: getMessenger(),
messenger,
});

expect(controller.state).toStrictEqual({ contractBalances: {} });
});

it('should poll and update balances in the right interval', async () => {
controllerMessenger.registerActionHandler(
'PreferencesController:getState',
jest.fn().mockReturnValue({ selectedAddress: '0x1234' }),
);
const updateBalancesSpy = jest.spyOn(
TokenBalancesController.prototype,
'updateBalances',
);

new TokenBalancesController({
interval: 10,
onTokensStateChange: jest.fn(),
getSelectedAddress: () => '0x1234',
getERC20BalanceOf: jest.fn(),
messenger: getMessenger(),
messenger,
});
await flushPromises();

Expand All @@ -84,14 +87,16 @@ describe('TokenBalancesController', () => {

it('should update balances if enabled', async () => {
const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0';
controllerMessenger.registerActionHandler(
'PreferencesController:getState',
jest.fn().mockReturnValue({ selectedAddress: '0x1234' }),
);
const controller = new TokenBalancesController({
disabled: false,
tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }],
interval: 10,
onTokensStateChange: jest.fn(),
getSelectedAddress: () => '0x1234',
getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)),
messenger: getMessenger(),
messenger,
});

await controller.updateBalances();
Expand All @@ -103,14 +108,16 @@ describe('TokenBalancesController', () => {

it('should not update balances if disabled', async () => {
const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0';
controllerMessenger.registerActionHandler(
'PreferencesController:getState',
jest.fn().mockReturnValue({ selectedAddress: '0x1234' }),
);
const controller = new TokenBalancesController({
disabled: true,
tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }],
interval: 10,
onTokensStateChange: jest.fn(),
getSelectedAddress: () => '0x1234',
getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)),
messenger: getMessenger(),
messenger,
});

await controller.updateBalances();
Expand All @@ -120,14 +127,16 @@ describe('TokenBalancesController', () => {

it('should update balances if controller is manually enabled', async () => {
const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0';
controllerMessenger.registerActionHandler(
'PreferencesController:getState',
jest.fn().mockReturnValue({ selectedAddress: '0x1234' }),
);
const controller = new TokenBalancesController({
disabled: true,
tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }],
interval: 10,
onTokensStateChange: jest.fn(),
getSelectedAddress: () => '0x1234',
getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)),
messenger: getMessenger(),
messenger,
});

await controller.updateBalances();
Expand All @@ -144,14 +153,16 @@ describe('TokenBalancesController', () => {

it('should not update balances if controller is manually disabled', async () => {
const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0';
controllerMessenger.registerActionHandler(
'PreferencesController:getState',
jest.fn().mockReturnValue({ selectedAddress: '0x1234' }),
);
const controller = new TokenBalancesController({
disabled: false,
tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }],
interval: 10,
onTokensStateChange: jest.fn(),
getSelectedAddress: () => '0x1234',
getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)),
messenger: getMessenger(),
messenger,
});

await controller.updateBalances();
Expand All @@ -169,23 +180,20 @@ describe('TokenBalancesController', () => {
});

it('should update balances if tokens change and controller is manually enabled', async () => {
const tokensStateChangeListeners: ((state: TokensState) => void)[] = [];
const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0';
controllerMessenger.registerActionHandler(
'PreferencesController:getState',
jest.fn().mockReturnValue({ selectedAddress: '0x1234' }),
);
const controller = new TokenBalancesController({
disabled: true,
tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }],
interval: 10,
getSelectedAddress: () => '0x1234',
getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)),
onTokensStateChange: (listener) => {
tokensStateChangeListeners.push(listener);
},
messenger: getMessenger(),
messenger,
});
const triggerTokensStateChange = async (state: TokensState) => {
for (const listener of tokensStateChangeListeners) {
listener(state);
}
controllerMessenger.publish('TokensController:stateChange', state, []);
};

await controller.updateBalances();
Expand All @@ -210,23 +218,20 @@ describe('TokenBalancesController', () => {
});

it('should not update balances if tokens change and controller is manually disabled', async () => {
const tokensStateChangeListeners: ((state: TokensState) => void)[] = [];
const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0';
controllerMessenger.registerActionHandler(
'PreferencesController:getState',
jest.fn().mockReturnValue({ selectedAddress: '0x1234' }),
);
const controller = new TokenBalancesController({
disabled: false,
tokens: [{ address, decimals: 18, symbol: 'EOS', aggregators: [] }],
interval: 10,
getSelectedAddress: () => '0x1234',
getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)),
onTokensStateChange: (listener) => {
tokensStateChangeListeners.push(listener);
},
messenger: getMessenger(),
messenger,
});
const triggerTokensStateChange = async (state: TokensState) => {
for (const listener of tokensStateChangeListeners) {
listener(state);
}
controllerMessenger.publish('TokensController:stateChange', state, []);
};

await controller.updateBalances();
Expand All @@ -253,12 +258,14 @@ describe('TokenBalancesController', () => {
});

it('should clear previous interval', async () => {
controllerMessenger.registerActionHandler(
'PreferencesController:getState',
jest.fn().mockReturnValue({ selectedAddress: '0x1234' }),
);
const controller = new TokenBalancesController({
interval: 1337,
onTokensStateChange: jest.fn(),
getSelectedAddress: () => '0x1234',
getERC20BalanceOf: jest.fn(),
messenger: getMessenger(),
messenger,
});

const mockClearTimeout = jest.spyOn(global, 'clearTimeout');
Expand All @@ -281,13 +288,15 @@ describe('TokenBalancesController', () => {
aggregators: [],
},
];
controllerMessenger.registerActionHandler(
'PreferencesController:getState',
jest.fn().mockReturnValue({ selectedAddress }),
);
const controller = new TokenBalancesController({
interval: 1337,
tokens,
onTokensStateChange: jest.fn(),
getSelectedAddress: () => selectedAddress,
getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)),
messenger: getMessenger(),
messenger,
});

expect(controller.state.contractBalances).toStrictEqual({});
Expand All @@ -313,13 +322,16 @@ describe('TokenBalancesController', () => {
aggregators: [],
},
];

controllerMessenger.registerActionHandler(
'PreferencesController:getState',
jest.fn().mockReturnValue({}),
);
const controller = new TokenBalancesController({
interval: 1337,
tokens,
onTokensStateChange: jest.fn(),
getSelectedAddress: jest.fn(),
getERC20BalanceOf: getERC20BalanceOfStub,
messenger: getMessenger(),
messenger,
});

expect(controller.state.contractBalances).toStrictEqual({});
Expand All @@ -340,20 +352,17 @@ describe('TokenBalancesController', () => {
});

it('should update balances when tokens change', async () => {
const tokensStateChangeListeners: ((state: TokensState) => void)[] = [];
controllerMessenger.registerActionHandler(
'PreferencesController:getState',
jest.fn().mockReturnValue({ selectedAddress: '0x1234' }),
);
const controller = new TokenBalancesController({
onTokensStateChange: (listener) => {
tokensStateChangeListeners.push(listener);
},
getSelectedAddress: jest.fn(),
getERC20BalanceOf: jest.fn(),
interval: 1337,
messenger: getMessenger(),
messenger,
});
const triggerTokensStateChange = async (state: TokensState) => {
for (const listener of tokensStateChangeListeners) {
listener(state);
}
controllerMessenger.publish('TokensController:stateChange', state, []);
};
const updateBalancesSpy = jest.spyOn(controller, 'updateBalances');

Expand All @@ -372,20 +381,17 @@ describe('TokenBalancesController', () => {
});

it('should update token balances when detected tokens are added', async () => {
const tokensStateChangeListeners: ((state: TokensState) => void)[] = [];
controllerMessenger.registerActionHandler(
'PreferencesController:getState',
jest.fn().mockReturnValue({ selectedAddress: '0x1234' }),
);
const controller = new TokenBalancesController({
interval: 1337,
onTokensStateChange: (listener) => {
tokensStateChangeListeners.push(listener);
},
getSelectedAddress: () => '0x1234',
getERC20BalanceOf: jest.fn().mockReturnValue(new BN(1)),
messenger: getMessenger(),
messenger,
});
const triggerTokensStateChange = async (state: TokensState) => {
for (const listener of tokensStateChangeListeners) {
listener(state);
}
controllerMessenger.publish('TokensController:stateChange', state, []);
};
expect(controller.state.contractBalances).toStrictEqual({});

Expand Down

0 comments on commit 669bb4f

Please sign in to comment.