refactor: simplify performance cache
madlabman committed May 2, 2024
1 parent 4d3c13a commit 185ae48
Showing 7 changed files with 206 additions and 216 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -136,3 +136,7 @@ dmypy.json

# vim
*.swp

# Cache
*.pkl
*.buf
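
The two new ignore patterns above suggest the simplified cache persists itself to disk as a pickle (*.pkl), written through a temporary buffer file (*.buf). A minimal sketch of that write-then-rename pattern, with hypothetical file names and function names (the actual persistence code lives in src/modules/csm/state.py and is not shown in this diff):

import os
import pickle
from pathlib import Path

CACHE_PATH = Path("cache.pkl")  # hypothetical location

def commit(obj) -> None:
    # Dump to the .buf file first, then rename it over the .pkl atomically,
    # so an interrupted write never leaves a corrupted cache behind.
    buf = CACHE_PATH.with_suffix(".buf")
    with buf.open("wb") as fd:
        pickle.dump(obj, fd)
    os.replace(buf, CACHE_PATH)

def load(default):
    # Return the cached object if present, otherwise fall back to the default.
    try:
        with CACHE_PATH.open("rb") as fd:
            return pickle.load(fd)
    except FileNotFoundError:
        return default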
44 changes: 27 additions & 17 deletions src/modules/csm/checkpoint.py
@@ -3,9 +3,10 @@
from threading import Thread, Lock
from typing import Any, Iterable, cast

from src.modules.csm.typings import FramePerformance
from src.modules.csm.state import State
from src.providers.consensus.client import ConsensusClient
from src.typings import EpochNumber, BlockRoot, SlotNumber, BlockStamp
from src.typings import EpochNumber, BlockRoot, SlotNumber, BlockStamp, ValidatorIndex
from src.utils.range import seq
from src.utils.web3converter import Web3Converter

logger = logging.getLogger(__name__)
@@ -15,7 +16,7 @@
class CheckpointsFactory:
cc: ConsensusClient
converter: Web3Converter
frame_performance: FramePerformance
state: State

# min checkpoint step is 10 because it's a reasonable number of epochs to process at once (~1 hour)
MIN_CHECKPOINT_STEP = 10
@@ -24,10 +25,10 @@ class CheckpointsFactory:
# in the end we get 255 committees and 8192 block_roots to check for every checkpoint
MAX_CHECKPOINT_STEP = 255

def __init__(self, cc: ConsensusClient, converter: Web3Converter, frame_performance: FramePerformance):
def __init__(self, cc: ConsensusClient, converter: Web3Converter, state: State):
self.cc = cc
self.converter = converter
self.frame_performance = frame_performance
self.state = state

def prepare_checkpoints(
self,
@@ -36,7 +37,7 @@ def prepare_checkpoints(
finalized_epoch: EpochNumber
):
def _prepare_checkpoint(_slot: SlotNumber, _duty_epochs: list[EpochNumber]):
return Checkpoint(self.cc, self.converter, self.frame_performance, _slot, _duty_epochs)
return Checkpoint(self.cc, self.converter, self.state, _slot, _duty_epochs)

def _max_in_seq(items: Iterable[Any]) -> Any:
sorted_ = sorted(items)
@@ -48,7 +49,7 @@ def _max_in_seq(items: Iterable[Any]) -> Any:
item = curr
return item

l_epoch = _max_in_seq((l_epoch, *self.frame_performance.processed_epochs))
l_epoch = _max_in_seq((l_epoch, *self.state.processed_epochs))
processing_delay = finalized_epoch - l_epoch

if l_epoch == r_epoch:
@@ -59,7 +60,8 @@ def _max_in_seq(items: Iterable[Any]) -> Any:
logger.info({"msg": f"Minimum checkpoint step is not reached, current delay is {processing_delay}"})
return []

duty_epochs = cast(list[EpochNumber], list(range(l_epoch, r_epoch + 1)))
r_epoch = min(r_epoch, EpochNumber(finalized_epoch - 1))
duty_epochs = cast(list[EpochNumber], list(seq(l_epoch, r_epoch)))
checkpoints: list[Checkpoint] = []
checkpoint_epochs = []
for index, epoch in enumerate(duty_epochs, 1):
@@ -83,7 +85,7 @@ class Checkpoint:
converter: Web3Converter

threads: list[Thread]
frame_performance: FramePerformance
state: State

slot: SlotNumber # last slot of the epoch
duty_epochs: list[EpochNumber] # max 255 elements
@@ -93,7 +95,7 @@ def __init__(
self,
cc: ConsensusClient,
converter: Web3Converter,
frame_performance: FramePerformance,
state: State,
slot: SlotNumber,
duty_epochs: list[EpochNumber]
):
@@ -103,15 +105,15 @@ def __init__(
self.duty_epochs = duty_epochs
self.block_roots = []
self.threads = []
self.frame_performance = frame_performance
self.state = state

@property
def free_threads(self):
return self.MAX_THREADS - len(self.threads)

def process(self, last_finalized_blockstamp: BlockStamp):
for duty_epoch in self.duty_epochs:
if duty_epoch in self.frame_performance.processed_epochs:
if duty_epoch in self.state.processed_epochs:
continue
if not self.block_roots:
self._get_block_roots()
@@ -140,9 +142,9 @@ def _select_roots_to_check(
# inspired by the spec
# https://github.com/ethereum/consensus-specs/blob/dev/specs/phase0/beacon-chain.md#get_block_root_at_slot
roots_to_check = []
slots = range(
slots = seq(
self.converter.get_epoch_first_slot(duty_epoch),
self.converter.get_epoch_last_slot(EpochNumber(duty_epoch + 1)) + 1
self.converter.get_epoch_last_slot(EpochNumber(duty_epoch + 1))
)
for slot_to_check in slots:
# TODO: get the magic number from the CL spec
@@ -167,16 +169,24 @@ def _process_epoch(
):
logger.info({"msg": f"Process epoch {duty_epoch}"})
start = time.time()
checked_roots = set()
committees = self._prepare_committees(last_finalized_blockstamp, EpochNumber(duty_epoch))
for root in roots_to_check:
if root is None:
continue
slot_data = self.cc.get_block_details_raw(BlockRoot(root))
self._process_attestations(slot_data, committees)
checked_roots.add(root)

with lock:
self.frame_performance.dump(duty_epoch, committees, checked_roots)
for committee in committees.values():
for validator in committee:
self.state.inc(
ValidatorIndex(int(validator['index'])),
included=validator['included'],
)
self.state.processed_epochs.add(duty_epoch)
self.state.commit()
self.state.status()

logger.info({"msg": f"Epoch {duty_epoch} processed in {time.time() - start:.2f} seconds"})

def _prepare_committees(self, last_finalized_blockstamp: BlockStamp, epoch: int) -> dict:
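The epoch processing above now feeds per-validator duty counters straight into the shared State object instead of dumping FramePerformance snapshots. A rough sketch of the interface implied by the call sites in this diff (inc, processed_epochs, commit, status, clear, load, avg_perf, index lookup); the real src/modules/csm/state.py is not part of this excerpt, and the DutyAccumulator name here is assumed:

from dataclasses import dataclass

@dataclass
class DutyAccumulator:
    assigned: int = 0
    included: int = 0

    @property
    def perf(self) -> float:
        # Share of assigned attestation duties that were actually included on chain.
        return self.included / self.assigned if self.assigned else 0.0

class State:
    def __init__(self) -> None:
        self.data: dict[int, DutyAccumulator] = {}
        self.processed_epochs: set[int] = set()

    def inc(self, validator_index: int, included: bool) -> None:
        # One assigned duty per call; count it as included if it was seen on chain.
        acc = self.data.setdefault(validator_index, DutyAccumulator())
        acc.assigned += 1
        acc.included += int(included)

    def __getitem__(self, validator_index: int) -> DutyAccumulator:
        # Raises KeyError for validators with no recorded duties (the caller handles it).
        return self.data[validator_index]

    @property
    def avg_perf(self) -> float:
        assigned = sum(a.assigned for a in self.data.values())
        included = sum(a.included for a in self.data.values())
        return included / assigned if assigned else 0.0

    def clear(self) -> None:
        self.data.clear()
        self.processed_epochs.clear()

    def status(self) -> None:
        # Log a short summary of what has been collected so far.
        print(f"assigned={sum(a.assigned for a in self.data.values())} avg_perf={self.avg_perf:.4f}")

    def commit(self) -> None:
        # Persist the accumulated data to disk (see the pickle sketch above).
        ...

    @classmethod
    def load(cls) -> "State":
        # Read the persisted cache if one exists; otherwise start empty.
        return cls()
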
92 changes: 49 additions & 43 deletions src/modules/csm/csm.py
@@ -8,14 +8,16 @@
from src.metrics.prometheus.business import CONTRACT_ON_PAUSE
from src.metrics.prometheus.duration_meter import duration_meter
from src.modules.csm.checkpoint import CheckpointsFactory
from src.modules.csm.state import State
from src.modules.csm.tree import Tree
from src.modules.csm.typings import FramePerformance, ReportData
from src.modules.csm.types import ReportData
from src.modules.submodules.consensus import ConsensusModule
from src.modules.submodules.oracle_module import BaseModule, ModuleExecuteDelay
from src.modules.submodules.typings import ZERO_HASH
from src.providers.execution.contracts.CSFeeOracle import CSFeeOracle
from src.typings import BlockStamp, EpochNumber, ReferenceBlockStamp, SlotNumber, ValidatorIndex
from src.utils.cache import global_lru_cache as lru_cache
from src.utils.range import seq
from src.utils.slot import get_first_non_missed_slot
from src.utils.web3converter import Web3Converter
from src.web3py.extensions.lido_validators import NodeOperatorId, StakingModule, ValidatorsByNodeOperator
@@ -24,6 +26,10 @@
logger = logging.getLogger(__name__)


class InvalidState(Exception):
...


class CSOracle(BaseModule, ConsensusModule):
"""
CSM performance module collects performance of CSM node operators and creates a Merkle tree of the resulting
@@ -39,18 +45,17 @@ class CSOracle(BaseModule, ConsensusModule):
CONTRACT_VERSION = 1

report_contract: CSFeeOracle
frame_performance: FramePerformance | None

def __init__(self, w3: Web3):
self.report_contract = w3.csm.oracle
self.frame_performance = None
self.state: State | None = None
super().__init__(w3)

def refresh_contracts(self):
self.report_contract = self.w3.csm.oracle # type: ignore

def execute_module(self, last_finalized_blockstamp: BlockStamp) -> ModuleExecuteDelay:
collected = self._collect_data(last_finalized_blockstamp)
collected = self.collect_data(last_finalized_blockstamp)
if not collected:
# The data is not fully collected yet, wait for the next epoch
logger.info(
@@ -69,25 +74,28 @@ def execute_module(self, last_finalized_blockstamp: BlockStamp) -> ModuleExecute
@duration_meter()
def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple:
# pylint: disable=too-many-branches
assert self.state

assert self.frame_performance
assert self.frame_performance.is_coherent

self._print_collect_result()
try:
self.validate_state(blockstamp)
except InvalidState as ex:
raise ValueError("Unable to build report") from ex
self.state.status()

threshold = self.frame_performance.avg_perf * self.w3.csm.oracle.perf_threshold(blockstamp.block_hash)
threshold = self.state.avg_perf * self.w3.csm.oracle.perf_threshold(blockstamp.block_hash)
l_ref_slot, r_ref_slot = self.current_frame_range(blockstamp)
# NOTE: r_block is guaranteed to be <= ref_slot, and the check
# in the inner frames assures the l_block <= r_block.
stuck_operators = self.w3.csm.get_csm_stuck_node_operators(
get_first_non_missed_slot(
self.w3.cc,
self.frame_performance.l_slot,
l_ref_slot,
blockstamp.slot_number,
direction='forward',
).message.body.execution_payload.block_hash,
get_first_non_missed_slot(
self.w3.cc,
self.frame_performance.r_slot,
r_ref_slot,
blockstamp.slot_number,
direction='back',
).message.body.execution_payload.block_hash,
@@ -104,7 +112,7 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple:

for v in validators:
try:
perf = self.frame_performance.aggr_per_val[ValidatorIndex(int(v.index))].perf
perf = self.state[ValidatorIndex(int(v.index))].perf
if perf > threshold:
portion += 1
except KeyError:
@@ -191,28 +199,35 @@ def module(self) -> StakingModule:
def module_validators_by_node_operators(self, blockstamp: BlockStamp) -> ValidatorsByNodeOperator:
return self.w3.lido_validators.get_module_validators_by_node_operators(self.module.id, blockstamp)

def _collect_data(self, blockstamp: BlockStamp) -> bool:
def validate_state(self, blockstamp) -> None:
assert self.state
converter = self.converter(blockstamp)
l_ref_slot, r_ref_slot = self.current_frame_range(blockstamp)
l_epoch = EpochNumber(converter.get_epoch_by_slot(l_ref_slot) + 1)
r_epoch = converter.get_epoch_by_slot(r_ref_slot)
for epoch in self.state.processed_epochs:
if l_epoch <= epoch <= r_epoch:
continue
logger.info({"msg": f"Invalid state: processed {epoch=}, but range is [{l_epoch};{r_epoch}]"})
raise InvalidState

def collect_data(self, blockstamp: BlockStamp) -> bool:
"""Ongoing report data collection before the report ref slot and it's submission"""
logger.info({"msg": "Collecting data for the report"})

l_ref_slot, r_ref_slot = self.current_frame_range(blockstamp)
logger.info({"msg": f"Frame for performance data collect: ({l_ref_slot};{r_ref_slot}]"})

if self.frame_performance:
# If the cache is in memory, its left border should follow up the last ref slot.
assert self.frame_performance.l_slot <= l_ref_slot
# If the frame is extending we can reuse the cache.
if r_ref_slot > self.frame_performance.r_slot:
self.frame_performance.r_slot = r_ref_slot
# If the collected data overlaps the current frame, the cache should be invalidated.
if l_ref_slot > self.frame_performance.l_slot or r_ref_slot < self.frame_performance.r_slot:
self.frame_performance = None

if not self.frame_performance:
self.frame_performance = FramePerformance.try_read(
l_slot=l_ref_slot,
r_slot=r_ref_slot,
)
self.state = self.state or State.load()

try:
self.validate_state(blockstamp)
except InvalidState:
logger.info({"msg": "Discarding invalidated state cache"})
self.state.clear()
self.state.commit()

self.state.status()

converter = self.converter(blockstamp)
# Finalized slot is the first slot of justifying epoch, so we need to take the previous
@@ -223,21 +238,24 @@ def _collect_data(self, blockstamp: BlockStamp) -> bool:
return False
r_epoch = converter.get_epoch_by_slot(r_ref_slot)

factory = CheckpointsFactory(self.w3.cc, converter, self.frame_performance)
factory = CheckpointsFactory(self.w3.cc, converter, self.state)
checkpoints = factory.prepare_checkpoints(l_epoch, r_epoch, finalized_epoch)

start = time.time()
for checkpoint in checkpoints:
# TODO: Check that we still need to check these checkpoints.
if converter.get_epoch_by_slot(checkpoint.slot) > finalized_epoch:
# checkpoint isn't finalized yet, can't be processed
logger.info({"msg": f"Checkpoint for slot {checkpoint.slot} is not finalized yet"})
break
logger.info({"msg": f"Processing checkpoint for slot {checkpoint.slot}"})
logger.info({"msg": f"Processing {len(checkpoint.duty_epochs)} epochs"})
checkpoint.process(blockstamp)
if checkpoints:
logger.info({"msg": f"All epochs processed in {time.time() - start:.2f} seconds"})
return self.frame_performance.is_coherent

return all(epoch in self.state.processed_epochs for epoch in seq(l_epoch, r_epoch))

@lru_cache(maxsize=1)
def current_frame_range(self, blockstamp: BlockStamp) -> tuple[SlotNumber, SlotNumber]:
l_ref_slot = self.w3.csm.get_csm_last_processing_ref_slot(blockstamp)
r_ref_slot = self.get_current_frame(blockstamp).ref_slot
@@ -263,18 +281,6 @@ def current_frame_range(self, blockstamp: BlockStamp) -> tuple[SlotNumber, SlotN
def converter(self, blockstamp: BlockStamp) -> Web3Converter:
return Web3Converter(self.get_chain_config(blockstamp), self.get_frame_config(blockstamp))

def _print_collect_result(self):
assert self.frame_performance
assigned = 0
inc = 0
for _, aggr in self.frame_performance.aggr_per_val.items():
assigned += aggr.assigned
inc += aggr.included

logger.info({"msg": f"Assigned: {assigned}"})
logger.info({"msg": f"Included: {inc}"})
logger.info({"msg": f"Average performance: {self.frame_performance.avg_perf}"})

def _slot_to_block_identifier(self, slot: SlotNumber) -> BlockIdentifier:
block = self.w3.cc.get_block_details(slot)
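
Both of the Python files shown here replace range(a, b + 1) calls with seq(a, b) from src.utils.range, and collect_data now reports completeness as all(epoch in self.state.processed_epochs for epoch in seq(l_epoch, r_epoch)). The helper itself is not shown in the diff; a plausible definition, assuming it is simply an inclusive integer range:

def seq(start: int, end: int) -> range:
    # Inclusive range helper: seq(3, 5) yields 3, 4, 5.
    return range(start, end + 1)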

