Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Allow contract-to-contract write methods #550

Merged
merged 10 commits into from
Sep 30, 2024
25 changes: 22 additions & 3 deletions backend/consensus/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import asyncio
from collections import deque
import json
import traceback
from typing import Callable, Iterator

Expand Down Expand Up @@ -106,8 +107,9 @@ async def _run_consensus(self):
# https://docs.sqlalchemy.org/en/20/orm/session_basics.html#is-the-session-thread-safe-is-asyncsession-safe-to-share-in-concurrent-tasks
transaction = await queue.get()
with self.get_session() as session:
tg.create_task(
self.exec_transaction(

async def exec_transaction_with_session_handling():
await self.exec_transaction(
transaction,
TransactionsProcessor(session),
ChainSnapshot(session),
Expand All @@ -116,7 +118,9 @@ async def _run_consensus(self):
contract_address, session
),
)
)
session.commit()

tg.create_task(exec_transaction_with_session_handling())

except Exception as e:
print("Error running consensus", e)
Expand Down Expand Up @@ -286,6 +290,21 @@ async def exec_transaction(
consensus_data,
)

# Insert pending transactions generated by contract-to-contract calls
pending_transactions_to_insert = leader_receipt.pending_transactions
for pending_transaction in pending_transactions_to_insert:
transactions_processor.insert_transaction(
transaction.to_address, # new calls are done by the contract
pending_transaction.address,
{
"function_name": pending_transaction.method_name,
"function_args": json.dumps(pending_transaction.args),
},
value=0, # No value gets transferred?
type=TransactionType.RUN_CONTRACT.value,
leader_only=transaction.leader_only, # Cascade
)

def execute_transfer(
self,
transaction: Transaction,
Expand Down
44 changes: 32 additions & 12 deletions backend/node/genvm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
from backend.node.genvm.equivalence_principle import EquivalencePrinciple
from backend.node.genvm.code_enforcement import code_enforcement_check
from backend.node.genvm.std.vector_store import VectorStore
from backend.node.genvm.types import Receipt, ExecutionResultStatus, ExecutionMode
from backend.node.genvm.types import (
PendingTransaction,
Receipt,
ExecutionResultStatus,
ExecutionMode,
)
from backend.protocol_rpc.message_handler.base import MessageHandler


Expand Down Expand Up @@ -72,6 +77,7 @@ def __init__(
self.contract_runner = ContractRunner(
validator_mode, validator, contract_snapshot_factory
)
self.pending_transactions: list[PendingTransaction] = []

@staticmethod
def _get_contract_class_name(contract_code: str) -> str:
Expand Down Expand Up @@ -101,6 +107,7 @@ def _generate_receipt(
eq_outputs=self.contract_runner.eq_outputs,
execution_result=execution_result,
error=error,
pending_transactions=self.pending_transactions,
)

async def deploy_contract(
Expand Down Expand Up @@ -131,6 +138,7 @@ async def deploy_contract(
"Contract": partial(
ExternalContract,
self.contract_runner.contract_snapshot_factory,
lambda x: self.pending_transactions.append(x),
),
}
):
Expand Down Expand Up @@ -227,6 +235,7 @@ async def run_contract(
"Contract": partial(
ExternalContract,
self.contract_runner.contract_snapshot_factory,
lambda x: self.pending_transactions.append(x),
),
}
):
Expand Down Expand Up @@ -392,6 +401,7 @@ def get_contract_data(
"Contract": partial(
ExternalContract,
contract_snapshot_factory,
None, # TODO: should read methods be allowed to add new transactions?
)
}
):
Expand All @@ -410,24 +420,34 @@ def get_contract_data(

class ExternalContract:
def __init__(
self, contract_snapshot_factory: Callable[[str], ContractSnapshot], address: str
self,
contract_snapshot_factory: Callable[[str], ContractSnapshot],
schedule_pending_transaction: Callable[[PendingTransaction], None],
address: str,
):
self.address = address

self.contract_snapshot = contract_snapshot_factory(address)
self.contract_snapshot_factory = contract_snapshot_factory
self.schedule_pending_transaction = schedule_pending_transaction

def __getattr__(self, name):
def method(*args, **kwargs):
if not re.match("get_", name):
raise Exception("Method name must start with 'get_'")

return GenVM.get_contract_data(
self.contract_snapshot.contract_code,
self.contract_snapshot.encoded_state,
name,
args,
self.contract_snapshot_factory,
)
if re.match("get_", name):
return GenVM.get_contract_data(
self.contract_snapshot.contract_code,
self.contract_snapshot.encoded_state,
name,
args,
self.contract_snapshot_factory,
)
else:
self.schedule_pending_transaction(
PendingTransaction(
address=self.address, method_name=name, args=args
)
)

return None

return method
14 changes: 13 additions & 1 deletion backend/node/genvm/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from enum import Enum
from typing import Optional
from typing import Iterable, Optional


class Vote(Enum):
Expand All @@ -18,6 +18,13 @@ class ExecutionResultStatus(Enum):
ERROR = "ERROR"


@dataclass
class PendingTransaction:
address: str # Address of the contract to call
method_name: str
args: list


@dataclass
class Receipt:
class_name: str
Expand All @@ -31,6 +38,7 @@ class Receipt:
execution_result: ExecutionResultStatus
error: Optional[Exception] = None
vote: Optional[Vote] = None
pending_transactions: Iterable[PendingTransaction] = ()

def to_dict(self):
return {
Expand All @@ -45,4 +53,8 @@ def to_dict(self):
"node_config": self.node_config,
"eq_outputs": self.eq_outputs,
"error": str(self.error) if self.error else None,
"pending_transactions": [
pending_transaction.__dict__
for pending_transaction in self.pending_transactions
],
}
2 changes: 1 addition & 1 deletion backend/protocol_rpc/transactions_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def decode_method_call_data(data: str) -> DecodedMethodCallData:
try:
data_decoded = rlp.decode(data_bytes, MethodCallTransactionPayload)
except rlp.exceptions.DeserializationError as e:
print("Error decoding method call data, falling back to default:", e)
print("WARN | falling back to default decode method call data:", e)
data_decoded = rlp.decode(data_bytes, MethodCallTransactionPayloadDefault)

leader_only = getattr(data_decoded, "leader_only", False)
Expand Down
5 changes: 4 additions & 1 deletion frontend/src/hooks/useInputMap.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ export const useInputMap = () => {
const component = InputTypesMap[type];

if (!component) {
throw new Error(`Component not found for input type: ${type}`);
console.warn(
`Component not found for input type: ${type}, defaulting to string`,
);
return StringField;
}

return component;
Expand Down
12 changes: 8 additions & 4 deletions frontend/test/unit/hooks/useInputMap.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,13 @@ describe('useInputMap composable', () => {
expect(component).toBe(BooleanField);
});

it('should throw an error for an unknown type', () => {
expect(() => getComponent('unknown')).toThrowError(
'Component not found for input type: unknown',
);
it('should default to string for an empty type', () => {
const component = getComponent('');
expect(component).toBe(StringField);
});

it('should default to string for an unknown type', () => {
const component = getComponent('unknown');
expect(component).toBe(StringField);
});
});
35 changes: 35 additions & 0 deletions tests/integration/contract_examples/multi_tenant_storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from backend.node.genvm.icontract import IContract


class MultiTentantStorage(IContract):
"""
Same functionality as UserStorage, but implemented with multiple storage contracts.
Each user is assigned to a storage contract, and all storage contracts are managed by this same contract.
This contract does not prevent users from directly interacting with the storage contracts, but it doesn't bother us for testing purposes.
This is done to test contract calls between different contracts.
"""

def __init__(self, storage_contracts: list[str]):
self.all_storage_contracts = storage_contracts.copy()
self.available_storage_contracts = self.all_storage_contracts
self.mappings = {} # mapping of user address to storage contract address

def get_available_contracts(self) -> list[str]:
return self.available_storage_contracts

def get_all_storages(self) -> dict[str, str]:
return {
storage_contract: Contract(storage_contract).get_storage()
for storage_contract in self.all_storage_contracts
}

def update_storage(self, new_storage: str) -> None:
# Assign user to a storage contract if not already assigned
if contract_runner.from_address not in self.mappings:
self.mappings[contract_runner.from_address] = (
self.available_storage_contracts[0]
)
self.available_storage_contracts = self.available_storage_contracts[1:]

contract_to_use = self.mappings[contract_runner.from_address]
Contract(contract_to_use).update_storage(new_storage)
119 changes: 119 additions & 0 deletions tests/integration/contract_examples/test_multi_tenant_storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import json
import os
import time

from tests.common.accounts import create_new_account
from tests.common.request import (
call_contract_method,
deploy_intelligent_contract,
send_transaction,
)
from tests.common.response import has_success_status


def test_multi_tenant_storage(setup_validators):
"""
This test verifies the functionality of a multi-tenant storage contract. It deploys two separate storage contracts
and a multi-tenant storage contract that manages them. The test aims to:

1. Deploy two different storage contracts with initial storage values.
2. Deploy a multi-tenant storage contract that can interact with multiple storage contracts.
3. Test the ability of the multi-tenant contract to update and retrieve storage values for multiple users
across different storage contracts.
4. Ensure the multi-tenant contract correctly assigns users to storage contracts and manages their data.

This test demonstrates contract-to-contract interactions and multi-tenant data management.
"""
main_account = create_new_account()
user_account_a = create_new_account()
user_account_b = create_new_account()

current_directory = os.path.dirname(os.path.abspath(__file__))

# Storage Contracts
contract_code = open("examples/contracts/storage.py", "r").read()

## Deploy first Storage Contract
first_storage_contract_address, transaction_response_deploy = (
deploy_intelligent_contract(
main_account,
contract_code,
json.dumps({"initial_storage": "initial_storage_a"}),
)
)
assert has_success_status(transaction_response_deploy)

## Deploy second Storage Contract

second_storage_contract_address, transaction_response_deploy = (
deploy_intelligent_contract(
main_account,
contract_code,
json.dumps({"initial_storage": "initial_storage_b"}),
)
)
assert has_success_status(transaction_response_deploy)

# Deploy Multi Tenant Storage Contract
contract_file = os.path.join(current_directory, "multi_tenant_storage.py")
contract_code = open(contract_file, "r").read()

multi_tenant_storage_address, transaction_response_deploy = (
deploy_intelligent_contract(
main_account,
contract_code,
json.dumps(
{
"storage_contracts": [
first_storage_contract_address,
second_storage_contract_address,
]
}
),
)
)
assert has_success_status(transaction_response_deploy)

# update storage for first contract
transaction_response_call = send_transaction(
user_account_a,
multi_tenant_storage_address,
"update_storage",
["user_a_storage"],
)

assert has_success_status(transaction_response_call)

# update storage for second contract
transaction_response_call = send_transaction(
user_account_b,
multi_tenant_storage_address,
"update_storage",
["user_b_storage"],
)

assert has_success_status(transaction_response_call)

# wait for transactions to be processed
for attempt in range(10):
print(f"Attempt {attempt}")

# get all storages
storages = call_contract_method(
multi_tenant_storage_address,
main_account,
"get_all_storages",
[],
)
print(f"Storages: {storages}")

if storages == {
first_storage_contract_address: "user_a_storage",
second_storage_contract_address: "user_b_storage",
}:
break

time.sleep(5)

else:
assert False, f"Storages don't match: {storages}"
2 changes: 2 additions & 0 deletions tests/integration/pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[pytest]
asyncio_default_fixture_loop_scope = function