From dfc3acdd5e9d70d17aaa8d77c92e9e376f135ca6 Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 15 Mar 2023 17:59:52 -0500 Subject: [PATCH] add rollbackToPreviousProvider method --- .../src/NftController.test.ts | 5 +- packages/controller-utils/jest.config.js | 2 +- packages/controller-utils/src/types.ts | 17 +- .../src/NetworkController.ts | 56 +- .../tests/NetworkController.test.ts | 764 +++++++++++------- .../tests/provider-api-tests/helpers.ts | 4 +- .../tests/provider-api-tests/shared-tests.ts | 5 +- 7 files changed, 545 insertions(+), 308 deletions(-) diff --git a/packages/assets-controllers/src/NftController.test.ts b/packages/assets-controllers/src/NftController.test.ts index b65b1e34d49..353f72a5f6b 100644 --- a/packages/assets-controllers/src/NftController.test.ts +++ b/packages/assets-controllers/src/NftController.test.ts @@ -14,6 +14,7 @@ import { OPENSEA_API_URL, ERC721, NetworksChainId, + NetworkType, } from '@metamask/controller-utils'; import { Network } from '@ethersproject/providers'; import { AssetsContractController } from './AssetsContractController'; @@ -46,8 +47,8 @@ const DEPRESSIONIST_CLOUDFLARE_IPFS_SUBDOMAIN_PATH = getFormattedIpfsUrl( true, ); -const SEPOLIA = { chainId: '11155111', type: 'sepolia' as const }; -const GOERLI = { chainId: '5', type: 'goerli' as const }; +const SEPOLIA = { chainId: '11155111', type: NetworkType.sepolia }; +const GOERLI = { chainId: '5', type: NetworkType.goerli }; // Mock out detectNetwork function for cleaner tests, Ethers calls this a bunch of times because the Web3Provider is paranoid. jest.mock('@ethersproject/providers', () => { diff --git a/packages/controller-utils/jest.config.js b/packages/controller-utils/jest.config.js index 52278f2baf4..60e9658a633 100644 --- a/packages/controller-utils/jest.config.js +++ b/packages/controller-utils/jest.config.js @@ -17,7 +17,7 @@ module.exports = merge(baseConfig, { coverageThreshold: { global: { branches: 68.05, - functions: 80.55, + functions: 76.92, lines: 69.82, statements: 70.17, }, diff --git a/packages/controller-utils/src/types.ts b/packages/controller-utils/src/types.ts index 5be4fc21bc8..4abdbf46f7a 100644 --- a/packages/controller-utils/src/types.ts +++ b/packages/controller-utils/src/types.ts @@ -1,12 +1,17 @@ /** * Human-readable network name */ -export type NetworkType = - | 'localhost' - | 'mainnet' - | 'goerli' - | 'sepolia' - | 'rpc'; +export enum NetworkType { + localhost = 'localhost', + mainnet = 'mainnet', + goerli = 'goerli', + sepolia = 'sepolia', + rpc = 'rpc', +} + +export const isNetworkType = (val: any): val is NetworkType => { + return Object.values(NetworkType).includes(val); +}; export enum NetworksChainId { mainnet = '1', diff --git a/packages/network-controller/src/NetworkController.ts b/packages/network-controller/src/NetworkController.ts index 13ee2ee3660..c55084a2891 100644 --- a/packages/network-controller/src/NetworkController.ts +++ b/packages/network-controller/src/NetworkController.ts @@ -12,8 +12,6 @@ import { RestrictedControllerMessenger, } from '@metamask/base-controller'; import { - MAINNET, - RPC, TESTNET_NETWORK_TYPE_TO_TICKER_SYMBOL, NetworksChainId, NetworkType, @@ -21,6 +19,7 @@ import { } from '@metamask/controller-utils'; import { assertIsStrictHexString } from '@metamask/utils'; +import { isNetworkType } from '../../controller-utils/src/types'; /** * @type ProviderConfig @@ -39,7 +38,7 @@ export type ProviderConfig = { chainId: string; ticker?: string; nickname?: string; - id?: string; + id?: NetworkConfigurationId; }; export type Block = { @@ -148,7 +147,10 @@ export type NetworkControllerOptions = { export const defaultState: NetworkState = { network: 'loading', isCustomNetwork: false, - providerConfig: { type: MAINNET, chainId: NetworksChainId.mainnet }, + providerConfig: { + type: NetworkType.mainnet, + chainId: NetworksChainId.mainnet, + }, networkDetails: { isEIP1559Compatible: false }, networkConfigurations: {}, }; @@ -166,6 +168,8 @@ type MetaMetricsEventPayload = { value?: number; }; +type NetworkConfigurationId = string; + /** * Controller that creates and manages an Ethereum network provider. */ @@ -184,6 +188,8 @@ export class NetworkController extends BaseControllerV2< private mutex = new Mutex(); + #previousNetworkSpecifier: NetworkType | NetworkConfigurationId | null; + #provider: Provider | undefined; #providerProxy: ProviderProxy | undefined; @@ -238,6 +244,8 @@ export class NetworkController extends BaseControllerV2< return this.ethQuery; }, ); + + this.#previousNetworkSpecifier = this.state.providerConfig.type; } private initializeProvider( @@ -252,15 +260,15 @@ export class NetworkController extends BaseControllerV2< }); switch (type) { - case MAINNET: - case 'goerli': - case 'sepolia': + case NetworkType.mainnet: + case NetworkType.goerli: + case NetworkType.sepolia: this.setupInfuraProvider(type); break; - case 'localhost': + case NetworkType.localhost: this.setupStandardProvider(LOCALHOST_RPC_URL); break; - case RPC: + case NetworkType.rpc: rpcUrl && this.setupStandardProvider(rpcUrl, chainId, ticker, nickname); break; default: @@ -433,12 +441,25 @@ export class NetworkController extends BaseControllerV2< } } + /** + * Convenience method to set the current provider config to the private providerConfig class variable. + */ + #setCurrentAsPreviousProvider() { + const { type, id } = this.state.providerConfig; + if (type === NetworkType.rpc && id) { + this.#previousNetworkSpecifier = id; + } else { + this.#previousNetworkSpecifier = type; + } + } + /** * Convenience method to update provider network type settings. * * @param type - Human readable network name. */ setProviderType(type: NetworkType) { + this.#setCurrentAsPreviousProvider(); // If testnet the ticker symbol should use a testnet prefix const ticker = type in TESTNET_NETWORK_TYPE_TO_TICKER_SYMBOL && @@ -452,6 +473,7 @@ export class NetworkController extends BaseControllerV2< state.providerConfig.chainId = NetworksChainId[type]; state.providerConfig.rpcUrl = undefined; state.providerConfig.nickname = undefined; + state.providerConfig.id = undefined; }); this.refreshNetwork(); } @@ -462,6 +484,8 @@ export class NetworkController extends BaseControllerV2< * @param networkConfigurationId - The unique id for the network configuration to set as the active provider. */ setActiveNetwork(networkConfigurationId: string) { + this.#setCurrentAsPreviousProvider(); + const targetNetwork = this.state.networkConfigurations[networkConfigurationId]; @@ -472,7 +496,7 @@ export class NetworkController extends BaseControllerV2< } this.update((state) => { - state.providerConfig.type = RPC; + state.providerConfig.type = NetworkType.rpc; state.providerConfig.rpcUrl = targetNetwork.rpcUrl; state.providerConfig.chainId = targetNetwork.chainId; state.providerConfig.ticker = targetNetwork.ticker; @@ -663,6 +687,18 @@ export class NetworkController extends BaseControllerV2< delete state.networkConfigurations[networkConfigurationId]; }); } + + /** + * Rolls back provider config to the previous provider in case of errors or inability to connect during network switch. + */ + rollbackToPreviousProvider() { + const specifier = this.#previousNetworkSpecifier; + if (isNetworkType(specifier)) { + this.setProviderType(specifier); + } else if (typeof specifier === 'string') { + this.setActiveNetwork(specifier); + } + } } export default NetworkController; diff --git a/packages/network-controller/tests/NetworkController.test.ts b/packages/network-controller/tests/NetworkController.test.ts index bac19c6d8cc..9296a892eff 100644 --- a/packages/network-controller/tests/NetworkController.test.ts +++ b/packages/network-controller/tests/NetworkController.test.ts @@ -10,6 +10,7 @@ import type { ProviderEngine } from 'web3-provider-engine'; import createMetamaskProvider from 'web3-provider-engine/zero'; import { Patch } from 'immer'; import { v4 } from 'uuid'; +import { NetworkType } from '@metamask/controller-utils'; import { waitForResult } from '../../../tests/helpers'; import { NetworkController, @@ -85,7 +86,7 @@ describe('NetworkController', () => { networkConfigurations: {}, network: 'loading', isCustomNetwork: false, - providerConfig: { type: 'mainnet' as const, chainId: '1' }, + providerConfig: { type: NetworkType.mainnet, chainId: '1' }, networkDetails: { isEIP1559Compatible: false }, }); }); @@ -104,7 +105,7 @@ describe('NetworkController', () => { networkConfigurations: {}, network: 'loading', isCustomNetwork: true, - providerConfig: { type: 'mainnet', chainId: '1' }, + providerConfig: { type: NetworkType.mainnet, chainId: '1' }, networkDetails: { isEIP1559Compatible: true }, }); }, @@ -183,259 +184,263 @@ describe('NetworkController', () => { }); }); - (['mainnet', 'goerli', 'sepolia'] as const).forEach((networkType) => { - describe(`when the provider config in state contains a network type of "${networkType}"`, () => { - it(`sets the provider to an Infura provider pointed to ${networkType}`, async () => { - await withController( - { - state: { - providerConfig: buildProviderConfig({ - type: networkType, - }), + [NetworkType.mainnet, NetworkType.goerli, NetworkType.sepolia].forEach( + (networkType) => { + describe(`when the provider config in state contains a network type of "${networkType}"`, () => { + it(`sets the provider to an Infura provider pointed to ${networkType}`, async () => { + await withController( + { + state: { + providerConfig: buildProviderConfig({ + type: networkType, + }), + }, + infuraProjectId: 'infura-project-id', }, - infuraProjectId: 'infura-project-id', - }, - async ({ controller }) => { - const fakeInfuraProvider = buildFakeInfuraProvider(); - createInfuraProviderMock.mockReturnValue(fakeInfuraProvider); - const fakeInfuraSubprovider = buildFakeInfuraSubprovider(); - SubproviderMock.mockReturnValue(fakeInfuraSubprovider); - const fakeMetamaskProvider = buildFakeMetamaskProvider([ - { - request: { - method: 'eth_chainId', - }, - response: { - result: '0x1337', + async ({ controller }) => { + const fakeInfuraProvider = buildFakeInfuraProvider(); + createInfuraProviderMock.mockReturnValue(fakeInfuraProvider); + const fakeInfuraSubprovider = buildFakeInfuraSubprovider(); + SubproviderMock.mockReturnValue(fakeInfuraSubprovider); + const fakeMetamaskProvider = buildFakeMetamaskProvider([ + { + request: { + method: 'eth_chainId', + }, + response: { + result: '0x1337', + }, }, - }, - ]); - createMetamaskProviderMock.mockReturnValue( - fakeMetamaskProvider, - ); + ]); + createMetamaskProviderMock.mockReturnValue( + fakeMetamaskProvider, + ); - controller.providerConfig = { - // NOTE: Neither the type nor chainId needs to match the - // values in state, or match each other - type: 'mainnet', - chainId: '99999', - nickname: 'some nickname', - }; + controller.providerConfig = { + // NOTE: Neither the type nor chainId needs to match the + // values in state, or match each other + type: NetworkType.mainnet, + chainId: '99999', + nickname: 'some nickname', + }; + + expect(createInfuraProviderMock).toHaveBeenCalledWith({ + network: networkType, + projectId: 'infura-project-id', + }); + expect(createMetamaskProviderMock).toHaveBeenCalledWith({ + type: NetworkType.mainnet, + chainId: '99999', + nickname: 'some nickname', + dataSubprovider: fakeInfuraSubprovider, + engineParams: { + blockTrackerProvider: fakeInfuraProvider, + pollingInterval: 12000, + }, + }); + const { provider } = controller.getProviderAndBlockTracker(); + assert(provider, 'Provider is not set'); + const promisifiedSendAsync = promisify( + provider.sendAsync, + ).bind(provider); + const chainIdResult = await promisifiedSendAsync({ + id: 1, + jsonrpc: '2.0', + method: 'eth_chainId', + }); + expect(chainIdResult.result).toBe('0x1337'); + }, + ); + }); - expect(createInfuraProviderMock).toHaveBeenCalledWith({ - network: networkType, - projectId: 'infura-project-id', - }); - expect(createMetamaskProviderMock).toHaveBeenCalledWith({ - type: 'mainnet', - chainId: '99999', - nickname: 'some nickname', - dataSubprovider: fakeInfuraSubprovider, - engineParams: { - blockTrackerProvider: fakeInfuraProvider, - pollingInterval: 12000, + it('ensures that the existing provider is stopped while replacing it', async () => { + await withController( + { + state: { + providerConfig: buildProviderConfig({ + type: networkType, + }), }, - }); - const { provider } = controller.getProviderAndBlockTracker(); - assert(provider, 'Provider is not set'); - const promisifiedSendAsync = promisify(provider.sendAsync).bind( - provider, - ); - const chainIdResult = await promisifiedSendAsync({ - id: 1, - jsonrpc: '2.0', - method: 'eth_chainId', - }); - expect(chainIdResult.result).toBe('0x1337'); - }, - ); - }); - - it('ensures that the existing provider is stopped while replacing it', async () => { - await withController( - { - state: { - providerConfig: buildProviderConfig({ - type: networkType, - }), + infuraProjectId: 'infura-project-id', }, - infuraProjectId: 'infura-project-id', - }, - ({ controller }) => { - const fakeInfuraProvider = buildFakeInfuraProvider(); - createInfuraProviderMock.mockReturnValue(fakeInfuraProvider); - const fakeInfuraSubprovider = buildFakeInfuraSubprovider(); - SubproviderMock.mockReturnValue(fakeInfuraSubprovider); - const fakeMetamaskProviders = [ - buildFakeMetamaskProvider(), - buildFakeMetamaskProvider(), - ]; - jest.spyOn(fakeMetamaskProviders[0], 'stop'); - createMetamaskProviderMock - .mockImplementationOnce(() => fakeMetamaskProviders[0]) - .mockImplementationOnce(() => fakeMetamaskProviders[1]); - - controller.providerConfig = buildProviderConfig(); - controller.providerConfig = buildProviderConfig(); - assert(controller.getProviderAndBlockTracker().provider); - jest.runAllTimers(); - - expect(fakeMetamaskProviders[0].stop).toHaveBeenCalled(); - }, - ); - }); + ({ controller }) => { + const fakeInfuraProvider = buildFakeInfuraProvider(); + createInfuraProviderMock.mockReturnValue(fakeInfuraProvider); + const fakeInfuraSubprovider = buildFakeInfuraSubprovider(); + SubproviderMock.mockReturnValue(fakeInfuraSubprovider); + const fakeMetamaskProviders = [ + buildFakeMetamaskProvider(), + buildFakeMetamaskProvider(), + ]; + jest.spyOn(fakeMetamaskProviders[0], 'stop'); + createMetamaskProviderMock + .mockImplementationOnce(() => fakeMetamaskProviders[0]) + .mockImplementationOnce(() => fakeMetamaskProviders[1]); + + controller.providerConfig = buildProviderConfig(); + controller.providerConfig = buildProviderConfig(); + assert(controller.getProviderAndBlockTracker().provider); + jest.runAllTimers(); + + expect(fakeMetamaskProviders[0].stop).toHaveBeenCalled(); + }, + ); + }); - describe('when an "error" event occurs on the new provider', () => { - describe('if the network version could not be retrieved while providerConfig was being set', () => { - it('retrieves the network version twice more (due to the "error" event being listened to twice) and, assuming success, persists it to state', async () => { - const messenger = buildMessenger(); - await withController( - { - messenger, - state: { - providerConfig: buildProviderConfig({ - type: networkType, - }), + describe('when an "error" event occurs on the new provider', () => { + describe('if the network version could not be retrieved while providerConfig was being set', () => { + it('retrieves the network version twice more (due to the "error" event being listened to twice) and, assuming success, persists it to state', async () => { + const messenger = buildMessenger(); + await withController( + { + messenger, + state: { + providerConfig: buildProviderConfig({ + type: networkType, + }), + }, + infuraProjectId: 'infura-project-id', }, - infuraProjectId: 'infura-project-id', - }, - async ({ controller }) => { - const fakeInfuraProvider = buildFakeInfuraProvider(); - createInfuraProviderMock.mockReturnValue( - fakeInfuraProvider, - ); - const fakeInfuraSubprovider = buildFakeInfuraSubprovider(); - SubproviderMock.mockReturnValue(fakeInfuraSubprovider); - const fakeMetamaskProvider = buildFakeMetamaskProvider([ - { - request: { - method: 'net_version', - }, - response: { - error: 'oops', + async ({ controller }) => { + const fakeInfuraProvider = buildFakeInfuraProvider(); + createInfuraProviderMock.mockReturnValue( + fakeInfuraProvider, + ); + const fakeInfuraSubprovider = + buildFakeInfuraSubprovider(); + SubproviderMock.mockReturnValue(fakeInfuraSubprovider); + const fakeMetamaskProvider = buildFakeMetamaskProvider([ + { + request: { + method: 'net_version', + }, + response: { + error: 'oops', + }, }, - }, - { - request: { - method: 'net_version', + { + request: { + method: 'net_version', + }, + response: { + result: '1', + }, }, - response: { - result: '1', + { + request: { + method: 'net_version', + }, + response: { + result: '2', + }, }, - }, - { - request: { - method: 'net_version', + ]); + createMetamaskProviderMock.mockReturnValue( + fakeMetamaskProvider, + ); + + await waitForPublishedEvents( + messenger, + 'NetworkController:providerConfigChange', + { + produceEvents: () => { + controller.providerConfig = buildProviderConfig(); + assert( + controller.getProviderAndBlockTracker().provider, + ); + }, }, - response: { - result: '2', + ); + + await waitForStateChanges(messenger, { + propertyPath: ['network'], + count: 2, + produceStateChanges: () => { + controller + .getProviderAndBlockTracker() + .provider.emit('error', { some: 'error' }); }, - }, - ]); - createMetamaskProviderMock.mockReturnValue( - fakeMetamaskProvider, - ); + }); + expect(controller.state.network).toBe('2'); + }, + ); + }); + }); - await waitForPublishedEvents( + describe('if the network version could be retrieved while providerConfig was being set', () => { + it('does not retrieve the network version again', async () => { + const messenger = buildMessenger(); + await withController( + { messenger, - 'NetworkController:providerConfigChange', - { - produceEvents: () => { - controller.providerConfig = buildProviderConfig(); - assert( - controller.getProviderAndBlockTracker().provider, - ); - }, + state: { + providerConfig: buildProviderConfig({ + type: networkType, + }), }, - ); - - await waitForStateChanges(messenger, { - propertyPath: ['network'], - count: 2, - produceStateChanges: () => { - controller - .getProviderAndBlockTracker() - .provider.emit('error', { some: 'error' }); - }, - }); - expect(controller.state.network).toBe('2'); - }, - ); - }); - }); - - describe('if the network version could be retrieved while providerConfig was being set', () => { - it('does not retrieve the network version again', async () => { - const messenger = buildMessenger(); - await withController( - { - messenger, - state: { - providerConfig: buildProviderConfig({ - type: networkType, - }), + infuraProjectId: 'infura-project-id', }, - infuraProjectId: 'infura-project-id', - }, - async ({ controller }) => { - const fakeInfuraProvider = buildFakeInfuraProvider(); - createInfuraProviderMock.mockReturnValue( - fakeInfuraProvider, - ); - const fakeInfuraSubprovider = buildFakeInfuraSubprovider(); - SubproviderMock.mockReturnValue(fakeInfuraSubprovider); - const fakeMetamaskProvider = buildFakeMetamaskProvider([ - { - request: { - method: 'net_version', + async ({ controller }) => { + const fakeInfuraProvider = buildFakeInfuraProvider(); + createInfuraProviderMock.mockReturnValue( + fakeInfuraProvider, + ); + const fakeInfuraSubprovider = + buildFakeInfuraSubprovider(); + SubproviderMock.mockReturnValue(fakeInfuraSubprovider); + const fakeMetamaskProvider = buildFakeMetamaskProvider([ + { + request: { + method: 'net_version', + }, + response: { + result: '1', + }, }, - response: { - result: '1', + { + request: { + method: 'net_version', + }, + response: { + result: '2', + }, }, - }, - { - request: { - method: 'net_version', + ]); + createMetamaskProviderMock.mockReturnValue( + fakeMetamaskProvider, + ); + + await waitForPublishedEvents( + messenger, + 'NetworkController:providerConfigChange', + { + produceEvents: () => { + controller.providerConfig = buildProviderConfig(); + assert( + controller.getProviderAndBlockTracker().provider, + ); + }, }, - response: { - result: '2', + ); + + await waitForStateChanges(messenger, { + propertyPath: ['network'], + count: 0, + produceStateChanges: () => { + controller + .getProviderAndBlockTracker() + .provider.emit('error', { some: 'error' }); }, - }, - ]); - createMetamaskProviderMock.mockReturnValue( - fakeMetamaskProvider, - ); - - await waitForPublishedEvents( - messenger, - 'NetworkController:providerConfigChange', - { - produceEvents: () => { - controller.providerConfig = buildProviderConfig(); - assert( - controller.getProviderAndBlockTracker().provider, - ); - }, - }, - ); - - await waitForStateChanges(messenger, { - propertyPath: ['network'], - count: 0, - produceStateChanges: () => { - controller - .getProviderAndBlockTracker() - .provider.emit('error', { some: 'error' }); - }, - }); - expect(controller.state.network).toBe('1'); - }, - ); + }); + expect(controller.state.network).toBe('1'); + }, + ); + }); }); }); }); - }); - }); + }, + ); describe(`when the provider config in state contains a network type of "localhost"`, () => { it('sets the provider to a custom RPC provider pointed to localhost, initialized with the configured chain ID, nickname, and ticker', async () => { @@ -443,7 +448,7 @@ describe('NetworkController', () => { { state: { providerConfig: buildProviderConfig({ - type: 'localhost', + type: NetworkType.localhost, chainId: '66666', nickname: "doesn't matter", rpcUrl: 'http://doesntmatter.com', @@ -466,11 +471,11 @@ describe('NetworkController', () => { controller.providerConfig = buildProviderConfig({ // NOTE: The type does not need to match the type in state - type: 'mainnet', + type: NetworkType.mainnet, }); expect(createMetamaskProviderMock).toHaveBeenCalledWith({ - type: 'mainnet', + type: NetworkType.mainnet, chainId: undefined, engineParams: { pollingInterval: 12000 }, nickname: undefined, @@ -494,7 +499,7 @@ describe('NetworkController', () => { { state: { providerConfig: buildProviderConfig({ - type: 'localhost', + type: NetworkType.localhost, }), }, }, @@ -527,7 +532,7 @@ describe('NetworkController', () => { messenger, state: { providerConfig: buildProviderConfig({ - type: 'localhost', + type: NetworkType.localhost, }), }, }, @@ -598,7 +603,7 @@ describe('NetworkController', () => { messenger, state: { providerConfig: buildProviderConfig({ - type: 'localhost', + type: NetworkType.localhost, }), }, }, @@ -662,7 +667,7 @@ describe('NetworkController', () => { { state: { providerConfig: { - type: 'rpc', + type: NetworkType.rpc, chainId: '123', nickname: 'some cool network', rpcUrl: 'http://example.com', @@ -687,11 +692,11 @@ describe('NetworkController', () => { controller.providerConfig = buildProviderConfig({ // NOTE: The type does not need to match the type in state - type: 'mainnet', + type: NetworkType.mainnet, }); expect(createMetamaskProviderMock).toHaveBeenCalledWith({ - type: 'mainnet', + type: NetworkType.mainnet, chainId: '123', engineParams: { pollingInterval: 12000 }, nickname: 'some cool network', @@ -715,7 +720,7 @@ describe('NetworkController', () => { { state: { providerConfig: buildProviderConfig({ - type: 'rpc', + type: NetworkType.rpc, rpcUrl: 'http://example.com', }), }, @@ -749,7 +754,7 @@ describe('NetworkController', () => { messenger, state: { providerConfig: buildProviderConfig({ - type: 'rpc', + type: NetworkType.rpc, rpcUrl: 'http://example.com', }), }, @@ -821,7 +826,7 @@ describe('NetworkController', () => { messenger, state: { providerConfig: buildProviderConfig({ - type: 'rpc', + type: NetworkType.rpc, rpcUrl: 'http://example.com', }), }, @@ -885,7 +890,7 @@ describe('NetworkController', () => { { state: { providerConfig: buildProviderConfig({ - type: 'rpc', + type: NetworkType.rpc, }), }, }, @@ -1015,7 +1020,7 @@ describe('NetworkController', () => { messenger, state: { providerConfig: { - type: 'mainnet', + type: NetworkType.mainnet, chainId: '1', }, }, @@ -1044,7 +1049,7 @@ describe('NetworkController', () => { expect(providerConfigChanges).toStrictEqual([ [ { - type: 'mainnet', + type: NetworkType.mainnet, chainId: '1', }, ], @@ -1142,7 +1147,7 @@ describe('NetworkController', () => { messenger, state: { providerConfig: { - type: 'mainnet', + type: NetworkType.mainnet, chainId: '1', }, }, @@ -1171,7 +1176,7 @@ describe('NetworkController', () => { expect(providerConfigChanges).toStrictEqual([ [ { - type: 'mainnet', + type: NetworkType.mainnet, chainId: '1', }, ], @@ -1243,7 +1248,7 @@ describe('NetworkController', () => { messenger, state: { providerConfig: { - type: 'localhost', + type: NetworkType.localhost, rpcUrl: 'http://somethingexisting.com', chainId: '99999', ticker: 'something existing', @@ -1262,16 +1267,17 @@ describe('NetworkController', () => { await waitForStateChanges(messenger, { propertyPath: ['network'], produceStateChanges: () => { - controller.setProviderType('mainnet' as const); + controller.setProviderType(NetworkType.mainnet); }, }); expect(controller.state.providerConfig).toStrictEqual({ - type: 'mainnet', + type: NetworkType.mainnet, ticker: 'ETH', chainId: '1', rpcUrl: undefined, nickname: undefined, + id: undefined, }); }, ); @@ -1298,7 +1304,7 @@ describe('NetworkController', () => { await waitForStateChanges(messenger, { propertyPath: ['isCustomNetwork'], produceStateChanges: () => { - controller.setProviderType('mainnet' as const); + controller.setProviderType(NetworkType.mainnet); }, }); @@ -1329,10 +1335,10 @@ describe('NetworkController', () => { ]); createMetamaskProviderMock.mockReturnValue(fakeMetamaskProvider); - controller.setProviderType('mainnet' as const); + controller.setProviderType(NetworkType.mainnet); expect(createInfuraProviderMock).toHaveBeenCalledWith({ - network: 'mainnet', + network: NetworkType.mainnet, projectId: 'infura-project-id', }); expect(createMetamaskProviderMock).toHaveBeenCalledWith({ @@ -1383,7 +1389,7 @@ describe('NetworkController', () => { await waitForStateChanges(messenger, { propertyPath: ['networkDetails', 'isEIP1559Compatible'], produceStateChanges: () => { - controller.setProviderType('mainnet' as const); + controller.setProviderType(NetworkType.mainnet); }, }); @@ -1409,8 +1415,8 @@ describe('NetworkController', () => { .mockImplementationOnce(() => fakeMetamaskProviders[0]) .mockImplementationOnce(() => fakeMetamaskProviders[1]); - controller.setProviderType('mainnet' as const); - controller.setProviderType('mainnet' as const); + controller.setProviderType(NetworkType.mainnet); + controller.setProviderType(NetworkType.mainnet); assert(controller.getProviderAndBlockTracker().provider); jest.runAllTimers(); @@ -1441,7 +1447,7 @@ describe('NetworkController', () => { await waitForStateChanges(messenger, { propertyPath: ['network'], produceStateChanges: () => { - controller.setProviderType('mainnet' as const); + controller.setProviderType(NetworkType.mainnet); }, }); @@ -1483,7 +1489,7 @@ describe('NetworkController', () => { 'NetworkController:providerConfigChange', { produceEvents: () => { - controller.setProviderType('mainnet' as const); + controller.setProviderType(NetworkType.mainnet); assert(controller.getProviderAndBlockTracker().provider); }, }, @@ -1535,7 +1541,7 @@ describe('NetworkController', () => { 'NetworkController:providerConfigChange', { produceEvents: () => { - controller.setProviderType('mainnet' as const); + controller.setProviderType(NetworkType.mainnet); assert(controller.getProviderAndBlockTracker().provider); }, }, @@ -1560,13 +1566,13 @@ describe('NetworkController', () => { ( [ { - networkType: 'goerli', + networkType: NetworkType.goerli, ticker: 'GoerliETH', chainId: '5', networkName: 'Goerli', }, { - networkType: 'sepolia', + networkType: NetworkType.sepolia, ticker: 'SepoliaETH', chainId: '11155111', networkName: 'Sepolia', @@ -1581,7 +1587,7 @@ describe('NetworkController', () => { messenger, state: { providerConfig: { - type: 'localhost', + type: NetworkType.localhost, rpcUrl: 'http://somethingexisting.com', chainId: '99999', ticker: 'something existing', @@ -1610,6 +1616,7 @@ describe('NetworkController', () => { chainId, rpcUrl: undefined, nickname: undefined, + id: undefined, }); }, ); @@ -1907,7 +1914,7 @@ describe('NetworkController', () => { messenger, state: { providerConfig: { - type: 'localhost', + type: NetworkType.localhost, rpcUrl: 'http://somethingexisting.com', chainId: '99999', ticker: 'something existing', @@ -1922,16 +1929,17 @@ describe('NetworkController', () => { await waitForStateChanges(messenger, { propertyPath: ['providerConfig'], produceStateChanges: () => { - controller.setProviderType('rpc' as const); + controller.setProviderType(NetworkType.rpc); }, }); expect(controller.state.providerConfig).toStrictEqual({ - type: 'rpc', + type: NetworkType.rpc, ticker: 'ETH', chainId: '', rpcUrl: undefined, nickname: undefined, + id: undefined, }); }, ); @@ -1955,7 +1963,7 @@ describe('NetworkController', () => { { propertyPath: ['isCustomNetwork'] }, ); - controller.setProviderType('rpc' as const); + controller.setProviderType(NetworkType.rpc); await expect(promiseForIsCustomNetworkChange).toNeverResolve(); }, @@ -1967,7 +1975,7 @@ describe('NetworkController', () => { const fakeMetamaskProvider = buildFakeMetamaskProvider(); createMetamaskProviderMock.mockReturnValue(fakeMetamaskProvider); - controller.setProviderType('rpc' as const); + controller.setProviderType(NetworkType.rpc); expect(createMetamaskProviderMock).not.toHaveBeenCalled(); expect( @@ -2001,7 +2009,7 @@ describe('NetworkController', () => { await waitForStateChanges(messenger, { propertyPath: ['networkDetails', 'isEIP1559Compatible'], produceStateChanges: () => { - controller.setProviderType('rpc' as const); + controller.setProviderType(NetworkType.rpc); }, }); @@ -2021,7 +2029,7 @@ describe('NetworkController', () => { messenger, state: { providerConfig: { - type: 'localhost', + type: NetworkType.localhost, rpcUrl: 'http://somethingexisting.com', chainId: '99999', ticker: 'something existing', @@ -2036,16 +2044,17 @@ describe('NetworkController', () => { await waitForStateChanges(messenger, { propertyPath: ['network'], produceStateChanges: () => { - controller.setProviderType('localhost' as const); + controller.setProviderType(NetworkType.localhost); }, }); expect(controller.state.providerConfig).toStrictEqual({ - type: 'localhost', + type: NetworkType.localhost, ticker: 'ETH', chainId: '', rpcUrl: undefined, nickname: undefined, + id: undefined, }); }, ); @@ -2067,7 +2076,7 @@ describe('NetworkController', () => { await waitForStateChanges(messenger, { propertyPath: ['isCustomNetwork'], produceStateChanges: () => { - controller.setProviderType('localhost' as const); + controller.setProviderType(NetworkType.localhost); }, }); @@ -2090,7 +2099,7 @@ describe('NetworkController', () => { ]); createMetamaskProviderMock.mockReturnValue(fakeMetamaskProvider); - controller.setProviderType('localhost' as const); + controller.setProviderType(NetworkType.localhost); expect(createMetamaskProviderMock).toHaveBeenCalledWith({ chainId: undefined, @@ -2131,7 +2140,7 @@ describe('NetworkController', () => { await waitForStateChanges(messenger, { propertyPath: ['networkDetails', 'isEIP1559Compatible'], produceStateChanges: () => { - controller.setProviderType('localhost' as const); + controller.setProviderType(NetworkType.localhost); }, }); @@ -2152,8 +2161,8 @@ describe('NetworkController', () => { .mockImplementationOnce(() => fakeMetamaskProviders[0]) .mockImplementationOnce(() => fakeMetamaskProviders[1]); - controller.setProviderType('localhost' as const); - controller.setProviderType('localhost' as const); + controller.setProviderType(NetworkType.localhost); + controller.setProviderType(NetworkType.localhost); assert(controller.getProviderAndBlockTracker().provider); jest.runAllTimers(); @@ -2180,7 +2189,7 @@ describe('NetworkController', () => { await waitForStateChanges(messenger, { propertyPath: ['network'], produceStateChanges: () => { - controller.setProviderType('localhost' as const); + controller.setProviderType(NetworkType.localhost); }, }); @@ -2218,7 +2227,7 @@ describe('NetworkController', () => { 'NetworkController:providerConfigChange', { produceEvents: () => { - controller.setProviderType('localhost' as const); + controller.setProviderType(NetworkType.localhost); assert(controller.getProviderAndBlockTracker().provider); }, }, @@ -2266,7 +2275,7 @@ describe('NetworkController', () => { 'NetworkController:providerConfigChange', { produceEvents: () => { - controller.setProviderType('localhost' as const); + controller.setProviderType(NetworkType.localhost); assert(controller.getProviderAndBlockTracker().provider); }, }, @@ -2298,7 +2307,7 @@ describe('NetworkController', () => { messenger, state: { providerConfig: { - type: 'localhost', + type: NetworkType.localhost, rpcUrl: 'http://somethingexisting.com', chainId: '99999', ticker: 'something existing', @@ -2333,7 +2342,7 @@ describe('NetworkController', () => { }); expect(controller.state.providerConfig).toStrictEqual({ - type: 'rpc', + type: NetworkType.rpc, rpcUrl: 'https://mock-rpc-url', chainId: '0xtest', ticker: 'TEST', @@ -3474,7 +3483,7 @@ describe('NetworkController', () => { messenger, state: { providerConfig: { - type: 'mainnet', + type: NetworkType.mainnet, chainId: '1', }, }, @@ -3485,7 +3494,7 @@ describe('NetworkController', () => { ); expect(providerConfig).toStrictEqual({ - type: 'mainnet', + type: NetworkType.mainnet, chainId: '1', }); }, @@ -3646,6 +3655,184 @@ describe('NetworkController', () => { }); }); }); + + describe('rollbackToPreviousProvider', () => { + it('should overwrite the current provider with the previous provider when current provider has type "rpc" and previous provider has type "mainnet"', async () => { + const messenger = buildMessenger(); + const networkConfiguration = { + rpcUrl: 'https://mock-rpc-url', + chainId: '0xtest', + ticker: 'TEST', + nickname: undefined, + id: 'testNetworkConfigurationId', + }; + + const initialProviderConfig = { + ...buildProviderConfig({ + type: NetworkType.mainnet, + chainId: '1', + ticker: 'ETH', + }), + }; + await withController( + { + messenger, + state: { + networkConfigurations: { + testNetworkConfigurationId: networkConfiguration, + }, + providerConfig: initialProviderConfig, + }, + }, + async ({ controller }) => { + const fakeMetamaskProvider = buildFakeMetamaskProvider(); + createMetamaskProviderMock.mockReturnValue(fakeMetamaskProvider); + controller.setActiveNetwork('testNetworkConfigurationId'); + expect(controller.state.providerConfig).toStrictEqual({ + ...networkConfiguration, + type: NetworkType.rpc, + }); + controller.rollbackToPreviousProvider(); + expect(controller.state.providerConfig).toStrictEqual( + initialProviderConfig, + ); + }, + ); + }); + + it('should overwrite the current provider with the previous provider when current provider has type "mainnet" and previous provider has type "rpc"', async () => { + const messenger = buildMessenger(); + const networkConfiguration = { + rpcUrl: 'https://mock-rpc-url', + chainId: '0xtest', + ticker: 'TEST', + nickname: undefined, + id: 'testNetworkConfigurationId', + }; + + const initialProviderConfig = { + ...buildProviderConfig({ + ...networkConfiguration, + }), + type: NetworkType.rpc, + }; + await withController( + { + messenger, + state: { + networkConfigurations: { + testNetworkConfigurationId: networkConfiguration, + }, + providerConfig: initialProviderConfig, + }, + }, + async ({ controller }) => { + const fakeMetamaskProvider = buildFakeMetamaskProvider(); + createMetamaskProviderMock.mockReturnValue(fakeMetamaskProvider); + controller.setProviderType(NetworkType.mainnet); + expect(controller.state.providerConfig).toStrictEqual({ + type: NetworkType.mainnet, + chainId: '1', + ticker: 'ETH', + nickname: undefined, + id: undefined, + rpcUrl: undefined, + }); + controller.rollbackToPreviousProvider(); + expect(controller.state.providerConfig).toStrictEqual({ + ...networkConfiguration, + type: NetworkType.rpc, + }); + }, + ); + }); + + it('should overwrite the current provider with the previous provider when current provider has type "mainnet" and previous provider has type "sepolia"', async () => { + const messenger = buildMessenger(); + const initialProviderConfig = { + ...buildProviderConfig({ + type: NetworkType.mainnet, + chainId: '1', + ticker: 'ETH', + }), + }; + await withController( + { + messenger, + state: { + providerConfig: initialProviderConfig, + }, + }, + async ({ controller }) => { + const fakeMetamaskProvider = buildFakeMetamaskProvider(); + createMetamaskProviderMock.mockReturnValue(fakeMetamaskProvider); + controller.setProviderType(NetworkType.sepolia); + expect(controller.state.providerConfig).toStrictEqual({ + ...buildProviderConfig({ + type: NetworkType.sepolia, + chainId: '11155111', + ticker: 'SepoliaETH', + }), + }); + controller.rollbackToPreviousProvider(); + expect(controller.state.providerConfig).toStrictEqual( + initialProviderConfig, + ); + }, + ); + }); + + it('should overwrite the current provider with the previous provider when current provider has type "rpc" and previous provider has type "rpc"', async () => { + const messenger = buildMessenger(); + const networkConfiguration1 = { + rpcUrl: 'https://mock-rpc-url', + chainId: '0xtest', + ticker: 'TEST', + id: 'testNetworkConfigurationId', + nickname: 'test-network-1', + }; + + const networkConfiguration2 = { + rpcUrl: 'https://mock-rpc-url-2', + chainId: '0xtest2', + ticker: 'TEST2', + id: 'testNetworkConfigurationId2', + nickname: 'test-network-2', + }; + + const initialProviderConfig = { + ...buildProviderConfig({ + ...networkConfiguration1, + type: NetworkType.rpc, + }), + }; + await withController( + { + messenger, + state: { + networkConfigurations: { + testNetworkConfigurationId: networkConfiguration1, + testNetworkConfigurationId2: networkConfiguration2, + }, + providerConfig: initialProviderConfig, + }, + }, + async ({ controller }) => { + const fakeMetamaskProvider = buildFakeMetamaskProvider(); + createMetamaskProviderMock.mockReturnValue(fakeMetamaskProvider); + controller.setActiveNetwork('testNetworkConfigurationId2'); + expect(controller.state.providerConfig).toStrictEqual({ + ...networkConfiguration2, + type: NetworkType.rpc, + }); + controller.rollbackToPreviousProvider(); + expect(controller.state.providerConfig).toStrictEqual( + initialProviderConfig, + ); + }, + ); + }); + }); }); /** @@ -3718,7 +3905,14 @@ async function withController( * @returns The complete ProviderConfig object. */ function buildProviderConfig(config: Partial = {}) { - return { type: 'localhost' as const, chainId: '1337', ...config }; + return { + type: NetworkType.localhost, + chainId: '1337', + id: undefined, + nickname: undefined, + rpcUrl: undefined, + ...config, + }; } /** diff --git a/packages/network-controller/tests/provider-api-tests/helpers.ts b/packages/network-controller/tests/provider-api-tests/helpers.ts index 779dda64f43..68282957d97 100644 --- a/packages/network-controller/tests/provider-api-tests/helpers.ts +++ b/packages/network-controller/tests/provider-api-tests/helpers.ts @@ -284,7 +284,7 @@ export type MockCommunications = { export const withMockedCommunications = async ( { providerType, - infuraNetwork = 'mainnet', + infuraNetwork = NetworkType.mainnet, customRpcUrl = MOCK_RPC_URL, }: MockOptions, fn: (comms: MockCommunications) => Promise, @@ -367,7 +367,7 @@ export const waitForPromiseToBeFulfilledAfterRunningAllTimers = async ( export const withNetworkClient = async ( { providerType, - infuraNetwork = 'mainnet', + infuraNetwork = NetworkType.mainnet, customRpcUrl = MOCK_RPC_URL, }: MockOptions, fn: (client: MockNetworkClient) => Promise, diff --git a/packages/network-controller/tests/provider-api-tests/shared-tests.ts b/packages/network-controller/tests/provider-api-tests/shared-tests.ts index 1ea74a8bb38..a25937ca59e 100644 --- a/packages/network-controller/tests/provider-api-tests/shared-tests.ts +++ b/packages/network-controller/tests/provider-api-tests/shared-tests.ts @@ -1,5 +1,6 @@ /* eslint-disable jest/require-top-level-describe, jest/no-export, jest/no-identical-title, jest/no-if */ +import { NetworkType } from '@metamask/controller-utils/src'; import { testsForRpcMethodsThatCheckForBlockHashInResponse } from './block-hash-in-response'; import { testsForRpcMethodSupportingBlockParam } from './block-param'; import { @@ -367,14 +368,14 @@ export const testsForProviderType = (providerType: ProviderType) => { describe('net_version', () => { it('does hit RPC endpoint to get net_version', async () => { await withMockedCommunications( - { providerType, infuraNetwork: 'goerli' }, + { providerType, infuraNetwork: NetworkType.goerli }, async (comms) => { comms.mockRpcCall({ request: { method: 'net_version' }, response: { result: '5' }, }); const networkId = await withNetworkClient( - { providerType, infuraNetwork: 'goerli' }, + { providerType, infuraNetwork: NetworkType.goerli }, ({ makeRpcCall }) => { return makeRpcCall({ method: 'net_version',