diff --git a/src/modules/csm/checkpoint.py b/src/modules/csm/checkpoint.py
index 50507c605..c833df3be 100644
--- a/src/modules/csm/checkpoint.py
+++ b/src/modules/csm/checkpoint.py
@@ -14,105 +14,155 @@
 from src.utils.range import sequence
 from src.utils.web3converter import Web3Converter
 
+from dataclasses import dataclass
+
 logger = logging.getLogger(__name__)
 lock = Lock()
 
 
-class CheckpointsFactory:
-    cc: ConsensusClient
+@dataclass
+class Checkpoint:
+    slot: SlotNumber  # last slot of the epoch
+    duty_epochs: list[EpochNumber]  # max 255 elements
+
+
+class CheckpointsIterator:
     converter: Web3Converter
-    state: State
 
-    # min checkpoint step is 10 because it's a reasonable number of epochs to process at once (~1 hour)
+    l_epoch: EpochNumber
+    r_epoch: EpochNumber
+
+    checkpoints: list[Checkpoint]
+
+    # Min checkpoint step is 10 because it's a reasonable number of epochs to process at once (~1 hour)
     MIN_CHECKPOINT_STEP = 10
-    # max checkpoint step is 255 epochs because block_roots size from state is 8192 slots (256 epochs)
-    # to check duty of every epoch, we need to check 64 slots (32 slots of duty epoch + 32 slots of next epoch)
-    # in the end we got 255 committees and 8192 block_roots to check them for every checkpoint
+    # Max checkpoint step is 255 epochs because the block_roots size of the state is 8192 slots (256 epochs).
+    # To check the duty of one epoch we need to check 64 slots (32 slots of the duty epoch + 32 slots of the next epoch).
+    # In the end we get 255 committees and 8192 block_roots to check per checkpoint.
    MAX_CHECKPOINT_STEP = 255
 
-    def __init__(self, cc: ConsensusClient, converter: Web3Converter, state: State):
-        self.cc = cc
+    def __init__(
+        self, converter: Web3Converter, l_epoch: EpochNumber, r_epoch: EpochNumber, finalized_epoch: EpochNumber
+    ):
+        assert l_epoch <= r_epoch
         self.converter = converter
-        self.state = state
+        self.l_epoch = l_epoch
+        self.r_epoch = self._get_adjusted_r_epoch(r_epoch, finalized_epoch)
 
-    def prepare_checkpoints(self, l_epoch: EpochNumber, r_epoch: EpochNumber, finalized_epoch: EpochNumber):
-        def _prepare_checkpoint(_slot: SlotNumber, _duty_epochs: list[EpochNumber]):
-            return Checkpoint(self.cc, self.converter, self.state, _slot, _duty_epochs)
+    def __iter__(self):
+        self.checkpoints = []
 
-        if not self.state.unprocessed_epochs:
-            logger.info({"msg": "All epochs processed. No checkpoint required."})
-            return []
+        duty_epochs = cast(list[EpochNumber], list(sequence(self.l_epoch, self.r_epoch)))
 
-        l_epoch = min(self.state.unprocessed_epochs) or l_epoch
-        assert l_epoch <= r_epoch
+        checkpoint_epochs = []
+        for index, epoch in enumerate(duty_epochs, 1):
+            checkpoint_epochs.append(epoch)
+            if index % self.MAX_CHECKPOINT_STEP == 0 or epoch == self.r_epoch:
+                # We need the last slot of the next epoch so that the 8192 roots in the `checkpoint_slot`
+                # state block_roots are enough to check the duties of every epoch in the checkpoint.
+                # To check the duties of an epoch you need the 32 slots of that epoch
+                # and the 32 slots of the next epoch.
+                checkpoint_slot = self.converter.get_epoch_last_slot(EpochNumber(epoch + 1))
+                logger.info(
+                    {"msg": f"Checkpoint slot {checkpoint_slot} with {len(checkpoint_epochs)} duty epochs is prepared"}
+                )
+                self.checkpoints.append(Checkpoint(checkpoint_slot, checkpoint_epochs))
+                checkpoint_epochs = []
+        logger.info({"msg": f"Checkpoints to process: {len(self.checkpoints)}"})
+        return self
+
+    def __next__(self):
+        if not self.checkpoints:
+            raise StopIteration
+        return self.checkpoints.pop(0)
 
-        processing_delay = finalized_epoch - l_epoch
+    def _get_adjusted_r_epoch(self, r_epoch: EpochNumber, finalized_epoch: EpochNumber):
+        processing_delay = finalized_epoch - self.l_epoch
         if processing_delay < self.MIN_CHECKPOINT_STEP and finalized_epoch < r_epoch:
             logger.info(
                 {
                     "msg": f"Minimum checkpoint step is not reached, current delay is {processing_delay} epochs",
                     "finalized_epoch": finalized_epoch,
-                    "l_epoch": l_epoch,
+                    "l_epoch": self.l_epoch,
                     "r_epoch": r_epoch,
                 }
             )
-            return []
-
-        r_epoch = min(r_epoch, EpochNumber(finalized_epoch - 1))
-        duty_epochs = cast(list[EpochNumber], list(sequence(l_epoch, r_epoch)))
-        checkpoints: list[Checkpoint] = []
-        checkpoint_epochs = []
-        for index, epoch in enumerate(duty_epochs, 1):
-            checkpoint_epochs.append(epoch)
-            if index % self.MAX_CHECKPOINT_STEP == 0 or epoch == r_epoch:
-                checkpoint_slot = self.converter.get_epoch_last_slot(EpochNumber(epoch + 1))
-                checkpoints.append(_prepare_checkpoint(checkpoint_slot, checkpoint_epochs))
-                logger.info(
-                    {"msg": f"Checkpoint slot {checkpoint_slot} with {len(checkpoint_epochs)} duty epochs is prepared"}
-                )
-                checkpoint_epochs = []
-        logger.info({"msg": f"Checkpoints to process: {len(checkpoints)}"})
-        return checkpoints
+            raise ValueError('Minimum checkpoint step is not reached yet')
+        adjusted_r_epoch = min(r_epoch, EpochNumber(finalized_epoch - 1))
+        if r_epoch != adjusted_r_epoch:
+            logger.warning(
+                {"msg": f"Right border epoch of the checkpoints iterator is recalculated according to the finalized epoch. "
+                        f"Before: {r_epoch} After: {adjusted_r_epoch}"}
+            )
+        return adjusted_r_epoch
 
 
-class Checkpoint:
+class CheckpointProcessor:
     cc: ConsensusClient
     converter: Web3Converter
     state: State
+    finalized_blockstamp: BlockStamp
 
-    slot: SlotNumber  # last slot of the epoch
-    duty_epochs: list[EpochNumber]  # max 255 elements
-    block_roots: list[BlockRoot | None]  # max 8192 elements
-
-    def __init__(
-        self,
-        cc: ConsensusClient,
-        converter: Web3Converter,
-        state: State,
-        slot: SlotNumber,
-        duty_epochs: list[EpochNumber],
-    ):
+    def __init__(self, cc: ConsensusClient, state: State, converter: Web3Converter, finalized_blockstamp: BlockStamp):
         self.cc = cc
         self.converter = converter
-        self.slot = slot
-        self.duty_epochs = duty_epochs
-        self.block_roots = []
         self.state = state
+        self.finalized_blockstamp = finalized_blockstamp
 
-    def process(self, last_finalized_blockstamp: BlockStamp):
-        def _unprocessed():
-            for _epoch in self.duty_epochs:
-                if _epoch in self.state.unprocessed_epochs:
-                    if not self.block_roots:
-                        self._get_block_roots()
-                    yield _epoch
+    def exec(self, checkpoint: Checkpoint) -> int:
+        logger.info(
+            {"msg": f"Processing checkpoint for slot {checkpoint.slot} with {len(checkpoint.duty_epochs)} epochs"}
+        )
+        unprocessed_epochs = [e for e in checkpoint.duty_epochs if e in self.state.unprocessed_epochs]
+        if not unprocessed_epochs:
+            logger.info({"msg": "Nothing to process in the checkpoint"})
+            return 0
+        block_roots = self._get_block_roots(checkpoint.slot)
+        duty_epochs_roots = {
+            duty_epoch: self._select_block_roots(duty_epoch, block_roots, checkpoint.slot)
+            for duty_epoch in unprocessed_epochs
+        }
+        self._process(unprocessed_epochs, duty_epochs_roots)
+        return len(unprocessed_epochs)
+
+    def _get_block_roots(self, checkpoint_slot: SlotNumber):
+        logger.info({"msg": f"Get block roots for slot {checkpoint_slot}"})
+        # A checkpoint is just a point in time for us, that's why we use a slot rather than a root
+        br = self.cc.get_state_block_roots(checkpoint_slot)
+        # replace duplicated roots with None to mark missed slots
+        return [None if br[i] == br[i - 1] else br[i] for i in range(len(br))]
+
+    def _select_block_roots(
+        self, duty_epoch: EpochNumber, block_roots: list[BlockRoot | None], checkpoint_slot: SlotNumber
+    ) -> list[BlockRoot | None]:
+        roots_to_check = []
+        SLOTS_PER_HISTORICAL_ROOT = int(self.cc.get_config_spec().SLOTS_PER_HISTORICAL_ROOT)
+        # To check duties in the current epoch you need
+        # the 32 slots of the current epoch and the 32 slots of the next epoch
+        slots = sequence(
+            self.converter.get_epoch_first_slot(duty_epoch),
+            self.converter.get_epoch_last_slot(EpochNumber(duty_epoch + 1)),
+        )
+        for slot_to_check in slots:
+            # From the spec:
+            # https://github.com/ethereum/consensus-specs/blob/dev/specs/phase0/beacon-chain.md#get_block_root_at_slot
+            if checkpoint_slot - SLOTS_PER_HISTORICAL_ROOT < slot_to_check <= checkpoint_slot:
+                roots_to_check.append(block_roots[slot_to_check % SLOTS_PER_HISTORICAL_ROOT])
+                continue
+            raise ValueError("Slot is out of the state block roots range")
+        return roots_to_check
+
+    def _process(self, unprocessed_epochs: list[EpochNumber], duty_epochs_roots: dict[EpochNumber, list[BlockRoot | None]]):
         with ThreadPoolExecutor(max_workers=variables.CSM_ORACLE_MAX_CONCURRENCY) as ext:
             try:
                 futures = {
-                    ext.submit(self._process_epoch, last_finalized_blockstamp, duty_epoch)
-                    for duty_epoch in _unprocessed()
+                    ext.submit(
+                        self._check_duty,
+                        duty_epoch,
+                        duty_epochs_roots[duty_epoch]
+                    )
+                    for duty_epoch in unprocessed_epochs
                 }
                 for future in as_completed(futures):
                     future.result()
@@ -123,49 +173,27 @@ def _unprocessed():
                 ext.shutdown(wait=False, cancel_futures=True)
                 raise SystemExit(1) from e
             except Exception as e:
-                logger.error({"msg": "Error processing epochs in threads, wait the current threads", "error": str(e)})
+                logger.error(
+                    {"msg": "Error processing epochs in threads, wait the current threads", "error": str(e)})
                 # Wait only for the current running threads to prevent
                 # a lot of similar error-possible requests to the consensus node.
                 # Raise the error after a batch of running threads is finished
                 ext.shutdown(wait=True, cancel_futures=True)
                 raise ValueError(e) from e
 
-    def _select_roots_to_check(self, duty_epoch: EpochNumber) -> list[BlockRoot | None]:
-        # inspired by the spec
-        # https://github.com/ethereum/consensus-specs/blob/dev/specs/phase0/beacon-chain.md#get_block_root_at_slot
-        SLOTS_PER_HISTORICAL_ROOT = int(self.cc.get_config_spec().SLOTS_PER_HISTORICAL_ROOT)
-        roots_to_check = []
-        slots = sequence(
-            self.converter.get_epoch_first_slot(duty_epoch),
-            self.converter.get_epoch_last_slot(EpochNumber(duty_epoch + 1)),
-        )
-        for slot_to_check in slots:
-            if self.slot - SLOTS_PER_HISTORICAL_ROOT < slot_to_check <= self.slot:
-                roots_to_check.append(self.block_roots[slot_to_check % SLOTS_PER_HISTORICAL_ROOT])
-                continue
-            raise ValueError("Slot is out of the state block roots range")
-        return roots_to_check
-
-    def _get_block_roots(self):
-        logger.info({"msg": f"Get block roots for slot {self.slot}"})
-        # checkpoint for us like a time point, that's why we use slot, not root
-        br = self.cc.get_state_block_roots(self.slot)
-        # replace duplicated roots to None to mark missed slots
-        self.block_roots = [None if br[i] == br[i - 1] else br[i] for i in range(len(br))]
-
-    def _process_epoch(
+    def _check_duty(
         self,
-        last_finalized_blockstamp: BlockStamp,
         duty_epoch: EpochNumber,
+        block_roots: list[BlockRoot | None],
     ):
         logger.info({"msg": f"Process epoch {duty_epoch}"})
         start = time.time()
-        committees = self._prepare_committees(last_finalized_blockstamp, EpochNumber(duty_epoch))
-        for root in self._select_roots_to_check(duty_epoch):
+        committees = self._prepare_committees(EpochNumber(duty_epoch))
+        for root in block_roots:
             if root is None:
                 continue
             attestations = self.cc.get_block_attestations(BlockRoot(root))
-            self._process_attestations(attestations, committees)
+            process_attestations(attestations, committees)
 
         with lock:
             for committee in committees.values():
@@ -182,36 +210,42 @@ def _process_epoch(
 
         logger.info({"msg": f"Epoch {duty_epoch} processed in {time.time() - start:.2f} seconds"})
 
-    def _prepare_committees(self, last_finalized_blockstamp: BlockStamp, epoch: int) -> dict:
+    def _prepare_committees(self, epoch: int) -> dict:
         start = time.time()
         committees = {}
-        for committee in self.cc.get_attestation_committees(last_finalized_blockstamp, EpochNumber(epoch)):
+        for committee in self.cc.get_attestation_committees(self.finalized_blockstamp, EpochNumber(epoch)):
             validators = []
             # Order of insertion is used to track the positions in the committees.
             for validator in committee.validators:
                 data = {"index": validator, "included": False}
                 validators.append(data)
-            committees[f"{committee.slot}{committee.index}"] = validators
+            committees[f"{committee.slot}_{committee.index}"] = validators
         logger.info({"msg": f"Committees for epoch {epoch} processed in {time.time() - start:.2f} seconds"})
         return committees
 
-    def _process_attestations(self, attestations: Iterable[BlockAttestation], committees: dict) -> None:
-        def to_bits(aggregation_bits: str):
-            # copied from https://github.com/ethereum/py-ssz/blob/main/ssz/sedes/bitvector.py#L66
-            att_bytes = bytes.fromhex(aggregation_bits[2:])
-            return [bool((att_bytes[bit_index // 8] >> bit_index % 8) % 2) for bit_index in range(len(att_bytes) * 8)]
-
-        for attestation in attestations:
-            committee_id = f"{attestation.data.slot}{attestation.data.index}"
-            committee = committees.get(committee_id)
-            att_bits = to_bits(attestation.aggregation_bits)
-            if not committee:
+
+def process_attestations(attestations: Iterable[BlockAttestation], committees: dict) -> None:
+
+    for attestation in attestations:
+        committee_id = f"{attestation.data.slot}_{attestation.data.index}"
+        committee = committees.get(committee_id)
+        att_bits = _to_bits(attestation.aggregation_bits)
+        if not committee:
+            continue
+        for index_in_committee, validator in enumerate(committee):
+            if validator['included']:
+                # validator has already fulfilled its duties
+                continue
-            for index_in_committee, validator in enumerate(committee):
-                if validator['included']:
-                    # validator has already fulfilled its duties
-                    continue
-                attested = att_bits[index_in_committee]
-                if attested:
-                    validator['included'] = True
-                    committees[committee_id][index_in_committee] = validator
+            if _is_attested(att_bits, index_in_committee):
+                validator['included'] = True
+                committees[committee_id][index_in_committee] = validator
+
+
+def _is_attested(bits: list[bool], index: int) -> bool:
+    return bits[index]
+
+
+def _to_bits(aggregation_bits: str):
+    # copied from https://github.com/ethereum/py-ssz/blob/main/ssz/sedes/bitvector.py#L66
+    att_bytes = bytes.fromhex(aggregation_bits[2:])
+    return [bool((att_bytes[bit_index // 8] >> bit_index % 8) % 2) for bit_index in range(len(att_bytes) * 8)]
diff --git a/src/modules/csm/csm.py b/src/modules/csm/csm.py
index 7ba054178..f99e2c788 100644
--- a/src/modules/csm/csm.py
+++ b/src/modules/csm/csm.py
@@ -7,14 +7,14 @@
 from src.metrics.prometheus.business import CONTRACT_ON_PAUSE
 from src.metrics.prometheus.duration_meter import duration_meter
-from src.modules.csm.checkpoint import CheckpointsFactory
-from src.modules.csm.state import InvalidState, State
+from src.modules.csm.checkpoint import CheckpointsIterator, CheckpointProcessor
+from src.modules.csm.state import State, InvalidState
 from src.modules.csm.tree import Tree
 from src.modules.csm.types import ReportData
 from src.modules.submodules.consensus import ConsensusModule
 from src.modules.submodules.oracle_module import BaseModule, ModuleExecuteDelay
 from src.modules.submodules.types import ZERO_HASH
-from src.providers.execution.contracts import CSFeeOracle
+from src.providers.execution.contracts.cs_fee_oracle import CSFeeOracle
 from src.types import BlockStamp, EpochNumber, ReferenceBlockStamp, SlotNumber, ValidatorIndex
 from src.utils.cache import global_lru_cache as lru_cache
 from src.utils.slot import get_first_non_missed_slot
@@ -222,24 +222,33 @@ def collect_data(self, blockstamp: BlockStamp) -> bool:
         self.state.validate_for_collect(l_epoch, r_epoch)
         self.state.status()
 
-        factory = CheckpointsFactory(self.w3.cc, converter, self.state)
-        checkpoints = factory.prepare_checkpoints(l_epoch, r_epoch, finalized_epoch)
+        if done := self.state.is_fulfilled:
+            logger.info({"msg": "All epochs are already processed. Nothing to collect"})
+            return done
+        checkpoints = CheckpointsIterator(
+            converter, min(self.state.unprocessed_epochs) or l_epoch, r_epoch, finalized_epoch
+        )
+        processor = CheckpointProcessor(self.w3.cc, self.state, converter, blockstamp)
+
+        new_processed_epochs = 0
         start = time.time()
         for checkpoint in checkpoints:
+
             if self.current_frame_range(self._receive_last_finalized_slot()) != (l_ref_slot, r_ref_slot):
-                logger.info({"msg": "Checkpoints were prepared for an outdated frame, stop proccessing"})
+                logger.info({"msg": "Checkpoints were prepared for an outdated frame, stop processing"})
                 raise ValueError("Outdated checkpoint")
-            if converter.get_epoch_by_slot(checkpoint.slot) > finalized_epoch:
+            if checkpoint.slot > blockstamp.slot_number:
                 logger.info({"msg": f"Checkpoint for slot {checkpoint.slot} is not finalized yet"})
                 break
-            logger.info({"msg": f"Processing checkpoint for slot {checkpoint.slot}"})
-            logger.info({"msg": f"Processing {len(checkpoint.duty_epochs)} epochs"})
-            checkpoint.process(blockstamp)
-        if checkpoints:
-            logger.info({"msg": f"All epochs were processed in {time.time() - start:.2f} seconds"})
+            new_processed_epochs += processor.exec(checkpoint)
+
+        if new_processed_epochs:
+            logger.info(
+                {"msg": f"Collecting data for {new_processed_epochs} epochs was done in {time.time() - start:.2f} sec"}
+            )
 
         return self.state.is_fulfilled
diff --git a/src/providers/execution/contracts/__init__.py b/src/providers/execution/contracts/__init__.py
deleted file mode 100644
index 89030f53d..000000000
--- a/src/providers/execution/contracts/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .CSFeeDistributor import *
-from .CSFeeOracle import *
-from .CSModule import *
diff --git a/src/providers/execution/contracts/CSFeeDistributor.py b/src/providers/execution/contracts/cs_fee_distributor.py
similarity index 100%
rename from src/providers/execution/contracts/CSFeeDistributor.py
rename to src/providers/execution/contracts/cs_fee_distributor.py
diff --git a/src/providers/execution/contracts/CSFeeOracle.py b/src/providers/execution/contracts/cs_fee_oracle.py
similarity index 100%
rename from src/providers/execution/contracts/CSFeeOracle.py
rename to src/providers/execution/contracts/cs_fee_oracle.py
diff --git a/src/providers/execution/contracts/CSModule.py b/src/providers/execution/contracts/cs_module.py
similarity index 100%
rename from src/providers/execution/contracts/CSModule.py
rename to src/providers/execution/contracts/cs_module.py
diff --git a/src/web3py/extensions/csm.py b/src/web3py/extensions/csm.py
index b720fcb58..6ae91dbd0 100644
--- a/src/web3py/extensions/csm.py
+++ b/src/web3py/extensions/csm.py
@@ -13,7 +13,9 @@
 from src import variables
 from src.metrics.prometheus.business import FRAME_PREV_REPORT_REF_SLOT
-from src.providers.execution.contracts import CSFeeDistributor, CSFeeOracle, CSModule
+from src.providers.execution.contracts.cs_fee_distributor import CSFeeDistributor
+from src.providers.execution.contracts.cs_fee_oracle import CSFeeOracle
+from src.providers.execution.contracts.cs_module import CSModule
 from src.providers.ipfs import CIDv0, CIDv1, is_cid_v0
 from src.types import BlockStamp, SlotNumber
 from src.web3py.extensions.lido_validators import NodeOperatorId
diff --git a/tests/factory/configs.py b/tests/factory/configs.py
index f7f4f6a25..ea89cf235 100644
--- a/tests/factory/configs.py
+++ b/tests/factory/configs.py
@@ -1,7 +1,13 @@
 from typing import overload
 from src.modules.accounting.types import OracleReportLimits
 from src.modules.submodules.types import ChainConfig, FrameConfig
-from src.providers.consensus.types import BeaconSpecResponse
+from src.providers.consensus.types import (
+    BeaconSpecResponse,
+    SlotAttestationCommittee,
+    BlockAttestation,
+    AttestationData,
+    Checkpoint,
+)
 from src.services.bunker_cases.types import BunkerConfig
 from tests.factory.web3_factory import Web3Factory
@@ -45,3 +51,24 @@ class BeaconSpecResponseFactory(Web3Factory):
     SECONDS_PER_SLOT = 12
     SLOTS_PER_EPOCH = 32
     SLOTS_PER_HISTORICAL_ROOT = 8192
+
+
+class SlotAttestationCommitteeFactory(Web3Factory):
+    __model__ = SlotAttestationCommittee
+
+    slot = 0
+    index = 0
+    validators = []
+
+
+class BlockAttestationFactory(Web3Factory):
+    __model__ = BlockAttestation
+
+    aggregation_bits = "0x"
+    data = AttestationData(
+        slot="0",
+        index="0",
+        beacon_block_root="0x",
+        source=Checkpoint("0", "0x"),
+        target=Checkpoint("0", "0x"),
+    )
diff --git a/tests/modules/csm/test_checkpoint.py b/tests/modules/csm/test_checkpoint.py
new file mode 100644
index 000000000..ceeb5bc6c
--- /dev/null
+++ b/tests/modules/csm/test_checkpoint.py
@@ -0,0 +1,290 @@
+from copy import deepcopy
+from typing import cast, Iterator
+from unittest.mock import Mock
+
+import pytest
+
+from src.modules.csm.checkpoint import CheckpointsIterator, Checkpoint, CheckpointProcessor, process_attestations
+from src.modules.csm.state import State
+from src.modules.submodules.types import ChainConfig, FrameConfig
+from src.providers.consensus.client import ConsensusClient
+from src.providers.consensus.types import BeaconSpecResponse, SlotAttestationCommittee, BlockAttestation
+from src.utils.web3converter import Web3Converter
+from tests.factory.configs import (
+    FrameConfigFactory,
+    ChainConfigFactory,
+    BeaconSpecResponseFactory,
+    SlotAttestationCommitteeFactory,
+    BlockAttestationFactory,
+)
+
+
+@pytest.fixture
+def frame_config() -> FrameConfig:
+    return FrameConfigFactory.build(
+        epochs_per_frame=225,
+    )
+
+
+@pytest.fixture
+def chain_config() -> ChainConfig:
+    return ChainConfigFactory.build(
+        slots_per_epoch=32,
+        seconds_per_slot=12,
+        genesis_time=0,
+    )
+
+
+@pytest.fixture
+def converter(frame_config: FrameConfig, chain_config: ChainConfig) -> Web3Converter:
+    return Web3Converter(chain_config, frame_config)
+
+
+def test_checkpoints_iterator_min_epoch_is_not_reached(converter):
+    with pytest.raises(ValueError, match='Minimum checkpoint step is not reached yet'):
+        CheckpointsIterator(converter, 100, 600, 109)
+
+
+def test_checkpoints_iterator_r_epoch_is_changed_by_finalized(converter):
+    l_epoch = 100
+    r_epoch = 600
+    finalized_epoch = 110
+    expected = finalized_epoch - 1
+    iterator = CheckpointsIterator(converter, l_epoch, r_epoch, finalized_epoch)
+    assert r_epoch != iterator.r_epoch, "Right border should be changed"
+    assert iterator.r_epoch == expected, "Right border should be equal to the finalized epoch minus one"
+
+
+@pytest.mark.parametrize(
+    "l_epoch,r_epoch,finalized_epoch,expected_checkpoints",
+    [
+        (0, 255, 255, [Checkpoint(8191, list(range(0, 255)))]),
+        (15, 255, 255, [Checkpoint(8191, list(range(15, 255)))]),
+        (15, 255, 25, [Checkpoint(831, list(range(15, 25)))]),
+        (0, 255 * 2, 255 * 2, [Checkpoint(8191, list(range(0, 255))), Checkpoint(16351, list(range(255, 510)))]),
+        (15, 255 * 2, 350, [Checkpoint(8671, list(range(15, 270))), Checkpoint(11231, list(range(270, 350)))]),
+    ],
+)
+def test_checkpoints_iterator_given_checkpoints(converter, l_epoch, r_epoch, finalized_epoch, expected_checkpoints):
+    iterator = CheckpointsIterator(converter, l_epoch, r_epoch, finalized_epoch)
+    assert iter(iterator).checkpoints == expected_checkpoints
+
+
+@pytest.fixture
+def consensus_client():
+    return ConsensusClient('http://localhost', 5 * 60, 5, 5)
+
+
+@pytest.fixture
+def mock_get_state_block_roots(consensus_client):
+    def _get_state_block_roots(state_id):
+        match state_id:
+            # with no duplicated roots
+            case 0:
+                return [f'0x{r}' for r in range(0, 8192)]
+            # with duplicated roots
+            case 1:
+                br = [f'0x{r}' for r in range(0, 8192)]
+                return [br[i - 1] if i % 2 == 0 else br[i] for i in range(len(br))]
+
+    consensus_client.get_state_block_roots = Mock(side_effect=_get_state_block_roots)
+
+
+def test_checkpoints_processor_get_block_roots(consensus_client, mock_get_state_block_roots, converter: Web3Converter):
+    state = ...
+    finalized_blockstamp = ...
+    processor = CheckpointProcessor(
+        consensus_client,
+        state,
+        converter,
+        finalized_blockstamp,
+    )
+    roots = processor._get_block_roots(0)
+    assert len([r for r in roots if r is not None]) == 8192
+
+
+def test_checkpoints_processor_get_block_roots_with_duplicates(
+    consensus_client, mock_get_state_block_roots, converter: Web3Converter
+):
+    state = ...
+    finalized_blockstamp = ...
+    processor = CheckpointProcessor(
+        consensus_client,
+        state,
+        converter,
+        finalized_blockstamp,
+    )
+    roots = processor._get_block_roots(1)
+    assert len([r for r in roots if r is not None]) == 4096
+
+
+@pytest.fixture
+def mock_get_config_spec(consensus_client):
+    bc_spec = cast(BeaconSpecResponse, BeaconSpecResponseFactory.build())
+    bc_spec.SLOTS_PER_HISTORICAL_ROOT = 8192
+    consensus_client.get_config_spec = Mock(return_value=bc_spec)
+
+
+def test_checkpoints_processor_select_block_roots(
+    consensus_client, mock_get_state_block_roots, mock_get_config_spec, converter: Web3Converter
+):
+    state = ...
+    finalized_blockstamp = ...
+    processor = CheckpointProcessor(
+        consensus_client,
+        state,
+        converter,
+        finalized_blockstamp,
+    )
+    roots = processor._get_block_roots(0)
+    selected = processor._select_block_roots(10, roots, 8192)
+    assert len(selected) == 64
+    assert selected == [f'0x{r}' for r in range(320, 384)]
+
+
+def test_checkpoints_processor_select_block_roots_out_of_range(
+    consensus_client, mock_get_state_block_roots, mock_get_config_spec, converter: Web3Converter
+):
+    state = ...
+    finalized_blockstamp = ...
+    processor = CheckpointProcessor(
+        consensus_client,
+        state,
+        converter,
+        finalized_blockstamp,
+    )
+    roots = processor._get_block_roots(0)
+    with pytest.raises(ValueError, match="Slot is out of the state block roots range"):
+        processor._select_block_roots(255, roots, 8192)
+
+
+@pytest.fixture()
+def mock_get_attestation_committees(consensus_client):
+    def _get_attestation_committees(finalized_slot, epoch):
+        committees = []
+        validators = [v for v in range(0, 2048 * 32)]
+        for i in range(epoch * 32, epoch * 32 + 32):  # 1 epoch = 32 slots.
+            for j in range(0, 64):  # 64 committees per slot
+                committee = deepcopy(cast(SlotAttestationCommittee, SlotAttestationCommitteeFactory.build()))
+                committee.slot = i
+                committee.index = j
+                # 32 validators per committee
+                committee.validators = [validators.pop() for _ in range(32)]
+                committees.append(committee)
+        return committees
+
+    consensus_client.get_attestation_committees = Mock(side_effect=_get_attestation_committees)
+
+
+def test_checkpoints_processor_prepare_committees(mock_get_attestation_committees, consensus_client, converter):
+    state = ...
+    finalized_blockstamp = ...
+    processor = CheckpointProcessor(
+        consensus_client,
+        state,
+        converter,
+        finalized_blockstamp,
+    )
+    raw = consensus_client.get_attestation_committees(0, 0)
+    committees = processor._prepare_committees(0)
+    assert len(committees) == 2048
+    for index, (committee_id, validators) in enumerate(committees.items()):
+        slot, committee_index = committee_id.split('_')
+        committee_from_raw = raw[index]
+        assert int(slot) == committee_from_raw.slot
+        assert int(committee_index) == committee_from_raw.index
+        assert len(validators) == 32
+        for validator in validators:
+            assert validator['included'] is False
+
+
+def test_checkpoints_processor_process_attestations(mock_get_attestation_committees, consensus_client, converter):
+    state = ...
+    finalized_blockstamp = ...
+    processor = CheckpointProcessor(
+        consensus_client,
+        state,
+        converter,
+        finalized_blockstamp,
+    )
+    committees = processor._prepare_committees(0)
+    # normal attestation
+    attestation = cast(BlockAttestation, BlockAttestationFactory.build())
+    attestation.data.slot = 0
+    attestation.data.index = 0
+    attestation.aggregation_bits = '0x' + 'f' * 32
+    # the same but with no included attestations in bits
+    attestation2 = cast(BlockAttestation, BlockAttestationFactory.build())
+    attestation2.data.slot = 0
+    attestation2.data.index = 0
+    attestation2.aggregation_bits = '0x' + '0' * 32
+    process_attestations([attestation, attestation2], committees)
+    for validator in committees["0_0"]:
+        # only the first attestation is accounted
+        assert validator['included'] is True
+
+
+def test_checkpoints_processor_process_attestations_undefined_committee(
+    mock_get_attestation_committees, consensus_client, converter
+):
+    state = ...
+    finalized_blockstamp = ...
+    processor = CheckpointProcessor(
+        consensus_client,
+        state,
+        converter,
+        finalized_blockstamp,
+    )
+    committees = processor._prepare_committees(0)
+    # undefined committee
+    attestation = cast(BlockAttestation, BlockAttestationFactory.build())
+    attestation.data.slot = 100500
+    attestation.data.index = 100500
+    attestation.aggregation_bits = '0x' + 'f' * 32
+    process_attestations([attestation], committees)
+    for validators in committees.values():
+        for v in validators:
+            assert v['included'] is False
+
+
+@pytest.fixture()
+def mock_get_block_attestations(consensus_client):
+    def _get_block_attestations(root):
+        attestations = []
+        for i in range(0, 64):
+            attestation = deepcopy(cast(BlockAttestation, BlockAttestationFactory.build()))
+            attestation.data.slot = root[2:]
+            attestation.data.index = str(i)
+            attestation.aggregation_bits = '0x' + 'f' * 32
+            attestations.append(attestation)
+        return attestations
+
+    consensus_client.get_block_attestations = Mock(side_effect=_get_block_attestations)
+
+
+def test_checkpoints_processor_check_duty(
+    mock_get_state_block_roots, mock_get_attestation_committees, mock_get_block_attestations, consensus_client, converter
+):
+    state = State()
+    state.validate_for_collect(0, 255)
+    finalized_blockstamp = ...
+    processor = CheckpointProcessor(
+        consensus_client,
+        state,
+        converter,
+        finalized_blockstamp,
+    )
+    roots = processor._get_block_roots(0)
+    processor._check_duty(0, roots[:64])
+    assert len(state._processed_epochs) == 1
+    assert len(state._epochs_to_process) == 256
+    assert len(state.unprocessed_epochs) == 255
+    assert len(state.data) == 2048 * 32
+
+
+def test_checkpoints_processor_process():
+    ...
+
+
+def test_checkpoints_processor_exec():
+    ...
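For orientation, the intended wiring of the two new classes mirrors the `collect_data` change above. The sketch below is illustrative only, not part of the patch: it assumes `cc` (a ConsensusClient), `converter`, `state`, `blockstamp`, `l_epoch`, `r_epoch` and `finalized_epoch` are already available from the caller's context, exactly as they are inside `collect_data`.

# Usage sketch (assumed context: cc, converter, state, blockstamp, l_epoch, r_epoch, finalized_epoch)
from src.modules.csm.checkpoint import CheckpointProcessor, CheckpointsIterator

# Raises ValueError if the minimum checkpoint step (10 epochs of delay) has not been reached yet.
checkpoints = CheckpointsIterator(converter, min(state.unprocessed_epochs) or l_epoch, r_epoch, finalized_epoch)
processor = CheckpointProcessor(cc, state, converter, blockstamp)

new_processed_epochs = 0
for checkpoint in checkpoints:  # each Checkpoint carries one state slot and up to 255 duty epochs
    new_processed_epochs += processor.exec(checkpoint)  # returns the number of epochs actually processed

Splitting the old Checkpoint class into an iterator (pure slot/epoch arithmetic) and a processor (consensus-layer I/O and state updates) is what lets the new unit tests above cover each half with a mocked ConsensusClient and no running node.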