Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: tests for csm build report #516

Merged
merged 11 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions src/modules/csm/csm.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,6 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple:
self.validate_state(blockstamp)

distributed, shares, log = self.calculate_distribution(blockstamp)
if not distributed:
logger.info({"msg": "No shares distributed in the current frame"})

# Load the previous tree if any.
prev_root = self.w3.csm.get_csm_tree_root(blockstamp)
Expand All @@ -110,18 +108,24 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple:
else:
logger.info({"msg": "No previous CID available"})

tree = self.make_tree(shares)
tree_cid: CID | None = None
if distributed > 0:
curr_tree = self.make_tree(shares)
if not curr_tree:
raise InconsistentData("No tree to publish but shares are distributed")
madlabman marked this conversation as resolved.
Show resolved Hide resolved
report_tree_root = curr_tree.root
report_tree_cid = self.publish_tree(curr_tree)
else:
logger.info({"msg": "No shares distributed in the current frame"})
report_tree_root = prev_root
report_tree_cid = prev_cid

log_cid = self.publish_log(log)
if tree:
tree_cid = self.publish_tree(tree)

return ReportData(
self.report_contract.get_consensus_version(blockstamp.block_hash),
blockstamp.ref_slot,
tree_root=tree.root if tree else prev_root,
tree_cid=tree_cid or prev_cid or "",
tree_root=report_tree_root,
tree_cid=report_tree_cid,
log_cid=log_cid,
distributed=distributed,
).as_tuple()
Expand Down
4 changes: 2 additions & 2 deletions src/modules/csm/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from dataclasses import dataclass
from typing import Literal, TypeAlias
from typing import TypeAlias

from hexbytes import HexBytes

Expand All @@ -19,7 +19,7 @@ class ReportData:
consensusVersion: int
ref_slot: SlotNumber
tree_root: HexBytes
tree_cid: CID | Literal[""]
tree_cid: CID
madlabman marked this conversation as resolved.
Show resolved Hide resolved
log_cid: CID
distributed: int

Expand Down
4 changes: 1 addition & 3 deletions src/web3py/extensions/csm.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,8 @@ def get_csm_last_processing_ref_slot(self, blockstamp: BlockStamp) -> SlotNumber
def get_csm_tree_root(self, blockstamp: BlockStamp) -> HexBytes:
return self.fee_distributor.tree_root(blockstamp.block_hash)

def get_csm_tree_cid(self, blockstamp: BlockStamp) -> CID | None:
def get_csm_tree_cid(self, blockstamp: BlockStamp) -> CID:
result = self.fee_distributor.tree_cid(blockstamp.block_hash)
if not result:
return None
madlabman marked this conversation as resolved.
Show resolved Hide resolved
return CIDv0(result) if is_cid_v0(result) else CIDv1(result)

def get_operators_with_stucks_in_range(
Expand Down
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from web3.types import Timestamp

import src.variables
from src.constants import UINT64_MAX
from src.types import BlockNumber, EpochNumber, ReferenceBlockStamp, SlotNumber
from src.variables import CONSENSUS_CLIENT_URI, EXECUTION_CLIENT_URI, KEYS_API_URI
from src.web3py.contract_tweak import tweak_w3_contracts
Expand Down Expand Up @@ -113,6 +114,9 @@ def keys_api_client(request, responses_path, web3):
@pytest.fixture()
def csm(web3):
mock = Mock()
mock.module.MAX_OPERATORS_COUNT = UINT64_MAX
web3.ipfs = Mock()
web3.lido_contracts = Mock()
madlabman marked this conversation as resolved.
Show resolved Hide resolved
web3.attach_modules({"csm": lambda: mock})
return mock

Expand Down
253 changes: 252 additions & 1 deletion tests/modules/csm/test_csm_module.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
import logging
from collections import defaultdict
from dataclasses import dataclass
from typing import NoReturn
from typing import NoReturn, Iterable
from unittest.mock import Mock, patch, PropertyMock

import pytest
from hexbytes import HexBytes

from src.constants import UINT64_MAX
from src.modules.csm.csm import CSOracle
from src.modules.csm.state import AttestationsAccumulator, State
from src.modules.csm.tree import Tree
from src.modules.submodules.oracle_module import ModuleExecuteDelay
from src.modules.submodules.types import CurrentFrame
from src.providers.ipfs import CIDv0
from src.types import EpochNumber, NodeOperatorId, SlotNumber, StakingModuleId, ValidatorIndex
from src.web3py.extensions.csm import CSM
from tests.factory.blockstamp import ReferenceBlockStampFactory
Expand Down Expand Up @@ -493,3 +498,249 @@ def test_collect_data_fulfilled_state(
# assert that it is not early return from function
msg = list(filter(lambda log: "All epochs are already processed. Nothing to collect" in log, caplog.messages))
assert len(msg) == 0, "Unexpected message found in logs"


@dataclass(frozen=True)
class BuildReportTestParam:
prev_root: str
prev_cid: str
prev_acc_shares: Iterable[tuple[NodeOperatorId, int]]
curr_distribution: Mock
curr_root: str
curr_cid: str
curr_log: str
madlabman marked this conversation as resolved.
Show resolved Hide resolved
expected_make_tree_call_args: tuple | None
expected_func_result: tuple


@pytest.mark.parametrize(
"param",
[
pytest.param(
BuildReportTestParam(
prev_root="",
prev_cid="",
prev_acc_shares=[],
curr_distribution=Mock(
return_value=(
# distributed
0,
# shares
defaultdict(int),
# log
Mock(),
)
),
curr_root="",
curr_cid="",
curr_log="Qm1337",
expected_make_tree_call_args=None,
expected_func_result=(1, 100500, "", "", "Qm1337", 0),
),
id="empty_prev_report_and_no_new_distribution",
),
pytest.param(
BuildReportTestParam(
prev_root="",
prev_cid="",
prev_acc_shares=[],
curr_distribution=Mock(
return_value=(
# distributed
6,
# shares
defaultdict(int, {NodeOperatorId(0): 1, NodeOperatorId(1): 2, NodeOperatorId(2): 3}),
# log
Mock(),
)
),
curr_root="0x100e",
curr_cid="0x100c",
madlabman marked this conversation as resolved.
Show resolved Hide resolved
curr_log="Qm1337",
expected_make_tree_call_args=(({NodeOperatorId(0): 1, NodeOperatorId(1): 2, NodeOperatorId(2): 3},),),
expected_func_result=(1, 100500, "0x100e", "0x100c", "Qm1337", 6),
),
id="empty_prev_report_and_new_distribution",
),
pytest.param(
BuildReportTestParam(
prev_root="0x100e",
prev_cid="0x100c",
prev_acc_shares=[(NodeOperatorId(0), 100), (NodeOperatorId(1), 200), (NodeOperatorId(2), 300)],
curr_distribution=Mock(
return_value=(
# distributed
6,
# shares
defaultdict(int, {NodeOperatorId(0): 1, NodeOperatorId(1): 2, NodeOperatorId(3): 3}),
# log
Mock(),
)
),
curr_root="0x101e",
curr_cid="0x101c",
curr_log="Qm1337",
expected_make_tree_call_args=(
({NodeOperatorId(0): 101, NodeOperatorId(1): 202, NodeOperatorId(2): 300, NodeOperatorId(3): 3},),
),
expected_func_result=(1, 100500, "0x101e", "0x101c", "Qm1337", 6),
),
id="non_empty_prev_report_and_new_distribution",
),
pytest.param(
BuildReportTestParam(
prev_root="0x100e",
prev_cid="0x100c",
prev_acc_shares=[(NodeOperatorId(0), 100), (NodeOperatorId(1), 200), (NodeOperatorId(2), 300)],
curr_distribution=Mock(
return_value=(
# distributed
0,
# shares
defaultdict(int),
# log
Mock(),
)
),
curr_root="",
curr_cid="",
curr_log="Qm1337",
expected_make_tree_call_args=None,
expected_func_result=(1, 100500, "0x100e", "0x100c", "Qm1337", 0),
),
id="non_empty_prev_report_and_no_new_distribution",
),
],
)
def test_build_report(module: CSOracle, param: BuildReportTestParam):
module.validate_state = Mock()
module.report_contract.get_consensus_version = Mock(return_value=1)
# mock previous report
module.w3.csm.get_csm_tree_root = Mock(return_value=param.prev_root)
module.w3.csm.get_csm_tree_cid = Mock(return_value=param.prev_cid)
module.get_accumulated_shares = Mock(return_value=param.prev_acc_shares)
# mock current frame
module.calculate_distribution = param.curr_distribution
module.make_tree = Mock(return_value=Mock(root=param.curr_root))
module.publish_tree = Mock(return_value=param.curr_cid)
module.publish_log = Mock(return_value=param.curr_log)

report = module.build_report(blockstamp=Mock(ref_slot=100500))

assert module.make_tree.call_args == param.expected_make_tree_call_args
assert report == param.expected_func_result


def test_execute_module_not_collected(module: CSOracle):
module.collect_data = Mock(return_value=False)

execute_delay = module.execute_module(
last_finalized_blockstamp=Mock(slot_number=100500),
)
assert execute_delay is ModuleExecuteDelay.NEXT_FINALIZED_EPOCH


def test_execute_module_no_report_blockstamp(module: CSOracle):
module.collect_data = Mock(return_value=True)
module.get_blockstamp_for_report = Mock(return_value=None)

execute_delay = module.execute_module(
last_finalized_blockstamp=Mock(slot_number=100500),
)
assert execute_delay is ModuleExecuteDelay.NEXT_FINALIZED_EPOCH


def test_execute_module_processed(module: CSOracle):
module.collect_data = Mock(return_value=True)
module.get_blockstamp_for_report = Mock(return_value=Mock(slot_number=100500))
module.process_report = Mock()

execute_delay = module.execute_module(
last_finalized_blockstamp=Mock(slot_number=100500),
)
assert execute_delay is ModuleExecuteDelay.NEXT_SLOT


@pytest.fixture()
def tree():
return Tree.new(
[
(NodeOperatorId(0), 0),
(NodeOperatorId(1), 1),
(NodeOperatorId(2), 42),
(NodeOperatorId(UINT64_MAX), 0),
]
)


def test_get_accumulated_shares(module: CSOracle, tree: Tree):
encoded_tree = tree.encode()
module.w3.ipfs.fetch = Mock(return_value=encoded_tree)

for i, leaf in enumerate(module.get_accumulated_shares(cid=CIDv0("0x100500"), root=tree.root)):
assert tuple(leaf) == tree.tree.values[i]["value"]


def test_get_accumulated_shares_unexpected_root(module: CSOracle, tree: Tree):
encoded_tree = tree.encode()
module.w3.ipfs.fetch = Mock(return_value=encoded_tree)

with pytest.raises(ValueError):
next(module.get_accumulated_shares(cid=CIDv0("0x100500"), root=HexBytes("0x100500")))


@dataclass(frozen=True)
class MakeTreeTestParam:
shares: dict[NodeOperatorId, int]
expected_tree_values: tuple | None


@pytest.mark.parametrize(
"param",
[
pytest.param(MakeTreeTestParam(shares={}, expected_tree_values=None), id="empty"),
pytest.param(
MakeTreeTestParam(
shares={NodeOperatorId(0): 1, NodeOperatorId(1): 2, NodeOperatorId(2): 3},
expected_tree_values=(
{'treeIndex': 4, 'value': (0, 1)},
{'treeIndex': 2, 'value': (1, 2)},
{'treeIndex': 3, 'value': (2, 3)},
),
),
id="normal_tree",
),
pytest.param(
MakeTreeTestParam(
shares={NodeOperatorId(0): 1},
expected_tree_values=(
{'treeIndex': 2, 'value': (0, 1)},
{'treeIndex': 1, 'value': (18446744073709551615, 0)},
),
),
id="put_stone",
),
pytest.param(
MakeTreeTestParam(
shares={
NodeOperatorId(0): 1,
NodeOperatorId(1): 2,
NodeOperatorId(2): 3,
NodeOperatorId(18446744073709551615): 0,
},
expected_tree_values=(
{'treeIndex': 4, 'value': (0, 1)},
{'treeIndex': 2, 'value': (1, 2)},
{'treeIndex': 3, 'value': (2, 3)},
),
),
id="remove_stone",
),
],
)
def test_make_tree(module: CSOracle, param: MakeTreeTestParam):
tree = module.make_tree(param.shares)
if param.expected_tree_values is not None:
assert tree.tree.values == param.expected_tree_values
else:
assert not tree
Loading