diff --git a/packages/smock/src/smockit/smockit.ts b/packages/smock/src/smockit/smockit.ts index fb1df4d14e81..1673064518aa 100644 --- a/packages/smock/src/smockit/smockit.ts +++ b/packages/smock/src/smockit/smockit.ts @@ -89,11 +89,15 @@ const smockifyFunction = ( return { functionName: fragment.name, + functionSignature: fragment.format(), data, } }) .filter((functionResult: any) => { - return functionResult.functionName === functionName + return ( + functionResult.functionName === functionName || + functionResult.functionSignature === functionName + ) }) .map((functionResult: any) => { return functionResult.data diff --git a/packages/smock/test/contracts/TestHelpers_MockCaller.sol b/packages/smock/test/contracts/TestHelpers_MockCaller.sol new file mode 100644 index 000000000000..818f3d77087a --- /dev/null +++ b/packages/smock/test/contracts/TestHelpers_MockCaller.sol @@ -0,0 +1,8 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.7.0; + +contract TestHelpers_MockCaller { + function callMock(address _target, bytes memory _data) public { + _target.call(_data); + } +} diff --git a/packages/smock/test/smockit/call-assertions.spec.ts b/packages/smock/test/smockit/call-assertions.spec.ts new file mode 100644 index 000000000000..2df21355cf17 --- /dev/null +++ b/packages/smock/test/smockit/call-assertions.spec.ts @@ -0,0 +1,56 @@ +/* Imports: External */ +import hre from 'hardhat' +import { expect } from 'chai' +import { Contract } from 'ethers' + +/* Imports: Internal */ +import { MockContract, smockit } from '../../src' + +describe('[smock]: call assertion tests', () => { + const ethers = (hre as any).ethers + + let mock: MockContract + beforeEach(async () => { + mock = await smockit('TestHelpers_BasicReturnContract') + }) + + let mockCaller: Contract + before(async () => { + const mockCallerFactory = await ethers.getContractFactory( + 'TestHelpers_MockCaller' + ) + mockCaller = await mockCallerFactory.deploy() + }) + + describe('overloaded functions', () => { + it('should be able to modify both versions of an overloaded function', async () => { + mock.smocked['overloadedFunction(uint256)'].will.return.with(0) + mock.smocked['overloadedFunction(uint256,uint256)'].will.return.with(0) + + const expected1 = ethers.BigNumber.from(1234) + await mockCaller.callMock( + mock.address, + mock.interface.encodeFunctionData('overloadedFunction(uint256)', [ + expected1, + ]) + ) + + expect( + mock.smocked['overloadedFunction(uint256)'].calls[0] + ).to.deep.equal([expected1]) + + const expected2 = ethers.BigNumber.from(5678) + await mockCaller.callMock( + mock.address, + mock.interface.encodeFunctionData( + 'overloadedFunction(uint256,uint256)', + [expected2, expected2] + ) + ) + + expect( + mock.smocked['overloadedFunction(uint256,uint256)'].calls[0] + ).to.deep.equal([expected2, expected2]) + }) + }) +})