diff --git a/contracts/ModuleRegistry.sol b/contracts/ModuleRegistry.sol index 59a02af74..a11111e9f 100644 --- a/contracts/ModuleRegistry.sol +++ b/contracts/ModuleRegistry.sol @@ -36,6 +36,7 @@ contract ModuleRegistry is IModuleRegistry, EternalStorage { */ bytes32 constant INITIALIZE = 0x9ef7257c3339b099aacf96e55122ee78fb65a36bd2a6c19249882be9c98633bf; //keccak256("initialised") + bytes32 constant LOCKED = 0xab99c6d7581cbb37d2e578d3097bfdd3323e05447f1fd7670b6c3a3fb9d9ff79; //keccak256("locked") bytes32 constant POLYTOKEN = 0xacf8fbd51bb4b83ba426cdb12f63be74db97c412515797993d2a385542e311d7; //keccak256("polyToken") bytes32 constant PAUSED = 0xee35723ac350a69d2a92d3703f17439cbaadf2f093a21ba5bf5f1a53eb2a14d9; //keccak256("paused") bytes32 constant OWNER = 0x02016836a56b71f0d02689e69e326f4f4c1b9057164ef592671cf0d37c8040c0; //keccak256("owner") @@ -87,6 +88,16 @@ contract ModuleRegistry is IModuleRegistry, EternalStorage { } } + /** + * @notice Modifier to prevent reentrancy + */ + modifier nonReentrant() { + set(LOCKED, getUintValue(LOCKED) + 1); + uint256 localCounter = getUintValue(LOCKED); + _; + require(localCounter == getUintValue(LOCKED)); + } + /** * @notice Modifier to make a function callable only when the contract is not paused and ignore is msg.sender is owner. */ @@ -121,6 +132,10 @@ contract ModuleRegistry is IModuleRegistry, EternalStorage { set(INITIALIZE, true); } + function _customModules() internal view returns (bool) { + return IFeatureRegistry(getAddressValue(FEATURE_REGISTRY)).getFeatureStatus("customModulesAllowed"); + } + /** * @notice Called by a SecurityToken to check if the ModuleFactory is verified or appropriate custom module * @dev ModuleFactory reputation increases by one every time it is deployed(used) by a ST. @@ -129,10 +144,11 @@ contract ModuleRegistry is IModuleRegistry, EternalStorage { * @param _moduleFactory is the address of the relevant module factory * @param _isUpgrade whether or not the function is being called as a result of an upgrade */ - function useModule(address _moduleFactory, bool _isUpgrade) external { - if (IFeatureRegistry(getAddressValue(FEATURE_REGISTRY)).getFeatureStatus("customModulesAllowed")) { + function useModule(address _moduleFactory, bool _isUpgrade) external nonReentrant { + if (_customModules()) { require( - getBoolValue(Encoder.getKey("verified", _moduleFactory)) || IOwnable(_moduleFactory).owner() == IOwnable(msg.sender).owner(), + getBoolValue(Encoder.getKey("verified", _moduleFactory)) || getAddressValue(Encoder.getKey("factoryOwner", _moduleFactory)) + == IOwnable(msg.sender).owner(), "ModuleFactory must be verified or SecurityToken owner must be ModuleFactory owner" ); } else { @@ -168,10 +184,13 @@ contract ModuleRegistry is IModuleRegistry, EternalStorage { * @notice Called by the ModuleFactory owner to register new modules for SecurityTokens to use * @param _moduleFactory is the address of the module factory to be registered */ - function registerModule(address _moduleFactory) external whenNotPausedOrOwner { - if (IFeatureRegistry(getAddressValue(FEATURE_REGISTRY)).getFeatureStatus("customModulesAllowed")) { + function registerModule(address _moduleFactory) external whenNotPausedOrOwner nonReentrant { + address factoryOwner = IOwnable(_moduleFactory).owner(); + // This is set statically to avoid having to call back out to unverified factories to determine owner + set(Encoder.getKey("factoryOwner", _moduleFactory), factoryOwner); + if (_customModules()) { require( - msg.sender == IOwnable(_moduleFactory).owner() || msg.sender == owner(), + msg.sender == factoryOwner || msg.sender == owner(), "msg.sender must be the Module Factory owner or registry curator" ); } else { @@ -190,14 +209,15 @@ contract ModuleRegistry is IModuleRegistry, EternalStorage { } require(moduleTypes.length != 0, "Factory must have type"); // NB - here we index by the first type of the module. - uint8 moduleType = moduleFactory.types()[0]; + uint8 moduleType = moduleTypes[0]; + require(uint256(moduleType) != 0, "Invalid type"); set(Encoder.getKey("registry", _moduleFactory), uint256(moduleType)); set( Encoder.getKey("moduleListIndex", _moduleFactory), uint256(getArrayAddress(Encoder.getKey("moduleList", uint256(moduleType))).length) ); pushArray(Encoder.getKey("moduleList", uint256(moduleType)), _moduleFactory); - emit ModuleRegistered(_moduleFactory, IOwnable(_moduleFactory).owner()); + emit ModuleRegistered(_moduleFactory, factoryOwner); } /** @@ -209,7 +229,7 @@ contract ModuleRegistry is IModuleRegistry, EternalStorage { require(moduleType != 0, "Module factory should be registered"); require( - msg.sender == IOwnable(_moduleFactory).owner() || msg.sender == owner(), + msg.sender == owner() || msg.sender == getAddressValue(Encoder.getKey("factoryOwner", _moduleFactory)), "msg.sender must be the Module Factory owner or registry curator" ); uint256 index = getUintValue(Encoder.getKey("moduleListIndex", _moduleFactory)); @@ -232,6 +252,8 @@ contract ModuleRegistry is IModuleRegistry, EternalStorage { set(Encoder.getKey("verified", _moduleFactory), false); // delete moduleListIndex[_moduleFactory]; set(Encoder.getKey("moduleListIndex", _moduleFactory), uint256(0)); + // delete module owner + set(Encoder.getKey("factoryOwner", _moduleFactory), address(0)); emit ModuleRemoved(_moduleFactory, msg.sender); } @@ -255,12 +277,12 @@ contract ModuleRegistry is IModuleRegistry, EternalStorage { * @notice -> Only if Polymath enabled the feature. * @param _moduleFactory is the address of the module factory to be verified */ - function unverifyModule(address _moduleFactory) external { + function unverifyModule(address _moduleFactory) external nonReentrant { // Can be called by the registry owner, the module factory, or the module factory owner bool isOwner = msg.sender == owner(); - bool isFactoryOwner = msg.sender == IOwnable(_moduleFactory).owner(); bool isFactory = msg.sender == _moduleFactory; - require(isOwner || isFactoryOwner || isFactory, "Not authorised"); + bool isFactoryOwner = msg.sender == getAddressValue(Encoder.getKey("factoryOwner", _moduleFactory)); + require(isOwner || isFactory || isFactoryOwner, "Not authorised"); require(getUintValue(Encoder.getKey("registry", _moduleFactory)) != uint256(0), "Module factory must be registered"); set(Encoder.getKey("verified", _moduleFactory), false); emit ModuleUnverified(_moduleFactory); @@ -300,6 +322,7 @@ contract ModuleRegistry is IModuleRegistry, EternalStorage { uint256 i; uint256 j; for (i = 0; i < _modules.length; i++) { + // NB - a malicious unverified module could throw on .tags() counter = counter + IModuleFactory(_modules[i]).tags().length; } bytes32[] memory tags = new bytes32[](counter); @@ -323,16 +346,42 @@ contract ModuleRegistry is IModuleRegistry, EternalStorage { * @return bool indicating whether module factory is verified * @return address array which contains the list of securityTokens that use that module factory */ - function getFactoryDetails(address _factoryAddress) external view returns(bool, address[] memory) { - return (getBoolValue(Encoder.getKey("verified", _factoryAddress)), getArrayAddress(Encoder.getKey("reputation", _factoryAddress))); + function getFactoryDetails(address _factoryAddress) external view returns(bool, address, address[] memory) { + return (getBoolValue(Encoder.getKey("verified", _factoryAddress)), getAddressValue(Encoder.getKey("factoryOwner", _factoryAddress)), getArrayAddress(Encoder.getKey("reputation", _factoryAddress))); } /** - * @notice Returns the list of addresses of Module Factory of a particular type + * @notice Returns the list of addresses of verified Module Factory of a particular type * @param _moduleType Type of Module * @return address array that contains the list of addresses of module factory contracts. */ function getModulesByType(uint8 _moduleType) public view returns(address[] memory) { + address[] memory _addressList = getArrayAddress(Encoder.getKey("moduleList", uint256(_moduleType))); + uint256 _len = _addressList.length; + uint256 counter = 0; + for (uint256 i = 0; i < _len; i++) { + if (getBoolValue(Encoder.getKey("verified", _addressList[i]))) { + counter++; + } + } + address[] memory _tempArray = new address[](counter); + counter = 0; + for (uint256 j = 0; j < _len; j++) { + if (getBoolValue(Encoder.getKey("verified", _addressList[j]))) { + _tempArray[counter] = _addressList[j]; + counter++; + } + } + return _tempArray; + } + + + /** + * @notice Returns the list of addresses of all Module Factory of a particular type + * @param _moduleType Type of Module + * @return address array that contains the list of addresses of module factory contracts. + */ + function getAllModulesByType(uint8 _moduleType) external view returns(address[] memory) { return getArrayAddress(Encoder.getKey("moduleList", uint256(_moduleType))); } @@ -345,15 +394,13 @@ contract ModuleRegistry is IModuleRegistry, EternalStorage { function getModulesByTypeAndToken(uint8 _moduleType, address _securityToken) public view returns(address[] memory) { address[] memory _addressList = getArrayAddress(Encoder.getKey("moduleList", uint256(_moduleType))); uint256 _len = _addressList.length; - bool _isCustomModuleAllowed = IFeatureRegistry(getAddressValue(FEATURE_REGISTRY)).getFeatureStatus( - "customModulesAllowed" - ); + bool _isCustomModuleAllowed = _customModules(); uint256 counter = 0; for (uint256 i = 0; i < _len; i++) { if (_isCustomModuleAllowed) { - if (IOwnable(_addressList[i]).owner() == IOwnable(_securityToken).owner() || getBoolValue( - Encoder.getKey("verified", _addressList[i]) - )) if (isCompatibleModule(_addressList[i], _securityToken)) counter++; + if (getBoolValue( + Encoder.getKey("verified", _addressList[i])) || getAddressValue(Encoder.getKey("factoryOwner", _addressList[i])) == IOwnable(_securityToken).owner() + ) if (isCompatibleModule(_addressList[i], _securityToken)) counter++; } else if (getBoolValue(Encoder.getKey("verified", _addressList[i]))) { if (isCompatibleModule(_addressList[i], _securityToken)) counter++; } @@ -362,7 +409,7 @@ contract ModuleRegistry is IModuleRegistry, EternalStorage { counter = 0; for (uint256 j = 0; j < _len; j++) { if (_isCustomModuleAllowed) { - if (IOwnable(_addressList[j]).owner() == IOwnable(_securityToken).owner() || getBoolValue( + if (getAddressValue(Encoder.getKey("factoryOwner", _addressList[j])) == IOwnable(_securityToken).owner() || getBoolValue( Encoder.getKey("verified", _addressList[j]) )) { if (isCompatibleModule(_addressList[j], _securityToken)) { diff --git a/contracts/interfaces/IModuleRegistry.sol b/contracts/interfaces/IModuleRegistry.sol index a1bff3dcb..c2bff56ce 100644 --- a/contracts/interfaces/IModuleRegistry.sol +++ b/contracts/interfaces/IModuleRegistry.sol @@ -53,7 +53,7 @@ interface IModuleRegistry { * @return bool indicating whether module factory is verified * @return address array which contains the list of securityTokens that use that module factory */ - function getFactoryDetails(address _factoryAddress) external view returns(bool, address[] memory); + function getFactoryDetails(address _factoryAddress) external view returns(bool, address, address[] memory); /** * @notice Returns all the tags related to the a module type which are valid for the given token @@ -72,6 +72,12 @@ interface IModuleRegistry { */ function getTagsByType(uint8 _moduleType) external view returns(bytes32[] memory, address[] memory); + /** + * @notice Returns the list of addresses of all Module Factory of a particular type + * @param _moduleType Type of Module + * @return address array that contains the list of addresses of module factory contracts. + */ + function getAllModulesByType(uint8 _moduleType) external view returns(address[] memory); /** * @notice Returns the list of addresses of Module Factory of a particular type * @param _moduleType Type of Module diff --git a/test/k_module_registry.js b/test/k_module_registry.js index 2acabae44..16cc89ff9 100644 --- a/test/k_module_registry.js +++ b/test/k_module_registry.js @@ -246,11 +246,11 @@ contract("ModuleRegistry", async (accounts) => { assert.equal(tx.logs[0].args._owner, account_polymath, "Should be the right owner"); let _list = await I_MRProxied.getModulesByType(transferManagerKey); - assert.equal(_list.length, 1, "Length should be 1"); - assert.equal(_list[0], I_GeneralTransferManagerFactory.address); + assert.equal(_list.length, 0, "Length should be 0 - unverified"); + // assert.equal(_list[0], I_GeneralTransferManagerFactory.address); let _reputation = await I_MRProxied.getFactoryDetails(I_GeneralTransferManagerFactory.address); - assert.equal(_reputation[1].length, 0); + assert.equal(_reputation[2].length, 0); }); it("Should fail the register the module -- Already registered module", async () => { @@ -284,7 +284,11 @@ contract("ModuleRegistry", async (accounts) => { let tx = await I_MRProxied.verifyModule(I_GeneralTransferManagerFactory.address, { from: account_polymath }); assert.equal(tx.logs[0].args._moduleFactory, I_GeneralTransferManagerFactory.address, "Failed in verifying the module"); let info = await I_MRProxied.getFactoryDetails.call(I_GeneralTransferManagerFactory.address); + let _list = await I_MRProxied.getModulesByType(transferManagerKey); + assert.equal(_list.length, 1, "Length should be 1"); + assert.equal(_list[0], I_GeneralTransferManagerFactory.address); assert.equal(info[0], true); + assert.equal(info[1], account_polymath); }); it("Should successfully verify the module -- false", async () => { @@ -357,7 +361,7 @@ contract("ModuleRegistry", async (accounts) => { "CappedSTOFactory module was not added" ); let _reputation = await I_MRProxied.getFactoryDetails.call(I_CappedSTOFactory2.address); - assert.equal(_reputation[1].length, 1); + assert.equal(_reputation[2].length, 1); }); it("Should successfully add module when custom modules switched on -- fail because factory owner is different", async () => { @@ -475,55 +479,52 @@ contract("ModuleRegistry", async (accounts) => { it("Should successfully remove module and delete data if msg.sender is curator", async () => { let snap = await takeSnapshot(); - + console.log("All modules: " + (await I_MRProxied.getModulesByType.call(3))); let sto1 = (await I_MRProxied.getModulesByType.call(3))[0]; - let sto2 = (await I_MRProxied.getModulesByType.call(3))[1]; - let sto3 = (await I_MRProxied.getModulesByType.call(3))[2]; - let sto4 = (await I_MRProxied.getModulesByType.call(3))[3]; + // let sto2 = (await I_MRProxied.getModulesByType.call(3))[1]; + // let sto3 = (await I_MRProxied.getModulesByType.call(3))[2]; + // let sto4 = (await I_MRProxied.getModulesByType.call(3))[3]; - assert.equal(sto1, I_CappedSTOFactory1.address); - assert.equal(sto2, I_CappedSTOFactory2.address); - assert.equal((await I_MRProxied.getModulesByType.call(3)).length, 4); + assert.equal(sto1, I_TestSTOFactory.address); + assert.equal((await I_MRProxied.getModulesByType.call(3)).length, 1); - let tx = await I_MRProxied.removeModule(sto4, { from: account_polymath }); + let tx = await I_MRProxied.removeModule(sto1, { from: account_polymath }); - assert.equal(tx.logs[0].args._moduleFactory, sto4, "Event is not properly emitted for _moduleFactory"); + assert.equal(tx.logs[0].args._moduleFactory, sto1, "Event is not properly emitted for _moduleFactory"); assert.equal(tx.logs[0].args._decisionMaker, account_polymath, "Event is not properly emitted for _decisionMaker"); let sto3_end = (await I_MRProxied.getModulesByType.call(3))[2]; - // re-ordering - assert.equal(sto3_end, sto3); // delete related data - assert.equal(await I_MRProxied.getUintValue.call(web3.utils.soliditySha3("registry", sto4)), 0); - assert.equal((await I_MRProxied.getFactoryDetails.call(sto4))[1], 0); - assert.equal((await I_MRProxied.getModulesByType.call(3)).length, 3); - assert.equal(await I_MRProxied.getBoolValue.call(web3.utils.soliditySha3("verified", sto4)), false); + assert.equal(await I_MRProxied.getUintValue.call(web3.utils.soliditySha3("registry", sto1)), 0); + assert.equal((await I_MRProxied.getFactoryDetails.call(sto1))[1], 0); + assert.equal((await I_MRProxied.getModulesByType.call(3)).length, 0); + assert.equal(await I_MRProxied.getBoolValue.call(web3.utils.soliditySha3("verified", sto1)), false); await revertToSnapshot(snap); }); it("Should successfully remove module and delete data if msg.sender is owner", async () => { - let sto1 = (await I_MRProxied.getModulesByType.call(3))[0]; - let sto2 = (await I_MRProxied.getModulesByType.call(3))[1]; + let sto1 = (await I_MRProxied.getAllModulesByType.call(3))[0]; + let sto2 = (await I_MRProxied.getAllModulesByType.call(3))[1]; assert.equal(sto1, I_CappedSTOFactory1.address); assert.equal(sto2, I_CappedSTOFactory2.address); - assert.equal((await I_MRProxied.getModulesByType.call(3)).length, 4); + assert.equal((await I_MRProxied.getAllModulesByType.call(3)).length, 4); let tx = await I_MRProxied.removeModule(sto2, { from: token_owner }); assert.equal(tx.logs[0].args._moduleFactory, sto2, "Event is not properly emitted for _moduleFactory"); assert.equal(tx.logs[0].args._decisionMaker, token_owner, "Event is not properly emitted for _decisionMaker"); - let sto1_end = (await I_MRProxied.getModulesByType.call(3))[0]; + let sto1_end = (await I_MRProxied.getAllModulesByType.call(3))[0]; // re-ordering assert.equal(sto1_end, sto1); // delete related data assert.equal(await I_MRProxied.getUintValue.call(web3.utils.soliditySha3("registry", sto2)), 0); assert.equal((await I_MRProxied.getFactoryDetails.call(sto2))[1], 0); - assert.equal((await I_MRProxied.getModulesByType.call(3)).length, 3); + assert.equal((await I_MRProxied.getAllModulesByType.call(3)).length, 3); assert.equal(await I_MRProxied.getBoolValue.call(web3.utils.soliditySha3("verified", sto2)), false); }); diff --git a/test/u_module_registry_proxy.js b/test/u_module_registry_proxy.js index e3447bf27..a1ab58a15 100644 --- a/test/u_module_registry_proxy.js +++ b/test/u_module_registry_proxy.js @@ -257,7 +257,7 @@ contract("ModuleRegistryProxy", async (accounts) => { describe("Execute functionality of the implementation contract on the earlier storage", async () => { it("Should get the previous data", async () => { let _data = await I_MRProxied.getFactoryDetails.call(I_GeneralTransferManagerFactory.address); - assert.equal(_data[1].length, new BN(0), "Should give the original length"); + assert.equal(_data[2].length, new BN(0), "Should give the original length"); }); it("Should alter the old storage", async () => { @@ -265,7 +265,7 @@ contract("ModuleRegistryProxy", async (accounts) => { from: account_polymath }); let _data = await I_MRProxied.getFactoryDetails.call(I_GeneralTransferManagerFactory.address); - assert.equal(_data[1].length, 2, "Should give the updated length"); + assert.equal(_data[2].length, 2, "Should give the updated length"); }); });