Skip to content

Commit

Permalink
add rollbackToPreviousProvider method
Browse files Browse the repository at this point in the history
  • Loading branch information
adonesky1 committed Mar 17, 2023
1 parent 479e349 commit dfc3acd
Show file tree
Hide file tree
Showing 7 changed files with 545 additions and 308 deletions.
5 changes: 3 additions & 2 deletions packages/assets-controllers/src/NftController.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import {
OPENSEA_API_URL,
ERC721,
NetworksChainId,
NetworkType,
} from '@metamask/controller-utils';
import { Network } from '@ethersproject/providers';
import { AssetsContractController } from './AssetsContractController';
Expand Down Expand Up @@ -46,8 +47,8 @@ const DEPRESSIONIST_CLOUDFLARE_IPFS_SUBDOMAIN_PATH = getFormattedIpfsUrl(
true,
);

const SEPOLIA = { chainId: '11155111', type: 'sepolia' as const };
const GOERLI = { chainId: '5', type: 'goerli' as const };
const SEPOLIA = { chainId: '11155111', type: NetworkType.sepolia };
const GOERLI = { chainId: '5', type: NetworkType.goerli };

// Mock out detectNetwork function for cleaner tests, Ethers calls this a bunch of times because the Web3Provider is paranoid.
jest.mock('@ethersproject/providers', () => {
Expand Down
2 changes: 1 addition & 1 deletion packages/controller-utils/jest.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ module.exports = merge(baseConfig, {
coverageThreshold: {
global: {
branches: 68.05,
functions: 80.55,
functions: 76.92,
lines: 69.82,
statements: 70.17,
},
Expand Down
17 changes: 11 additions & 6 deletions packages/controller-utils/src/types.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
/**
* Human-readable network name
*/
export type NetworkType =
| 'localhost'
| 'mainnet'
| 'goerli'
| 'sepolia'
| 'rpc';
export enum NetworkType {
localhost = 'localhost',
mainnet = 'mainnet',
goerli = 'goerli',
sepolia = 'sepolia',
rpc = 'rpc',
}

export const isNetworkType = (val: any): val is NetworkType => {
return Object.values(NetworkType).includes(val);
};

export enum NetworksChainId {
mainnet = '1',
Expand Down
56 changes: 46 additions & 10 deletions packages/network-controller/src/NetworkController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@ import {
RestrictedControllerMessenger,
} from '@metamask/base-controller';
import {
MAINNET,
RPC,
TESTNET_NETWORK_TYPE_TO_TICKER_SYMBOL,
NetworksChainId,
NetworkType,
isSafeChainId,
} from '@metamask/controller-utils';

import { assertIsStrictHexString } from '@metamask/utils';
import { isNetworkType } from '../../controller-utils/src/types';

/**
* @type ProviderConfig
Expand All @@ -39,7 +38,7 @@ export type ProviderConfig = {
chainId: string;
ticker?: string;
nickname?: string;
id?: string;
id?: NetworkConfigurationId;
};

export type Block = {
Expand Down Expand Up @@ -148,7 +147,10 @@ export type NetworkControllerOptions = {
export const defaultState: NetworkState = {
network: 'loading',
isCustomNetwork: false,
providerConfig: { type: MAINNET, chainId: NetworksChainId.mainnet },
providerConfig: {
type: NetworkType.mainnet,
chainId: NetworksChainId.mainnet,
},
networkDetails: { isEIP1559Compatible: false },
networkConfigurations: {},
};
Expand All @@ -166,6 +168,8 @@ type MetaMetricsEventPayload = {
value?: number;
};

type NetworkConfigurationId = string;

/**
* Controller that creates and manages an Ethereum network provider.
*/
Expand All @@ -184,6 +188,8 @@ export class NetworkController extends BaseControllerV2<

private mutex = new Mutex();

#previousNetworkSpecifier: NetworkType | NetworkConfigurationId | null;

#provider: Provider | undefined;

#providerProxy: ProviderProxy | undefined;
Expand Down Expand Up @@ -238,6 +244,8 @@ export class NetworkController extends BaseControllerV2<
return this.ethQuery;
},
);

this.#previousNetworkSpecifier = this.state.providerConfig.type;
}

private initializeProvider(
Expand All @@ -252,15 +260,15 @@ export class NetworkController extends BaseControllerV2<
});

switch (type) {
case MAINNET:
case 'goerli':
case 'sepolia':
case NetworkType.mainnet:
case NetworkType.goerli:
case NetworkType.sepolia:
this.setupInfuraProvider(type);
break;
case 'localhost':
case NetworkType.localhost:
this.setupStandardProvider(LOCALHOST_RPC_URL);
break;
case RPC:
case NetworkType.rpc:
rpcUrl && this.setupStandardProvider(rpcUrl, chainId, ticker, nickname);
break;
default:
Expand Down Expand Up @@ -433,12 +441,25 @@ export class NetworkController extends BaseControllerV2<
}
}

/**
* Convenience method to set the current provider config to the private providerConfig class variable.
*/
#setCurrentAsPreviousProvider() {
const { type, id } = this.state.providerConfig;
if (type === NetworkType.rpc && id) {
this.#previousNetworkSpecifier = id;
} else {
this.#previousNetworkSpecifier = type;
}
}

/**
* Convenience method to update provider network type settings.
*
* @param type - Human readable network name.
*/
setProviderType(type: NetworkType) {
this.#setCurrentAsPreviousProvider();
// If testnet the ticker symbol should use a testnet prefix
const ticker =
type in TESTNET_NETWORK_TYPE_TO_TICKER_SYMBOL &&
Expand All @@ -452,6 +473,7 @@ export class NetworkController extends BaseControllerV2<
state.providerConfig.chainId = NetworksChainId[type];
state.providerConfig.rpcUrl = undefined;
state.providerConfig.nickname = undefined;
state.providerConfig.id = undefined;
});
this.refreshNetwork();
}
Expand All @@ -462,6 +484,8 @@ export class NetworkController extends BaseControllerV2<
* @param networkConfigurationId - The unique id for the network configuration to set as the active provider.
*/
setActiveNetwork(networkConfigurationId: string) {
this.#setCurrentAsPreviousProvider();

const targetNetwork =
this.state.networkConfigurations[networkConfigurationId];

Expand All @@ -472,7 +496,7 @@ export class NetworkController extends BaseControllerV2<
}

this.update((state) => {
state.providerConfig.type = RPC;
state.providerConfig.type = NetworkType.rpc;
state.providerConfig.rpcUrl = targetNetwork.rpcUrl;
state.providerConfig.chainId = targetNetwork.chainId;
state.providerConfig.ticker = targetNetwork.ticker;
Expand Down Expand Up @@ -663,6 +687,18 @@ export class NetworkController extends BaseControllerV2<
delete state.networkConfigurations[networkConfigurationId];
});
}

/**
* Rolls back provider config to the previous provider in case of errors or inability to connect during network switch.
*/
rollbackToPreviousProvider() {
const specifier = this.#previousNetworkSpecifier;
if (isNetworkType(specifier)) {
this.setProviderType(specifier);
} else if (typeof specifier === 'string') {
this.setActiveNetwork(specifier);
}
}
}

export default NetworkController;
Loading

0 comments on commit dfc3acd

Please sign in to comment.