Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add overloading to contract functions. #28

Merged
merged 14 commits into from
Nov 10, 2023
2 changes: 1 addition & 1 deletion .github/workflows/static.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
- name: get all Python files
id: list_files
run: |
echo "files=$(git ls-files '*.py' '*.pyi' | xargs)" >> $GITHUB_OUTPUT
echo "files=$(git ls-files '*.py' '*.pyi' | grep -v '^snapshots/' | xargs)" >> $GITHUB_OUTPUT

- name: run Pylint on files
run: |
Expand Down
4 changes: 2 additions & 2 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ fail-under=10
#from-stdin=

# Files or directories to be skipped. They should be base names, not paths.
ignore=CVS
ignore=

# Add files or directories matching the regular expressions patterns to the
# ignore-list. The regex matches against paths and can be in Posix or Windows
# format. Because '\' represents the directory delimiter on Windows systems, it
# can't be used as an escape character.
ignore-paths=docs,.venv,hyperdrive_solidity/.venv
ignore-paths=docs,.venv,snapshots

# Files or directories matching the regular expression patterns are skipped.
# The regex matches against base names, not paths. The default value ignores
Expand Down
4 changes: 2 additions & 2 deletions pypechain/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ def main(argv: Sequence[str] | None = None) -> None:

abi_file_path, output_dir, line_length = parse_arguments(argv)

# Create/clear the directory
# Create/clear the output directory
setup_directory(output_dir)

# List to store all JSON files to be processed
# List to store all JSON ABI files to be processed
json_files_to_process = []

# Check if provided path is a directory or file
Expand Down
155 changes: 99 additions & 56 deletions pypechain/render/contract.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Functions to render Python files from an abi usng a jinja2 template."""
from __future__ import annotations

from pathlib import Path
from typing import TypedDict

from jinja2 import Template
from web3.types import ABI

from pypechain.utilities.abi import (
get_abi_items,
Expand All @@ -14,10 +16,26 @@
load_abi_from_file,
)
from pypechain.utilities.format import capitalize_first_letter_only
from pypechain.utilities.sort import get_intersection_and_unique
from pypechain.utilities.templates import get_jinja_env


def render_contract_file(contract_name: str, contract_template: Template, abi_file_path: Path) -> str:
class SignatureData(TypedDict):
"""Define the structure of the signature_datas dictionary"""

input_names_and_types: list[str]
input_names: list[str]
outputs: list[str]


class FunctionData(TypedDict):
"""Define the structure of the function_data dictionary"""

name: str
capitalized_name: str
signature_datas: list[SignatureData]


def render_contract_file(contract_name: str, abi_file_path: Path) -> str:
"""Returns the serialized code of the contract file to be generated.

Arguments
Expand All @@ -32,68 +50,93 @@ def render_contract_file(contract_name: str, contract_template: Template, abi_fi
str
A serialized python file.
"""
env = get_jinja_env()
base_template = env.get_template("contract.py/base.py.jinja2")
functions_template = env.get_template("contract.py/functions.py.jinja2")
abi_template = env.get_template("contract.py/abi.py.jinja2")
contract_template = env.get_template("contract.py/contract.py.jinja2")

# TODO: add return types to function calls

abi = load_abi_from_file(abi_file_path)
function_datas, constructor_data = get_function_datas(abi)
has_overloading = any(len(function_data["signature_datas"]) > 1 for function_data in function_datas.values())

functions_block = functions_template.render(
abi=abi,
contract_name=contract_name,
functions=function_datas,
# TODO: use this data to add a typed constructor
constructor=constructor_data,
)

abi_block = abi_template.render(
abi=abi,
contract_name=contract_name,
)

contract_block = contract_template.render(
contract_name=contract_name,
functions=function_datas,
)

# TODO: return types to function calls
# Extract function names and their input parameters from the ABI
function_datas = {}
constructor_data = {}
for abi_function in get_abi_items(abi_file_path):
# Render the template
return base_template.render(
contract_name=contract_name,
has_overloading=has_overloading,
functions_block=functions_block,
abi_block=abi_block,
contract_block=contract_block,
# TODO: use this data to add a typed constructor
# constructor_data=constructor_data,
)


def get_function_datas(abi: ABI) -> tuple[dict[str, FunctionData], SignatureData | None]:
"""_summary_

Arguments
---------
abi : ABI
An application boundary interface for smart contract in json format.

Returns
-------
tuple[dict[str, FunctionData], SignatureData | None]
A tuple where the first value is a dictionary of FunctionData's keyed by function name and
the second value is SignatureData for the constructor.
"""
function_datas: dict[str, FunctionData] = {}
constructor_data: SignatureData | None = None
for abi_function in get_abi_items(abi):
if is_abi_function(abi_function):
# TODO: investigate better typing here? templete.render expects an object so we'll have
# to convert.

# hanndle constructor
if is_abi_constructor(abi_function):
constructor_data = {
"input_names_and_types": [get_input_names_and_values(abi_function)],
"input_names": [get_input_names(abi_function)],
"outputs": [get_output_names(abi_function)],
"input_names_and_types": get_input_names_and_values(abi_function),
"input_names": get_input_names(abi_function),
"outputs": get_output_names(abi_function),
}

# handle functions
# handle all other functions
else:
name = abi_function.get("name", "")
if name and name not in function_datas:
function_data = {
# TODO: pass a typeguarded ABIFunction that has only required fields?
# name is required in the typeguard. Should be safe to default to empty string.
"name": name,
"capitalized_name": capitalize_first_letter_only(name),
"input_names_and_types": [get_input_names_and_values(abi_function)],
"input_names": [get_input_names(abi_function)],
"outputs": [get_output_names(abi_function)],
}
signature_data: SignatureData = {
"input_names_and_types": get_input_names_and_values(abi_function),
"input_names": get_input_names(abi_function),
"outputs": get_output_names(abi_function),
}
function_data: FunctionData = {
# TODO: pass a typeguarded ABIFunction that has only required fields?
# name is required in the typeguard. Should be safe to default to empty string.
"name": name,
"capitalized_name": capitalize_first_letter_only(name),
"signature_datas": [signature_data],
}
if not function_datas.get(name):
function_datas[name] = function_data
else: # this function already exists, presumably with a different signature
function_datas[name]["input_names_and_types"].append(get_input_names_and_values(abi_function))
function_datas[name]["input_names"].append(get_input_names(abi_function))
function_datas[name]["outputs"].append(get_output_names(abi_function))
# input_names_and_types will need optional args at the end
(
shared_input_names_and_types,
unique_input_names_and_types,
) = get_intersection_and_unique(function_datas[name]["input_names_and_types"])
function_datas[name]["required_input_names_and_types"] = shared_input_names_and_types
function_datas[name]["optional_input_names_and_types"] = []
for name_and_type in unique_input_names_and_types: # optional args
name_and_type += " | None = None"
function_datas[name]["optional_input_names_and_types"].append(name_and_type)
# we will also need the names to be separated
(
shared_input_names,
unique_input_names,
) = get_intersection_and_unique(function_datas[name]["input_names"])
function_datas[name]["required_input_names"] = shared_input_names

function_datas[name]["optional_input_names"] = unique_input_names

abi = load_abi_from_file(abi_file_path)

# Render the template
return contract_template.render(
abi=abi,
contract_name=contract_name,
functions=list(function_datas.values()),
# TODO: use this data to add a typed constructor
constructor=constructor_data,
)
else:
function_datas[name]["signature_datas"].append(signature_data)
return function_datas, constructor_data
69 changes: 69 additions & 0 deletions pypechain/render/contract_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""Tests for rendering the contract file."""
import json
import os

from web3.types import ABI

from pypechain.render.contract import get_function_datas
from pypechain.utilities.templates import get_jinja_env

# using pytest fixtures necessitates this.
# pylint: disable=redefined-outer-name

current_path = os.path.abspath(os.path.dirname(__file__))
project_root = os.path.dirname(os.path.dirname(current_path))


class TestOverloading:
"""Tests pipeline from bots making trades to viewing the trades in the db"""

def test_overloading(self, snapshot):
"""Runs the entire pipeline and checks the database at the end.
All arguments are fixtures.
"""

env = get_jinja_env()
functions_template = env.get_template("contract.py/functions.py.jinja2")

# TODO: add return types to function calls

abi_str = """
[
{
"constant": true,
"inputs": [],
"name": "balanceOf",
"outputs": [{"name": "", "type": "uint256"}],
"payable": false,
"stateMutability": "view",
"type": "function"
},
{
"constant": true,
"inputs": [{"name": "who", "type": "address"}],
"name": "balanceOf",
"outputs": [{"name": "", "type": "uint256"}],
"payable": false,
"stateMutability": "view",
"type": "function"
}
]
"""

abi: ABI = json.loads(abi_str)

function_datas, constructor_data = get_function_datas(abi)
has_overloading = any(len(function_data["signature_datas"]) > 1 for function_data in function_datas.values())
contract_name = "Overloaded"

functions_block = functions_template.render(
abi=abi,
contract_name=contract_name,
functions=function_datas,
# TODO: use this data to add a typed constructor
constructor=constructor_data,
)
assert has_overloading is True

snapshot.snapshot_dir = "snapshots" # This line is optional.
snapshot.assert_match(functions_block, "expected_overloading.py")
8 changes: 2 additions & 6 deletions pypechain/render/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from pypechain.render.types import render_types_file
from pypechain.utilities.file import write_string_to_file
from pypechain.utilities.format import apply_black_formatting
from pypechain.utilities.templates import setup_templates


def render_files(abi_file_path: str, output_dir: str, line_length: int) -> None:
Expand All @@ -20,12 +19,9 @@ def render_files(abi_file_path: str, output_dir: str, line_length: int) -> None:
contract_name = os.path.splitext(filename)[0]
contract_path = output_path.joinpath(f"{contract_name}")

# grab the templates
contract_template, types_template = setup_templates()

# render the code
rendered_contract_code = render_contract_file(contract_name, contract_template, file_path)
rendered_types_code = render_types_file(contract_name, types_template, file_path)
rendered_contract_code = render_contract_file(contract_name, file_path)
rendered_types_code = render_types_file(contract_name, file_path)

# Format the generated code using Black
formatted_contract_code = apply_black_formatting(rendered_contract_code, line_length)
Expand Down
13 changes: 5 additions & 8 deletions pypechain/render/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,11 @@
from dataclasses import asdict
from pathlib import Path

from jinja2 import Template
from pypechain.utilities.abi import get_events_for_abi, get_structs_for_abi, load_abi_from_file
from pypechain.utilities.templates import get_jinja_env

from pypechain.utilities.abi import (
get_events_for_abi,
get_structs_for_abi,
load_abi_from_file,
)


def render_types_file(contract_name: str, types_template: Template, abi_file_path: Path) -> str:
def render_types_file(contract_name: str, abi_file_path: Path) -> str:
"""Returns the serialized code of the types file to be generated.

Arguments
Expand All @@ -28,6 +23,8 @@ def render_types_file(contract_name: str, types_template: Template, abi_file_pat
str
A serialized python file.
"""
env = get_jinja_env()
types_template = env.get_template("types.py.jinja2")

abi = load_abi_from_file(abi_file_path)

Expand Down
Loading