Skip to content

Commit

Permalink
fix: re-refactoring, tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vgorkavenko committed Sep 18, 2024
1 parent 5b527fe commit 6e526f3
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 81 deletions.
51 changes: 29 additions & 22 deletions src/modules/csm/csm.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,37 +95,42 @@ def execute_module(self, last_finalized_blockstamp: BlockStamp) -> ModuleExecute
def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple:
self.validate_state(blockstamp)

distributed, shares, log = self.calculate_distribution(blockstamp)

# Load the previous tree if any.
prev_root = self.w3.csm.get_csm_tree_root(blockstamp)
prev_cid = self.w3.csm.get_csm_tree_cid(blockstamp)

if prev_cid:
if bool(prev_root) != bool(prev_cid):
raise InconsistentData(f"Got inconsistent previous tree data: {prev_root=} {prev_cid=}")

distributed, shares, log = self.calculate_distribution(blockstamp)

log_cid = self.publish_log(log)

if not distributed:
logger.info({"msg": "No shares distributed in the current frame"})
return ReportData(
self.report_contract.get_consensus_version(blockstamp.block_hash),
blockstamp.ref_slot,
tree_root=prev_root or HexBytes(32),
tree_cid=prev_cid or "",
log_cid=log_cid,
distributed=0,
).as_tuple()

if prev_cid and prev_root:
# Update cumulative amount of shares for all operators.
for no_id, acc_shares in self.get_accumulated_shares(prev_cid, prev_root):
shares[no_id] += acc_shares
else:
logger.info({"msg": "No previous CID available"})

if distributed > 0:
curr_tree = self.make_tree(shares)
if not curr_tree:
raise InconsistentData("No tree to publish but shares are distributed")
report_tree_root = curr_tree.root
report_tree_cid = self.publish_tree(curr_tree)
else:
logger.info({"msg": "No shares distributed in the current frame"})
report_tree_root = prev_root
report_tree_cid = prev_cid
logger.info({"msg": "No previous distribution. Nothing to accumulate"})

log_cid = self.publish_log(log)
tree = self.make_tree(shares)
tree_cid = self.publish_tree(tree)

return ReportData(
self.report_contract.get_consensus_version(blockstamp.block_hash),
blockstamp.ref_slot,
tree_root=report_tree_root,
tree_cid=report_tree_cid,
tree_root=tree.root,
tree_cid=tree_cid,
log_cid=log_cid,
distributed=distributed,
).as_tuple()
Expand Down Expand Up @@ -173,7 +178,9 @@ def collect_data(self, blockstamp: BlockStamp) -> bool:
report_blockstamp = self.get_blockstamp_for_report(blockstamp)
if report_blockstamp and report_blockstamp.ref_epoch != r_epoch:
logger.warning(
{"msg": f"Frame has been changed, but the change is not yet observed on finalized epoch {finalized_epoch}"}
{
"msg": f"Frame has been changed, but the change is not yet observed on finalized epoch {finalized_epoch}"
}
)
return False

Expand Down Expand Up @@ -300,9 +307,9 @@ def stuck_operators(self, blockstamp: ReferenceBlockStamp) -> set[NodeOperatorId
)
return stuck

def make_tree(self, shares: dict[NodeOperatorId, Shares]) -> Tree | None:
def make_tree(self, shares: dict[NodeOperatorId, Shares]) -> Tree:
if not shares:
return None
raise ValueError("No shares to build a tree")

# XXX: We put a stone here to make sure, that even with only 1 node operator in the tree, it's still possible to
# claim rewards. The CSModule contract skips pulling rewards if the proof's length is zero, which is the case
Expand Down
4 changes: 2 additions & 2 deletions src/modules/csm/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from dataclasses import dataclass
from typing import TypeAlias
from typing import TypeAlias, Literal

from hexbytes import HexBytes

Expand All @@ -19,7 +19,7 @@ class ReportData:
consensusVersion: int
ref_slot: SlotNumber
tree_root: HexBytes
tree_cid: CID
tree_cid: CID | Literal[""]
log_cid: CID
distributed: int

Expand Down
11 changes: 8 additions & 3 deletions src/web3py/extensions/csm.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,16 @@ def get_csm_last_processing_ref_slot(self, blockstamp: BlockStamp) -> SlotNumber
FRAME_PREV_REPORT_REF_SLOT.labels("csm_oracle").set(result)
return result

def get_csm_tree_root(self, blockstamp: BlockStamp) -> HexBytes:
return self.fee_distributor.tree_root(blockstamp.block_hash)
def get_csm_tree_root(self, blockstamp: BlockStamp) -> HexBytes | None:
result = self.fee_distributor.tree_root(blockstamp.block_hash)
if result == HexBytes(32):
return None
return result

def get_csm_tree_cid(self, blockstamp: BlockStamp) -> CID:
def get_csm_tree_cid(self, blockstamp: BlockStamp) -> CID | None:
result = self.fee_distributor.tree_cid(blockstamp.block_hash)
if result == "":
return None
return CIDv0(result) if is_cid_v0(result) else CIDv1(result)

def get_operators_with_stucks_in_range(
Expand Down
3 changes: 0 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,6 @@ def keys_api_client(request, responses_path, web3):
@pytest.fixture()
def csm(web3):
mock = Mock()
mock.module.MAX_OPERATORS_COUNT = UINT64_MAX
web3.ipfs = Mock()
web3.lido_contracts = Mock()
web3.attach_modules({"csm": lambda: mock})
return mock

Expand Down
Loading

0 comments on commit 6e526f3

Please sign in to comment.