Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: separate CSM frame log publishing. additional tests #514

Merged
merged 9 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion assets/CSFeeOracle.json

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion src/modules/csm/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ class FrameCheckpointsIterator:
# to check duty of every epoch, we need to check 64 slots (32 slots of duty epoch + 32 slots of next epoch).
# In the end we got 255 committees and 8192 block_roots to check them for every checkpoint.
MAX_CHECKPOINT_STEP = 255
# Delay from last duty epoch to get checkpoint slot
# Delay from last duty epoch to get checkpoint slot.
# Regard to EIP-7045 if we want to process epoch N, we need to get attestation data from epoch N and N + 1.
# To get attestation data block roots for epoch N and N + 1 we need to
# get roots from state checkpoint slot for epoch N + 2. That's why we need the delay from epoch N.
CHECKPOINT_SLOT_DELAY_EPOCHS = 2

def __init__(
Expand Down
67 changes: 44 additions & 23 deletions src/modules/csm/csm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from src.modules.csm.types import ReportData, Shares
from src.modules.submodules.consensus import ConsensusModule
from src.modules.submodules.oracle_module import BaseModule, ModuleExecuteDelay
from src.modules.submodules.types import ZERO_HASH
from src.providers.execution.contracts.cs_fee_oracle import CSFeeOracleContract
from src.providers.execution.exceptions import InconsistentData
from src.providers.ipfs import CID
Expand Down Expand Up @@ -95,32 +96,46 @@ def execute_module(self, last_finalized_blockstamp: BlockStamp) -> ModuleExecute
def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple:
self.validate_state(blockstamp)

distributed, shares, log = self.calculate_distribution(blockstamp)
if not distributed:
logger.info({"msg": "No shares distributed in the current frame"})

# Load the previous tree if any.
prev_root = self.w3.csm.get_csm_tree_root(blockstamp)
prev_cid = self.w3.csm.get_csm_tree_cid(blockstamp)

if prev_cid:
if (prev_cid is None) != (prev_root == ZERO_HASH):
raise InconsistentData(f"Got inconsistent previous tree data: {prev_root=} {prev_cid=}")

distributed, shares, log = self.calculate_distribution(blockstamp)

if distributed != sum(shares.values()):
raise InconsistentData(f"Invalid distribution: {sum(shares.values())=} != {distributed=}")

log_cid = self.publish_log(log)

if not distributed and not shares:
logger.info({"msg": "No shares distributed in the current frame"})
return ReportData(
self.report_contract.get_consensus_version(blockstamp.block_hash),
blockstamp.ref_slot,
tree_root=prev_root,
tree_cid=prev_cid or "",
log_cid=log_cid,
distributed=0,
).as_tuple()

if prev_cid and prev_root != ZERO_HASH:
# Update cumulative amount of shares for all operators.
for no_id, acc_shares in self.get_accumulated_shares(prev_cid, prev_root):
shares[no_id] += acc_shares
else:
logger.info({"msg": "No previous CID available"})
logger.info({"msg": "No previous distribution. Nothing to accumulate"})

tree = self.make_tree(shares)
tree_cid: CID | None = None

if tree:
tree_cid = self.publish_tree(tree, log)
tree_cid = self.publish_tree(tree)

return ReportData(
self.report_contract.get_consensus_version(blockstamp.block_hash),
blockstamp.ref_slot,
tree_root=tree.root if tree else prev_root,
tree_cid=tree_cid or prev_cid or "",
tree_root=tree.root,
tree_cid=tree_cid,
log_cid=log_cid,
distributed=distributed,
).as_tuple()

Expand Down Expand Up @@ -161,17 +176,20 @@ def collect_data(self, blockstamp: BlockStamp) -> bool:
l_epoch, r_epoch = self.current_frame_range(blockstamp)
logger.info({"msg": f"Frame for performance data collect: epochs [{l_epoch};{r_epoch}]"})

# Finalized slot is the first slot of justifying epoch, so we need to take the previous
finalized_epoch = EpochNumber(converter.get_epoch_by_slot(blockstamp.slot_number) - 1)

report_blockstamp = self.get_blockstamp_for_report(blockstamp)
if report_blockstamp and report_blockstamp.ref_epoch != r_epoch:
epoch = converter.get_epoch_by_slot(blockstamp.slot_number)
logger.warning(
{"msg": f"Frame has been changed, but the change is not yet observed on finalized epoch {epoch}"}
{
"msg": f"Frame has been changed, but the change is not yet observed on finalized epoch {finalized_epoch}"
}
)
return False

# Finalized slot is the first slot of justifying epoch, so we need to take the previous
finalized_epoch = EpochNumber(converter.get_epoch_by_slot(blockstamp.slot_number) - 1)
if l_epoch > finalized_epoch:
logger.info({"msg": "The starting epoch of the frame is not finalized yet"})
return False

self.state.migrate(l_epoch, r_epoch)
Expand Down Expand Up @@ -293,9 +311,9 @@ def stuck_operators(self, blockstamp: ReferenceBlockStamp) -> set[NodeOperatorId
)
return stuck

def make_tree(self, shares: dict[NodeOperatorId, Shares]) -> Tree | None:
def make_tree(self, shares: dict[NodeOperatorId, Shares]) -> Tree:
if not shares:
return None
raise ValueError("No shares to build a tree")

# XXX: We put a stone here to make sure, that even with only 1 node operator in the tree, it's still possible to
# claim rewards. The CSModule contract skips pulling rewards if the proof's length is zero, which is the case
Expand All @@ -311,13 +329,16 @@ def make_tree(self, shares: dict[NodeOperatorId, Shares]) -> Tree | None:
logger.info({"msg": "New tree built for the report", "root": repr(tree.root)})
return tree

def publish_tree(self, tree: Tree, log: FramePerfLog) -> CID:
log_cid = self.w3.ipfs.publish(log.encode())
logger.info({"msg": "Frame log uploaded to IPFS", "cid": repr(log_cid)})
tree_cid = self.w3.ipfs.publish(tree.encode({"logCID": log_cid}))
def publish_tree(self, tree: Tree) -> CID:
tree_cid = self.w3.ipfs.publish(tree.encode())
logger.info({"msg": "Tree dump uploaded to IPFS", "cid": repr(tree_cid)})
return tree_cid

def publish_log(self, log: FramePerfLog) -> CID:
log_cid = self.w3.ipfs.publish(log.encode())
logger.info({"msg": "Frame log uploaded to IPFS", "cid": repr(log_cid)})
return log_cid

@lru_cache(maxsize=1)
def current_frame_range(self, blockstamp: BlockStamp) -> tuple[EpochNumber, EpochNumber]:
converter = self.converter(blockstamp)
Expand Down
18 changes: 5 additions & 13 deletions src/modules/csm/tree.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from dataclasses import dataclass
from typing import Self, Sequence, TypedDict
from typing import Self, Sequence

from hexbytes import HexBytes
from oz_merkle_tree import Dump, StandardMerkleTree
Expand All @@ -9,14 +9,6 @@
from src.providers.ipfs.cid import CID


class TreeMeta(TypedDict):
logCID: CID


class TreeDump(Dump[RewardTreeLeaf]):
metadata: TreeMeta


class TreeJSONEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, bytes):
Expand Down Expand Up @@ -45,7 +37,7 @@ def decode(cls, content: bytes) -> Self:
except json.JSONDecodeError as e:
raise ValueError("Unsupported tree format") from e

def encode(self, metadata: TreeMeta) -> bytes:
def encode(self) -> bytes:
"""Convert the underlying StandardMerkleTree to a binary representation"""

return (
Expand All @@ -54,12 +46,12 @@ def encode(self, metadata: TreeMeta) -> bytes:
separators=(',', ':'),
sort_keys=True,
)
.encode(self.dump(metadata))
.encode(self.dump())
.encode()
)

def dump(self, metadata: TreeMeta) -> TreeDump:
return {**self.tree.dump(), "metadata": metadata}
def dump(self) -> Dump[RewardTreeLeaf]:
return self.tree.dump()

@classmethod
def new(cls, values: Sequence[RewardTreeLeaf]) -> Self:
Expand Down
4 changes: 3 additions & 1 deletion src/modules/csm/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from dataclasses import dataclass
from typing import Literal, TypeAlias
from typing import TypeAlias, Literal

from hexbytes import HexBytes

Expand All @@ -20,6 +20,7 @@ class ReportData:
ref_slot: SlotNumber
tree_root: HexBytes
tree_cid: CID | Literal[""]
log_cid: CID
distributed: int

def as_tuple(self):
Expand All @@ -29,5 +30,6 @@ def as_tuple(self):
self.ref_slot,
self.tree_root,
str(self.tree_cid),
str(self.log_cid),
self.distributed,
)
2 changes: 1 addition & 1 deletion src/web3py/extensions/csm.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def get_csm_tree_root(self, blockstamp: BlockStamp) -> HexBytes:

def get_csm_tree_cid(self, blockstamp: BlockStamp) -> CID | None:
result = self.fee_distributor.tree_cid(blockstamp.block_hash)
if not result:
if result == "":
return None
return CIDv0(result) if is_cid_v0(result) else CIDv1(result)

Expand Down
Loading
Loading