diff --git a/starknet_py/utils/typed_data.py b/starknet_py/utils/typed_data.py index 84e2e2835..04a76a16a 100644 --- a/starknet_py/utils/typed_data.py +++ b/starknet_py/utils/typed_data.py @@ -132,6 +132,9 @@ def _encode_value( hashes = [self._encode_value(type_name, val) for val in value] return compute_hash_on_elements(hashes) + if type_name not in _get_basic_type_names(self.domain.resolved_revision): + raise ValueError(f"Type [{type_name}] is not defined in types.") + basic_type = BasicType(type_name) if basic_type == BasicType.MERKLE_TREE and isinstance(value, list): @@ -164,12 +167,34 @@ def _encode_data(self, type_name: str, data: dict) -> List[int]: return values def _verify_types(self): - reserved_type_names = ["felt", "felt*", "string", "selector", "merkletree"] + if self.domain.separator_name not in self.types: + raise ValueError(f"Types must contain '{self.domain.separator_name}'.") + + basic_type_names = _get_basic_type_names(self.domain.resolved_revision) - for type_name in reserved_type_names: + for type_name in basic_type_names: if type_name in self.types: raise ValueError(f"Reserved type name: {type_name}") + referenced_types = { + ref_type.contains + if ref_type.contains is not None + else strip_pointer(ref_type.type) + for type_name in self.types + for ref_type in self.types[type_name] + } + referenced_types.update([self.domain.separator_name, self.primary_type]) + + for type_name in self.types: + if not type_name: + raise ValueError("Type names cannot be empty.") + if is_pointer(type_name): + raise ValueError(f"Type names cannot end in *. {type_name} was found.") + if type_name not in referenced_types: + raise ValueError( + f"Dangling types are not allowed. Unreferenced type {type_name} was found." + ) + def _get_dependencies(self, type_name: str) -> List[str]: if type_name not in self.types: # type_name is a primitive type, has no dependencies @@ -280,7 +305,7 @@ def get_hex(value: Union[int, str]) -> str: def is_pointer(value: str) -> bool: - return len(value) > 0 and value[-1] == "*" + return value.endswith("*") def strip_pointer(value: str) -> str: @@ -306,9 +331,26 @@ class BasicType(Enum): FELT = "felt" SELECTOR = "selector" MERKLE_TREE = "merkletree" + STRING = "string" SHORT_STRING = "shortstring" +def _get_basic_type_names(revision: Revision) -> List[str]: + basic_types_v0 = [ + BasicType.FELT, + BasicType.SELECTOR, + BasicType.MERKLE_TREE, + BasicType.STRING, + ] + + basic_types_v1 = basic_types_v0 + [ + BasicType.SHORT_STRING, + ] + + basic_types = basic_types_v0 if revision == Revision.V0 else basic_types_v1 + return [basic_type.value for basic_type in basic_types] + + # pylint: disable=unused-argument # pylint: disable=no-self-use diff --git a/starknet_py/utils/typed_data_test.py b/starknet_py/utils/typed_data_test.py index ffe0a6e3d..62751b2f3 100644 --- a/starknet_py/utils/typed_data_test.py +++ b/starknet_py/utils/typed_data_test.py @@ -9,7 +9,13 @@ from starknet_py.net.models.typed_data import Revision from starknet_py.tests.e2e.fixtures.constants import TYPED_DATA_DIR -from starknet_py.utils.typed_data import Domain, Parameter, TypedData, get_hex +from starknet_py.utils.typed_data import ( + BasicType, + Domain, + Parameter, + TypedData, + get_hex, +) class CasesRev0(Enum): @@ -197,15 +203,80 @@ def _make_typed_data(included_type: str, revision: Revision): @pytest.mark.parametrize( - "included_type", + "included_type, revision", + [ + ("", Revision.V1), + ("myType*", Revision.V1) + ], +) +def test_invalid_type_names(included_type: str, revision: Revision): + with pytest.raises(ValueError): + _make_typed_data(included_type, revision) + + +@pytest.mark.parametrize( + "included_type, revision", [ - "felt", - "felt*", - "string", - "selector", - "merkletree" + (BasicType.FELT.value, Revision.V0), + (BasicType.STRING.value, Revision.V0), + (BasicType.SELECTOR.value, Revision.V0), + (BasicType.MERKLE_TREE.value, Revision.V0), + (BasicType.FELT.value, Revision.V1), + (BasicType.STRING.value, Revision.V1), + (BasicType.SELECTOR.value, Revision.V1), + (BasicType.MERKLE_TREE.value, Revision.V1), + (BasicType.SHORT_STRING.value, Revision.V1), ], ) -def test_invalid_types(included_type: str): +def test_types_redefinition(included_type: str, revision: Revision): with pytest.raises(ValueError, match=f"Reserved type name: {included_type}"): - _make_typed_data(included_type, Revision.V1) + _make_typed_data(included_type, revision) + + +def test_custom_type_definition(): + _make_typed_data("myType", Revision.V0) + + +@pytest.mark.parametrize( + "revision", + list(Revision), +) +def test_missing_domain_type(revision: Revision): + domain = domain_v0 if revision == Revision.V0 else domain_v1 + + with pytest.raises(ValueError, match=f"Types must contain '{domain.separator_name}'."): + TypedData( + types={}, + primary_type="felt", + domain=domain, + message={}, + ) + + +def test_dangling_type(): + with pytest.raises(ValueError, match="Dangling types are not allowed. Unreferenced type dangling was found."): + TypedData( + types={ + **domain_type_v1, + "dangling": [], + "mytype": [] + }, + primary_type="mytype", + domain=domain_v1, + message={"mytype": 1}, + ) + + +def test_missing_dependency(): + typed_data = TypedData( + types={ + **domain_type_v1, + "house": [Parameter(name="fridge", type="ice cream")] + }, + primary_type="house", + domain=domain_v1, + message={"fridge": 1}, + ) + + with pytest.raises(ValueError, match=r"Type \[ice cream\] is not defined in types."): + typed_data.struct_hash("house", {"fridge": 1})