Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds errors to the Contract class #102

Merged
merged 9 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions example/abis/Example.json

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion example/contracts/Example.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ pragma solidity ^0.8.0;

contract Example {

event Flip(uint flip);
event Flop(uint flop);

string contractName;

enum Letters { A, B, C }
Expand Down Expand Up @@ -34,7 +37,9 @@ contract Example {
revert WrongChoice(answer, "Thank you for playing, but you chose the wrong letter");
}

function flipFlop(uint flip, uint flop) public pure returns (uint _flop, uint _flip) {
function flipFlop(uint flip, uint flop) public returns (uint _flop, uint _flip) {
emit Flip(flip);
emit Flop(flop);
return (flop,flip);
}

Expand Down
257 changes: 252 additions & 5 deletions example/types/ExampleContract.py

Large diffs are not rendered by default.

21 changes: 21 additions & 0 deletions example/types/ExampleTypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

from dataclasses import dataclass

from web3.types import ABIEvent, ABIEventParams


@dataclass
class SimpleStruct:
Expand All @@ -45,6 +47,25 @@ class NestedStruct:
innerStruct: InnerStruct


Flip = ABIEvent(
anonymous=False,
inputs=[
ABIEventParams(indexed=False, name="flip", type="uint256"),
],
name="Flip",
type="event",
)

Flop = ABIEvent(
anonymous=False,
inputs=[
ABIEventParams(indexed=False, name="flop", type="uint256"),
],
name="Flop",
type="event",
)


@dataclass
class ErrorInfo:
"""Custom contract error information."""
Expand Down
14 changes: 13 additions & 1 deletion pypechain/render/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def render_contract_file(contract_info: ContractInfo) -> str | None:

function_datas, constructor_data = get_function_datas(contract_info.abi)
event_datas = get_event_datas(contract_info.abi)
error_infos = contract_info.errors.values()

has_bytecode = bool(contract_info.bytecode)
has_events = bool(len(event_datas.values()))
Expand All @@ -227,6 +228,12 @@ def render_contract_file(contract_info: ContractInfo) -> str | None:
events=event_datas,
)

has_errors = bool(len(error_infos))
errors_block = templates.errors_template.render(
contract_name=contract_info.contract_name,
errors=error_infos,
)

abi_block = templates.abi_template.render(
abi=contract_info.abi,
bytecode=contract_info.bytecode,
Expand All @@ -236,6 +243,7 @@ def render_contract_file(contract_info: ContractInfo) -> str | None:
contract_block = templates.contract_template.render(
has_bytecode=has_bytecode,
has_events=has_events,
has_errors=has_errors,
contract_name=contract_info.contract_name,
constructor=constructor_data,
functions=function_datas,
Expand All @@ -254,9 +262,11 @@ def render_contract_file(contract_info: ContractInfo) -> str | None:
has_overloading=has_overloading,
has_multiple_return_values=has_multiple_return_values,
has_bytecode=has_bytecode,
has_events=has_events,
functions_block=functions_block,
has_events=has_events,
events_block=events_block,
has_errors=has_errors,
errors_block=errors_block,
abi_block=abi_block,
contract_block=contract_block,
# TODO: use this data to add a typed constructor
Expand Down Expand Up @@ -320,6 +330,7 @@ class ContractTemplates(NamedTuple):
base_template: Any
functions_template: Any
events_template: Any
errors_template: Any
abi_template: Any
contract_template: Any

Expand All @@ -330,6 +341,7 @@ def get_templates_for_contract_file(env):
base_template=env.get_template("contract.py/base.py.jinja2"),
functions_template=env.get_template("contract.py/functions.py.jinja2"),
events_template=env.get_template("contract.py/events.py.jinja2"),
errors_template=env.get_template("contract.py/errors.py.jinja2"),
abi_template=env.get_template("contract.py/abi.py.jinja2"),
contract_template=env.get_template("contract.py/contract.py.jinja2"),
)
Expand Down
8 changes: 6 additions & 2 deletions pypechain/templates/contract.py/base.py.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ from dataclasses import fields, is_dataclass
from typing import Any, NamedTuple, Tuple, Type, TypeVar, cast, overload
from typing import Iterable, Sequence

from eth_abi.codec import ABICodec
from eth_abi.registry import registry as default_registry
from eth_typing import ChecksumAddress, HexStr
from eth_account.signers.local import LocalAccount
from hexbytes import HexBytes
Expand All @@ -34,13 +36,13 @@ from web3 import Web3
from web3.contract.contract import Contract, ContractFunction, ContractFunctions, ContractConstructor
from web3.contract.contract import ContractEvent, ContractEvents
from web3.exceptions import FallbackNotFound
from web3.types import ABI, BlockIdentifier, CallOverride, TxParams
from web3.types import ABI, BlockIdentifier, CallOverride, TxParams, ABIFunction
from web3.types import EventData
from web3._utils.filters import LogFilter
{% for struct_info in structs_used %}
from .{{struct_info.contract_name}}Types import {{struct_info.name}}
{% endfor %}
from .utilities import tuple_to_dataclass, dataclass_to_tuple, rename_returned_types
from .utilities import dataclass_to_tuple, get_abi_input_types, rename_returned_types, tuple_to_dataclass


structs = {
Expand All @@ -53,6 +55,8 @@ structs = {

{% if has_events %}{{ events_block }}{% endif %}

{% if has_errors %}{{ errors_block }}{% endif %}

{{abi_block}}

{{contract_block}}
8 changes: 8 additions & 0 deletions pypechain/templates/contract.py/contract.py.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ class {{contract_name}}Contract(Contract):
{% if has_events -%}
self.events = {{contract_name}}ContractEvents({{contract_name | lower}}_abi, self.w3, address) # type: ignore
{%- endif %}
{% if has_errors -%}
self.errors = {{contract_name}}ContractErrors()
{%- endif %}

except FallbackNotFound:
print("Fallback function not found. Continuing...")
Expand All @@ -22,6 +25,10 @@ class {{contract_name}}Contract(Contract):
events: {{contract_name}}ContractEvents
{%- endif %}

{% if has_errors -%}
errors: {{contract_name}}ContractErrors = {{contract_name}}ContractErrors()
{%- endif %}

functions: {{contract_name}}ContractFunctions

{% set has_constructor_args = constructor.input_names_and_types|length > 0 %}
Expand Down Expand Up @@ -115,6 +122,7 @@ class {{contract_name}}Contract(Contract):
"""
contract = super().factory(w3, class_name, **kwargs)
contract.functions = {{contract_name}}ContractFunctions({{contract_name | lower}}_abi, w3, None)
contract.errors = {{contract_name}}ContractErrors()

return contract

74 changes: 74 additions & 0 deletions pypechain/templates/contract.py/errors.py.jinja2
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
{# loop over all errors and create types for each #}
{%- for error_info in errors -%}
class {{contract_name}}{{error_info.name}}ContractError:
"""ContractError for {{error_info.name}}."""
# @combomethod destroys return types, so we are redefining functions as both class and instance
# pylint: disable=function-redefined

# 4 byte error selector
selector: str
# error signature, i.e. CustomError(uint256,bool)
signature: str

{# TODO: remove pylint disable when we add a type-hint for argument_names #}
# pylint: disable=useless-parent-delegation
def __init__(
self: "{{contract_name}}{{error_info.name}}ContractError",
) -> None:
self.selector = "{{error_info.selector}}"
self.signature = "{{error_info.signature}}"

def decode_error_data( # type: ignore
self: "{{contract_name}}{{error_info.name}}ContractError",
data: HexBytes,
# TODO: instead of returning a tuple, return a dataclass with the input names and types just like we do for functions
) -> tuple[Any, ...]:
"""Decodes error data returns from a smart contract."""
error_abi = cast(
ABIFunction,
[item for item in {{contract_name | lower}}_abi if item.get("name") == "{{error_info.name}}" and item.get("type") == "error"][0],
)
types = get_abi_input_types(error_abi)
abi_codec = ABICodec(default_registry)
decoded = abi_codec.decode(types, data)
return decoded

@classmethod
def decode_error_data( # type: ignore
cls: Type["{{contract_name}}{{error_info.name}}ContractError"],
data: HexBytes,
) -> tuple[Any, ...]:
"""Decodes error data returns from a smart contract."""
error_abi = cast(
ABIFunction,
[item for item in {{contract_name | lower}}_abi if item.get("name") == "{{error_info.name}}" and item.get("type") == "error"][0],
)
types = get_abi_input_types(error_abi)
abi_codec = ABICodec(default_registry)
decoded = abi_codec.decode(types, data)
return decoded
{% endfor %}

class {{contract_name}}ContractErrors:
"""ContractErrors for the {{contract_name}} contract."""
{% for error_info in errors %}
{{error_info.name}}: {{contract_name}}{{error_info.name}}ContractError
{% endfor %}

def __init__(
self,
) -> None:
{% for error_info in errors -%}
self.{{error_info.name}} = {{contract_name}}{{error_info.name}}ContractError()
{% endfor %}
self._all = [{% for error_info in errors -%}self.{{error_info.name}},{%- endfor %}]

def decode_custom_error(self, data: str) -> tuple[Any, ...]:
"""Decodes a custom contract error."""
selector = data[:10]
for err in self._all:
if err.selector == selector:
return err.decode_error_data(HexBytes(data[10:]))

raise ValueError(f"{{contract_name}} does not have a selector matching {selector}")

27 changes: 27 additions & 0 deletions pypechain/templates/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from dataclasses import fields, is_dataclass
from typing import Any, Tuple, TypeVar, cast

from eth_utils.abi import collapse_if_tuple
from web3.types import ABIFunction

T = TypeVar("T")


Expand Down Expand Up @@ -108,3 +111,27 @@ def rename_returned_types(
# cover case of single return value
converted_value = tuple_to_dataclass(return_types, structs, raw_values)
return converted_value


def get_abi_input_types(abi: ABIFunction) -> list[str]:
"""Gets all the solidity input types for a function or error.

Cribbed from web3._utils.abi.py file.

Parameters
----------

abi: ABIFunction
The ABIFunction or ABIError that we want to get input types for.

Returns
-------
list[str]
A list of solidity input types.

"""

if "inputs" not in abi and (abi.get("type") == "fallback" or abi.get("type") == "receive"):
return []
else:
return [collapse_if_tuple(cast(dict[str, Any], arg)) for arg in abi.get("inputs", [])]
11 changes: 11 additions & 0 deletions pypechain/test/errors/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
## Generate ABIs

Install [`solc`](https://docs.soliditylang.org/en/latest/installing-solidity.html).

If pytest fails to run while editing tests, you'll need to recompile the contracts manually. From
this directory run:

```bash
rm abis/Errors.json
solc contracts/Errors.sol --combined-json abi,bin,metadata >> abis/Errors.json
```
Empty file.
Loading
Loading