Skip to content

Commit

Permalink
Add support for declaring contract compiled with cairo 2.6 (#1314)
Browse files Browse the repository at this point in the history
* Add support for declaring contracts compiled with 2.6.0
  • Loading branch information
tkumor3 authored Mar 14, 2024
1 parent 1834b50 commit 0a2cbf2
Show file tree
Hide file tree
Showing 9 changed files with 10,972 additions and 11 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ name: Checks

env:
DEVNET_SHA: "c6ffb99"
CAIRO_LANG_VERSION: "0.13.1"

on:
push:
Expand Down Expand Up @@ -86,7 +87,7 @@ jobs:
- name: Install deprecated cairo compiler
run: |
pip install --upgrade setuptools
pip install cairo-lang==0.13.0
pip install cairo-lang==${{ env.CAIRO_LANG_VERSION }}
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
Expand All @@ -105,7 +106,7 @@ jobs:
uses: actions/cache@v3
with:
path: starknet_py/tests/e2e/mock/contracts_compiled
key: ${{ runner.os }}-contracts-${{ hashFiles('starknet_py/tests/e2e/mock/contracts', 'poetry.lock') }}
key: ${{ runner.os }}-contracts-${{ hashFiles('starknet_py/tests/e2e/mock/contracts', 'poetry.lock') }}-${{ env.CAIRO_LANG_VERSION }}

- name: Compile contracts
if: steps.cache-contracts.outputs.cache-hit != 'true'
Expand Down
2 changes: 1 addition & 1 deletion starknet_py/contract_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
def test_compute_hash(balance_contract):
assert (
Contract.compute_contract_hash(balance_contract)
== 0x12177EA61E5791CC068D7EE979B74F60A7205A23404C07440F4892B826147C0
== 0x7A98EAB69A2592EF5D3805990A43525D633DDC42B4D5B2524C7F38B7C59265F
)


Expand Down
105 changes: 103 additions & 2 deletions starknet_py/hash/casm_class_hash.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from typing import List
from typing import List, Optional, Sequence, Tuple

from poseidon_py.poseidon_hash import poseidon_hash_many

from starknet_py.cairo.felt import encode_shortstring
from starknet_py.hash.compiled_class_hash_objects import (
BytecodeLeaf,
BytecodeSegment,
BytecodeSegmentedNode,
BytecodeSegmentStructure,
NestedIntList,
)
from starknet_py.net.client_models import CasmClass, CasmClassEntryPoint

CASM_CLASS_VERSION = "COMPILED_CLASS_V1"
Expand All @@ -26,7 +33,14 @@ def compute_casm_class_hash(casm_contract_class: CasmClass) -> int:
_entry_points_array(_entry_points.constructor)
)

bytecode_hash = poseidon_hash_many(casm_contract_class.bytecode)
if casm_contract_class.bytecode_segment_lengths is not None:
bytecode_hash = create_bytecode_segment_structure(
bytecode=casm_contract_class.bytecode,
bytecode_segment_lengths=casm_contract_class.bytecode_segment_lengths,
visited_pcs=None,
).hash()
else:
bytecode_hash = poseidon_hash_many(casm_contract_class.bytecode)

return poseidon_hash_many(
[
Expand All @@ -51,3 +65,90 @@ def _entry_points_array(entry_points: List[CasmClassEntryPoint]) -> List[int]:
)

return entry_points_array


# create_bytecode_segment_structure and _create_bytecode_segment_structure_inner are copied from
# https://github.com/starkware-libs/cairo-lang/blob/v0.13.1/src/starkware/starknet/core/os/contract_class/compiled_class_hash.py


def create_bytecode_segment_structure(
bytecode: List[int],
bytecode_segment_lengths: NestedIntList,
visited_pcs: Optional[Sequence[int]],
) -> BytecodeSegmentStructure:
"""
Creates a BytecodeSegmentStructure instance from the given bytecode and
bytecode_segment_lengths.
"""
rev_visited_pcs = list(
visited_pcs if visited_pcs is not None else range(len(bytecode))
)[::-1]

res, total_len = _create_bytecode_segment_structure_inner(
bytecode=bytecode,
bytecode_segment_lengths=bytecode_segment_lengths,
visited_pcs=rev_visited_pcs,
bytecode_offset=0,
)
assert total_len == len(
bytecode
), f"Invalid length bytecode segment structure: {total_len}. Bytecode length: {len(bytecode)}."
assert len(rev_visited_pcs) == 0, f"PC {rev_visited_pcs[-1]} is out of range."
return res


def _create_bytecode_segment_structure_inner(
bytecode: List[int],
bytecode_segment_lengths: NestedIntList,
visited_pcs: List[int],
bytecode_offset: int,
) -> Tuple[BytecodeSegmentStructure, int]:
"""
Helper function for `create_bytecode_segment_structure`.
`visited_pcs` should be given in reverse order, and is consumed by the function.
Returns the BytecodeSegmentStructure and the total length of the processed segment.
"""
if isinstance(bytecode_segment_lengths, int):
segment_end = bytecode_offset + bytecode_segment_lengths

# Remove all the visited PCs that are in the segment.
while len(visited_pcs) > 0 and bytecode_offset <= visited_pcs[-1] < segment_end:
visited_pcs.pop()

return (
BytecodeLeaf(data=bytecode[bytecode_offset:segment_end]),
bytecode_segment_lengths,
)

res = []
total_len = 0
for item in bytecode_segment_lengths:
visited_pc_before = visited_pcs[-1] if len(visited_pcs) > 0 else None

current_structure, item_len = _create_bytecode_segment_structure_inner(
bytecode=bytecode,
bytecode_segment_lengths=item,
visited_pcs=visited_pcs,
bytecode_offset=bytecode_offset,
)

visited_pc_after = visited_pcs[-1] if len(visited_pcs) > 0 else None
is_used = visited_pc_after != visited_pc_before

if is_used and visited_pc_before != bytecode_offset:
raise ValueError(
f"Invalid segment structure: PC {visited_pc_before} was visited, "
f"but the beginning of the segment ({bytecode_offset}) was not."
)

res.append(
BytecodeSegment(
segment_length=item_len,
is_used=is_used,
inner_structure=current_structure,
)
)
bytecode_offset += item_len
total_len += item_len

return BytecodeSegmentedNode(segments=res), total_len
1 change: 1 addition & 0 deletions starknet_py/hash/casm_class_hash_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
("minimal_contract_compiled.casm", 0x186f6c4ca3af40dbcbf3f08f828ab0ee072938aaaedccc74ef3b9840cbd9fb3),
("test_contract_compiled.casm", 0x379b75c0922c68e73f6451b69e8e50b7f0745e6fa3f67ffc0b9608238eeaf45),
("token_bridge_compiled.casm", 0x1d60f20e5dd449af4e6b0d63821cfa95f3469faa942caf78eba2172e2ec3468),
("precompiled/starknet_contract_v2_6.casm", 0x603dd72504d8b0bc54df4f1102fdcf87fc3b2b94750a9083a5876913eec08e4),
],
)
def test_compute_casm_class_hash(casm_contract_class_source, expected_casm_class_hash):
Expand Down
12 changes: 6 additions & 6 deletions starknet_py/hash/class_hash_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@

@pytest.mark.parametrize(
"contract_source, expected_class_hash", [
("balance_compiled.json", 0x12177ea61e5791cc068d7ee979b74f60a7205a23404c07440f4892b826147c0),
("map_compiled.json", 0x45dc8f1a90d242f9ebdd07c42301eb16845fbad294f7f9118cce544c16d64b4),
("erc20_compiled.json", 0x528d1ce44f53e888c2259738018e2e77bea9cb97c8b7fc7edab67aa4a880181),
("oz_proxy_compiled.json", 0x3e1526155defb7e26a017e9020e1043cce3c5a9144a9ce497c95648ababbdf1),
("argent_proxy_compiled.json", 0x191295ed4e4bbc63209aaf4d025979f8180fe998c761f616ccd29b5acc8ae1f),
("universal_deployer_compiled.json", 0x1fda6c88607d4edd7881671959cf73fb2172c952910a60f3d01ef0cd63a635),
("balance_compiled.json", 0x7a98eab69a2592ef5d3805990a43525d633ddc42b4d5b2524c7f38b7c59265f),
("map_compiled.json", 0x5eefff2c17c81fb81b1c34d2a9f324e7baf8c3099165b94d037a84b74b6900e),
("erc20_compiled.json", 0x2c709fc176283331897d0c5f113ba64b00e1530c3e91103dcf1b05a056b1aa1),
("oz_proxy_compiled.json", 0x382f95037fa7983ff69465b9d3f7394ce336870631066de682cf547dc1899dd),
("argent_proxy_compiled.json", 0x743aa3636b7c795931e9c4ed56dc57e7edda223a66c09df04fda40f9ba4cd53),
("universal_deployer_compiled.json", 0x710bb1f5ef7f208249a370372a7586b72a759fbd2923013b14bd7f2e51bc4c),
("precompiled/oz_proxy_address_0.8.1_compiled.json", 0x413c36c287cb410d42f9e531563f68ac60a2913b5053608d640fb9b643acfe6),
]
)
Expand Down
111 changes: 111 additions & 0 deletions starknet_py/hash/compiled_class_hash_objects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# File is copied from
# https://github.com/starkware-libs/cairo-lang/blob/v0.13.1/src/starkware/starknet/core/os/contract_class/compiled_class_hash_objects.py

import dataclasses
import itertools
from abc import ABC, abstractmethod
from typing import Any, List, Union

from poseidon_py.poseidon_hash import poseidon_hash_many


class BytecodeSegmentStructure(ABC):
"""
Represents the structure of the bytecode to allow loading it partially into the OS memory.
See the documentation of the OS function `bytecode_hash_node` in `compiled_class.cairo`
for more details.
"""

@abstractmethod
def hash(self) -> int:
"""
Computes the hash of the node.
"""

def bytecode_with_skipped_segments(self):
"""
Returns the bytecode of the node.
Skipped segments are replaced with [-1, -2, -2, -2, ...].
"""
res: List[int] = []
self.add_bytecode_with_skipped_segments(res)
return res

@abstractmethod
def add_bytecode_with_skipped_segments(self, data: List[int]):
"""
Same as bytecode_with_skipped_segments, but appends the result to the given list.
"""


@dataclasses.dataclass
class BytecodeLeaf(BytecodeSegmentStructure):
"""
Represents a leaf in the bytecode segment tree.
"""

data: List[int]

def hash(self) -> int:
return poseidon_hash_many(self.data)

def add_bytecode_with_skipped_segments(self, data: List[int]):
data.extend(self.data)


@dataclasses.dataclass
class BytecodeSegmentedNode(BytecodeSegmentStructure):
"""
Represents an internal node in the bytecode segment tree.
Each child can be loaded into memory or skipped.
"""

segments: List["BytecodeSegment"]

def hash(self) -> int:
return (
poseidon_hash_many(
itertools.chain( # pyright: ignore
*[
(node.segment_length, node.inner_structure.hash())
for node in self.segments
]
)
)
+ 1
)

def add_bytecode_with_skipped_segments(self, data: List[int]):
for segment in self.segments:
if segment.is_used:
segment.inner_structure.add_bytecode_with_skipped_segments(data)
else:
data.append(-1)
data.extend(-2 for _ in range(segment.segment_length - 1))


@dataclasses.dataclass
class BytecodeSegment:
"""
Represents a child of BytecodeSegmentedNode.
"""

# The length of the segment.
segment_length: int
# Should the segment (or part of it) be loaded to memory.
# In other words, is the segment used during the execution.
# Note that if is_used is False, the entire segment is not loaded to memory.
# If is_used is True, it is possible that part of the segment will be skipped (according
# to the "is_used" field of the child segments).
is_used: bool
# The inner structure of the segment.
inner_structure: BytecodeSegmentStructure

def __post_init__(self):
assert (
self.segment_length > 0
), f"Invalid segment length: {self.segment_length}."


# Represents a nested list of integers. E.g., [1, [2, [3], 4], 5, 6].
NestedIntList = Union[int, List[Any]]
1 change: 1 addition & 0 deletions starknet_py/net/client_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,7 @@ class CasmClass:
pythonic_hints: List[Any]
compiler_version: str
entry_points_by_type: CasmClassEntryPointsByType
bytecode_segment_lengths: Optional[List[int]]


@dataclass
Expand Down
3 changes: 3 additions & 0 deletions starknet_py/net/schemas/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ def make_dataclass(self, data, **kwargs) -> CasmClassEntryPointsByType:
class CasmClassSchema(Schema):
prime = Felt(data_key="prime", required=True)
bytecode = fields.List(Felt(), data_key="bytecode", required=True)
bytecode_segment_lengths = fields.List(
Felt(), data_key="bytecode_segment_lengths", load_default=None
)
hints = fields.List(fields.Raw(), data_key="hints", required=True)
pythonic_hints = fields.List(fields.Raw(), data_key="pythonic_hints", required=True)
compiler_version = fields.String(data_key="compiler_version", required=True)
Expand Down
Loading

0 comments on commit 0a2cbf2

Please sign in to comment.