From c1cdc756eaa3b1e6c16a4c85ffde412a27e3ee4b Mon Sep 17 00:00:00 2001 From: Mark Stacey Date: Thu, 11 Mar 2021 09:52:10 -0330 Subject: [PATCH 1/3] Use DI for controller communication rather than context We now setup inter-controller communication using dependency injection rather than the `context` property. Any dependency a controller has is passed in via a constructor parameter. This was done in preparation for migrating to BaseControllerV2 and the new controller messaging system - it's just a temporary solution that will let us migrate controllers one at a time. The style of dependency injection here matches the extension. Just as we do there, we inject state snapshots, means of subscribing to state, and individual methods rather than entire controllers. This helps to simplify tests and makes it easier to understand how controllers interact. --- README.md | 2 - src/BaseController.test.ts | 10 - src/BaseController.ts | 26 -- src/ComposableController.test.ts | 129 +++--- src/ComposableController.ts | 63 +-- src/assets/AccountTrackerController.test.ts | 31 +- src/assets/AccountTrackerController.ts | 44 ++- src/assets/AssetsController.test.ts | 15 +- src/assets/AssetsController.ts | 106 ++--- src/assets/AssetsDetectionController.test.ts | 76 ++-- src/assets/AssetsDetectionController.ts | 116 +++--- src/assets/TokenBalancesController.test.ts | 137 +++++-- src/assets/TokenBalancesController.ts | 54 +-- src/assets/TokenRatesController.test.ts | 79 ++-- src/assets/TokenRatesController.ts | 47 ++- src/keyring/KeyringController.test.ts | 12 +- src/keyring/KeyringController.ts | 60 ++- src/transaction/TransactionController.test.ts | 368 +++++++++++------- src/transaction/TransactionController.ts | 96 ++--- 19 files changed, 816 insertions(+), 655 deletions(-) diff --git a/README.md b/README.md index 0f6a421bfb..dc334645e7 100644 --- a/README.md +++ b/README.md @@ -245,8 +245,6 @@ console.log(datamodel.state); // {NetworkController: {...}, TokenRatesController console.log(datamodel.flatState); // {infura: {...}, contractExchangeRates: [...]} ``` -**Advanced Note:** The ComposableController builds a map of all child controllers keyed by controller name. This object is cached as a `context` instance variable on both the ComposableController itself as well as all child controllers. This means that child controllers can call methods on other sibling controllers through the `context` variable, e.g. `this.context.SomeController.someMethod()`. - ## Linking during development Linking `@metamask/controllers` into other projects involves a special NPM command to ensure that dependencies are not duplicated. This is because `@metamask/controllers` ships modules that are transpiled but not bundled, and [NPM does not deduplicate](https://github.com/npm/npm/issues/7742) linked dependency trees. diff --git a/src/BaseController.test.ts b/src/BaseController.test.ts index bd3078f2d6..7df1ac1f0c 100644 --- a/src/BaseController.test.ts +++ b/src/BaseController.test.ts @@ -1,13 +1,10 @@ import { stub } from 'sinon'; import BaseController, { BaseConfig, BaseState } from './BaseController'; -import ComposableController from './ComposableController'; const STATE = { name: 'foo' }; const CONFIG = { disabled: true }; class TestController extends BaseController { - requiredControllers = ['Foo']; - constructor(config?: BaseConfig, state?: BaseState) { super(config, state); this.initialize(); @@ -68,11 +65,4 @@ describe('BaseController', () => { controller.notify(); expect(listener.called).toBe(false); }); - - it('should throw if siblings are missing dependencies', () => { - const controller = new TestController(); - expect(() => { - new ComposableController([controller]); - }).toThrow('BaseController must be composed with Foo.'); - }); }); diff --git a/src/BaseController.ts b/src/BaseController.ts index be3f97aa64..645faa563e 100644 --- a/src/BaseController.ts +++ b/src/BaseController.ts @@ -1,5 +1,3 @@ -import { ChildControllerContext } from './ComposableController'; - /** * State change callbacks */ @@ -31,13 +29,6 @@ export interface BaseState { * Controller class that provides configuration, state management, and subscriptions */ export class BaseController { - /** - * Map of all sibling child controllers keyed by name if this - * controller is composed using a ComposableController, allowing - * any API on any sibling controller to be accessed - */ - context: ChildControllerContext = {}; - /** * Default options used to configure this controller */ @@ -58,11 +49,6 @@ export class BaseController { */ name = 'BaseController'; - /** - * List of required sibling controllers this controller needs to function - */ - requiredControllers: string[] = []; - private readonly initialConfig: C; private readonly initialState: S; @@ -158,18 +144,6 @@ export class BaseController { }); } - /** - * Extension point called if and when this controller is composed - * with other controllers using a ComposableController - */ - onComposed() { - this.requiredControllers.forEach((name) => { - if (!this.context[name]) { - throw new Error(`${this.name} must be composed with ${name}.`); - } - }); - } - /** * Adds new listener to be notified of state changes * diff --git a/src/ComposableController.test.ts b/src/ComposableController.test.ts index 2b02529bed..ea68b17fa7 100644 --- a/src/ComposableController.test.ts +++ b/src/ComposableController.test.ts @@ -11,15 +11,29 @@ import CurrencyRateController from './assets/CurrencyRateController'; describe('ComposableController', () => { it('should compose controller state', () => { + const preferencesController = new PreferencesController(); + const networkController = new NetworkController(); + const assetContractController = new AssetsContractController(); + const assetController = new AssetsController({ + onPreferencesStateChange: (listener) => preferencesController.subscribe(listener), + onNetworkStateChange: (listener) => networkController.subscribe(listener), + getAssetName: assetContractController.getAssetName.bind(assetContractController), + getAssetSymbol: assetContractController.getAssetSymbol.bind(assetContractController), + getCollectibleTokenURI: assetContractController.getCollectibleTokenURI.bind(assetContractController), + }); + const currencyRateController = new CurrencyRateController(); const controller = new ComposableController([ new AddressBookController(), - new AssetsController(), - new AssetsContractController(), + assetController, + assetContractController, new EnsController(), - new CurrencyRateController(), - new NetworkController(), - new PreferencesController(), - new TokenRatesController(), + currencyRateController, + networkController, + preferencesController, + new TokenRatesController({ + onAssetsStateChange: (listener) => assetController.subscribe(listener), + onCurrencyRateStateChange: (listener) => currencyRateController.subscribe(listener), + }), ]); expect(controller.state).toEqual({ AddressBookController: { addressBook: {} }, @@ -62,15 +76,29 @@ describe('ComposableController', () => { }); it('should compose flat controller state', () => { + const preferencesController = new PreferencesController(); + const networkController = new NetworkController(); + const assetContractController = new AssetsContractController(); + const assetController = new AssetsController({ + onPreferencesStateChange: (listener) => preferencesController.subscribe(listener), + onNetworkStateChange: (listener) => networkController.subscribe(listener), + getAssetName: assetContractController.getAssetName.bind(assetContractController), + getAssetSymbol: assetContractController.getAssetSymbol.bind(assetContractController), + getCollectibleTokenURI: assetContractController.getCollectibleTokenURI.bind(assetContractController), + }); + const currencyRateController = new CurrencyRateController(); const controller = new ComposableController([ new AddressBookController(), - new AssetsController(), - new AssetsContractController(), + assetController, + assetContractController, new EnsController(), - new CurrencyRateController(), - new NetworkController(), - new PreferencesController(), - new TokenRatesController(), + currencyRateController, + networkController, + preferencesController, + new TokenRatesController({ + onAssetsStateChange: (listener) => assetController.subscribe(listener), + onCurrencyRateStateChange: (listener) => currencyRateController.subscribe(listener), + }), ]); expect(controller.flatState).toEqual({ addressBook: {}, @@ -101,26 +129,12 @@ describe('ComposableController', () => { }); }); - it('should expose sibling context', () => { - const controller = new ComposableController([ - new AddressBookController(), - new AssetsController(), - new AssetsContractController(), - new CurrencyRateController(), - new EnsController(), - new NetworkController(), - new PreferencesController(), - new TokenRatesController(), - ]); - const addressContext = controller.context.TokenRatesController.context - .AddressBookController as AddressBookController; - expect(addressContext).toBeDefined(); - addressContext.set('0x32Be343B94f860124dC4fEe278FDCBD38C102D88', 'foo'); - expect(controller.flatState).toEqual({ + it('should set initial state', () => { + const state = { addressBook: { - 1: { - '0x32Be343B94f860124dC4fEe278FDCBD38C102D88': { - address: '0x32Be343B94f860124dC4fEe278FDCBD38C102D88', + '0x1': { + '0x1234': { + address: 'bar', chainId: '1', isEns: false, memo: '', @@ -128,58 +142,9 @@ describe('ComposableController', () => { }, }, }, - allCollectibleContracts: {}, - allCollectibles: {}, - allTokens: {}, - collectibleContracts: [], - collectibles: [], - contractExchangeRates: {}, - conversionDate: 0, - conversionRate: 0, - currentCurrency: 'usd', - ensEntries: {}, - featureFlags: {}, - frequentRpcList: [], - identities: {}, - ignoredCollectibles: [], - ignoredTokens: [], - ipfsGateway: 'https://ipfs.io/ipfs/', - lostIdentities: {}, - nativeCurrency: 'ETH', - network: 'loading', - provider: { type: 'mainnet', chainId: NetworksChainId.mainnet }, - selectedAddress: '', - suggestedAssets: [], - tokens: [], - usdConversionRate: 0, - }); - }); - - it('should get and set new stores', () => { - const controller = new ComposableController(); - const addressBook = new AddressBookController(); - controller.controllers = [addressBook]; - expect(controller.controllers).toEqual([addressBook]); - }); - - it('should set initial state', () => { - const state = { - AddressBookController: { - addressBook: [ - { - 1: { - address: 'bar', - chainId: '1', - isEns: false, - memo: '', - name: 'foo', - }, - }, - ], - }, }; - const controller = new ComposableController([new AddressBookController()], state); - expect(controller.state).toEqual(state); + const controller = new ComposableController([new AddressBookController(undefined, state)]); + expect(controller.state).toEqual({ AddressBookController: state }); }); it('should notify listeners of nested state change', () => { diff --git a/src/ComposableController.ts b/src/ComposableController.ts index 2e6374259b..044e194cce 100644 --- a/src/ComposableController.ts +++ b/src/ComposableController.ts @@ -1,12 +1,5 @@ import BaseController from './BaseController'; -/** - * Child controller instances keyed by controller name - */ -export interface ChildControllerContext { - [key: string]: BaseController; -} - /** * List of child controller instances */ @@ -15,15 +8,8 @@ export type ControllerList = BaseController[]; /** * Controller that can be used to compose multiple controllers together */ -export class ComposableController extends BaseController { - private cachedState: any; - - private internalControllers: ControllerList = []; - - /** - * Map of stores to compose together - */ - context: ChildControllerContext = {}; +export class ComposableController extends BaseController { + private controllers: ControllerList = []; /** * Name of this controller used during composition @@ -36,45 +22,22 @@ export class ComposableController extends BaseController { * @param controllers - Map of names to controller instances * @param initialState - Initial state keyed by child controller name */ - constructor(controllers: ControllerList = [], initialState?: any) { - super(); + constructor(controllers: ControllerList) { + super( + undefined, + controllers.reduce((state, controller) => { + state[controller.name] = controller.state; + return state; + }, {} as any), + ); this.initialize(); - this.cachedState = initialState; this.controllers = controllers; - this.cachedState = undefined; - } - - /** - * Get current list of child composed store instances - * - * @returns - List of names to controller instances - */ - get controllers() { - return this.internalControllers; - } - - /** - * Set new list of controller instances - * - * @param controllers - List of names to controller instsances - */ - set controllers(controllers: ControllerList) { - this.internalControllers = controllers; - const initialState: any = {}; - controllers.forEach((controller) => { + this.controllers.forEach((controller) => { const { name } = controller; - this.context[name] = controller; - controller.context = this.context; - this.cachedState?.[name] && controller.update(this.cachedState[name]); - initialState[name] = controller.state; controller.subscribe((state) => { this.update({ [name]: state }); }); }); - controllers.forEach((controller) => { - controller.onComposed(); - }); - this.update(initialState, true); } /** @@ -86,8 +49,8 @@ export class ComposableController extends BaseController { */ get flatState() { let flatState = {}; - for (const name in this.context) { - flatState = { ...flatState, ...this.context[name].state }; + for (const controller of this.controllers) { + flatState = { ...flatState, ...controller.state }; } return flatState; } diff --git a/src/assets/AccountTrackerController.test.ts b/src/assets/AccountTrackerController.test.ts index 925464d5f0..f82204053e 100644 --- a/src/assets/AccountTrackerController.test.ts +++ b/src/assets/AccountTrackerController.test.ts @@ -1,34 +1,37 @@ import { stub, spy } from 'sinon'; import HttpProvider from 'ethjs-provider-http'; +import type { ContactEntry } from '../user/AddressBookController'; import PreferencesController from '../user/PreferencesController'; -import ComposableController from '../ComposableController'; import AccountTrackerController from './AccountTrackerController'; const provider = new HttpProvider('https://ropsten.infura.io/v3/341eacb578dd44a1a049cbc5f6fd4035'); describe('AccountTrackerController', () => { it('should set default state', () => { - const controller = new AccountTrackerController(); + const controller = new AccountTrackerController({ onPreferencesStateChange: stub(), initialIdentities: {} }); expect(controller.state).toEqual({ accounts: {}, }); }); it('should throw when provider property is accessed', () => { - const controller = new AccountTrackerController(); + const controller = new AccountTrackerController({ onPreferencesStateChange: stub(), initialIdentities: {} }); expect(() => console.log(controller.provider)).toThrow('Property only used for setting'); }); it('should get real balance', async () => { const address = '0xc38bf1ad06ef69f0c04e29dbeb4152b4175f0a8d'; - const controller = new AccountTrackerController({ provider }); - controller.context = { PreferencesController: { state: { identities: { [address]: {} } } } } as any; + const controller = new AccountTrackerController( + { onPreferencesStateChange: stub(), initialIdentities: { [address]: {} as ContactEntry } }, + { provider }, + ); await controller.refresh(); expect(controller.state.accounts[address].balance).toBeDefined(); }); it('should sync addresses', () => { const controller = new AccountTrackerController( + { onPreferencesStateChange: stub(), initialIdentities: { baz: {} as ContactEntry } }, { provider }, { accounts: { @@ -37,17 +40,18 @@ describe('AccountTrackerController', () => { }, }, ); - controller.context = { PreferencesController: { state: { identities: { baz: {} } } } } as any; controller.refresh(); expect(controller.state.accounts).toEqual({ baz: { balance: '0x0' } }); }); it('should subscribe to new sibling preference controllers', async () => { const preferences = new PreferencesController(); - const controller = new AccountTrackerController({ provider }); + const controller = new AccountTrackerController( + { onPreferencesStateChange: (listener) => preferences.subscribe(listener), initialIdentities: {} }, + { provider }, + ); controller.refresh = stub(); - new ComposableController([controller, preferences]); preferences.setFeatureFlag('foo', true); expect((controller.refresh as any).called).toBe(true); }); @@ -55,11 +59,16 @@ describe('AccountTrackerController', () => { it('should call refresh every ten seconds', async () => { await new Promise((resolve) => { const preferences = new PreferencesController(); - const controller = new AccountTrackerController({ provider, interval: 100 }); + const poll = spy(AccountTrackerController.prototype, 'poll'); + const controller = new AccountTrackerController( + { + onPreferencesStateChange: (listener) => preferences.subscribe(listener), + initialIdentities: {}, + }, + { provider, interval: 100 }, + ); stub(controller, 'refresh'); - const poll = spy(controller, 'poll'); - new ComposableController([controller, preferences]); expect(poll.called).toBe(true); expect(poll.calledTwice).toBe(false); setTimeout(() => { diff --git a/src/assets/AccountTrackerController.ts b/src/assets/AccountTrackerController.ts index 79126524fd..707b3cf218 100644 --- a/src/assets/AccountTrackerController.ts +++ b/src/assets/AccountTrackerController.ts @@ -1,7 +1,7 @@ import EthQuery from 'eth-query'; import { Mutex } from 'async-mutex'; import BaseController, { BaseConfig, BaseState } from '../BaseController'; -import PreferencesController from '../user/PreferencesController'; +import { PreferencesState } from '../user/PreferencesController'; import { BNToHex, query, safelyExecuteWithTimeout } from '../util'; /** @@ -49,11 +49,8 @@ export class AccountTrackerController extends BaseController existing.indexOf(address) === -1); const oldAddresses = existing.filter((address) => addresses.indexOf(address) === -1); @@ -71,24 +68,40 @@ export class AccountTrackerController extends BaseController, state?: Partial) { + constructor( + { + onPreferencesStateChange, + initialIdentities, + }: { + onPreferencesStateChange: (listener: (preferencesState: PreferencesState) => void) => void; + initialIdentities: PreferencesState['identities']; + }, + config?: Partial, + state?: Partial, + ) { super(config, state); this.defaultConfig = { interval: 10000, }; this.defaultState = { accounts: {} }; this.initialize(); + this.identities = initialIdentities; + onPreferencesStateChange(({ identities }) => { + this.identities = identities; + this.refresh(); + }); + this.poll(); } /** @@ -106,17 +119,6 @@ export class AccountTrackerController extends BaseController { const sandbox = createSandbox(); beforeEach(() => { - assetsController = new AssetsController(); preferences = new PreferencesController(); network = new NetworkController(); assetsContract = new AssetsContractController(); - - new ComposableController([assetsController, assetsContract, network, preferences]); + assetsController = new AssetsController({ + onPreferencesStateChange: (listener) => preferences.subscribe(listener), + onNetworkStateChange: (listener) => network.subscribe(listener), + getAssetName: assetsContract.getAssetName.bind(assetsContract), + getAssetSymbol: assetsContract.getAssetSymbol.bind(assetsContract), + getCollectibleTokenURI: assetsContract.getCollectibleTokenURI.bind(assetsContract), + }); nock(OPEN_SEA_HOST) .get(`${OPEN_SEA_PATH}/asset_contract/0xfoO`) @@ -407,9 +410,9 @@ describe('AssetsController', () => { const networkType = 'rinkeby'; const address = '0x123'; preferences.update({ selectedAddress: address }); - expect(assetsController.context.PreferencesController.state.selectedAddress).toEqual(address); + expect(preferences.state.selectedAddress).toEqual(address); network.update({ provider: { type: networkType, chainId: NetworksChainId[networkType] } }); - expect(assetsController.context.NetworkController.state.provider.type).toEqual(networkType); + expect(network.state.provider.type).toEqual(networkType); }); it('should add a valid suggested asset via watchAsset', async () => { diff --git a/src/assets/AssetsController.ts b/src/assets/AssetsController.ts index acdc77f510..834a122cb2 100644 --- a/src/assets/AssetsController.ts +++ b/src/assets/AssetsController.ts @@ -3,12 +3,12 @@ import { toChecksumAddress } from 'ethereumjs-util'; import { v1 as random } from 'uuid'; import { Mutex } from 'async-mutex'; import BaseController, { BaseConfig, BaseState } from '../BaseController'; -import PreferencesController from '../user/PreferencesController'; -import NetworkController, { NetworkType } from '../network/NetworkController'; +import type { PreferencesState } from '../user/PreferencesController'; +import type { NetworkState, NetworkType } from '../network/NetworkController'; import { safelyExecute, handleFetch, validateTokenToWatch } from '../util'; -import { Token } from './TokenRatesController'; -import { AssetsContractController } from './AssetsContractController'; -import { ApiCollectibleResponse } from './AssetsDetectionController'; +import type { Token } from './TokenRatesController'; +import type { ApiCollectibleResponse } from './AssetsDetectionController'; +import type { AssetsContractController } from './AssetsContractController'; /** * @type Collectible @@ -225,8 +225,7 @@ export class AssetsController extends BaseController contractAddress: string, tokenId: number, ): Promise { - const assetsContract = this.context.AssetsContractController as AssetsContractController; - const tokenURI = await assetsContract.getCollectibleTokenURI(contractAddress, tokenId); + const tokenURI = await this.getCollectibleTokenURI(contractAddress, tokenId); const object = await handleFetch(tokenURI); const image = Object.prototype.hasOwnProperty.call(object, 'image') ? 'image' @@ -292,9 +291,8 @@ export class AssetsController extends BaseController private async getCollectibleContractInformationFromContract( contractAddress: string, ): Promise { - const assetsContractController = this.context.AssetsContractController as AssetsContractController; - const name = await assetsContractController.getAssetName(contractAddress); - const symbol = await assetsContractController.getAssetSymbol(contractAddress); + const name = await this.getAssetName(contractAddress); + const symbol = await this.getAssetSymbol(contractAddress); return { name, symbol }; } @@ -505,18 +503,41 @@ export class AssetsController extends BaseController */ name = 'AssetsController'; - /** - * List of required sibling controllers this controller needs to function - */ - requiredControllers = ['AssetsContractController', 'NetworkController', 'PreferencesController']; + private getAssetName: AssetsContractController['getAssetName']; + + private getAssetSymbol: AssetsContractController['getAssetSymbol']; + + private getCollectibleTokenURI: AssetsContractController['getCollectibleTokenURI']; /** * Creates a AssetsController instance * + * @param options + * @param options.onPreferencesStateChange - Allows subscribing to preference controller state changes + * @param options.onNetworkStateChange - Allows subscribing to network controller state changes + * @param options.getAssetName - Gets the name of the asset at the given address + * @param options.getAssetSymbol - Gets the symbol of the asset at the given address + * @param options.getCollectibleTokenURI - Gets the URI of the NFT at the given address, with the given ID * @param config - Initial options used to configure this controller * @param state - Initial state to set on this controller */ - constructor(config?: Partial, state?: Partial) { + constructor( + { + onPreferencesStateChange, + onNetworkStateChange, + getAssetName, + getAssetSymbol, + getCollectibleTokenURI, + }: { + onPreferencesStateChange: (listener: (preferencesState: PreferencesState) => void) => void; + onNetworkStateChange: (listener: (networkState: NetworkState) => void) => void; + getAssetName: AssetsContractController['getAssetName']; + getAssetSymbol: AssetsContractController['getAssetSymbol']; + getCollectibleTokenURI: AssetsContractController['getCollectibleTokenURI']; + }, + config?: Partial, + state?: Partial, + ) { super(config, state); this.defaultConfig = { networkType: 'mainnet', @@ -535,6 +556,30 @@ export class AssetsController extends BaseController tokens: [], }; this.initialize(); + this.getAssetName = getAssetName; + this.getAssetSymbol = getAssetSymbol; + this.getCollectibleTokenURI = getCollectibleTokenURI; + onPreferencesStateChange(({ selectedAddress }) => { + const { allCollectibleContracts, allCollectibles, allTokens } = this.state; + const { chainId } = this.config; + this.configure({ selectedAddress }); + this.update({ + collectibleContracts: allCollectibleContracts[selectedAddress]?.[chainId] || [], + collectibles: allCollectibles[selectedAddress]?.[chainId] || [], + tokens: allTokens[selectedAddress]?.[chainId] || [], + }); + }); + onNetworkStateChange(({ provider }) => { + const { allCollectibleContracts, allCollectibles, allTokens } = this.state; + const { selectedAddress } = this.config; + const { chainId } = provider; + this.configure({ chainId }); + this.update({ + collectibleContracts: allCollectibleContracts[selectedAddress]?.[chainId] || [], + collectibles: allCollectibles[selectedAddress]?.[chainId] || [], + tokens: allTokens[selectedAddress]?.[chainId] || [], + }); + }); } /** @@ -822,37 +867,6 @@ export class AssetsController extends BaseController clearIgnoredCollectibles() { this.update({ ignoredCollectibles: [] }); } - - /** - * Extension point called if and when this controller is composed - * with other controllers using a ComposableController - */ - onComposed() { - super.onComposed(); - const preferences = this.context.PreferencesController as PreferencesController; - const network = this.context.NetworkController as NetworkController; - preferences.subscribe(({ selectedAddress }) => { - const { allCollectibleContracts, allCollectibles, allTokens } = this.state; - const { chainId } = this.config; - this.configure({ selectedAddress }); - this.update({ - collectibleContracts: allCollectibleContracts[selectedAddress]?.[chainId] || [], - collectibles: allCollectibles[selectedAddress]?.[chainId] || [], - tokens: allTokens[selectedAddress]?.[chainId] || [], - }); - }); - network.subscribe(({ provider }) => { - const { allCollectibleContracts, allCollectibles, allTokens } = this.state; - const { selectedAddress } = this.config; - const { chainId } = provider; - this.configure({ chainId }); - this.update({ - collectibleContracts: allCollectibleContracts[selectedAddress]?.[chainId] || [], - collectibles: allCollectibles[selectedAddress]?.[chainId] || [], - tokens: allTokens[selectedAddress]?.[chainId] || [], - }); - }); - } } export default AssetsController; diff --git a/src/assets/AssetsDetectionController.test.ts b/src/assets/AssetsDetectionController.test.ts index 1b5c786cdc..8c0f95be4e 100644 --- a/src/assets/AssetsDetectionController.test.ts +++ b/src/assets/AssetsDetectionController.test.ts @@ -1,9 +1,8 @@ -import { createSandbox, stub } from 'sinon'; -import { BN } from 'ethereumjs-util'; +import { createSandbox, SinonStub, stub } from 'sinon'; import nock from 'nock'; +import { BN } from 'ethereumjs-util'; import { NetworkController, NetworksChainId } from '../network/NetworkController'; import { PreferencesController } from '../user/PreferencesController'; -import { ComposableController } from '../ComposableController'; import { AssetsController } from './AssetsController'; import { AssetsContractController } from './AssetsContractController'; import { AssetsDetectionController } from './AssetsDetectionController'; @@ -21,16 +20,31 @@ describe('AssetsDetectionController', () => { let network: NetworkController; let assets: AssetsController; let assetsContract: AssetsContractController; + let getBalancesInSingleCall: SinonStub<[AssetsContractController['getBalancesInSingleCall']]>; const sandbox = createSandbox(); beforeEach(() => { - assetsDetection = new AssetsDetectionController(); preferences = new PreferencesController(); network = new NetworkController(); - assets = new AssetsController(); assetsContract = new AssetsContractController(); - - new ComposableController([assets, assetsContract, assetsDetection, network, preferences]); + assets = new AssetsController({ + onPreferencesStateChange: (listener) => preferences.subscribe(listener), + onNetworkStateChange: (listener) => network.subscribe(listener), + getAssetName: assetsContract.getAssetName.bind(assetsContract), + getAssetSymbol: assetsContract.getAssetSymbol.bind(assetsContract), + getCollectibleTokenURI: assetsContract.getCollectibleTokenURI.bind(assetsContract), + }); + getBalancesInSingleCall = sandbox.stub(); + assetsDetection = new AssetsDetectionController({ + onAssetsStateChange: (listener) => assets.subscribe(listener), + onPreferencesStateChange: (listener) => preferences.subscribe(listener), + onNetworkStateChange: (listener) => network.subscribe(listener), + getOpenSeaApiKey: () => assets.openSeaApiKey, + getBalancesInSingleCall: (getBalancesInSingleCall as unknown) as AssetsContractController['getBalancesInSingleCall'], + addTokens: assets.addTokens.bind(assets), + addCollectible: assets.addCollectible.bind(assets), + getAssetsState: () => assets.state, + }); nock(OPEN_SEA_HOST) .get(`${OPEN_SEA_PATH}/assets?owner=0x2&limit=300`) @@ -122,7 +136,19 @@ describe('AssetsDetectionController', () => { await new Promise((resolve) => { const mockTokens = stub(AssetsDetectionController.prototype, 'detectTokens'); const mockCollectibles = stub(AssetsDetectionController.prototype, 'detectCollectibles'); - new AssetsDetectionController({ interval: 10 }); + new AssetsDetectionController( + { + onAssetsStateChange: (listener) => assets.subscribe(listener), + onPreferencesStateChange: (listener) => preferences.subscribe(listener), + onNetworkStateChange: (listener) => network.subscribe(listener), + getOpenSeaApiKey: () => assets.openSeaApiKey, + getBalancesInSingleCall: assetsContract.getBalancesInSingleCall.bind(assetsContract), + addTokens: assets.addTokens.bind(assets), + addCollectible: assets.addCollectible.bind(assets), + getAssetsState: () => assets.state, + }, + { interval: 10 }, + ); expect(mockTokens.calledOnce).toBe(true); expect(mockCollectibles.calledOnce).toBe(true); setTimeout(() => { @@ -146,7 +172,19 @@ describe('AssetsDetectionController', () => { await new Promise((resolve) => { const mockTokens = stub(AssetsDetectionController.prototype, 'detectTokens'); const mockCollectibles = stub(AssetsDetectionController.prototype, 'detectCollectibles'); - new AssetsDetectionController({ interval: 10, networkType: ROPSTEN }); + new AssetsDetectionController( + { + onAssetsStateChange: (listener) => assets.subscribe(listener), + onPreferencesStateChange: (listener) => preferences.subscribe(listener), + onNetworkStateChange: (listener) => network.subscribe(listener), + getOpenSeaApiKey: () => assets.openSeaApiKey, + getBalancesInSingleCall: assetsContract.getBalancesInSingleCall.bind(assetsContract), + addTokens: assets.addTokens.bind(assets), + addCollectible: assets.addCollectible.bind(assets), + getAssetsState: () => assets.state, + }, + { interval: 10, networkType: ROPSTEN }, + ); expect(mockTokens.called).toBe(false); expect(mockCollectibles.called).toBe(false); mockTokens.restore(); @@ -327,9 +365,7 @@ describe('AssetsDetectionController', () => { it('should detect tokens correctly', async () => { assetsDetection.configure({ networkType: MAINNET, selectedAddress: '0x1' }); - sandbox - .stub(assetsContract, 'getBalancesInSingleCall') - .resolves({ '0x6810e776880C02933D47DB1b9fc05908e5386b96': new BN(1) }); + getBalancesInSingleCall.resolves({ '0x6810e776880C02933D47DB1b9fc05908e5386b96': new BN(1) }); await assetsDetection.detectTokens(); expect(assets.state.tokens).toEqual([ { @@ -342,9 +378,7 @@ describe('AssetsDetectionController', () => { it('should not autodetect tokens that exist in the ignoreList', async () => { assetsDetection.configure({ networkType: MAINNET, selectedAddress: '0x1' }); - sandbox - .stub(assetsContract, 'getBalancesInSingleCall') - .resolves({ '0x6810e776880C02933D47DB1b9fc05908e5386b96': new BN(1) }); + getBalancesInSingleCall.resolves({ '0x6810e776880C02933D47DB1b9fc05908e5386b96': new BN(1) }); await assetsDetection.detectTokens(); assets.removeAndIgnoreToken('0x6810e776880C02933D47DB1b9fc05908e5386b96'); @@ -354,9 +388,7 @@ describe('AssetsDetectionController', () => { it('should not detect tokens if there is no selectedAddress set', async () => { assetsDetection.configure({ networkType: MAINNET }); - sandbox - .stub(assetsContract, 'getBalancesInSingleCall') - .resolves({ '0x6810e776880C02933D47DB1b9fc05908e5386b96': new BN(1) }); + getBalancesInSingleCall.resolves({ '0x6810e776880C02933D47DB1b9fc05908e5386b96': new BN(1) }); await assetsDetection.detectTokens(); expect(assets.state.tokens).toEqual([]); }); @@ -369,14 +401,14 @@ describe('AssetsDetectionController', () => { const detectAssets = sandbox.stub(assetsDetection, 'detectAssets'); preferences.update({ selectedAddress: secondAddress }); preferences.update({ selectedAddress: secondAddress }); - expect(assetsDetection.context.PreferencesController.state.selectedAddress).toEqual(secondAddress); + expect(preferences.state.selectedAddress).toEqual(secondAddress); expect(detectAssets.calledTwice).toBe(false); preferences.update({ selectedAddress: firstAddress }); - expect(assetsDetection.context.PreferencesController.state.selectedAddress).toEqual(firstAddress); + expect(preferences.state.selectedAddress).toEqual(firstAddress); network.update({ provider: { type: secondNetworkType, chainId: NetworksChainId[secondNetworkType] } }); - expect(assetsDetection.context.NetworkController.state.provider.type).toEqual(secondNetworkType); + expect(network.state.provider.type).toEqual(secondNetworkType); network.update({ provider: { type: firstNetworkType, chainId: NetworksChainId[firstNetworkType] } }); - expect(assetsDetection.context.NetworkController.state.provider.type).toEqual(firstNetworkType); + expect(network.state.provider.type).toEqual(firstNetworkType); assets.update({ tokens: TOKENS }); expect(assetsDetection.config.tokens).toEqual(TOKENS); }); diff --git a/src/assets/AssetsDetectionController.ts b/src/assets/AssetsDetectionController.ts index c213e6a302..dd3ff9fa40 100644 --- a/src/assets/AssetsDetectionController.ts +++ b/src/assets/AssetsDetectionController.ts @@ -1,14 +1,13 @@ import { toChecksumAddress } from 'ethereumjs-util'; import contractMap from '@metamask/contract-metadata'; import BaseController, { BaseConfig, BaseState } from '../BaseController'; -import NetworkController, { NetworkType } from '../network/NetworkController'; -import PreferencesController from '../user/PreferencesController'; +import type { NetworkState, NetworkType } from '../network/NetworkController'; +import type { PreferencesState } from '../user/PreferencesController'; import { safelyExecute, timeoutFetch } from '../util'; -import AssetsContractController from './AssetsContractController'; +import type { AssetsController, AssetsState } from './AssetsController'; +import type { AssetsContractController } from './AssetsContractController'; import { Token } from './TokenRatesController'; -import AssetsController from './AssetsController'; - const DEFAULT_INTERVAL = 180000; const MAINNET = 'mainnet'; @@ -61,12 +60,12 @@ export class AssetsDetectionController extends BaseController string | undefined; + + private getBalancesInSingleCall: AssetsContractController['getBalancesInSingleCall']; + + private addTokens: AssetsController['addTokens']; + + private addCollectible: AssetsController['addCollectible']; + + private getAssetsState: () => AssetsState; /** * Creates a AssetsDetectionController instance * + * @param options + * @param options.onAssetsStateChange - Allows subscribing to assets controller state changes + * @param options.onPreferencesStateChange - Allows subscribing to preferences controller state changes + * @param options.onNetworkStateChange - Allows subscribing to network controller state changes + * @param options.getOpenSeaApiKey - Gets the OpenSea API key, if one is set + * @param options.getBalancesInSingleCall - Gets the balances of a list of tokens for the given address + * @param options.addTokens - Add a list of tokens + * @param options.addCollectible - Add a collectible + * @param options.initialAssetsState - The initial state of the Assets controller * @param config - Initial options used to configure this controller * @param state - Initial state to set on this controller */ - constructor(config?: Partial, state?: Partial) { + constructor( + { + onAssetsStateChange, + onPreferencesStateChange, + onNetworkStateChange, + getOpenSeaApiKey, + getBalancesInSingleCall, + addTokens, + addCollectible, + getAssetsState, + }: { + onAssetsStateChange: (listener: (assetsState: AssetsState) => void) => void; + onPreferencesStateChange: (listener: (preferencesState: PreferencesState) => void) => void; + onNetworkStateChange: (listener: (networkState: NetworkState) => void) => void; + getOpenSeaApiKey: () => string | undefined; + getBalancesInSingleCall: AssetsContractController['getBalancesInSingleCall']; + addTokens: AssetsController['addTokens']; + addCollectible: AssetsController['addCollectible']; + getAssetsState: () => AssetsState; + }, + config?: Partial, + state?: Partial, + ) { super(config, state); this.defaultConfig = { interval: DEFAULT_INTERVAL, @@ -104,6 +139,24 @@ export class AssetsDetectionController extends BaseController { + this.configure({ tokens }); + }); + onPreferencesStateChange(({ selectedAddress }) => { + const actualSelectedAddress = this.config.selectedAddress; + if (selectedAddress !== actualSelectedAddress) { + this.configure({ selectedAddress }); + this.detectAssets(); + } + }); + onNetworkStateChange(({ provider }) => { + this.configure({ networkType: provider.type }); + }); + this.getOpenSeaApiKey = getOpenSeaApiKey; + this.getBalancesInSingleCall = getBalancesInSingleCall; + this.addCollectible = addCollectible; this.poll(); } @@ -162,20 +215,18 @@ export class AssetsDetectionController extends BaseController { - const balances = await assetsContractController.getBalancesInSingleCall(selectedAddress, tokensToDetect); - const assetsController = this.context.AssetsController as AssetsController; - const { ignoredTokens } = assetsController.state; + const balances = await this.getBalancesInSingleCall(selectedAddress, tokensToDetect); const tokensToAdd = []; for (const tokenAddress in balances) { let ignored; /* istanbul ignore else */ + const { ignoredTokens } = this.getAssetsState(); if (ignoredTokens.length) { ignored = ignoredTokens.find((token) => token.address === toChecksumAddress(tokenAddress)); } @@ -188,7 +239,7 @@ export class AssetsDetectionController extends BaseController { - const assetsController = this.context.AssetsController as AssetsController; - const { ignoredCollectibles } = assetsController.state; const apiCollectibles = await this.getOwnerCollectibles(); const addCollectiblesPromises = apiCollectibles.map(async (collectible: ApiCollectibleResponse) => { const { @@ -222,6 +271,7 @@ export class AssetsDetectionController extends BaseController { /* istanbul ignore next */ @@ -230,7 +280,7 @@ export class AssetsDetectionController extends BaseController { - this.configure({ tokens }); - }); - preferences.subscribe(({ selectedAddress }) => { - const actualSelectedAddress = this.config.selectedAddress; - if (selectedAddress !== actualSelectedAddress) { - this.configure({ selectedAddress }); - this.detectAssets(); - } - }); - network.subscribe(({ provider }) => { - this.configure({ networkType: provider.type }); - }); - } } export default AssetsDetectionController; diff --git a/src/assets/TokenBalancesController.test.ts b/src/assets/TokenBalancesController.test.ts index c919379e8c..eb3b194114 100644 --- a/src/assets/TokenBalancesController.test.ts +++ b/src/assets/TokenBalancesController.test.ts @@ -1,7 +1,6 @@ import { createSandbox, stub } from 'sinon'; import { BN } from 'ethereumjs-util'; import HttpProvider from 'ethjs-provider-http'; -import ComposableController from '../ComposableController'; import { NetworkController } from '../network/NetworkController'; import { PreferencesController } from '../user/PreferencesController'; import { AssetsController } from './AssetsController'; @@ -12,18 +11,13 @@ import { BN as exportedBn, TokenBalancesController } from './TokenBalancesContro const MAINNET_PROVIDER = new HttpProvider('https://mainnet.infura.io'); describe('TokenBalancesController', () => { - let tokenBalances: TokenBalancesController; const sandbox = createSandbox(); - const getToken = (address: string) => { + const getToken = (tokenBalances: TokenBalancesController, address: string) => { const { tokens } = tokenBalances.config; return tokens.find((token) => token.address === address); }; - beforeEach(() => { - tokenBalances = new TokenBalancesController(); - }); - afterEach(() => { sandbox.restore(); }); @@ -33,10 +27,20 @@ describe('TokenBalancesController', () => { }); it('should set default state', () => { + const tokenBalances = new TokenBalancesController({ + onAssetsStateChange: stub(), + getSelectedAddress: () => '0x1234', + getBalanceOf: stub(), + }); expect(tokenBalances.state).toEqual({ contractBalances: {} }); }); it('should set default config', () => { + const tokenBalances = new TokenBalancesController({ + onAssetsStateChange: stub(), + getSelectedAddress: () => '0x1234', + getBalanceOf: stub(), + }); expect(tokenBalances.config).toEqual({ interval: 180000, tokens: [], @@ -46,7 +50,14 @@ describe('TokenBalancesController', () => { it('should poll and update balances in the right interval', async () => { await new Promise((resolve) => { const mock = stub(TokenBalancesController.prototype, 'updateBalances'); - new TokenBalancesController({ interval: 10 }); + new TokenBalancesController( + { + onAssetsStateChange: stub(), + getSelectedAddress: () => '0x1234', + getBalanceOf: stub(), + }, + { interval: 10 }, + ); expect(mock.called).toBe(true); expect(mock.calledTwice).toBe(false); setTimeout(() => { @@ -58,21 +69,35 @@ describe('TokenBalancesController', () => { }); it('should not update rates if disabled', async () => { - const controller = new TokenBalancesController({ - disabled: true, - interval: 10, - }); - const mock = stub(controller, 'update'); - await controller.updateBalances(); + const tokenBalances = new TokenBalancesController( + { + onAssetsStateChange: stub(), + getSelectedAddress: () => '0x1234', + getBalanceOf: stub(), + }, + { + disabled: true, + interval: 10, + }, + ); + const mock = stub(tokenBalances, 'update'); + await tokenBalances.updateBalances(); expect(mock.called).toBe(false); }); it('should clear previous interval', async () => { const mock = stub(global, 'clearTimeout'); - const controller = new TokenBalancesController({ interval: 1337 }); + const tokenBalances = new TokenBalancesController( + { + onAssetsStateChange: stub(), + getSelectedAddress: () => '0x1234', + getBalanceOf: stub(), + }, + { interval: 1337 }, + ); await new Promise((resolve) => { setTimeout(() => { - controller.poll(1338); + tokenBalances.poll(1338); expect(mock.called).toBe(true); mock.restore(); resolve(); @@ -81,46 +106,66 @@ describe('TokenBalancesController', () => { }); it('should update all balances', async () => { - const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; - expect(tokenBalances.state.contractBalances).toEqual({}); - tokenBalances.configure({ tokens: [{ address, decimals: 18, symbol: 'EOS' }] }); - const assets = new AssetsController(); const assetsContract = new AssetsContractController(); const network = new NetworkController(); const preferences = new PreferencesController(); + const assets = new AssetsController({ + onPreferencesStateChange: (listener) => preferences.subscribe(listener), + onNetworkStateChange: (listener) => network.subscribe(listener), + getAssetName: assetsContract.getAssetName.bind(assetsContract), + getAssetSymbol: assetsContract.getAssetSymbol.bind(assetsContract), + getCollectibleTokenURI: assetsContract.getCollectibleTokenURI.bind(assetsContract), + }); + const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; + const tokenBalances = new TokenBalancesController( + { + onAssetsStateChange: (listener) => assets.subscribe(listener), + getSelectedAddress: () => preferences.state.selectedAddress, + getBalanceOf: stub().returns(new BN(1)), + }, + { interval: 1337, tokens: [{ address, decimals: 18, symbol: 'EOS' }] }, + ); + expect(tokenBalances.state.contractBalances).toEqual({}); - new ComposableController([assets, assetsContract, network, preferences, tokenBalances]); assetsContract.configure({ provider: MAINNET_PROVIDER }); - stub(assetsContract, 'getBalanceOf').resolves(new BN(1)); await tokenBalances.updateBalances(); - const mytoken = getToken(address); + const mytoken = getToken(tokenBalances, address); expect(mytoken?.balanceError).toBeNull(); expect(Object.keys(tokenBalances.state.contractBalances)).toContain(address); expect(tokenBalances.state.contractBalances[address].toNumber()).toBeGreaterThan(0); }); it('should handle `getBalanceOf` error case', async () => { - const errorMsg = 'Failed to get balance'; - const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; - expect(tokenBalances.state.contractBalances).toEqual({}); - tokenBalances.configure({ tokens: [{ address, decimals: 18, symbol: 'EOS' }] }); - const assets = new AssetsController(); - const assetsContract = new AssetsContractController(); + const assetsContract = new AssetsContractController({ provider: MAINNET_PROVIDER }); const network = new NetworkController(); const preferences = new PreferencesController(); + const assets = new AssetsController({ + onPreferencesStateChange: (listener) => preferences.subscribe(listener), + onNetworkStateChange: (listener) => network.subscribe(listener), + getAssetName: assetsContract.getAssetName.bind(assetsContract), + getAssetSymbol: assetsContract.getAssetSymbol.bind(assetsContract), + getCollectibleTokenURI: assetsContract.getCollectibleTokenURI.bind(assetsContract), + }); + const errorMsg = 'Failed to get balance'; + const address = '0x86fa049857e0209aa7d9e616f7eb3b3b78ecfdb0'; + const getBalanceOfStub = stub().returns(Promise.reject(new Error(errorMsg))); + const tokenBalances = new TokenBalancesController( + { + onAssetsStateChange: (listener) => assets.subscribe(listener), + getSelectedAddress: () => preferences.state.selectedAddress, + getBalanceOf: getBalanceOfStub, + }, + { interval: 1337, tokens: [{ address, decimals: 18, symbol: 'EOS' }] }, + ); - new ComposableController([assets, assetsContract, network, preferences, tokenBalances]); - assetsContract.configure({ provider: MAINNET_PROVIDER }); - const mock = stub(assetsContract, 'getBalanceOf').returns(Promise.reject(new Error(errorMsg))); + expect(tokenBalances.state.contractBalances).toEqual({}); await tokenBalances.updateBalances(); - const mytoken = getToken(address); + const mytoken = getToken(tokenBalances, address); expect(mytoken?.balanceError).toBeInstanceOf(Error); expect(mytoken?.balanceError?.message).toBe(errorMsg); expect(tokenBalances.state.contractBalances[address].toNumber()).toEqual(0); - // test reset case - mock.restore(); - stub(assetsContract, 'getBalanceOf').resolves(new BN(1)); + getBalanceOfStub.returns(new BN(1)); await tokenBalances.updateBalances(); expect(mytoken?.balanceError).toBeNull(); expect(Object.keys(tokenBalances.state.contractBalances)).toContain(address); @@ -128,15 +173,27 @@ describe('TokenBalancesController', () => { }); it('should subscribe to new sibling assets controllers', async () => { - const assets = new AssetsController(); const assetsContract = new AssetsContractController(); const network = new NetworkController(); const preferences = new PreferencesController(); - - new ComposableController([assets, assetsContract, network, preferences, tokenBalances]); + const assets = new AssetsController({ + onPreferencesStateChange: (listener) => preferences.subscribe(listener), + onNetworkStateChange: (listener) => network.subscribe(listener), + getAssetName: assetsContract.getAssetName.bind(assetsContract), + getAssetSymbol: assetsContract.getAssetSymbol.bind(assetsContract), + getCollectibleTokenURI: assetsContract.getCollectibleTokenURI.bind(assetsContract), + }); + const tokenBalances = new TokenBalancesController( + { + onAssetsStateChange: (listener) => assets.subscribe(listener), + getSelectedAddress: () => preferences.state.selectedAddress, + getBalanceOf: assetsContract.getBalanceOf.bind(assetsContract), + }, + { interval: 1337 }, + ); const updateBalances = sandbox.stub(tokenBalances, 'updateBalances'); await assets.addToken('0xfoO', 'FOO', 18); - const { tokens } = tokenBalances.context.AssetsController.state; + const { tokens } = assets.state; const found = tokens.filter((token: Token) => token.address === '0xfoO'); expect(found.length > 0).toBe(true); expect(updateBalances.called).toBe(true); diff --git a/src/assets/TokenBalancesController.ts b/src/assets/TokenBalancesController.ts index 6f53751e17..38813a7aad 100644 --- a/src/assets/TokenBalancesController.ts +++ b/src/assets/TokenBalancesController.ts @@ -1,9 +1,10 @@ import { BN } from 'ethereumjs-util'; import BaseController, { BaseConfig, BaseState } from '../BaseController'; import { safelyExecute } from '../util'; -import AssetsController from './AssetsController'; +import type { PreferencesState } from '../user/PreferencesController'; import { Token } from './TokenRatesController'; -import { AssetsContractController } from './AssetsContractController'; +import type { AssetsState } from './AssetsController'; +import type { AssetsContractController } from './AssetsContractController'; // TODO: Remove this export in the next major release export { BN }; @@ -44,18 +45,33 @@ export class TokenBalancesController extends BaseController PreferencesState['selectedAddress']; + + private getBalanceOf: AssetsContractController['getBalanceOf']; /** * Creates a TokenBalancesController instance * + * @param options + * @param options.onAssetsStateChange - Allows subscribing to assets controller state changes + * @param options.getSelectedAddress - Gets the current selected address + * @param options.getBalanceOf - Gets the balance of the given account at the given contract address * @param config - Initial options used to configure this controller * @param state - Initial state to set on this controller */ - constructor(config?: Partial, state?: Partial) { + constructor( + { + onAssetsStateChange, + getSelectedAddress, + getBalanceOf, + }: { + onAssetsStateChange: (listener: (tokenState: AssetsState) => void) => void; + getSelectedAddress: () => PreferencesState['selectedAddress']; + getBalanceOf: AssetsContractController['getBalanceOf']; + }, + config?: Partial, + state?: Partial, + ) { super(config, state); this.defaultConfig = { interval: 180000, @@ -63,6 +79,12 @@ export class TokenBalancesController extends BaseController { + this.configure({ tokens }); + this.updateBalances(); + }); + this.getSelectedAddress = getSelectedAddress; + this.getBalanceOf = getBalanceOf; this.poll(); } @@ -89,15 +111,12 @@ export class TokenBalancesController extends BaseController { - this.configure({ tokens }); - this.updateBalances(); - }); - } } export default TokenBalancesController; diff --git a/src/assets/TokenRatesController.test.ts b/src/assets/TokenRatesController.test.ts index 93e081141b..f601eccf53 100644 --- a/src/assets/TokenRatesController.test.ts +++ b/src/assets/TokenRatesController.test.ts @@ -1,6 +1,5 @@ import { stub } from 'sinon'; import nock from 'nock'; -import ComposableController from '../ComposableController'; import { PreferencesController } from '../user/PreferencesController'; import { NetworkController } from '../network/NetworkController'; import TokenRatesController, { Token } from './TokenRatesController'; @@ -35,12 +34,12 @@ describe('TokenRatesController', () => { }); it('should set default state', () => { - const controller = new TokenRatesController(); + const controller = new TokenRatesController({ onAssetsStateChange: stub(), onCurrencyRateStateChange: stub() }); expect(controller.state).toEqual({ contractExchangeRates: {} }); }); it('should initialize with the default config', () => { - const controller = new TokenRatesController(); + const controller = new TokenRatesController({ onAssetsStateChange: stub(), onCurrencyRateStateChange: stub() }); expect(controller.config).toEqual({ disabled: false, interval: 180000, @@ -50,17 +49,20 @@ describe('TokenRatesController', () => { }); it('should throw when tokens property is accessed', () => { - const controller = new TokenRatesController(); + const controller = new TokenRatesController({ onAssetsStateChange: stub(), onCurrencyRateStateChange: stub() }); expect(() => console.log(controller.tokens)).toThrow('Property only used for setting'); }); it('should poll and update rate in the right interval', async () => { await new Promise((resolve) => { const mock = stub(TokenRatesController.prototype, 'fetchExchangeRate'); - new TokenRatesController({ - interval: 10, - tokens: [{ address: 'bar', decimals: 0, symbol: '' }], - }); + new TokenRatesController( + { onAssetsStateChange: stub(), onCurrencyRateStateChange: stub() }, + { + interval: 10, + tokens: [{ address: 'bar', decimals: 0, symbol: '' }], + }, + ); expect(mock.called).toBe(true); expect(mock.calledTwice).toBe(false); setTimeout(() => { @@ -72,9 +74,12 @@ describe('TokenRatesController', () => { }); it('should not update rates if disabled', async () => { - const controller = new TokenRatesController({ - interval: 10, - }); + const controller = new TokenRatesController( + { onAssetsStateChange: stub(), onCurrencyRateStateChange: stub() }, + { + interval: 10, + }, + ); controller.fetchExchangeRate = stub(); controller.disabled = true; await controller.updateExchangeRates(); @@ -83,7 +88,10 @@ describe('TokenRatesController', () => { it('should clear previous interval', async () => { const mock = stub(global, 'clearTimeout'); - const controller = new TokenRatesController({ interval: 1337 }); + const controller = new TokenRatesController( + { onAssetsStateChange: stub(), onCurrencyRateStateChange: stub() }, + { interval: 1337 }, + ); await new Promise((resolve) => { setTimeout(() => { controller.poll(1338); @@ -95,14 +103,24 @@ describe('TokenRatesController', () => { }); it('should update all rates', async () => { - const assets = new AssetsController(); const assetsContract = new AssetsContractController(); - const currencyRate = new CurrencyRateController(); - const controller = new TokenRatesController({ interval: 10 }); const network = new NetworkController(); const preferences = new PreferencesController(); - - new ComposableController([controller, assets, assetsContract, currencyRate, network, preferences]); + const assets = new AssetsController({ + onPreferencesStateChange: (listener) => preferences.subscribe(listener), + onNetworkStateChange: (listener) => network.subscribe(listener), + getAssetName: assetsContract.getAssetName.bind(assetsContract), + getAssetSymbol: assetsContract.getAssetSymbol.bind(assetsContract), + getCollectibleTokenURI: assetsContract.getCollectibleTokenURI.bind(assetsContract), + }); + const currencyRate = new CurrencyRateController(); + const controller = new TokenRatesController( + { + onAssetsStateChange: (listener) => assets.subscribe(listener), + onCurrencyRateStateChange: (listener) => currencyRate.subscribe(listener), + }, + { interval: 10 }, + ); const address = '0x89d24A6b4CcB1B6fAA2625fE562bDD9a23260359'; const address2 = '0xfoO'; expect(controller.state.contractExchangeRates).toEqual({}); @@ -118,7 +136,10 @@ describe('TokenRatesController', () => { }); it('should handle balance not found in API', async () => { - const controller = new TokenRatesController({ interval: 10 }); + const controller = new TokenRatesController( + { onAssetsStateChange: stub(), onCurrencyRateStateChange: stub() }, + { interval: 10 }, + ); stub(controller, 'fetchExchangeRate').throws({ error: 'Not Found', message: 'Not Found' }); expect(controller.state.contractExchangeRates).toEqual({}); controller.tokens = [{ address: 'bar', decimals: 0, symbol: '' }]; @@ -128,17 +149,27 @@ describe('TokenRatesController', () => { }); it('should subscribe to new sibling assets controllers', async () => { - const assets = new AssetsController(); const assetsContract = new AssetsContractController(); - const currencyRate = new CurrencyRateController(); - const controller = new TokenRatesController(); const network = new NetworkController(); const preferences = new PreferencesController(); - - new ComposableController([controller, assets, assetsContract, currencyRate, network, preferences]); + const assets = new AssetsController({ + onPreferencesStateChange: (listener) => preferences.subscribe(listener), + onNetworkStateChange: (listener) => network.subscribe(listener), + getAssetName: assetsContract.getAssetName.bind(assetsContract), + getAssetSymbol: assetsContract.getAssetSymbol.bind(assetsContract), + getCollectibleTokenURI: assetsContract.getCollectibleTokenURI.bind(assetsContract), + }); + const currencyRate = new CurrencyRateController(); + const controller = new TokenRatesController( + { + onAssetsStateChange: (listener) => assets.subscribe(listener), + onCurrencyRateStateChange: (listener) => currencyRate.subscribe(listener), + }, + { interval: 10 }, + ); await assets.addToken('0xfoO', 'FOO', 18); currencyRate.update({ nativeCurrency: 'gno' }); - const { tokens } = controller.context.AssetsController.state; + const { tokens } = assets.state; const found = tokens.filter((token: Token) => token.address === '0xfoO'); expect(found.length > 0).toBe(true); expect(controller.config.nativeCurrency).toEqual('gno'); diff --git a/src/assets/TokenRatesController.ts b/src/assets/TokenRatesController.ts index 89ae923bcd..b9e0ef47db 100644 --- a/src/assets/TokenRatesController.ts +++ b/src/assets/TokenRatesController.ts @@ -1,8 +1,9 @@ import { toChecksumAddress } from 'ethereumjs-util'; import BaseController, { BaseConfig, BaseState } from '../BaseController'; import { safelyExecute, handleFetch } from '../util'; -import AssetsController from './AssetsController'; -import CurrencyRateController from './CurrencyRateController'; + +import type { AssetsState } from './AssetsController'; +import type { CurrencyRateState } from './CurrencyRateController'; /** * @type CoinGeckoResponse @@ -77,18 +78,26 @@ export class TokenRatesController extends BaseController, state?: Partial) { + constructor( + { + onAssetsStateChange, + onCurrencyRateStateChange, + }: { + onAssetsStateChange: (listener: (assetState: AssetsState) => void) => void; + onCurrencyRateStateChange: (listener: (currencyRateState: CurrencyRateState) => void) => void; + }, + config?: Partial, + state?: Partial, + ) { super(config, state); this.defaultConfig = { disabled: true, @@ -99,6 +108,12 @@ export class TokenRatesController extends BaseController { + this.configure({ tokens: assetsState.tokens }); + }); + onCurrencyRateStateChange((currencyRateState) => { + this.configure({ nativeCurrency: currencyRateState.nativeCurrency }); + }); this.poll(); } @@ -142,22 +157,6 @@ export class TokenRatesController extends BaseController { - this.configure({ tokens: assets.state.tokens }); - }); - currencyRate.subscribe(() => { - this.configure({ nativeCurrency: currencyRate.state.nativeCurrency }); - }); - } - /** * Updates exchange rates for all tokens * diff --git a/src/keyring/KeyringController.test.ts b/src/keyring/KeyringController.test.ts index 6306efe65b..5f592b2897 100644 --- a/src/keyring/KeyringController.test.ts +++ b/src/keyring/KeyringController.test.ts @@ -9,7 +9,6 @@ import { stub } from 'sinon'; import Transaction from 'ethereumjs-tx'; import MockEncryptor from '../../tests/mocks/mockEncryptor'; import PreferencesController from '../user/PreferencesController'; -import ComposableController from '../ComposableController'; import KeyringController, { AccountImportStrategy, Keyring, @@ -33,10 +32,17 @@ describe('KeyringController', () => { let initialState: { isUnlocked: boolean; keyringTypes: string[]; keyrings: Keyring[] }; const baseConfig: Partial = { encryptor: new MockEncryptor() }; beforeEach(async () => { - keyringController = new KeyringController(baseConfig); preferences = new PreferencesController(); + keyringController = new KeyringController( + { + removeIdentity: preferences.removeIdentity.bind(preferences), + syncIdentities: preferences.syncIdentities.bind(preferences), + updateIdentities: preferences.updateIdentities.bind(preferences), + setSelectedAddress: preferences.setSelectedAddress.bind(preferences), + }, + baseConfig, + ); - new ComposableController([keyringController, preferences]); initialState = await keyringController.createNewVaultAndKeychain(password); }); diff --git a/src/keyring/KeyringController.ts b/src/keyring/KeyringController.ts index 27a7e6ad65..b4378ef942 100644 --- a/src/keyring/KeyringController.ts +++ b/src/keyring/KeyringController.ts @@ -117,24 +117,50 @@ export class KeyringController extends BaseController, state?: Partial) { + constructor( + { + removeIdentity, + syncIdentities, + updateIdentities, + setSelectedAddress, + }: { + removeIdentity: PreferencesController['removeIdentity']; + syncIdentities: PreferencesController['syncIdentities']; + updateIdentities: PreferencesController['updateIdentities']; + setSelectedAddress: PreferencesController['setSelectedAddress']; + }, + config?: Partial, + state?: Partial, + ) { super(config, state); privates.set(this, { keyring: new Keyring(Object.assign({ initState: state }, config)) }); this.defaultState = { ...privates.get(this).keyring.store.getState(), keyrings: [], }; + this.removeIdentity = removeIdentity; + this.syncIdentities = syncIdentities; + this.updateIdentities = updateIdentities; + this.setSelectedAddress = setSelectedAddress; this.initialize(); this.fullUpdate(); } @@ -145,7 +171,6 @@ export class KeyringController extends BaseController { - const preferences = this.context.PreferencesController as PreferencesController; const primaryKeyring = privates.get(this).keyring.getKeyringsByType('HD Key Tree')[0]; /* istanbul ignore if */ if (!primaryKeyring) { @@ -157,10 +182,10 @@ export class KeyringController extends BaseController { if (!oldAccounts.includes(selectedAddress)) { - preferences.update({ selectedAddress }); + this.setSelectedAddress(selectedAddress); } }); return this.fullUpdate(); @@ -191,12 +216,11 @@ export class KeyringController extends BaseController { let privateKey; - const preferences = this.context.PreferencesController as PreferencesController; switch (strategy) { case 'privateKey': const [importedKey] = args; @@ -307,8 +329,8 @@ export class KeyringController extends BaseController { - const preferences = this.context.PreferencesController as PreferencesController; - preferences.removeIdentity(address); + this.removeIdentity(address); await privates.get(this).keyring.removeAccount(address); return this.fullUpdate(); } @@ -404,10 +425,9 @@ export class KeyringController extends BaseController { - const preferences = this.context.PreferencesController as PreferencesController; await privates.get(this).keyring.submitPassword(password); const accounts = await privates.get(this).keyring.getAccounts(); - await preferences.syncIdentities(accounts); + await this.syncIdentities(accounts); return this.fullUpdate(); } diff --git a/src/transaction/TransactionController.test.ts b/src/transaction/TransactionController.test.ts index f29daa68c5..746e6d7052 100644 --- a/src/transaction/TransactionController.test.ts +++ b/src/transaction/TransactionController.test.ts @@ -1,6 +1,6 @@ import { stub } from 'sinon'; import HttpProvider from 'ethjs-provider-http'; -import { NetworksChainId } from '../network/NetworkController'; +import { NetworksChainId, NetworkType, NetworkState } from '../network/NetworkController'; import { TransactionController, TransactionStatus, TransactionMeta } from './TransactionController'; const globalAny: any = global; @@ -63,18 +63,18 @@ const MOCK_PRFERENCES = { state: { selectedAddress: 'foo' } }; const PROVIDER = new HttpProvider('https://ropsten.infura.io/v3/341eacb578dd44a1a049cbc5f6fd4035'); const MAINNET_PROVIDER = new HttpProvider('https://mainnet.infura.io/v3/341eacb578dd44a1a049cbc5f6fd4035'); const MOCK_NETWORK = { - provider: PROVIDER, - state: { network: '3', provider: { type: 'ropsten', chainId: NetworksChainId.ropsten } }, + getProvider: () => PROVIDER, + state: { network: '3', provider: { type: 'ropsten' as NetworkType, chainId: NetworksChainId.ropsten } }, subscribe: () => undefined, }; const MOCK_NETWORK_WITHOUT_CHAIN_ID = { - provider: PROVIDER, - state: { network: '3', provider: { type: 'ropsten' } }, + getProvider: () => PROVIDER, + state: { network: '3', provider: { type: 'ropsten' as NetworkType } }, subscribe: () => undefined, }; const MOCK_MAINNET_NETWORK = { - provider: MAINNET_PROVIDER, - state: { network: '1', provider: { type: 'mainnet', chainId: NetworksChainId.mainnet } }, + getProvider: () => MAINNET_PROVIDER, + state: { network: '1', provider: { type: 'mainnet' as NetworkType, chainId: NetworksChainId.mainnet } }, subscribe: () => undefined, }; @@ -525,12 +525,20 @@ describe('TransactionController', () => { }); it('should set default state', () => { - const controller = new TransactionController(); + const controller = new TransactionController({ + getNetworkState: () => MOCK_NETWORK.state, + onNetworkStateChange: MOCK_NETWORK.subscribe, + getProvider: MOCK_NETWORK.getProvider, + }); expect(controller.state).toEqual({ methodData: {}, transactions: [] }); }); it('should set default config', () => { - const controller = new TransactionController(); + const controller = new TransactionController({ + getNetworkState: () => MOCK_NETWORK.state, + onNetworkStateChange: MOCK_NETWORK.subscribe, + getProvider: MOCK_NETWORK.getProvider, + }); expect(controller.config).toEqual({ interval: 5000, provider: undefined, @@ -540,7 +548,14 @@ describe('TransactionController', () => { it('should poll and update transaction statuses in the right interval', async () => { await new Promise((resolve) => { const mock = stub(TransactionController.prototype, 'queryTransactionStatuses'); - new TransactionController({ interval: 10 }); + new TransactionController( + { + getNetworkState: () => MOCK_NETWORK.state, + onNetworkStateChange: MOCK_NETWORK.subscribe, + getProvider: MOCK_NETWORK.getProvider, + }, + { interval: 10 }, + ); expect(mock.called).toBe(true); expect(mock.calledTwice).toBe(false); setTimeout(() => { @@ -553,7 +568,14 @@ describe('TransactionController', () => { it('should clear previous interval', async () => { const mock = stub(global, 'clearTimeout'); - const controller = new TransactionController({ interval: 1337 }); + const controller = new TransactionController( + { + getNetworkState: () => MOCK_NETWORK.state, + onNetworkStateChange: MOCK_NETWORK.subscribe, + getProvider: MOCK_NETWORK.getProvider, + }, + { interval: 1337 }, + ); await new Promise((resolve) => { setTimeout(() => { controller.poll(1338); @@ -566,7 +588,14 @@ describe('TransactionController', () => { it('should not update the state if there are no updates on transaction statuses', async () => { await new Promise((resolve) => { - const controller = new TransactionController({ interval: 10 }); + const controller = new TransactionController( + { + getNetworkState: () => MOCK_NETWORK.state, + onNetworkStateChange: MOCK_NETWORK.subscribe, + getProvider: MOCK_NETWORK.getProvider, + }, + { interval: 10 }, + ); const func = stub(controller, 'update'); setTimeout(() => { expect(func.called).toBe(false); @@ -577,17 +606,21 @@ describe('TransactionController', () => { }); it('should throw when adding invalid transaction', async () => { - const controller = new TransactionController(); + const controller = new TransactionController({ + getNetworkState: () => MOCK_NETWORK.state, + onNetworkStateChange: MOCK_NETWORK.subscribe, + getProvider: MOCK_NETWORK.getProvider, + }); await expect(controller.addTransaction({ from: 'foo' } as any)).rejects.toThrow('Invalid "from" address'); }); it('should add a valid transaction', async () => { - const controller = new TransactionController({ provider: PROVIDER }); + const controller = new TransactionController({ + getNetworkState: () => MOCK_NETWORK.state, + onNetworkStateChange: MOCK_NETWORK.subscribe, + getProvider: MOCK_NETWORK.getProvider, + }); const from = '0xc38bf1ad06ef69f0c04e29dbeb4152b4175f0a8d'; - controller.context = { - NetworkController: MOCK_NETWORK, - } as any; - controller.onComposed(); await controller.addTransaction({ from, to: from, @@ -598,12 +631,42 @@ describe('TransactionController', () => { expect(controller.state.transactions[0].status).toBe(TransactionStatus.unapproved); }); + it('should add a valid transaction after a network switch', async () => { + const getNetworkState = stub().returns(MOCK_NETWORK.state); + let networkStateChangeListener: ((state: NetworkState) => void) | null = null; + const onNetworkStateChange = (listener: (state: NetworkState) => void) => { + networkStateChangeListener = listener; + }; + const getProvider = stub().returns(PROVIDER); + const controller = new TransactionController({ + getNetworkState, + onNetworkStateChange, + getProvider, + }); + + // switch from Ropsten to Mainnet + getNetworkState.returns(MOCK_MAINNET_NETWORK.state); + getProvider.returns(MAINNET_PROVIDER); + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + networkStateChangeListener!(MOCK_MAINNET_NETWORK.state); + + const from = '0xc38bf1ad06ef69f0c04e29dbeb4152b4175f0a8d'; + await controller.addTransaction({ + from, + to: from, + }); + expect(controller.state.transactions[0].transaction.from).toBe(from); + expect(controller.state.transactions[0].networkID).toBe(MOCK_MAINNET_NETWORK.state.network); + expect(controller.state.transactions[0].chainId).toBe(MOCK_MAINNET_NETWORK.state.provider.chainId); + expect(controller.state.transactions[0].status).toBe(TransactionStatus.unapproved); + }); + it('should cancel a transaction', async () => { - const controller = new TransactionController({ provider: PROVIDER }); - controller.context = { - NetworkController: MOCK_NETWORK, - } as any; - controller.onComposed(); + const controller = new TransactionController({ + getNetworkState: () => MOCK_NETWORK.state, + onNetworkStateChange: MOCK_NETWORK.subscribe, + getProvider: MOCK_NETWORK.getProvider, + }); const from = '0xc38bf1ad06ef69f0c04e29dbeb4152b4175f0a8d'; const { result } = await controller.addTransaction({ from, @@ -623,12 +686,12 @@ describe('TransactionController', () => { }); it('should wipe transactions', async () => { - const controller = new TransactionController({ provider: PROVIDER }); + const controller = new TransactionController({ + getNetworkState: () => MOCK_NETWORK.state, + onNetworkStateChange: MOCK_NETWORK.subscribe, + getProvider: MOCK_NETWORK.getProvider, + }); controller.wipeTransactions(); - controller.context = { - NetworkController: MOCK_NETWORK, - } as any; - controller.onComposed(); const from = '0xc38bf1ad06ef69f0c04e29dbeb4152b4175f0a8d'; await controller.addTransaction({ from, @@ -640,12 +703,12 @@ describe('TransactionController', () => { // This tests the fallback to networkID only when there is no chainId present. Should be removed when networkID is completely removed. it('should wipe transactions using networkID when there is no chainId', async () => { - const controller = new TransactionController({ provider: PROVIDER }); + const controller = new TransactionController({ + getNetworkState: () => MOCK_NETWORK.state, + onNetworkStateChange: MOCK_NETWORK.subscribe, + getProvider: MOCK_NETWORK.getProvider, + }); controller.wipeTransactions(); - controller.context = { - NetworkController: MOCK_NETWORK, - } as any; - controller.onComposed(); controller.state.transactions.push({ from: MOCK_PRFERENCES.state.selectedAddress, id: 'foo', @@ -658,16 +721,18 @@ describe('TransactionController', () => { }); it('should fail to approve an invalid transaction', async () => { - const controller = new TransactionController({ - provider: PROVIDER, - sign: () => { - throw new Error('foo'); + const controller = new TransactionController( + { + getNetworkState: () => MOCK_NETWORK.state, + onNetworkStateChange: MOCK_NETWORK.subscribe, + getProvider: MOCK_NETWORK.getProvider, }, - }); - controller.context = { - NetworkController: MOCK_NETWORK, - } as any; - controller.onComposed(); + { + sign: () => { + throw new Error('foo'); + }, + }, + ); const from = '0xe6509775f3f3614576c0d83f8647752f87cd6659'; const to = '0xc38bf1ad06ef69f0c04e29dbeb4152b4175f0a8d'; const { result } = await controller.addTransaction({ from, to }); @@ -680,12 +745,12 @@ describe('TransactionController', () => { }); it('should fail transaction if gas calculation fails', async () => { - const controller = new TransactionController({ provider: PROVIDER }); + const controller = new TransactionController({ + getNetworkState: () => MOCK_NETWORK.state, + onNetworkStateChange: MOCK_NETWORK.subscribe, + getProvider: MOCK_NETWORK.getProvider, + }); const from = '0xc38bf1ad06ef69f0c04e29dbeb4152b4175f0a8d'; - controller.context = { - NetworkController: MOCK_NETWORK, - } as any; - controller.onComposed(); mockFlags.estimateGas = 'Uh oh'; await expect( controller.addTransaction({ @@ -696,13 +761,14 @@ describe('TransactionController', () => { }); it('should fail if no sign method defined', async () => { - const controller = new TransactionController({ - provider: PROVIDER, - }); - controller.context = { - NetworkController: MOCK_NETWORK, - } as any; - controller.onComposed(); + const controller = new TransactionController( + { + getNetworkState: () => MOCK_NETWORK.state, + onNetworkStateChange: MOCK_NETWORK.subscribe, + getProvider: MOCK_NETWORK.getProvider, + }, + {}, + ); const from = '0xe6509775f3f3614576c0d83f8647752f87cd6659'; const to = '0xc38bf1ad06ef69f0c04e29dbeb4152b4175f0a8d'; const { result } = await controller.addTransaction({ from, to }); @@ -714,15 +780,17 @@ describe('TransactionController', () => { await expect(result).rejects.toThrow('No sign method defined'); }); - it('should fail if no chainId defined', async () => { - const controller = new TransactionController({ - provider: PROVIDER, - sign: async (transaction: any) => transaction, - }); - controller.context = { - NetworkController: MOCK_NETWORK_WITHOUT_CHAIN_ID, - } as any; - controller.onComposed(); + it('should fail if no chainId is defined', async () => { + const controller = new TransactionController( + { + getNetworkState: () => MOCK_NETWORK_WITHOUT_CHAIN_ID.state as NetworkState, + onNetworkStateChange: MOCK_NETWORK_WITHOUT_CHAIN_ID.subscribe, + getProvider: MOCK_NETWORK_WITHOUT_CHAIN_ID.getProvider, + }, + { + sign: async (transaction: any) => transaction, + }, + ); const from = '0xe6509775f3f3614576c0d83f8647752f87cd6659'; const to = '0xc38bf1ad06ef69f0c04e29dbeb4152b4175f0a8d'; const { result } = await controller.addTransaction({ from, to }); @@ -736,15 +804,17 @@ describe('TransactionController', () => { it('should approve a transaction', async () => { await new Promise(async (resolve) => { - const controller = new TransactionController({ - provider: PROVIDER, - sign: async (transaction: any) => transaction, - }); + const controller = new TransactionController( + { + getNetworkState: () => MOCK_NETWORK.state, + onNetworkStateChange: MOCK_NETWORK.subscribe, + getProvider: MOCK_NETWORK.getProvider, + }, + { + sign: async (transaction: any) => transaction, + }, + ); const from = '0xc38bf1ad06ef69f0c04e29dbeb4152b4175f0a8d'; - controller.context = { - NetworkController: MOCK_NETWORK, - } as any; - controller.onComposed(); await controller.addTransaction({ from, gas: '0x0', @@ -764,14 +834,16 @@ describe('TransactionController', () => { it('should query transaction statuses', async () => { await new Promise((resolve) => { - const controller = new TransactionController({ - provider: PROVIDER, - sign: async (transaction: any) => transaction, - }); - controller.context = { - NetworkController: MOCK_NETWORK, - } as any; - controller.onComposed(); + const controller = new TransactionController( + { + getNetworkState: () => MOCK_NETWORK.state, + onNetworkStateChange: MOCK_NETWORK.subscribe, + getProvider: MOCK_NETWORK.getProvider, + }, + { + sign: async (transaction: any) => transaction, + }, + ); controller.state.transactions.push({ from: MOCK_PRFERENCES.state.selectedAddress, id: 'foo', @@ -793,14 +865,16 @@ describe('TransactionController', () => { // This tests the fallback to networkID only when there is no chainId present. Should be removed when networkID is completely removed. it('should query transaction statuses with networkID only when there is no chainId', async () => { await new Promise((resolve) => { - const controller = new TransactionController({ - provider: PROVIDER, - sign: async (transaction: any) => transaction, - }); - controller.context = { - NetworkController: MOCK_NETWORK, - } as any; - controller.onComposed(); + const controller = new TransactionController( + { + getNetworkState: () => MOCK_NETWORK.state, + onNetworkStateChange: MOCK_NETWORK.subscribe, + getProvider: MOCK_NETWORK.getProvider, + }, + { + sign: async (transaction: any) => transaction, + }, + ); controller.state.transactions.push({ from: MOCK_PRFERENCES.state.selectedAddress, id: 'foo', @@ -820,12 +894,12 @@ describe('TransactionController', () => { it('should fetch all the transactions from an address, including incoming transactions, in ropsten', async () => { globalAny.fetch = mockFetchs(MOCK_FETCH_TX_HISTORY_DATA_OK); - const controller = new TransactionController({ provider: PROVIDER }); + const controller = new TransactionController({ + getNetworkState: () => MOCK_NETWORK.state, + onNetworkStateChange: MOCK_NETWORK.subscribe, + getProvider: MOCK_NETWORK.getProvider, + }); controller.wipeTransactions(); - controller.context = { - NetworkController: MOCK_NETWORK, - } as any; - controller.onComposed(); expect(controller.state.transactions).toHaveLength(0); const from = '0x6bf137f335ea1b8f193b8f6ea92561a60d23a207'; @@ -837,12 +911,12 @@ describe('TransactionController', () => { it('should fetch all the transactions from an address, including incoming token transactions, in mainnet', async () => { globalAny.fetch = mockFetchs(MOCK_FETCH_TX_HISTORY_DATA_OK); - const controller = new TransactionController({ provider: MAINNET_PROVIDER }); + const controller = new TransactionController({ + getNetworkState: () => MOCK_MAINNET_NETWORK.state, + onNetworkStateChange: MOCK_MAINNET_NETWORK.subscribe, + getProvider: MOCK_MAINNET_NETWORK.getProvider, + }); controller.wipeTransactions(); - controller.context = { - NetworkController: MOCK_MAINNET_NETWORK, - } as any; - controller.onComposed(); expect(controller.state.transactions).toHaveLength(0); const from = '0x6bf137f335ea1b8f193b8f6ea92561a60d23a207'; @@ -854,13 +928,13 @@ describe('TransactionController', () => { it('should fetch all the transactions from an address, including incoming token transactions, but not adding the ones already in state', async () => { globalAny.fetch = mockFetchs(MOCK_FETCH_TX_HISTORY_DATA_OK); - const controller = new TransactionController({ provider: MAINNET_PROVIDER }); + const controller = new TransactionController({ + getNetworkState: () => MOCK_MAINNET_NETWORK.state, + onNetworkStateChange: MOCK_MAINNET_NETWORK.subscribe, + getProvider: MOCK_MAINNET_NETWORK.getProvider, + }); const from = '0x6bf137f335ea1b8f193b8f6ea92561a60d23a207'; controller.wipeTransactions(); - controller.context = { - NetworkController: MOCK_MAINNET_NETWORK, - } as any; - controller.onComposed(); controller.state.transactions = TRANSACTIONS_IN_STATE; await controller.fetchAll(from); expect(controller.state.transactions).toHaveLength(17); @@ -876,12 +950,12 @@ describe('TransactionController', () => { it('should fetch all the transactions from an address, including incoming transactions, in mainnet from block', async () => { globalAny.fetch = mockFetchs(MOCK_FETCH_TX_HISTORY_DATA_OK); - const controller = new TransactionController({ provider: MAINNET_PROVIDER }); + const controller = new TransactionController({ + getNetworkState: () => MOCK_MAINNET_NETWORK.state, + onNetworkStateChange: MOCK_MAINNET_NETWORK.subscribe, + getProvider: MOCK_MAINNET_NETWORK.getProvider, + }); controller.wipeTransactions(); - controller.context = { - NetworkController: MOCK_MAINNET_NETWORK, - } as any; - controller.onComposed(); expect(controller.state.transactions).toHaveLength(0); const from = '0x6bf137f335ea1b8f193b8f6ea92561a60d23a207'; @@ -893,12 +967,12 @@ describe('TransactionController', () => { it('should return', async () => { globalAny.fetch = mockFetch(MOCK_FETCH_TX_HISTORY_DATA_ERROR); - const controller = new TransactionController({ provider: PROVIDER }); + const controller = new TransactionController({ + getNetworkState: () => MOCK_NETWORK.state, + onNetworkStateChange: MOCK_NETWORK.subscribe, + getProvider: MOCK_NETWORK.getProvider, + }); controller.wipeTransactions(); - controller.context = { - NetworkController: MOCK_NETWORK, - } as any; - controller.onComposed(); expect(controller.state.transactions).toHaveLength(0); const from = '0x6bf137f335ea1b8f193b8f6ea92561a60d23a207'; const result = await controller.fetchAll(from); @@ -907,11 +981,14 @@ describe('TransactionController', () => { }); it('should handle new method data', async () => { - const controller = new TransactionController({ provider: MOCK_MAINNET_NETWORK }); - controller.context = { - NetworkController: MOCK_MAINNET_NETWORK, - } as any; - controller.onComposed(); + const controller = new TransactionController( + { + getNetworkState: () => MOCK_MAINNET_NETWORK.state, + onNetworkStateChange: MOCK_MAINNET_NETWORK.subscribe, + getProvider: MOCK_MAINNET_NETWORK.getProvider, + }, + {}, + ); const registry = await controller.handleMethodData('0xf39b5b9b'); expect(registry.parsedRegistryMethod).toEqual({ args: [{ type: 'uint256' }, { type: 'uint256' }], @@ -921,11 +998,14 @@ describe('TransactionController', () => { }); it('should handle known method data', async () => { - const controller = new TransactionController({ provider: MOCK_MAINNET_NETWORK }); - controller.context = { - NetworkController: MOCK_MAINNET_NETWORK, - } as any; - controller.onComposed(); + const controller = new TransactionController( + { + getNetworkState: () => MOCK_MAINNET_NETWORK.state, + onNetworkStateChange: MOCK_MAINNET_NETWORK.subscribe, + getProvider: MOCK_MAINNET_NETWORK.getProvider, + }, + {}, + ); const registry = await controller.handleMethodData('0xf39b5b9b'); expect(registry.parsedRegistryMethod).toEqual({ args: [{ type: 'uint256' }, { type: 'uint256' }], @@ -937,15 +1017,17 @@ describe('TransactionController', () => { }); it('should stop a transaction', async () => { - const controller = new TransactionController({ - provider: PROVIDER, - sign: async (transaction: any) => transaction, - }); + const controller = new TransactionController( + { + getNetworkState: () => MOCK_NETWORK.state, + onNetworkStateChange: MOCK_NETWORK.subscribe, + getProvider: MOCK_NETWORK.getProvider, + }, + { + sign: async (transaction: any) => transaction, + }, + ); const from = '0xc38bf1ad06ef69f0c04e29dbeb4152b4175f0a8d'; - controller.context = { - NetworkController: MOCK_NETWORK, - } as any; - controller.onComposed(); const { result } = await controller.addTransaction({ from, gas: '0x0', @@ -959,12 +1041,10 @@ describe('TransactionController', () => { it('should fail to stop a transaction if no sign method', async () => { const controller = new TransactionController({ - provider: PROVIDER, + getNetworkState: () => MOCK_NETWORK.state, + onNetworkStateChange: MOCK_NETWORK.subscribe, + getProvider: MOCK_NETWORK.getProvider, }); - controller.context = { - NetworkController: MOCK_NETWORK, - } as any; - controller.onComposed(); const from = '0xe6509775f3f3614576c0d83f8647752f87cd6659'; const to = '0xc38bf1ad06ef69f0c04e29dbeb4152b4175f0a8d'; await controller.addTransaction({ from, to }); @@ -976,15 +1056,17 @@ describe('TransactionController', () => { it('should speed up a transaction', async () => { await new Promise(async (resolve) => { - const controller = new TransactionController({ - provider: PROVIDER, - sign: async (transaction: any) => transaction, - }); + const controller = new TransactionController( + { + getNetworkState: () => MOCK_NETWORK.state, + onNetworkStateChange: MOCK_NETWORK.subscribe, + getProvider: MOCK_NETWORK.getProvider, + }, + { + sign: async (transaction: any) => transaction, + }, + ); const from = '0xc38bf1ad06ef69f0c04e29dbeb4152b4175f0a8d'; - controller.context = { - NetworkController: MOCK_NETWORK, - } as any; - controller.onComposed(); await controller.addTransaction({ from, gas: '0x0', diff --git a/src/transaction/TransactionController.ts b/src/transaction/TransactionController.ts index dd79ce4b68..7f9bed4fd1 100644 --- a/src/transaction/TransactionController.ts +++ b/src/transaction/TransactionController.ts @@ -7,8 +7,7 @@ import Transaction from 'ethereumjs-tx'; import { v1 as random } from 'uuid'; import { Mutex } from 'async-mutex'; import BaseController, { BaseConfig, BaseState } from '../BaseController'; -import NetworkController from '../network/NetworkController'; - +import type { NetworkState, NetworkController } from '../network/NetworkController'; import { BNToHex, fractionBN, @@ -191,7 +190,6 @@ export interface EtherscanTransactionMeta { */ export interface TransactionConfig extends BaseConfig { interval: number; - provider: any; sign?: (transaction: Transaction, from: string) => Promise; } @@ -243,6 +241,8 @@ export class TransactionController extends BaseController NetworkState; + private failTransaction(transactionMeta: TransactionMeta, error: Error) { const newTransactionMeta = { ...transactionMeta, @@ -349,11 +349,6 @@ export class TransactionController extends BaseController, state?: Partial) { + constructor( + { + getNetworkState, + onNetworkStateChange, + getProvider, + }: { + getNetworkState: () => NetworkState; + onNetworkStateChange: (listener: (state: NetworkState) => void) => void; + getProvider: () => NetworkController['provider']; + }, + config?: Partial, + state?: Partial, + ) { super(config, state); this.defaultConfig = { interval: 5000, - provider: undefined, }; this.defaultState = { methodData: {}, transactions: [], }; this.initialize(); + const provider = getProvider(); + this.getNetworkState = getNetworkState; + this.ethQuery = new EthQuery(provider); + this.registry = new MethodRegistry({ provider }); + onNetworkStateChange(() => { + const newProvider = getProvider(); + this.ethQuery = new EthQuery(newProvider); + this.registry = new MethodRegistry({ provider: newProvider }); + }); this.poll(); } @@ -426,22 +445,15 @@ export class TransactionController extends BaseController { - const network = this.context.NetworkController as NetworkController; + const { provider, network } = this.getNetworkState(); const { transactions } = this.state; transaction = normalizeTransaction(transaction); validateTransaction(transaction); - const { - state: { - network: networkID, - provider: { chainId }, - }, - } = network; - const transactionMeta = { id: random(), - networkID, - chainId, + networkID: network, + chainId: provider.chainId, origin, status: TransactionStatus.unapproved as TransactionStatus.unapproved, time: Date.now(), @@ -494,9 +506,8 @@ export class TransactionController extends BaseController transactionID === id); const transactionMeta = transactions[index]; const { nonce } = transactionMeta.transaction; @@ -681,23 +692,6 @@ export class TransactionController extends BaseController { - this.ethQuery = network.provider ? new EthQuery(network.provider) : /* istanbul ignore next */ null; - this.registry = network.provider - ? new MethodRegistry({ provider: network.provider }) /* istanbul ignore next */ - : null; - }; - onProviderUpdate(); - network.subscribe(onProviderUpdate); - } - /** * Resiliently checks all submitted transactions on the blockchain * and verifies that it has been included in a block @@ -707,9 +701,8 @@ export class TransactionController extends BaseController Promise.all( @@ -761,12 +754,8 @@ export class TransactionController extends BaseController { // Using fallback to networkID only when there is no chainId present. Should be removed when networkID is completely removed. const isCurrentNetwork = chainId === currentChainId || (!chainId && networkID === currentNetworkID); @@ -785,13 +774,8 @@ export class TransactionController extends BaseController { - const network = this.context.NetworkController; - const { - state: { - network: currentNetworkID, - provider: { type: networkType, chainId: currentChainId }, - }, - } = network; + const { provider, network: currentNetworkID } = this.getNetworkState(); + const { chainId: currentChainId, type: networkType } = provider; const supportedNetworkIds = ['1', '3', '4', '42']; /* istanbul ignore next */ From b9ea46972d6877b2eb64c8371ef0792831e0ca0e Mon Sep 17 00:00:00 2001 From: Mark Stacey Date: Tue, 23 Mar 2021 18:15:19 -0230 Subject: [PATCH 2/3] Replace `initialIdentities` with `getIdentities` The `initialIdentities` option for the `AccountTrackerController` has been replaced with a `getIdentities` option that returns the identities on-demand. This bypasses the need to manage a copy of the identity state in the `AccountTrackerController`, and is a bit more similar to how this would be done with the new base controller API. --- src/assets/AccountTrackerController.test.ts | 33 ++++++++++++++++----- src/assets/AccountTrackerController.ts | 15 +++++----- src/user/PreferencesController.ts | 1 - 3 files changed, 33 insertions(+), 16 deletions(-) diff --git a/src/assets/AccountTrackerController.test.ts b/src/assets/AccountTrackerController.test.ts index f82204053e..8500d5b566 100644 --- a/src/assets/AccountTrackerController.test.ts +++ b/src/assets/AccountTrackerController.test.ts @@ -1,28 +1,39 @@ import { stub, spy } from 'sinon'; import HttpProvider from 'ethjs-provider-http'; import type { ContactEntry } from '../user/AddressBookController'; -import PreferencesController from '../user/PreferencesController'; +import { PreferencesController } from '../user/PreferencesController'; import AccountTrackerController from './AccountTrackerController'; const provider = new HttpProvider('https://ropsten.infura.io/v3/341eacb578dd44a1a049cbc5f6fd4035'); describe('AccountTrackerController', () => { it('should set default state', () => { - const controller = new AccountTrackerController({ onPreferencesStateChange: stub(), initialIdentities: {} }); + const controller = new AccountTrackerController({ + onPreferencesStateChange: stub(), + getIdentities: () => ({}), + }); expect(controller.state).toEqual({ accounts: {}, }); }); it('should throw when provider property is accessed', () => { - const controller = new AccountTrackerController({ onPreferencesStateChange: stub(), initialIdentities: {} }); + const controller = new AccountTrackerController({ + onPreferencesStateChange: stub(), + getIdentities: () => ({}), + }); expect(() => console.log(controller.provider)).toThrow('Property only used for setting'); }); it('should get real balance', async () => { const address = '0xc38bf1ad06ef69f0c04e29dbeb4152b4175f0a8d'; const controller = new AccountTrackerController( - { onPreferencesStateChange: stub(), initialIdentities: { [address]: {} as ContactEntry } }, + { + onPreferencesStateChange: stub(), + getIdentities: () => { + return { [address]: {} as ContactEntry }; + }, + }, { provider }, ); await controller.refresh(); @@ -31,7 +42,12 @@ describe('AccountTrackerController', () => { it('should sync addresses', () => { const controller = new AccountTrackerController( - { onPreferencesStateChange: stub(), initialIdentities: { baz: {} as ContactEntry } }, + { + onPreferencesStateChange: stub(), + getIdentities: () => { + return { baz: {} as ContactEntry }; + }, + }, { provider }, { accounts: { @@ -47,7 +63,10 @@ describe('AccountTrackerController', () => { it('should subscribe to new sibling preference controllers', async () => { const preferences = new PreferencesController(); const controller = new AccountTrackerController( - { onPreferencesStateChange: (listener) => preferences.subscribe(listener), initialIdentities: {} }, + { + onPreferencesStateChange: (listener) => preferences.subscribe(listener), + getIdentities: () => ({}), + }, { provider }, ); controller.refresh = stub(); @@ -63,7 +82,7 @@ describe('AccountTrackerController', () => { const controller = new AccountTrackerController( { onPreferencesStateChange: (listener) => preferences.subscribe(listener), - initialIdentities: {}, + getIdentities: () => ({}), }, { provider, interval: 100 }, ); diff --git a/src/assets/AccountTrackerController.ts b/src/assets/AccountTrackerController.ts index 707b3cf218..e342035144 100644 --- a/src/assets/AccountTrackerController.ts +++ b/src/assets/AccountTrackerController.ts @@ -50,7 +50,7 @@ export class AccountTrackerController extends BaseController existing.indexOf(address) === -1); const oldAddresses = existing.filter((address) => addresses.indexOf(address) === -1); @@ -68,24 +68,24 @@ export class AccountTrackerController extends BaseController PreferencesState['identities']; /** * Creates an AccountTracker instance * * @param options * @param options.onPreferencesStateChange - Allows subscribing to preference controller state changes - * @param options.initialIdentities - The initial `identities` state from the Preferences controller + * @param options.getIdentities - Gets the identities from the Preferences store * @param config - Initial options used to configure this controller * @param state - Initial state to set on this controller */ constructor( { onPreferencesStateChange, - initialIdentities, + getIdentities, }: { onPreferencesStateChange: (listener: (preferencesState: PreferencesState) => void) => void; - initialIdentities: PreferencesState['identities']; + getIdentities: () => PreferencesState['identities']; }, config?: Partial, state?: Partial, @@ -96,9 +96,8 @@ export class AccountTrackerController extends BaseController { - this.identities = identities; + this.getIdentities = getIdentities; + onPreferencesStateChange(() => { this.refresh(); }); this.poll(); diff --git a/src/user/PreferencesController.ts b/src/user/PreferencesController.ts index 4d952770cf..076156c7b2 100644 --- a/src/user/PreferencesController.ts +++ b/src/user/PreferencesController.ts @@ -47,7 +47,6 @@ export interface PreferencesState extends BaseState { lostIdentities: { [address: string]: ContactEntry }; selectedAddress: string; } - /** * Controller that stores shared settings and exposes convenience methods */ From f63c6f10029baa5357275185650cc8ba3d7f8e66 Mon Sep 17 00:00:00 2001 From: Mark Stacey Date: Tue, 13 Apr 2021 19:33:07 -0230 Subject: [PATCH 3/3] Fix `AssetsDetectionController` constructor comment The constructor JSDoc comment has been fixed. Also a mistakenly removed newline has been restored. --- src/assets/AssetsDetectionController.ts | 2 +- src/user/PreferencesController.ts | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/assets/AssetsDetectionController.ts b/src/assets/AssetsDetectionController.ts index dd3ff9fa40..5e492620a1 100644 --- a/src/assets/AssetsDetectionController.ts +++ b/src/assets/AssetsDetectionController.ts @@ -104,7 +104,7 @@ export class AssetsDetectionController extends BaseController