diff --git a/src/modules/csm/csm.py b/src/modules/csm/csm.py index 7ee52addd..7855a5ed8 100644 --- a/src/modules/csm/csm.py +++ b/src/modules/csm/csm.py @@ -95,37 +95,42 @@ def execute_module(self, last_finalized_blockstamp: BlockStamp) -> ModuleExecute def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple: self.validate_state(blockstamp) - distributed, shares, log = self.calculate_distribution(blockstamp) - - # Load the previous tree if any. prev_root = self.w3.csm.get_csm_tree_root(blockstamp) prev_cid = self.w3.csm.get_csm_tree_cid(blockstamp) - if prev_cid: + if bool(prev_root) != bool(prev_cid): + raise InconsistentData(f"Got inconsistent previous tree data: {prev_root=} {prev_cid=}") + + distributed, shares, log = self.calculate_distribution(blockstamp) + + log_cid = self.publish_log(log) + + if not distributed: + logger.info({"msg": "No shares distributed in the current frame"}) + return ReportData( + self.report_contract.get_consensus_version(blockstamp.block_hash), + blockstamp.ref_slot, + tree_root=prev_root or HexBytes(32), + tree_cid=prev_cid or "", + log_cid=log_cid, + distributed=0, + ).as_tuple() + + if prev_cid and prev_root: # Update cumulative amount of shares for all operators. for no_id, acc_shares in self.get_accumulated_shares(prev_cid, prev_root): shares[no_id] += acc_shares else: - logger.info({"msg": "No previous CID available"}) - - if distributed > 0: - curr_tree = self.make_tree(shares) - if not curr_tree: - raise InconsistentData("No tree to publish but shares are distributed") - report_tree_root = curr_tree.root - report_tree_cid = self.publish_tree(curr_tree) - else: - logger.info({"msg": "No shares distributed in the current frame"}) - report_tree_root = prev_root - report_tree_cid = prev_cid + logger.info({"msg": "No previous distribution. Nothing to accumulate"}) - log_cid = self.publish_log(log) + tree = self.make_tree(shares) + tree_cid = self.publish_tree(tree) return ReportData( self.report_contract.get_consensus_version(blockstamp.block_hash), blockstamp.ref_slot, - tree_root=report_tree_root, - tree_cid=report_tree_cid, + tree_root=tree.root, + tree_cid=tree_cid, log_cid=log_cid, distributed=distributed, ).as_tuple() @@ -173,7 +178,9 @@ def collect_data(self, blockstamp: BlockStamp) -> bool: report_blockstamp = self.get_blockstamp_for_report(blockstamp) if report_blockstamp and report_blockstamp.ref_epoch != r_epoch: logger.warning( - {"msg": f"Frame has been changed, but the change is not yet observed on finalized epoch {finalized_epoch}"} + { + "msg": f"Frame has been changed, but the change is not yet observed on finalized epoch {finalized_epoch}" + } ) return False @@ -300,9 +307,9 @@ def stuck_operators(self, blockstamp: ReferenceBlockStamp) -> set[NodeOperatorId ) return stuck - def make_tree(self, shares: dict[NodeOperatorId, Shares]) -> Tree | None: + def make_tree(self, shares: dict[NodeOperatorId, Shares]) -> Tree: if not shares: - return None + raise ValueError("No shares to build a tree") # XXX: We put a stone here to make sure, that even with only 1 node operator in the tree, it's still possible to # claim rewards. The CSModule contract skips pulling rewards if the proof's length is zero, which is the case diff --git a/src/modules/csm/types.py b/src/modules/csm/types.py index d0235b852..541a0fc5f 100644 --- a/src/modules/csm/types.py +++ b/src/modules/csm/types.py @@ -1,6 +1,6 @@ import logging from dataclasses import dataclass -from typing import TypeAlias +from typing import TypeAlias, Literal from hexbytes import HexBytes @@ -19,7 +19,7 @@ class ReportData: consensusVersion: int ref_slot: SlotNumber tree_root: HexBytes - tree_cid: CID + tree_cid: CID | Literal[""] log_cid: CID distributed: int diff --git a/src/web3py/extensions/csm.py b/src/web3py/extensions/csm.py index 637ef8796..1baa7964e 100644 --- a/src/web3py/extensions/csm.py +++ b/src/web3py/extensions/csm.py @@ -43,11 +43,16 @@ def get_csm_last_processing_ref_slot(self, blockstamp: BlockStamp) -> SlotNumber FRAME_PREV_REPORT_REF_SLOT.labels("csm_oracle").set(result) return result - def get_csm_tree_root(self, blockstamp: BlockStamp) -> HexBytes: - return self.fee_distributor.tree_root(blockstamp.block_hash) + def get_csm_tree_root(self, blockstamp: BlockStamp) -> HexBytes | None: + result = self.fee_distributor.tree_root(blockstamp.block_hash) + if result == HexBytes(32): + return None + return result - def get_csm_tree_cid(self, blockstamp: BlockStamp) -> CID: + def get_csm_tree_cid(self, blockstamp: BlockStamp) -> CID | None: result = self.fee_distributor.tree_cid(blockstamp.block_hash) + if result == "": + return None return CIDv0(result) if is_cid_v0(result) else CIDv1(result) def get_operators_with_stucks_in_range( diff --git a/tests/conftest.py b/tests/conftest.py index ee4010d71..99b965314 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -114,9 +114,6 @@ def keys_api_client(request, responses_path, web3): @pytest.fixture() def csm(web3): mock = Mock() - mock.module.MAX_OPERATORS_COUNT = UINT64_MAX - web3.ipfs = Mock() - web3.lido_contracts = Mock() web3.attach_modules({"csm": lambda: mock}) return mock diff --git a/tests/modules/csm/test_csm_module.py b/tests/modules/csm/test_csm_module.py index 685123279..6a015e00d 100644 --- a/tests/modules/csm/test_csm_module.py +++ b/tests/modules/csm/test_csm_module.py @@ -1,7 +1,7 @@ import logging from collections import defaultdict from dataclasses import dataclass -from typing import NoReturn, Iterable +from typing import NoReturn, Iterable, Literal, Type from unittest.mock import Mock, patch, PropertyMock import pytest @@ -13,7 +13,7 @@ from src.modules.csm.tree import Tree from src.modules.submodules.oracle_module import ModuleExecuteDelay from src.modules.submodules.types import CurrentFrame -from src.providers.ipfs import CIDv0 +from src.providers.ipfs import CIDv0, CID from src.types import EpochNumber, NodeOperatorId, SlotNumber, StakingModuleId, ValidatorIndex from src.web3py.extensions.csm import CSM from tests.factory.blockstamp import ReferenceBlockStampFactory @@ -215,7 +215,7 @@ class FrameTestParam: last_processing_ref_slot: int current_ref_slot: int finalized_slot: int - expected_frame: tuple[int, int] + expected_frame: tuple[int, int] | Type[ValueError] @pytest.mark.parametrize( @@ -228,10 +228,9 @@ class FrameTestParam: last_processing_ref_slot=0, current_ref_slot=0, finalized_slot=0, - expected_frame=(0, 0), + expected_frame=ValueError, ), id="initial_epoch_not_set", - marks=pytest.mark.xfail(raises=ValueError), ), pytest.param( FrameTestParam( @@ -329,10 +328,15 @@ def test_current_frame_range(module: CSOracle, csm: CSM, mock_chain_config: NoRe ) ) module.get_initial_ref_slot = Mock(return_value=param.initial_ref_slot) - bs = ReferenceBlockStampFactory.build(slot_number=param.finalized_slot) - l_epoch, r_epoch = module.current_frame_range(bs) - assert (l_epoch, r_epoch) == param.expected_frame + if param.expected_frame is ValueError: + with pytest.raises(ValueError): + module.current_frame_range(ReferenceBlockStampFactory.build(slot_number=param.finalized_slot)) + else: + bs = ReferenceBlockStampFactory.build(slot_number=param.finalized_slot) + + l_epoch, r_epoch = module.current_frame_range(bs) + assert (l_epoch, r_epoch) == param.expected_frame @pytest.fixture() @@ -502,13 +506,13 @@ def test_collect_data_fulfilled_state( @dataclass(frozen=True) class BuildReportTestParam: - prev_root: str - prev_cid: str + prev_tree_root: HexBytes | None + prev_tree_cid: CID | None prev_acc_shares: Iterable[tuple[NodeOperatorId, int]] curr_distribution: Mock - curr_root: str - curr_cid: str - curr_log: str + curr_tree_root: HexBytes + curr_tree_cid: CID | Literal[""] + curr_log_cid: CID expected_make_tree_call_args: tuple | None expected_func_result: tuple @@ -518,8 +522,8 @@ class BuildReportTestParam: [ pytest.param( BuildReportTestParam( - prev_root="", - prev_cid="", + prev_tree_root=None, + prev_tree_cid=None, prev_acc_shares=[], curr_distribution=Mock( return_value=( @@ -531,18 +535,18 @@ class BuildReportTestParam: Mock(), ) ), - curr_root="", - curr_cid="", - curr_log="Qm1337", + curr_tree_root=HexBytes(32), + curr_tree_cid="", + curr_log_cid=CID("QmLOG"), expected_make_tree_call_args=None, - expected_func_result=(1, 100500, "", "", "Qm1337", 0), + expected_func_result=(1, 100500, HexBytes(32), "", CID("QmLOG"), 0), ), id="empty_prev_report_and_no_new_distribution", ), pytest.param( BuildReportTestParam( - prev_root="", - prev_cid="", + prev_tree_root=None, + prev_tree_cid=None, prev_acc_shares=[], curr_distribution=Mock( return_value=( @@ -554,18 +558,25 @@ class BuildReportTestParam: Mock(), ) ), - curr_root="0x100e", - curr_cid="0x100c", - curr_log="Qm1337", + curr_tree_root=HexBytes("NEW_TREE_ROOT".encode()), + curr_tree_cid=CID("QmNEW_TREE"), + curr_log_cid=CID("QmLOG"), expected_make_tree_call_args=(({NodeOperatorId(0): 1, NodeOperatorId(1): 2, NodeOperatorId(2): 3},),), - expected_func_result=(1, 100500, "0x100e", "0x100c", "Qm1337", 6), + expected_func_result=( + 1, + 100500, + HexBytes("NEW_TREE_ROOT".encode()), + CID("QmNEW_TREE"), + CID("QmLOG"), + 6, + ), ), id="empty_prev_report_and_new_distribution", ), pytest.param( BuildReportTestParam( - prev_root="0x100e", - prev_cid="0x100c", + prev_tree_root=HexBytes("OLD_TREE_ROOT".encode()), + prev_tree_cid=CID("QmOLD_TREE"), prev_acc_shares=[(NodeOperatorId(0), 100), (NodeOperatorId(1), 200), (NodeOperatorId(2), 300)], curr_distribution=Mock( return_value=( @@ -577,20 +588,27 @@ class BuildReportTestParam: Mock(), ) ), - curr_root="0x101e", - curr_cid="0x101c", - curr_log="Qm1337", + curr_tree_root=HexBytes("NEW_TREE_ROOT".encode()), + curr_tree_cid=CID("QmNEW_TREE"), + curr_log_cid=CID("QmLOG"), expected_make_tree_call_args=( ({NodeOperatorId(0): 101, NodeOperatorId(1): 202, NodeOperatorId(2): 300, NodeOperatorId(3): 3},), ), - expected_func_result=(1, 100500, "0x101e", "0x101c", "Qm1337", 6), + expected_func_result=( + 1, + 100500, + HexBytes("NEW_TREE_ROOT".encode()), + CID("QmNEW_TREE"), + CID("QmLOG"), + 6, + ), ), id="non_empty_prev_report_and_new_distribution", ), pytest.param( BuildReportTestParam( - prev_root="0x100e", - prev_cid="0x100c", + prev_tree_root=HexBytes("OLD_TREE_ROOT".encode()), + prev_tree_cid=CID("QmOLD_TREE"), prev_acc_shares=[(NodeOperatorId(0), 100), (NodeOperatorId(1), 200), (NodeOperatorId(2), 300)], curr_distribution=Mock( return_value=( @@ -602,28 +620,35 @@ class BuildReportTestParam: Mock(), ) ), - curr_root="", - curr_cid="", - curr_log="Qm1337", + curr_tree_root=HexBytes(32), + curr_tree_cid="", + curr_log_cid=CID("QmLOG"), expected_make_tree_call_args=None, - expected_func_result=(1, 100500, "0x100e", "0x100c", "Qm1337", 0), + expected_func_result=( + 1, + 100500, + HexBytes("OLD_TREE_ROOT".encode()), + CID("QmOLD_TREE"), + CID("QmLOG"), + 0, + ), ), id="non_empty_prev_report_and_no_new_distribution", ), ], ) -def test_build_report(module: CSOracle, param: BuildReportTestParam): +def test_build_report(csm: CSM, module: CSOracle, param: BuildReportTestParam): module.validate_state = Mock() module.report_contract.get_consensus_version = Mock(return_value=1) # mock previous report - module.w3.csm.get_csm_tree_root = Mock(return_value=param.prev_root) - module.w3.csm.get_csm_tree_cid = Mock(return_value=param.prev_cid) + module.w3.csm.get_csm_tree_root = Mock(return_value=param.prev_tree_root) + module.w3.csm.get_csm_tree_cid = Mock(return_value=param.prev_tree_cid) module.get_accumulated_shares = Mock(return_value=param.prev_acc_shares) # mock current frame module.calculate_distribution = param.curr_distribution - module.make_tree = Mock(return_value=Mock(root=param.curr_root)) - module.publish_tree = Mock(return_value=param.curr_cid) - module.publish_log = Mock(return_value=param.curr_log) + module.make_tree = Mock(return_value=Mock(root=param.curr_tree_root)) + module.publish_tree = Mock(return_value=param.curr_tree_cid) + module.publish_log = Mock(return_value=param.curr_log_cid) report = module.build_report(blockstamp=Mock(ref_slot=100500)) @@ -675,7 +700,7 @@ def tree(): def test_get_accumulated_shares(module: CSOracle, tree: Tree): encoded_tree = tree.encode() - module.w3.ipfs.fetch = Mock(return_value=encoded_tree) + module.w3.ipfs = Mock(fetch=Mock(return_value=encoded_tree)) for i, leaf in enumerate(module.get_accumulated_shares(cid=CIDv0("0x100500"), root=tree.root)): assert tuple(leaf) == tree.tree.values[i]["value"] @@ -683,7 +708,7 @@ def test_get_accumulated_shares(module: CSOracle, tree: Tree): def test_get_accumulated_shares_unexpected_root(module: CSOracle, tree: Tree): encoded_tree = tree.encode() - module.w3.ipfs.fetch = Mock(return_value=encoded_tree) + module.w3.ipfs = Mock(fetch=Mock(return_value=encoded_tree)) with pytest.raises(ValueError): next(module.get_accumulated_shares(cid=CIDv0("0x100500"), root=HexBytes("0x100500"))) @@ -692,13 +717,13 @@ def test_get_accumulated_shares_unexpected_root(module: CSOracle, tree: Tree): @dataclass(frozen=True) class MakeTreeTestParam: shares: dict[NodeOperatorId, int] - expected_tree_values: tuple | None + expected_tree_values: tuple | Type[ValueError] @pytest.mark.parametrize( "param", [ - pytest.param(MakeTreeTestParam(shares={}, expected_tree_values=None), id="empty"), + pytest.param(MakeTreeTestParam(shares={}, expected_tree_values=ValueError), id="empty"), pytest.param( MakeTreeTestParam( shares={NodeOperatorId(0): 1, NodeOperatorId(1): 2, NodeOperatorId(2): 3}, @@ -739,8 +764,11 @@ class MakeTreeTestParam: ], ) def test_make_tree(module: CSOracle, param: MakeTreeTestParam): - tree = module.make_tree(param.shares) - if param.expected_tree_values is not None: - assert tree.tree.values == param.expected_tree_values + module.w3.csm.module.MAX_OPERATORS_COUNT = UINT64_MAX + + if param.expected_tree_values is ValueError: + with pytest.raises(ValueError): + module.make_tree(param.shares) else: - assert not tree + tree = module.make_tree(param.shares) + assert tree.tree.values == param.expected_tree_values