diff --git a/.circleci/config.yml b/.circleci/config.yml index 5f8039654d..328fd9c6db 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -134,6 +134,18 @@ jobs: - image: circleci/python:3.7 environment: TOXENV: py37-lint + py36-lint-eth2: + <<: *common + docker: + - image: circleci/python:3.6 + environment: + TOXENV: py36-lint-eth2 + py37-lint-eth2: + <<: *common + docker: + - image: circleci/python:3.7 + environment: + TOXENV: py37-lint-eth2 py36-docs: <<: *common @@ -418,6 +430,8 @@ workflows: - py36-long_run_integration - py36-lint + - py36-lint-eth2 - py37-lint + - py37-lint-eth2 - docker-image-build-test diff --git a/.flake8-eth2 b/.flake8-eth2 new file mode 100644 index 0000000000..a56eebad32 --- /dev/null +++ b/.flake8-eth2 @@ -0,0 +1,4 @@ +[flake8] +max-line-length= 100 +exclude= +ignore=W503,E203 diff --git a/Makefile b/Makefile index 725fafc908..27614ca8a5 100644 --- a/Makefile +++ b/Makefile @@ -25,6 +25,9 @@ clean-pyc: lint: tox -epy3{6,5}-lint +lint-eth2: + tox -epy37-lint-eth2 + test: py.test --tb native tests diff --git a/eth2/_utils/bitfield.py b/eth2/_utils/bitfield.py index 07474814ff..731acb06c5 100644 --- a/eth2/_utils/bitfield.py +++ b/eth2/_utils/bitfield.py @@ -1,7 +1,7 @@ -from cytoolz import ( - curry, -) +from cytoolz import curry + from eth2.beacon.typing import Bitfield + from .tuple import update_tuple_item @@ -12,13 +12,7 @@ def has_voted(bitfield: Bitfield, index: int) -> bool: @curry def set_voted(bitfield: Bitfield, index: int) -> Bitfield: - return Bitfield( - update_tuple_item( - bitfield, - index, - True, - ) - ) + return Bitfield(update_tuple_item(bitfield, index, True)) def get_bitfield_length(bit_count: int) -> int: @@ -32,9 +26,5 @@ def get_empty_bitfield(bit_count: int) -> Bitfield: def get_vote_count(bitfield: Bitfield) -> int: return len( - tuple( - index - for index in range(len(bitfield)) - if has_voted(bitfield, index) - ) + tuple(index for index in range(len(bitfield)) if has_voted(bitfield, index)) ) diff --git a/eth2/_utils/bls/__init__.py b/eth2/_utils/bls/__init__.py index 351c8fb730..3a6190a6ba 100644 --- a/eth2/_utils/bls/__init__.py +++ b/eth2/_utils/bls/__init__.py @@ -1,32 +1,17 @@ -from typing import ( - Sequence, - Type, -) - -from eth_typing import ( - BLSPubkey, - BLSSignature, - Hash32, -) +from typing import Sequence, Type +from eth_typing import BLSPubkey, BLSSignature, Hash32 from py_ecc.bls.typing import Domain -from eth2.beacon.exceptions import ( - SignatureError, -) +from eth2.beacon.exceptions import SignatureError -from .backends import ( - DEFAULT_BACKEND, - NoOpBackend, -) -from .backends.base import ( - BaseBLSBackend, -) +from .backends import DEFAULT_BACKEND, NoOpBackend +from .backends.base import BaseBLSBackend from .validation import ( + validate_many_public_keys, validate_private_key, - validate_signature, validate_public_key, - validate_many_public_keys, + validate_signature, ) @@ -49,51 +34,51 @@ def use_noop_backend(cls) -> None: cls.use(NoOpBackend) @classmethod - def privtopub(cls, - privkey: int) -> BLSPubkey: + def privtopub(cls, privkey: int) -> BLSPubkey: validate_private_key(privkey) return cls.backend.privtopub(privkey) @classmethod - def sign(cls, - message_hash: Hash32, - privkey: int, - domain: Domain) -> BLSSignature: + def sign(cls, message_hash: Hash32, privkey: int, domain: Domain) -> BLSSignature: validate_private_key(privkey) return cls.backend.sign(message_hash, privkey, domain) @classmethod - def aggregate_signatures(cls, - signatures: Sequence[BLSSignature]) -> BLSSignature: + def aggregate_signatures(cls, signatures: Sequence[BLSSignature]) -> BLSSignature: return cls.backend.aggregate_signatures(signatures) @classmethod - def aggregate_pubkeys(cls, - pubkeys: Sequence[BLSPubkey]) -> BLSPubkey: + def aggregate_pubkeys(cls, pubkeys: Sequence[BLSPubkey]) -> BLSPubkey: return cls.backend.aggregate_pubkeys(pubkeys) @classmethod - def verify(cls, - message_hash: Hash32, - pubkey: BLSPubkey, - signature: BLSSignature, - domain: Domain) -> bool: + def verify( + cls, + message_hash: Hash32, + pubkey: BLSPubkey, + signature: BLSSignature, + domain: Domain, + ) -> bool: return cls.backend.verify(message_hash, pubkey, signature, domain) @classmethod - def verify_multiple(cls, - pubkeys: Sequence[BLSPubkey], - message_hashes: Sequence[Hash32], - signature: BLSSignature, - domain: Domain) -> bool: + def verify_multiple( + cls, + pubkeys: Sequence[BLSPubkey], + message_hashes: Sequence[Hash32], + signature: BLSSignature, + domain: Domain, + ) -> bool: return cls.backend.verify_multiple(pubkeys, message_hashes, signature, domain) @classmethod - def validate(cls, - message_hash: Hash32, - pubkey: BLSPubkey, - signature: BLSSignature, - domain: Domain) -> None: + def validate( + cls, + message_hash: Hash32, + pubkey: BLSPubkey, + signature: BLSSignature, + domain: Domain, + ) -> None: if cls.backend != NoOpBackend: validate_signature(signature) validate_public_key(pubkey) @@ -108,11 +93,13 @@ def validate(cls, ) @classmethod - def validate_multiple(cls, - pubkeys: Sequence[BLSPubkey], - message_hashes: Sequence[Hash32], - signature: BLSSignature, - domain: Domain) -> None: + def validate_multiple( + cls, + pubkeys: Sequence[BLSPubkey], + message_hashes: Sequence[Hash32], + signature: BLSSignature, + domain: Domain, + ) -> None: if cls.backend != NoOpBackend: validate_signature(signature) validate_many_public_keys(pubkeys) diff --git a/eth2/_utils/bls/backends/__init__.py b/eth2/_utils/bls/backends/__init__.py index 40198805c2..50e6732c53 100644 --- a/eth2/_utils/bls/backends/__init__.py +++ b/eth2/_utils/bls/backends/__init__.py @@ -1,12 +1,8 @@ +from typing import Tuple, Type # noqa: F401 +from .base import BaseBLSBackend # noqa: F401 from .noop import NoOpBackend from .py_ecc import PyECCBackend -from .base import BaseBLSBackend # noqa: F401 -from typing import ( # noqa: F401 - Type, - Tuple, -) - AVAILABLE_BACKENDS = ( NoOpBackend, @@ -18,6 +14,7 @@ try: from .milagro import MilagroBackend + DEFAULT_BACKEND = MilagroBackend AVAILABLE_BACKENDS += (MilagroBackend,) except ImportError: @@ -25,6 +22,7 @@ try: from .chia import ChiaBackend + AVAILABLE_BACKENDS += (ChiaBackend,) except ImportError: pass diff --git a/eth2/_utils/bls/backends/base.py b/eth2/_utils/bls/backends/base.py index 330b998bb1..a1aefb978e 100644 --- a/eth2/_utils/bls/backends/base.py +++ b/eth2/_utils/bls/backends/base.py @@ -1,17 +1,7 @@ -from abc import ( - ABC, - abstractmethod, -) -from typing import ( - Sequence, -) - -from eth_typing import ( - BLSPubkey, - BLSSignature, - Hash32, -) +from abc import ABC, abstractmethod +from typing import Sequence +from eth_typing import BLSPubkey, BLSSignature, Hash32 from py_ecc.bls.typing import Domain @@ -23,17 +13,14 @@ def privtopub(k: int) -> BLSPubkey: @staticmethod @abstractmethod - def sign(message_hash: Hash32, - privkey: int, - domain: Domain) -> BLSSignature: + def sign(message_hash: Hash32, privkey: int, domain: Domain) -> BLSSignature: ... @staticmethod @abstractmethod - def verify(message_hash: Hash32, - pubkey: BLSPubkey, - signature: BLSSignature, - domain: Domain) -> bool: + def verify( + message_hash: Hash32, pubkey: BLSPubkey, signature: BLSSignature, domain: Domain + ) -> bool: ... @staticmethod @@ -48,8 +35,10 @@ def aggregate_pubkeys(pubkeys: Sequence[BLSPubkey]) -> BLSPubkey: @staticmethod @abstractmethod - def verify_multiple(pubkeys: Sequence[BLSPubkey], - message_hashes: Sequence[Hash32], - signature: BLSSignature, - domain: Domain) -> bool: + def verify_multiple( + pubkeys: Sequence[BLSPubkey], + message_hashes: Sequence[Hash32], + signature: BLSSignature, + domain: Domain, + ) -> bool: ... diff --git a/eth2/_utils/bls/backends/chia/__init__.py b/eth2/_utils/bls/backends/chia/__init__.py index d6c1ad495a..82fac04202 100644 --- a/eth2/_utils/bls/backends/chia/__init__.py +++ b/eth2/_utils/bls/backends/chia/__init__.py @@ -1,18 +1,9 @@ -from typing import ( - Sequence, -) - -from eth_typing import ( - BLSPubkey, - BLSSignature, - Hash32, -) +from typing import Sequence +from eth_typing import BLSPubkey, BLSSignature, Hash32 from py_ecc.bls.typing import Domain -from eth2._utils.bls.backends.base import ( - BaseBLSBackend, -) +from eth2._utils.bls.backends.base import BaseBLSBackend from .api import ( aggregate_pubkeys, @@ -30,16 +21,13 @@ def privtopub(k: int) -> BLSPubkey: return privtopub(k) @staticmethod - def sign(message_hash: Hash32, - privkey: int, - domain: Domain) -> BLSSignature: + def sign(message_hash: Hash32, privkey: int, domain: Domain) -> BLSSignature: return sign(message_hash, privkey, domain) @staticmethod - def verify(message_hash: Hash32, - pubkey: BLSPubkey, - signature: BLSSignature, - domain: Domain) -> bool: + def verify( + message_hash: Hash32, pubkey: BLSPubkey, signature: BLSSignature, domain: Domain + ) -> bool: return verify(message_hash, pubkey, signature, domain) @staticmethod @@ -51,8 +39,10 @@ def aggregate_pubkeys(pubkeys: Sequence[BLSPubkey]) -> BLSPubkey: return aggregate_pubkeys(pubkeys) @staticmethod - def verify_multiple(pubkeys: Sequence[BLSPubkey], - message_hashes: Sequence[Hash32], - signature: BLSSignature, - domain: Domain) -> bool: + def verify_multiple( + pubkeys: Sequence[BLSPubkey], + message_hashes: Sequence[Hash32], + signature: BLSSignature, + domain: Domain, + ) -> bool: return verify_multiple(pubkeys, message_hashes, signature, domain) diff --git a/eth2/_utils/bls/backends/chia/api.py b/eth2/_utils/bls/backends/chia/api.py index 1284b3435b..abc77d85eb 100644 --- a/eth2/_utils/bls/backends/chia/api.py +++ b/eth2/_utils/bls/backends/chia/api.py @@ -1,31 +1,11 @@ -from typing import ( - Sequence, - cast, -) - -from blspy import ( - AggregationInfo, - InsecureSignature, - PrivateKey, - PublicKey, - Signature, -) -from eth_typing import ( - BLSPubkey, - BLSSignature, - Hash32, -) -from eth_utils import ( - ValidationError, -) - +from typing import Sequence, cast +from blspy import AggregationInfo, InsecureSignature, PrivateKey, PublicKey, Signature +from eth_typing import BLSPubkey, BLSSignature, Hash32 +from eth_utils import ValidationError from py_ecc.bls.typing import Domain -from eth2.beacon.constants import ( - EMPTY_PUBKEY, - EMPTY_SIGNATURE, -) +from eth2.beacon.constants import EMPTY_PUBKEY, EMPTY_SIGNATURE def _privkey_from_int(privkey: int) -> PrivateKey: @@ -54,13 +34,9 @@ def combine_domain(message_hash: Hash32, domain: Domain) -> bytes: return message_hash + domain -def sign(message_hash: Hash32, - privkey: int, - domain: Domain) -> BLSSignature: +def sign(message_hash: Hash32, privkey: int, domain: Domain) -> BLSSignature: privkey_chia = _privkey_from_int(privkey) - sig_chia = privkey_chia.sign_insecure( - combine_domain(message_hash, domain) - ) + sig_chia = privkey_chia.sign_insecure(combine_domain(message_hash, domain)) sig_chia_bytes = sig_chia.serialize() return cast(BLSSignature, sig_chia_bytes) @@ -70,17 +46,13 @@ def privtopub(k: int) -> BLSPubkey: return cast(BLSPubkey, privkey_chia.get_public_key().serialize()) -def verify(message_hash: Hash32, - pubkey: BLSPubkey, - signature: BLSSignature, - domain: Domain) -> bool: +def verify( + message_hash: Hash32, pubkey: BLSPubkey, signature: BLSSignature, domain: Domain +) -> bool: pubkey_chia = _pubkey_from_bytes(pubkey) signature_chia = _signature_from_bytes(signature) signature_chia.set_aggregation_info( - AggregationInfo.from_msg( - pubkey_chia, - combine_domain(message_hash, domain), - ) + AggregationInfo.from_msg(pubkey_chia, combine_domain(message_hash, domain)) ) return cast(bool, signature_chia.verify()) @@ -90,8 +62,7 @@ def aggregate_signatures(signatures: Sequence[BLSSignature]) -> BLSSignature: return EMPTY_SIGNATURE signatures_chia = [ - InsecureSignature.from_bytes(signature) - for signature in signatures + InsecureSignature.from_bytes(signature) for signature in signatures ] aggregated_signature = InsecureSignature.aggregate(signatures_chia) aggregated_signature_bytes = aggregated_signature.serialize() @@ -101,31 +72,28 @@ def aggregate_signatures(signatures: Sequence[BLSSignature]) -> BLSSignature: def aggregate_pubkeys(pubkeys: Sequence[BLSPubkey]) -> BLSPubkey: if len(pubkeys) == 0: return EMPTY_PUBKEY - pubkeys_chia = [ - _pubkey_from_bytes(pubkey) - for pubkey in pubkeys - ] + pubkeys_chia = [_pubkey_from_bytes(pubkey) for pubkey in pubkeys] aggregated_pubkey_chia = PublicKey.aggregate_insecure(pubkeys_chia) return cast(BLSPubkey, aggregated_pubkey_chia.serialize()) -def verify_multiple(pubkeys: Sequence[BLSPubkey], - message_hashes: Sequence[Hash32], - signature: BLSSignature, - domain: Domain) -> bool: +def verify_multiple( + pubkeys: Sequence[BLSPubkey], + message_hashes: Sequence[Hash32], + signature: BLSSignature, + domain: Domain, +) -> bool: len_msgs = len(message_hashes) len_pubkeys = len(pubkeys) if len_pubkeys != len_msgs: raise ValueError( - "len(pubkeys) (%s) should be equal to len(message_hashes) (%s)" % ( - len_pubkeys, len_msgs - ) + "len(pubkeys) (%s) should be equal to len(message_hashes) (%s)" + % (len_pubkeys, len_msgs) ) message_hashes_with_domain = [ - combine_domain(message_hash, domain) - for message_hash in message_hashes + combine_domain(message_hash, domain) for message_hash in message_hashes ] pubkeys_chia = map(_pubkey_from_bytes, pubkeys) aggregate_infos = [ diff --git a/eth2/_utils/bls/backends/milagro.py b/eth2/_utils/bls/backends/milagro.py index a4c294d862..6b98ce81ee 100644 --- a/eth2/_utils/bls/backends/milagro.py +++ b/eth2/_utils/bls/backends/milagro.py @@ -1,17 +1,7 @@ -from typing import ( - Iterator, - Sequence, - Tuple, -) +from typing import Iterator, Sequence, Tuple -from eth_typing import ( - BLSPubkey, - BLSSignature, - Hash32, -) -from eth_utils import ( - to_tuple, -) +from eth_typing import BLSPubkey, BLSSignature, Hash32 +from eth_utils import to_tuple from milagro_bls_binding import ( aggregate_pubkeys, aggregate_signatures, @@ -20,17 +10,10 @@ verify, verify_multiple, ) -from py_ecc.bls.typing import ( - Domain, -) +from py_ecc.bls.typing import Domain -from eth2._utils.bls.backends.base import ( - BaseBLSBackend, -) -from eth2.beacon.constants import ( - EMPTY_PUBKEY, - EMPTY_SIGNATURE, -) +from eth2._utils.bls.backends.base import BaseBLSBackend +from eth2.beacon.constants import EMPTY_PUBKEY, EMPTY_SIGNATURE def to_int(domain: Domain) -> int: @@ -38,12 +21,13 @@ def to_int(domain: Domain) -> int: Convert Domain to big endian int since sigp/milagro_bls use big endian int on hash to g2. """ - return int.from_bytes(domain, 'big') + return int.from_bytes(domain, "big") @to_tuple -def filter_non_empty_pair(pubkeys: Sequence[BLSPubkey], - message_hashes: Sequence[Hash32]) -> Iterator[Tuple[BLSPubkey, Hash32]]: +def filter_non_empty_pair( + pubkeys: Sequence[BLSPubkey], message_hashes: Sequence[Hash32] +) -> Iterator[Tuple[BLSPubkey, Hash32]]: for i, pubkey in enumerate(pubkeys): if pubkey != EMPTY_PUBKEY: yield pubkey, message_hashes[i] @@ -52,26 +36,27 @@ def filter_non_empty_pair(pubkeys: Sequence[BLSPubkey], class MilagroBackend(BaseBLSBackend): @staticmethod def privtopub(k: int) -> BLSPubkey: - return privtopub(k.to_bytes(48, 'big')) + return privtopub(k.to_bytes(48, "big")) @staticmethod - def sign(message_hash: Hash32, - privkey: int, - domain: Domain) -> BLSSignature: - return sign(message_hash, privkey.to_bytes(48, 'big'), to_int(domain)) + def sign(message_hash: Hash32, privkey: int, domain: Domain) -> BLSSignature: + return sign(message_hash, privkey.to_bytes(48, "big"), to_int(domain)) @staticmethod - def verify(message_hash: Hash32, - pubkey: BLSPubkey, - signature: BLSSignature, - domain: Domain) -> bool: + def verify( + message_hash: Hash32, pubkey: BLSPubkey, signature: BLSSignature, domain: Domain + ) -> bool: if pubkey == EMPTY_PUBKEY: - raise ValueError(f"Empty public key breaks Milagro binding pubkey={pubkey}") + raise ValueError( + f"Empty public key breaks Milagro binding pubkey={pubkey}" + ) return verify(message_hash, pubkey, signature, to_int(domain)) @staticmethod def aggregate_signatures(signatures: Sequence[BLSSignature]) -> BLSSignature: - non_empty_signatures = tuple(sig for sig in signatures if sig != EMPTY_SIGNATURE) + non_empty_signatures = tuple( + sig for sig in signatures if sig != EMPTY_SIGNATURE + ) if len(non_empty_signatures) == 0: return EMPTY_SIGNATURE return aggregate_signatures(list(non_empty_signatures)) @@ -84,12 +69,16 @@ def aggregate_pubkeys(pubkeys: Sequence[BLSPubkey]) -> BLSPubkey: return aggregate_pubkeys(list(non_empty_pubkeys)) @staticmethod - def verify_multiple(pubkeys: Sequence[BLSPubkey], - message_hashes: Sequence[Hash32], - signature: BLSSignature, - domain: Domain) -> bool: + def verify_multiple( + pubkeys: Sequence[BLSPubkey], + message_hashes: Sequence[Hash32], + signature: BLSSignature, + domain: Domain, + ) -> bool: if signature == EMPTY_SIGNATURE: - raise ValueError(f"Empty signature breaks Milagro binding signature={signature}") + raise ValueError( + f"Empty signature breaks Milagro binding signature={signature}" + ) non_empty_pubkeys, filtered_message_hashes = zip( *filter_non_empty_pair(pubkeys, message_hashes) diff --git a/eth2/_utils/bls/backends/noop.py b/eth2/_utils/bls/backends/noop.py index 9fca851854..33324dbac7 100644 --- a/eth2/_utils/bls/backends/noop.py +++ b/eth2/_utils/bls/backends/noop.py @@ -1,41 +1,26 @@ -from typing import ( - Sequence, -) - -from eth_typing import ( - BLSPubkey, - BLSSignature, - Hash32, -) +from typing import Sequence +from eth_typing import BLSPubkey, BLSSignature, Hash32 from py_ecc.bls.typing import Domain -from eth2.beacon.constants import ( - EMPTY_PUBKEY, - EMPTY_SIGNATURE, -) +from eth2.beacon.constants import EMPTY_PUBKEY, EMPTY_SIGNATURE -from .base import ( - BaseBLSBackend, -) +from .base import BaseBLSBackend class NoOpBackend(BaseBLSBackend): @staticmethod def privtopub(k: int) -> BLSPubkey: - return BLSPubkey(k.to_bytes(48, 'little')) + return BLSPubkey(k.to_bytes(48, "little")) @staticmethod - def sign(message_hash: Hash32, - privkey: int, - domain: Domain) -> BLSSignature: + def sign(message_hash: Hash32, privkey: int, domain: Domain) -> BLSSignature: return EMPTY_SIGNATURE @staticmethod - def verify(message_hash: Hash32, - pubkey: BLSPubkey, - signature: BLSSignature, - domain: Domain) -> bool: + def verify( + message_hash: Hash32, pubkey: BLSPubkey, signature: BLSSignature, domain: Domain + ) -> bool: return True @staticmethod @@ -47,8 +32,10 @@ def aggregate_pubkeys(pubkeys: Sequence[BLSPubkey]) -> BLSPubkey: return EMPTY_PUBKEY @staticmethod - def verify_multiple(pubkeys: Sequence[BLSPubkey], - message_hashes: Sequence[Hash32], - signature: BLSSignature, - domain: Domain) -> bool: + def verify_multiple( + pubkeys: Sequence[BLSPubkey], + message_hashes: Sequence[Hash32], + signature: BLSSignature, + domain: Domain, + ) -> bool: return True diff --git a/eth2/_utils/bls/backends/py_ecc.py b/eth2/_utils/bls/backends/py_ecc.py index 14cf97ead6..1d4897f6f2 100644 --- a/eth2/_utils/bls/backends/py_ecc.py +++ b/eth2/_utils/bls/backends/py_ecc.py @@ -1,17 +1,6 @@ -from typing import ( - Sequence, -) - -from eth_typing import ( - BLSPubkey, - BLSSignature, - Hash32, -) +from typing import Sequence - -from eth2._utils.bls.backends.base import ( - BaseBLSBackend, -) +from eth_typing import BLSPubkey, BLSSignature, Hash32 from py_ecc.bls import ( aggregate_pubkeys, aggregate_signatures, @@ -22,10 +11,8 @@ ) from py_ecc.bls.typing import Domain -from eth2.beacon.constants import ( - EMPTY_PUBKEY, - EMPTY_SIGNATURE, -) +from eth2._utils.bls.backends.base import BaseBLSBackend +from eth2.beacon.constants import EMPTY_PUBKEY, EMPTY_SIGNATURE class PyECCBackend(BaseBLSBackend): @@ -34,16 +21,13 @@ def privtopub(k: int) -> BLSPubkey: return privtopub(k) @staticmethod - def sign(message_hash: Hash32, - privkey: int, - domain: Domain) -> BLSSignature: + def sign(message_hash: Hash32, privkey: int, domain: Domain) -> BLSSignature: return sign(message_hash, privkey, domain) @staticmethod - def verify(message_hash: Hash32, - pubkey: BLSPubkey, - signature: BLSSignature, - domain: Domain) -> bool: + def verify( + message_hash: Hash32, pubkey: BLSPubkey, signature: BLSSignature, domain: Domain + ) -> bool: return verify(message_hash, pubkey, signature, domain) @staticmethod @@ -61,8 +45,10 @@ def aggregate_pubkeys(pubkeys: Sequence[BLSPubkey]) -> BLSPubkey: return aggregate_pubkeys(pubkeys) @staticmethod - def verify_multiple(pubkeys: Sequence[BLSPubkey], - message_hashes: Sequence[Hash32], - signature: BLSSignature, - domain: Domain) -> bool: + def verify_multiple( + pubkeys: Sequence[BLSPubkey], + message_hashes: Sequence[Hash32], + signature: BLSSignature, + domain: Domain, + ) -> bool: return verify_multiple(pubkeys, message_hashes, signature, domain) diff --git a/eth2/_utils/bls/validation.py b/eth2/_utils/bls/validation.py index 13608c9a04..9cea33a938 100644 --- a/eth2/_utils/bls/validation.py +++ b/eth2/_utils/bls/validation.py @@ -1,20 +1,11 @@ -from py_ecc.optimized_bls12_381 import ( - curve_order, -) -from eth_typing import ( - BLSSignature, - BLSPubkey, -) -from eth2.beacon.constants import ( - EMPTY_PUBKEY, - EMPTY_SIGNATURE, -) -from eth2.beacon.exceptions import ( - SignatureError, - PublicKeyError, -) from typing import Sequence +from eth_typing import BLSPubkey, BLSSignature +from py_ecc.optimized_bls12_381 import curve_order + +from eth2.beacon.constants import EMPTY_PUBKEY, EMPTY_SIGNATURE +from eth2.beacon.exceptions import PublicKeyError, SignatureError + def validate_private_key(privkey: int) -> None: if privkey <= 0 or privkey >= curve_order: @@ -23,7 +14,7 @@ def validate_private_key(privkey: int) -> None: ) -def validate_public_key(pubkey: BLSPubkey, allow_empty: bool =False) -> None: +def validate_public_key(pubkey: BLSPubkey, allow_empty: bool = False) -> None: if len(pubkey) != 48: raise PublicKeyError( f"Invalid public key length, expect 48 got {len(pubkey)}. pubkey: {pubkey}" @@ -43,6 +34,4 @@ def validate_signature(signature: BLSSignature) -> None: f"Invalid signaute length, expect 96 got {len(signature)}. Signature: {signature}" ) if signature == EMPTY_SIGNATURE: - raise SignatureError( - f"Signature should not be empty. Signature: {signature}" - ) + raise SignatureError(f"Signature should not be empty. Signature: {signature}") diff --git a/eth2/_utils/funcs.py b/eth2/_utils/funcs.py index 4607f42f9a..48db66e989 100644 --- a/eth2/_utils/funcs.py +++ b/eth2/_utils/funcs.py @@ -9,8 +9,10 @@ def constantly(x: Any) -> Any: """ Return a function that returns ``x`` given any arguments. """ + def f(*args: Any, **kwargs: Any) -> Any: return x + return f diff --git a/eth2/_utils/hash.py b/eth2/_utils/hash.py index 8d52b54713..d46f232434 100644 --- a/eth2/_utils/hash.py +++ b/eth2/_utils/hash.py @@ -1,7 +1,5 @@ from hashlib import sha256 -from typing import ( - Union, -) +from typing import Union from eth_typing import Hash32 diff --git a/eth2/_utils/merkle/common.py b/eth2/_utils/merkle/common.py index a4acbc127e..02864893b1 100644 --- a/eth2/_utils/merkle/common.py +++ b/eth2/_utils/merkle/common.py @@ -1,27 +1,10 @@ -from typing import ( - Iterable, - NewType, - Sequence, -) - -from cytoolz import ( - iterate, - partition, - take, -) - -from eth_utils import to_tuple - -from eth2._utils.hash import ( - hash_eth2, -) -from eth_typing import ( - Hash32, -) -from eth_utils import ( - ValidationError, -) +from typing import Iterable, NewType, Sequence +from cytoolz import iterate, partition, take +from eth_typing import Hash32 +from eth_utils import ValidationError, to_tuple + +from eth2._utils.hash import hash_eth2 MerkleTree = NewType("MerkleTree", Sequence[Sequence[Hash32]]) MerkleProof = NewType("MerkleProof", Sequence[Hash32]) @@ -71,22 +54,22 @@ def get_merkle_proof(tree: MerkleTree, item_index: int) -> Iterable[Hash32]: return () branch_indices = get_branch_indices(item_index, len(tree)) - proof_indices = [i ^ 1 for i in branch_indices][:-1] # get sibling by flipping rightmost bit + proof_indices = [i ^ 1 for i in branch_indices][ + :-1 + ] # get sibling by flipping rightmost bit for layer, proof_index in zip(reversed(tree), proof_indices): yield layer[proof_index] -def verify_merkle_branch(leaf: Hash32, - proof: Sequence[Hash32], - depth: int, - index: int, - root: Hash32) -> bool: +def verify_merkle_branch( + leaf: Hash32, proof: Sequence[Hash32], depth: int, index: int, root: Hash32 +) -> bool: """ Verify that the given ``leaf`` is on the merkle branch ``proof``. """ value = leaf for i in range(depth): - if index // (2**i) % 2: + if index // (2 ** i) % 2: value = hash_eth2(proof[i] + value) else: value = hash_eth2(value + proof[i]) diff --git a/eth2/_utils/merkle/normal.py b/eth2/_utils/merkle/normal.py index 93a4a3d6ec..6e545326d4 100644 --- a/eth2/_utils/merkle/normal.py +++ b/eth2/_utils/merkle/normal.py @@ -6,40 +6,27 @@ """ import math -from typing import ( - Sequence, - Union, -) +from typing import Sequence, Union -from cytoolz import ( - identity, - iterate, - reduce, - take, -) +from cytoolz import identity, iterate, reduce, take +from eth_typing import Hash32 -from eth_typing import ( - Hash32, -) +from eth2._utils.hash import hash_eth2 -from eth2._utils.hash import ( - hash_eth2, -) from .common import ( # noqa: F401 + MerkleProof, + MerkleTree, _calc_parent_hash, _hash_layer, get_branch_indices, get_merkle_proof, get_root, - MerkleTree, - MerkleProof, ) -def verify_merkle_proof(root: Hash32, - item: Union[bytes, bytearray], - item_index: int, - proof: MerkleProof) -> bool: +def verify_merkle_proof( + root: Hash32, item: Union[bytes, bytearray], item_index: int, proof: MerkleProof +) -> bool: """ Verify a Merkle proof against a root hash. """ @@ -50,7 +37,9 @@ def verify_merkle_proof(root: Hash32, for branch_index in branch_indices ] proof_root = reduce( - lambda n1, n2_and_order: _calc_parent_hash(*n2_and_order[1]([n1, n2_and_order[0]])), + lambda n1, n2_and_order: _calc_parent_hash( + *n2_and_order[1]([n1, n2_and_order[0]]) + ), zip(proof, node_orderers), leaf, ) diff --git a/eth2/_utils/merkle/sparse.py b/eth2/_utils/merkle/sparse.py index 80a3a9c37c..aa96dc00c4 100644 --- a/eth2/_utils/merkle/sparse.py +++ b/eth2/_utils/merkle/sparse.py @@ -5,33 +5,22 @@ not considered to be part of the tree. """ -from typing import ( - Sequence, - Union, - TYPE_CHECKING, -) +from typing import TYPE_CHECKING, Sequence, Union -from eth_utils.toolz import ( - cons, - iterate, - take, -) -from eth2._utils.hash import ( - hash_eth2, -) +from eth_typing import Hash32 +from eth_utils.toolz import cons, iterate, take + +from eth2._utils.hash import hash_eth2 from eth2._utils.tuple import update_tuple_item -from eth_typing import ( - Hash32, -) from .common import ( # noqa: F401 + MerkleProof, + MerkleTree, _calc_parent_hash, _hash_layer, get_branch_indices, get_merkle_proof, get_root, - MerkleTree, - MerkleProof, ) if TYPE_CHECKING: @@ -39,14 +28,16 @@ TreeDepth = 32 EmptyNodeHashes = tuple( - take(TreeDepth, iterate(lambda node_hash: hash_eth2(node_hash + node_hash), b'\x00' * 32)) + take( + TreeDepth, + iterate(lambda node_hash: hash_eth2(node_hash + node_hash), b"\x00" * 32), + ) ) -def verify_merkle_proof(root: Hash32, - leaf: Hash32, - index: int, - proof: MerkleProof) -> bool: +def verify_merkle_proof( + root: Hash32, leaf: Hash32, index: int, proof: MerkleProof +) -> bool: """ Verify that the given ``item`` is on the merkle branch ``proof`` starting with the given ``root``. @@ -54,7 +45,7 @@ def verify_merkle_proof(root: Hash32, assert len(proof) == TreeDepth value = leaf for i in range(TreeDepth): - if index // (2**i) % 2: + if index // (2 ** i) % 2: value = hash_eth2(proof[i] + value) else: value = hash_eth2(value + proof[i]) @@ -82,11 +73,7 @@ def calc_merkle_tree_from_leaves(leaves: Sequence[Hash32]) -> MerkleTree: tree: Tuple[Sequence[Hash32], ...] = (leaves,) for i in range(TreeDepth): if len(tree[0]) % 2 == 1: - tree = update_tuple_item( - tree, - 0, - tuple(tree[0]) + (EmptyNodeHashes[i],), - ) + tree = update_tuple_item(tree, 0, tuple(tree[0]) + (EmptyNodeHashes[i],)) tree = tuple(cons(_hash_layer(tree[0]), tree)) return MerkleTree(tree) diff --git a/eth2/_utils/numeric.py b/eth2/_utils/numeric.py index 0a2dfa24fc..5578c89d0e 100644 --- a/eth2/_utils/numeric.py +++ b/eth2/_utils/numeric.py @@ -1,8 +1,6 @@ import decimal -from eth_typing import ( - Hash32, -) +from eth_typing import Hash32 def bitwise_xor(a: Hash32, b: Hash32) -> Hash32: @@ -22,17 +20,9 @@ def integer_squareroot(value: int) -> int: root of a 256-bit integer is a 128-bit integer. """ if not isinstance(value, int) or isinstance(value, bool): - raise ValueError( - "Value must be an integer: Got: {0}".format( - type(value), - ) - ) + raise ValueError("Value must be an integer: Got: {0}".format(type(value))) if value < 0: - raise ValueError( - "Value cannot be negative: Got: {0}".format( - value, - ) - ) + raise ValueError("Value cannot be negative: Got: {0}".format(value)) with decimal.localcontext() as ctx: ctx.prec = 128 diff --git a/eth2/_utils/ssz.py b/eth2/_utils/ssz.py index 4a95c428cc..69b05fcb8d 100644 --- a/eth2/_utils/ssz.py +++ b/eth2/_utils/ssz.py @@ -1,32 +1,25 @@ -from typing import ( - Iterable, - Optional, - Tuple, -) +from typing import Iterable, Optional, Tuple +from eth_utils import ValidationError, to_tuple +from eth_utils.toolz import curry import ssz -from eth_utils import ( - to_tuple, - ValidationError, -) -from eth_utils.toolz import ( - curry, -) - from eth2.beacon.types.blocks import BaseBeaconBlock @to_tuple -def diff_ssz_object(left: BaseBeaconBlock, - right: BaseBeaconBlock) -> Optional[Iterable[Tuple[str, str, str]]]: +def diff_ssz_object( + left: BaseBeaconBlock, right: BaseBeaconBlock +) -> Optional[Iterable[Tuple[str, str, str]]]: if left != right: ssz_type = type(left) for field_name, field_type in ssz_type._meta.fields: left_value = getattr(left, field_name) right_value = getattr(right, field_name) - if isinstance(field_type, type) and issubclass(field_type, ssz.Serializable): + if isinstance(field_type, type) and issubclass( + field_type, ssz.Serializable + ): sub_diff = diff_ssz_object(left_value, right_value) for sub_field_name, sub_left_value, sub_right_value in sub_diff: yield ( @@ -36,33 +29,27 @@ def diff_ssz_object(left: BaseBeaconBlock, ) elif isinstance(field_type, ssz.sedes.List): if tuple(left_value) != tuple(right_value): - yield ( - field_name, - left_value, - right_value, - ) + yield (field_name, left_value, right_value) elif left_value != right_value: - yield ( - field_name, - left_value, - right_value, - ) + yield (field_name, left_value, right_value) else: continue @curry -def validate_ssz_equal(obj_a: BaseBeaconBlock, - obj_b: BaseBeaconBlock, - obj_a_name: str=None, - obj_b_name: str=None) -> None: +def validate_ssz_equal( + obj_a: BaseBeaconBlock, + obj_b: BaseBeaconBlock, + obj_a_name: str = None, + obj_b_name: str = None, +) -> None: if obj_a == obj_b: return if obj_a_name is None: - obj_a_name = obj_a.__class__.__name__ + '_a' + obj_a_name = obj_a.__class__.__name__ + "_a" if obj_b_name is None: - obj_b_name = obj_b.__class__.__name__ + '_b' + obj_b_name = obj_b.__class__.__name__ + "_b" diff = diff_ssz_object(obj_a, obj_b) if len(diff) == 0: @@ -73,8 +60,7 @@ def validate_ssz_equal(obj_a: BaseBeaconBlock, diff_error_message = "\n - ".join( f"{field_name.ljust(longest_field_name, ' ')}:\n " f"(actual) : {actual}\n (expected): {expected}" - for field_name, actual, expected - in diff + for field_name, actual, expected in diff ) error_message = ( f"Mismatch between {obj_a_name} and {obj_b_name} " @@ -84,6 +70,5 @@ def validate_ssz_equal(obj_a: BaseBeaconBlock, validate_imported_block_unchanged = validate_ssz_equal( - obj_a_name="block", - obj_b_name="imported block", + obj_a_name="block", obj_b_name="imported block" ) diff --git a/eth2/_utils/tuple.py b/eth2/_utils/tuple.py index 2f392079b7..9afb0821f4 100644 --- a/eth2/_utils/tuple.py +++ b/eth2/_utils/tuple.py @@ -1,22 +1,16 @@ -from typing import ( - Any, - Callable, - Tuple, - TypeVar, -) +from typing import Any, Callable, Tuple, TypeVar -from eth_utils import ( - ValidationError, -) +from eth_utils import ValidationError +VType = TypeVar("VType") -VType = TypeVar('VType') - -def update_tuple_item_with_fn(tuple_data: Tuple[VType, ...], - index: int, - fn: Callable[[VType, Any], VType], - *args: Any) -> Tuple[VType, ...]: +def update_tuple_item_with_fn( + tuple_data: Tuple[VType, ...], + index: int, + fn: Callable[[VType, Any], VType], + *args: Any +) -> Tuple[VType, ...]: """ Update the ``index``th item of ``tuple_data`` to the result of calling ``fn`` on the existing value. @@ -29,22 +23,17 @@ def update_tuple_item_with_fn(tuple_data: Tuple[VType, ...], except IndexError: raise ValidationError( "the length of the given tuple_data is {}, the given index {} is out of index".format( - len(tuple_data), - index, + len(tuple_data), index ) ) else: return tuple(list_data) -def update_tuple_item(tuple_data: Tuple[VType, ...], - index: int, - new_value: VType) -> Tuple[VType, ...]: +def update_tuple_item( + tuple_data: Tuple[VType, ...], index: int, new_value: VType +) -> Tuple[VType, ...]: """ Update the ``index``th item of ``tuple_data`` to ``new_value`` """ - return update_tuple_item_with_fn( - tuple_data, - index, - lambda *_: new_value - ) + return update_tuple_item_with_fn(tuple_data, index, lambda *_: new_value) diff --git a/eth2/beacon/attestation_helpers.py b/eth2/beacon/attestation_helpers.py index 59c0cf41b0..4ca7a74c54 100644 --- a/eth2/beacon/attestation_helpers.py +++ b/eth2/beacon/attestation_helpers.py @@ -1,43 +1,29 @@ -from eth2._utils.bls import ( - bls, -) - -from eth_utils import ( - ValidationError, -) +from eth_utils import ValidationError +from eth2._utils.bls import bls +from eth2.beacon.committee_helpers import get_committee_count, get_start_shard +from eth2.beacon.exceptions import SignatureError from eth2.beacon.helpers import ( + compute_start_slot_of_epoch, get_active_validator_indices, get_domain, - compute_start_slot_of_epoch, -) -from eth2.beacon.committee_helpers import ( - get_committee_count, - get_start_shard, ) from eth2.beacon.signature_domain import SignatureDomain -from eth2.beacon.types.attestations import IndexedAttestation from eth2.beacon.types.attestation_data import AttestationData -from eth2.beacon.types.attestation_data_and_custody_bits import AttestationDataAndCustodyBit -from eth2.beacon.types.states import BeaconState -from eth2.beacon.typing import ( - Slot, -) -from eth2.configs import ( - CommitteeConfig, - Eth2Config, -) -from eth2.beacon.exceptions import ( - SignatureError, +from eth2.beacon.types.attestation_data_and_custody_bits import ( + AttestationDataAndCustodyBit, ) +from eth2.beacon.types.attestations import IndexedAttestation +from eth2.beacon.types.states import BeaconState +from eth2.beacon.typing import Slot +from eth2.configs import CommitteeConfig, Eth2Config -def get_attestation_data_slot(state: BeaconState, - data: AttestationData, - config: Eth2Config) -> Slot: +def get_attestation_data_slot( + state: BeaconState, data: AttestationData, config: Eth2Config +) -> Slot: active_validator_indices = get_active_validator_indices( - state.validators, - data.target.epoch, + state.validators, data.target.epoch ) committee_count = get_committee_count( len(active_validator_indices), @@ -46,42 +32,34 @@ def get_attestation_data_slot(state: BeaconState, config.TARGET_COMMITTEE_SIZE, ) offset = ( - data.crosslink.shard + config.SHARD_COUNT - get_start_shard( - state, - data.target.epoch, - CommitteeConfig(config), - ) + data.crosslink.shard + + config.SHARD_COUNT + - get_start_shard(state, data.target.epoch, CommitteeConfig(config)) ) % config.SHARD_COUNT committees_per_slot = committee_count // config.SLOTS_PER_EPOCH - return compute_start_slot_of_epoch( - data.target.epoch, - config.SLOTS_PER_EPOCH, - ) + offset // committees_per_slot + return ( + compute_start_slot_of_epoch(data.target.epoch, config.SLOTS_PER_EPOCH) + + offset // committees_per_slot + ) -def validate_indexed_attestation_aggregate_signature(state: BeaconState, - indexed_attestation: IndexedAttestation, - slots_per_epoch: int) -> None: +def validate_indexed_attestation_aggregate_signature( + state: BeaconState, indexed_attestation: IndexedAttestation, slots_per_epoch: int +) -> None: bit_0_indices = indexed_attestation.custody_bit_0_indices bit_1_indices = indexed_attestation.custody_bit_1_indices pubkeys = ( - bls.aggregate_pubkeys( - tuple(state.validators[i].pubkey for i in bit_0_indices) - ), - bls.aggregate_pubkeys( - tuple(state.validators[i].pubkey for i in bit_1_indices) - ), + bls.aggregate_pubkeys(tuple(state.validators[i].pubkey for i in bit_0_indices)), + bls.aggregate_pubkeys(tuple(state.validators[i].pubkey for i in bit_1_indices)), ) message_hashes = ( AttestationDataAndCustodyBit( - data=indexed_attestation.data, - custody_bit=False + data=indexed_attestation.data, custody_bit=False ).hash_tree_root, AttestationDataAndCustodyBit( - data=indexed_attestation.data, - custody_bit=True, + data=indexed_attestation.data, custody_bit=True ).hash_tree_root, ) @@ -99,10 +77,12 @@ def validate_indexed_attestation_aggregate_signature(state: BeaconState, ) -def validate_indexed_attestation(state: BeaconState, - indexed_attestation: IndexedAttestation, - max_validators_per_committee: int, - slots_per_epoch: int) -> None: +def validate_indexed_attestation( + state: BeaconState, + indexed_attestation: IndexedAttestation, + max_validators_per_committee: int, + slots_per_epoch: int, +) -> None: """ Derived from spec: `is_valid_indexed_attestation`. """ @@ -139,23 +119,28 @@ def validate_indexed_attestation(state: BeaconState, ) try: - validate_indexed_attestation_aggregate_signature(state, - indexed_attestation, - slots_per_epoch) + validate_indexed_attestation_aggregate_signature( + state, indexed_attestation, slots_per_epoch + ) except SignatureError as error: raise ValidationError( - f"Incorrect aggregate signature on the {indexed_attestation}", - error, + f"Incorrect aggregate signature on the {indexed_attestation}", error ) -def is_slashable_attestation_data(data_1: AttestationData, data_2: AttestationData) -> bool: +def is_slashable_attestation_data( + data_1: AttestationData, data_2: AttestationData +) -> bool: """ Check if ``data_1`` and ``data_2`` are slashable according to Casper FFG rules. """ return ( # Double vote - (data_1 != data_2 and data_1.target.epoch == data_2.target.epoch) or + (data_1 != data_2 and data_1.target.epoch == data_2.target.epoch) + or # Surround vote - (data_1.source.epoch < data_2.source.epoch and data_2.target.epoch < data_1.target.epoch) + ( + data_1.source.epoch < data_2.source.epoch + and data_2.target.epoch < data_1.target.epoch + ) ) diff --git a/eth2/beacon/chains/base.py b/eth2/beacon/chains/base.py index 33ecf790d7..1b41ef90c5 100644 --- a/eth2/beacon/chains/base.py +++ b/eth2/beacon/chains/base.py @@ -1,77 +1,39 @@ -from abc import ( - ABC, - abstractmethod, -) +from abc import ABC, abstractmethod import logging -from typing import ( - TYPE_CHECKING, - Tuple, - Type, -) - -from eth._utils.datatypes import ( - Configurable, -) -from eth.abc import ( - AtomicDatabaseAPI, -) -from eth.exceptions import ( - BlockNotFound, -) -from eth.validation import ( - validate_word, -) -from eth_typing import ( - Hash32, -) -from eth_utils import ( - ValidationError, - encode_hex, -) +from typing import TYPE_CHECKING, Tuple, Type + +from eth._utils.datatypes import Configurable +from eth.abc import AtomicDatabaseAPI +from eth.exceptions import BlockNotFound +from eth.validation import validate_word +from eth_typing import Hash32 +from eth_utils import ValidationError, encode_hex from eth2._utils.funcs import constantly -from eth2._utils.ssz import ( - validate_imported_block_unchanged, -) -from eth2.beacon.db.chain import ( - BaseBeaconChainDB, - BeaconChainDB, -) -from eth2.beacon.exceptions import ( - BlockClassError, - StateMachineNotFound, -) +from eth2._utils.ssz import validate_imported_block_unchanged +from eth2.beacon.db.chain import BaseBeaconChainDB, BeaconChainDB +from eth2.beacon.exceptions import BlockClassError, StateMachineNotFound from eth2.beacon.operations.attestation_pool import AttestationPool -from eth2.beacon.types.attestations import ( - Attestation, -) -from eth2.beacon.types.blocks import ( - BaseBeaconBlock, -) -from eth2.beacon.types.states import ( - BeaconState, -) -from eth2.beacon.typing import ( - FromBlockParams, - Slot, -) -from eth2.configs import ( - Eth2GenesisConfig, -) +from eth2.beacon.types.attestations import Attestation +from eth2.beacon.types.blocks import BaseBeaconBlock +from eth2.beacon.types.states import BeaconState +from eth2.beacon.typing import FromBlockParams, Slot +from eth2.configs import Eth2Config, Eth2GenesisConfig if TYPE_CHECKING: - from eth2.beacon.state_machines.base import ( # noqa: F401 - BaseBeaconStateMachine, - ) + from eth2.beacon.state_machines.base import BaseBeaconStateMachine # noqa: F401 class BaseBeaconChain(Configurable, ABC): """ The base class for all BeaconChain objects """ + chaindb = None # type: BaseBeaconChainDB chaindb_class = None # type: Type[BaseBeaconChainDB] - sm_configuration = None # type: Tuple[Tuple[Slot, Type[BaseBeaconStateMachine]], ...] + sm_configuration = ( + None + ) # type: Tuple[Tuple[Slot, Type[BaseBeaconStateMachine]], ...] chain_id = None # type: int # @@ -87,11 +49,13 @@ def get_chaindb_class(cls) -> Type[BaseBeaconChainDB]: # @classmethod @abstractmethod - def from_genesis(cls, - base_db: AtomicDatabaseAPI, - genesis_state: BeaconState, - genesis_block: BaseBeaconBlock, - genesis_config: Eth2GenesisConfig) -> 'BaseBeaconChain': + def from_genesis( + cls, + base_db: AtomicDatabaseAPI, + genesis_state: BeaconState, + genesis_block: BaseBeaconBlock, + genesis_config: Eth2GenesisConfig, + ) -> "BaseBeaconChain": ... # @@ -100,24 +64,24 @@ def from_genesis(cls, @classmethod @abstractmethod def get_state_machine_class( - cls, - block: BaseBeaconBlock) -> Type['BaseBeaconStateMachine']: + cls, block: BaseBeaconBlock + ) -> Type["BaseBeaconStateMachine"]: ... @abstractmethod - def get_state_machine(self, at_slot: Slot=None) -> 'BaseBeaconStateMachine': + def get_state_machine(self, at_slot: Slot = None) -> "BaseBeaconStateMachine": ... @classmethod @abstractmethod def get_state_machine_class_for_block_slot( - cls, - slot: Slot) -> Type['BaseBeaconStateMachine']: + cls, slot: Slot + ) -> Type["BaseBeaconStateMachine"]: ... @classmethod @abstractmethod - def get_genesis_state_machine_class(self) -> Type['BaseBeaconStateMachine']: + def get_genesis_state_machine_class(self) -> Type["BaseBeaconStateMachine"]: ... # @@ -135,9 +99,9 @@ def get_block_class(self, block_root: Hash32) -> Type[BaseBeaconBlock]: ... @abstractmethod - def create_block_from_parent(self, - parent_block: BaseBeaconBlock, - block_params: FromBlockParams) -> BaseBeaconBlock: + def create_block_from_parent( + self, parent_block: BaseBeaconBlock, block_params: FromBlockParams + ) -> BaseBeaconBlock: ... @abstractmethod @@ -162,17 +126,17 @@ def get_canonical_block_root(self, slot: Slot) -> Hash32: @abstractmethod def import_block( - self, - block: BaseBeaconBlock, - perform_validation: bool=True - ) -> Tuple[BaseBeaconBlock, Tuple[BaseBeaconBlock, ...], Tuple[BaseBeaconBlock, ...]]: + self, block: BaseBeaconBlock, perform_validation: bool = True + ) -> Tuple[ + BaseBeaconBlock, Tuple[BaseBeaconBlock, ...], Tuple[BaseBeaconBlock, ...] + ]: ... # # Attestation API # @abstractmethod - def get_attestation_by_root(self, attestation_root: Hash32)-> Attestation: + def get_attestation_by_root(self, attestation_root: Hash32) -> Attestation: ... @abstractmethod @@ -187,14 +151,17 @@ class BeaconChain(BaseBeaconChain): StateMachine classes, delegating operations to the appropriate StateMachine depending on the current block slot number. """ + logger = logging.getLogger("eth2.beacon.chains.BeaconChain") chaindb_class = BeaconChainDB # type: Type[BaseBeaconChainDB] - def __init__(self, - base_db: AtomicDatabaseAPI, - attestation_pool: AttestationPool, - genesis_config: Eth2GenesisConfig) -> None: + def __init__( + self, + base_db: AtomicDatabaseAPI, + attestation_pool: AttestationPool, + genesis_config: Eth2GenesisConfig, + ) -> None: if not self.sm_configuration: raise ValueError( "The Chain class cannot be instantiated with an empty `sm_configuration`" @@ -211,20 +178,25 @@ def __init__(self, # Helpers # @classmethod - def get_chaindb_class(cls) -> Type['BaseBeaconChainDB']: + def get_chaindb_class(cls) -> Type["BaseBeaconChainDB"]: if cls.chaindb_class is None: raise AttributeError("`chaindb_class` not set") return cls.chaindb_class + def get_config_by_slot(self, slot: Slot) -> Eth2Config: + return self.get_state_machine_class_for_block_slot(slot).config + # # Chain API # @classmethod - def from_genesis(cls, - base_db: AtomicDatabaseAPI, - genesis_state: BeaconState, - genesis_block: BaseBeaconBlock, - genesis_config: Eth2GenesisConfig) -> 'BaseBeaconChain': + def from_genesis( + cls, + base_db: AtomicDatabaseAPI, + genesis_state: BeaconState, + genesis_block: BaseBeaconBlock, + genesis_config: Eth2GenesisConfig, + ) -> "BaseBeaconChain": """ Initialize the ``BeaconChain`` from a genesis state. """ @@ -232,8 +204,7 @@ def from_genesis(cls, if type(genesis_block) != sm_class.block_class: raise BlockClassError( "Given genesis block class: {}, StateMachine.block_class: {}".format( - type(genesis_block), - sm_class.block_class + type(genesis_block), sm_class.block_class ) ) @@ -241,18 +212,17 @@ def from_genesis(cls, chaindb.persist_state(genesis_state) attestation_pool = AttestationPool() return cls._from_genesis_block( - base_db, - attestation_pool, - genesis_block, - genesis_config, + base_db, attestation_pool, genesis_block, genesis_config ) @classmethod - def _from_genesis_block(cls, - base_db: AtomicDatabaseAPI, - attestation_pool: AttestationPool, - genesis_block: BaseBeaconBlock, - genesis_config: Eth2GenesisConfig) -> 'BaseBeaconChain': + def _from_genesis_block( + cls, + base_db: AtomicDatabaseAPI, + attestation_pool: AttestationPool, + genesis_block: BaseBeaconBlock, + genesis_config: Eth2GenesisConfig, + ) -> "BaseBeaconChain": """ Initialize the ``BeaconChain`` from the genesis block. """ @@ -265,7 +235,9 @@ def _from_genesis_block(cls, # StateMachine API # @classmethod - def get_state_machine_class(cls, block: BaseBeaconBlock) -> Type['BaseBeaconStateMachine']: + def get_state_machine_class( + cls, block: BaseBeaconBlock + ) -> Type["BaseBeaconStateMachine"]: """ Returns the ``StateMachine`` instance for the given block slot number. """ @@ -273,20 +245,24 @@ def get_state_machine_class(cls, block: BaseBeaconBlock) -> Type['BaseBeaconStat @classmethod def get_state_machine_class_for_block_slot( - cls, - slot: Slot) -> Type['BaseBeaconStateMachine']: + cls, slot: Slot + ) -> Type["BaseBeaconStateMachine"]: """ Return the ``StateMachine`` class for the given block slot number. """ if cls.sm_configuration is None: - raise AttributeError("Chain classes must define the StateMachines in sm_configuration") + raise AttributeError( + "Chain classes must define the StateMachines in sm_configuration" + ) for start_slot, sm_class in reversed(cls.sm_configuration): if slot >= start_slot: return sm_class - raise StateMachineNotFound("No StateMachine available for block slot: #{0}".format(slot)) + raise StateMachineNotFound( + "No StateMachine available for block slot: #{0}".format(slot) + ) - def get_state_machine(self, at_slot: Slot=None) -> 'BaseBeaconStateMachine': + def get_state_machine(self, at_slot: Slot = None) -> "BaseBeaconStateMachine": """ Return the ``StateMachine`` instance for the given slot number. """ @@ -297,13 +273,11 @@ def get_state_machine(self, at_slot: Slot=None) -> 'BaseBeaconStateMachine': sm_class = self.get_state_machine_class_for_block_slot(slot) return sm_class( - chaindb=self.chaindb, - attestation_pool=self.attestation_pool, - slot=slot, + chaindb=self.chaindb, attestation_pool=self.attestation_pool, slot=slot ) @classmethod - def get_genesis_state_machine_class(cls) -> Type['BaseBeaconStateMachine']: + def get_genesis_state_machine_class(cls) -> Type["BaseBeaconStateMachine"]: return cls.sm_configuration[0][1] # @@ -313,11 +287,12 @@ def get_state_by_slot(self, slot: Slot) -> BeaconState: """ Return the requested state as specified by slot number. - Raise ``StateSlotNotFound`` if there's no state with the given slot in the db. + Raise ``StateNotFound`` if there's no state with the given slot number in the db. """ sm_class = self.get_state_machine_class_for_block_slot(slot) state_class = sm_class.get_state_class() - return self.chaindb.get_state_by_slot(slot, state_class) + state_root = self.chaindb.get_state_root_by_slot(slot) + return self.chaindb.get_state_by_root(state_root, state_class) # # Block API @@ -328,16 +303,16 @@ def get_block_class(self, block_root: Hash32) -> Type[BaseBeaconBlock]: block_class = sm_class.block_class return block_class - def create_block_from_parent(self, - parent_block: BaseBeaconBlock, - block_params: FromBlockParams) -> BaseBeaconBlock: + def create_block_from_parent( + self, parent_block: BaseBeaconBlock, block_params: FromBlockParams + ) -> BaseBeaconBlock: """ Passthrough helper to the ``StateMachine`` class of the block descending from the given block. """ slot = parent_block.slot + 1 if block_params.slot is None else block_params.slot return self.get_state_machine_class_for_block_slot( - slot=slot, + slot=slot ).create_block_from_parent(parent_block, block_params) def get_block_by_root(self, block_root: Hash32) -> BaseBeaconBlock: @@ -388,11 +363,15 @@ def get_canonical_block_root(self, slot: Slot) -> Hash32: """ return self.chaindb.get_canonical_block_root(slot) + def get_head_state(self) -> BeaconState: + head_state_slot = self.chaindb.get_head_state_slot() + return self.get_state_by_slot(head_state_slot) + def import_block( - self, - block: BaseBeaconBlock, - perform_validation: bool=True - ) -> Tuple[BaseBeaconBlock, Tuple[BaseBeaconBlock, ...], Tuple[BaseBeaconBlock, ...]]: + self, block: BaseBeaconBlock, perform_validation: bool = True + ) -> Tuple[ + BaseBeaconBlock, Tuple[BaseBeaconBlock, ...], Tuple[BaseBeaconBlock, ...] + ]: """ Import a complete block and returns a 3-tuple @@ -407,9 +386,7 @@ def import_block( raise ValidationError( "Attempt to import block #{}. Cannot import block {} before importing " "its parent block at {}".format( - block.slot, - block.signing_root, - block.parent_root, + block.slot, block.signing_root, block.parent_root ) ) @@ -422,8 +399,10 @@ def import_block( prev_state_slot = head_state_slot state_machine = self.get_state_machine(prev_state_slot) - - state, imported_block = state_machine.import_block(block) + state = self.get_state_by_slot(prev_state_slot) + state, imported_block = state_machine.import_block( + block, state, check_proposer_signature=perform_validation + ) # Validate the imported block. if perform_validation: @@ -433,17 +412,12 @@ def import_block( self.chaindb.persist_state(state) fork_choice_scoring = state_machine.get_fork_choice_scoring() - ( - new_canonical_blocks, - old_canonical_blocks, - ) = self.chaindb.persist_block( - imported_block, - imported_block.__class__, - fork_choice_scoring, + (new_canonical_blocks, old_canonical_blocks) = self.chaindb.persist_block( + imported_block, imported_block.__class__, fork_choice_scoring ) self.logger.debug( - 'IMPORTED_BLOCK: slot %s | signed root %s', + "IMPORTED_BLOCK: slot %s | signed root %s", imported_block.slot, encode_hex(imported_block.signing_root), ) @@ -453,7 +427,7 @@ def import_block( # # Attestation API # - def get_attestation_by_root(self, attestation_root: Hash32)-> Attestation: + def get_attestation_by_root(self, attestation_root: Hash32) -> Attestation: block_root, index = self.chaindb.get_attestation_key_by_root(attestation_root) block = self.get_block_by_root(block_root) return block.body.attestations[index] diff --git a/eth2/beacon/chains/testnet/__init__.py b/eth2/beacon/chains/testnet/__init__.py index e1f4d75299..365e58d228 100644 --- a/eth2/beacon/chains/testnet/__init__.py +++ b/eth2/beacon/chains/testnet/__init__.py @@ -1,27 +1,14 @@ -from typing import ( - TYPE_CHECKING, -) -from eth2.beacon.chains.base import ( - BeaconChain, -) -from eth2.beacon.state_machines.forks.xiao_long_bao import ( - XiaoLongBaoStateMachine, -) -from .constants import ( - TESTNET_CHAIN_ID, -) +from typing import TYPE_CHECKING + +from eth2.beacon.chains.base import BeaconChain +from eth2.beacon.state_machines.forks.xiao_long_bao import XiaoLongBaoStateMachine + +from .constants import TESTNET_CHAIN_ID if TYPE_CHECKING: - from eth2.beacon.typing import ( # noqa: F401 - Slot, - ) - from eth2.beacon.state_machines.base import ( # noqa: F401 - BaseBeaconStateMachine, - ) - from typing import ( # noqa: F401 - Tuple, - Type, - ) + from eth2.beacon.typing import Slot # noqa: F401 + from eth2.beacon.state_machines.base import BaseBeaconStateMachine # noqa: F401 + from typing import Tuple, Type # noqa: F401 state_machine_class = XiaoLongBaoStateMachine diff --git a/eth2/beacon/chains/testnet/constants.py b/eth2/beacon/chains/testnet/constants.py index ad449a21db..714ace7eda 100644 --- a/eth2/beacon/chains/testnet/constants.py +++ b/eth2/beacon/chains/testnet/constants.py @@ -1,3 +1 @@ - - TESTNET_CHAIN_ID = 5566 diff --git a/eth2/beacon/committee_helpers.py b/eth2/beacon/committee_helpers.py index 8ff00575c8..c6f3df5965 100644 --- a/eth2/beacon/committee_helpers.py +++ b/eth2/beacon/committee_helpers.py @@ -1,77 +1,50 @@ -from typing import ( - Iterable, - Sequence, - Tuple, -) - -from eth_utils import ( - to_tuple, - ValidationError, -) -from eth_typing import ( - Hash32, - BLSPubkey, -) +from typing import Iterable, Sequence, Tuple +from eth_typing import BLSPubkey, Hash32 +from eth_utils import ValidationError, to_tuple import ssz -from eth2._utils.hash import ( - hash_eth2, -) -from eth2._utils.tuple import ( - update_tuple_item, -) -from eth2.configs import ( - CommitteeConfig, -) -from eth2.beacon.constants import ( - MAX_RANDOM_BYTE, - MAX_INDEX_COUNT, -) -from eth2.beacon.helpers import ( - get_seed, - get_active_validator_indices, -) -from eth2.beacon.typing import ( - Epoch, - Gwei, - Shard, - Slot, - ValidatorIndex, -) +from eth2._utils.hash import hash_eth2 +from eth2._utils.tuple import update_tuple_item +from eth2.beacon.constants import MAX_INDEX_COUNT, MAX_RANDOM_BYTE +from eth2.beacon.helpers import get_active_validator_indices, get_seed from eth2.beacon.types.compact_committees import CompactCommittee from eth2.beacon.types.states import BeaconState from eth2.beacon.types.validators import Validator +from eth2.beacon.typing import Epoch, Gwei, Shard, Slot, ValidatorIndex +from eth2.configs import CommitteeConfig -def get_committees_per_slot(active_validator_count: int, - shard_count: int, - slots_per_epoch: int, - target_committee_size: int) -> int: +def get_committees_per_slot( + active_validator_count: int, + shard_count: int, + slots_per_epoch: int, + target_committee_size: int, +) -> int: return max( 1, min( shard_count // slots_per_epoch, active_validator_count // slots_per_epoch // target_committee_size, - ) + ), ) -def get_committee_count(active_validator_count: int, - shard_count: int, - slots_per_epoch: int, - target_committee_size: int) -> int: - return get_committees_per_slot( - active_validator_count, - shard_count, - slots_per_epoch, - target_committee_size, - ) * slots_per_epoch +def get_committee_count( + active_validator_count: int, + shard_count: int, + slots_per_epoch: int, + target_committee_size: int, +) -> int: + return ( + get_committees_per_slot( + active_validator_count, shard_count, slots_per_epoch, target_committee_size + ) + * slots_per_epoch + ) -def get_shard_delta(state: BeaconState, - epoch: Epoch, - config: CommitteeConfig) -> int: +def get_shard_delta(state: BeaconState, epoch: Epoch, config: CommitteeConfig) -> int: shard_count = config.SHARD_COUNT slots_per_epoch = config.SLOTS_PER_EPOCH @@ -88,9 +61,7 @@ def get_shard_delta(state: BeaconState, ) -def get_start_shard(state: BeaconState, - epoch: Epoch, - config: CommitteeConfig) -> Shard: +def get_start_shard(state: BeaconState, epoch: Epoch, config: CommitteeConfig) -> Shard: current_epoch = state.current_epoch(config.SLOTS_PER_EPOCH) next_epoch = state.next_epoch(config.SLOTS_PER_EPOCH) if epoch > next_epoch: @@ -103,16 +74,20 @@ def get_start_shard(state: BeaconState, while check_epoch > epoch: check_epoch -= 1 shard = ( - shard + config.SHARD_COUNT - get_shard_delta(state, Epoch(check_epoch), config) + shard + + config.SHARD_COUNT + - get_shard_delta(state, Epoch(check_epoch), config) ) % config.SHARD_COUNT return shard -def _find_proposer_in_committee(validators: Sequence[Validator], - committee: Sequence[ValidatorIndex], - epoch: Epoch, - seed: Hash32, - max_effective_balance: Gwei) -> ValidatorIndex: +def _find_proposer_in_committee( + validators: Sequence[Validator], + committee: Sequence[ValidatorIndex], + epoch: Epoch, + seed: Hash32, + max_effective_balance: Gwei, +) -> ValidatorIndex: base = int(epoch) i = 0 committee_len = len(committee) @@ -125,16 +100,18 @@ def _find_proposer_in_committee(validators: Sequence[Validator], i += 1 -def _calculate_first_committee_at_slot(state: BeaconState, - slot: Slot, - config: CommitteeConfig) -> Tuple[ValidatorIndex, ...]: +def _calculate_first_committee_at_slot( + state: BeaconState, slot: Slot, config: CommitteeConfig +) -> Tuple[ValidatorIndex, ...]: slots_per_epoch = config.SLOTS_PER_EPOCH shard_count = config.SHARD_COUNT target_committee_size = config.TARGET_COMMITTEE_SIZE current_epoch = state.current_epoch(slots_per_epoch) - active_validator_indices = get_active_validator_indices(state.validators, current_epoch) + active_validator_indices = get_active_validator_indices( + state.validators, current_epoch + ) committees_per_slot = get_committees_per_slot( len(active_validator_indices), @@ -144,28 +121,20 @@ def _calculate_first_committee_at_slot(state: BeaconState, ) offset = committees_per_slot * (slot % slots_per_epoch) - shard = ( - get_start_shard(state, current_epoch, config) + offset - ) % shard_count + shard = (get_start_shard(state, current_epoch, config) + offset) % shard_count - return get_crosslink_committee( - state, - current_epoch, - shard, - config, - ) + return get_crosslink_committee(state, current_epoch, shard, config) -def get_beacon_proposer_index(state: BeaconState, - committee_config: CommitteeConfig) -> ValidatorIndex: +def get_beacon_proposer_index( + state: BeaconState, committee_config: CommitteeConfig +) -> ValidatorIndex: """ Return the current beacon proposer index. """ first_committee = _calculate_first_committee_at_slot( - state, - state.slot, - committee_config, + state, state.slot, committee_config ) current_epoch = state.current_epoch(committee_config.SLOTS_PER_EPOCH) @@ -181,10 +150,9 @@ def get_beacon_proposer_index(state: BeaconState, ) -def compute_shuffled_index(index: int, - index_count: int, - seed: Hash32, - shuffle_round_count: int) -> int: +def compute_shuffled_index( + index: int, index_count: int, seed: Hash32, shuffle_round_count: int +) -> int: """ Return `p(index)` in a pseudorandom permutation `p` of `0...index_count-1` with ``seed`` as entropy. @@ -206,17 +174,19 @@ def compute_shuffled_index(index: int, new_index = index for current_round in range(shuffle_round_count): - pivot = int.from_bytes( - hash_eth2(seed + current_round.to_bytes(1, 'little'))[0:8], - 'little', - ) % index_count + pivot = ( + int.from_bytes( + hash_eth2(seed + current_round.to_bytes(1, "little"))[0:8], "little" + ) + % index_count + ) flip = (pivot + index_count - new_index) % index_count position = max(new_index, flip) source = hash_eth2( - seed + - current_round.to_bytes(1, 'little') + - (position // 256).to_bytes(4, 'little') + seed + + current_round.to_bytes(1, "little") + + (position // 256).to_bytes(4, "little") ) byte = source[(position % 256) // 8] bit = (byte >> (position % 8)) % 2 @@ -225,31 +195,31 @@ def compute_shuffled_index(index: int, return new_index -def _compute_committee(indices: Sequence[ValidatorIndex], - seed: Hash32, - index: int, - count: int, - shuffle_round_count: int) -> Iterable[ValidatorIndex]: +def _compute_committee( + indices: Sequence[ValidatorIndex], + seed: Hash32, + index: int, + count: int, + shuffle_round_count: int, +) -> Iterable[ValidatorIndex]: start = (len(indices) * index) // count end = (len(indices) * (index + 1)) // count for i in range(start, end): - shuffled_index = compute_shuffled_index(i, len(indices), seed, shuffle_round_count) + shuffled_index = compute_shuffled_index( + i, len(indices), seed, shuffle_round_count + ) yield indices[shuffled_index] @to_tuple -def get_crosslink_committee(state: BeaconState, - epoch: Epoch, - shard: Shard, - config: CommitteeConfig) -> Iterable[ValidatorIndex]: +def get_crosslink_committee( + state: BeaconState, epoch: Epoch, shard: Shard, config: CommitteeConfig +) -> Iterable[ValidatorIndex]: target_shard = ( shard + config.SHARD_COUNT - get_start_shard(state, epoch, config) ) % config.SHARD_COUNT - active_validator_indices = get_active_validator_indices( - state.validators, - epoch, - ) + active_validator_indices = get_active_validator_indices(state.validators, epoch) return _compute_committee( indices=active_validator_indices, @@ -265,10 +235,9 @@ def get_crosslink_committee(state: BeaconState, ) -def _compute_compact_committee_for_shard_in_epoch(state: BeaconState, - epoch: Epoch, - shard: Shard, - config: CommitteeConfig) -> CompactCommittee: +def _compute_compact_committee_for_shard_in_epoch( + state: BeaconState, epoch: Epoch, shard: Shard, config: CommitteeConfig +) -> CompactCommittee: effective_balance_increment = config.EFFECTIVE_BALANCE_INCREMENT pubkeys: Tuple[BLSPubkey, ...] = tuple() @@ -281,23 +250,17 @@ def _compute_compact_committee_for_shard_in_epoch(state: BeaconState, compact_validator = (index << 16) + (validator.slashed << 15) + compact_balance compact_validators += (compact_validator,) - return CompactCommittee( - pubkeys=pubkeys, - compact_validators=compact_validators, - ) + return CompactCommittee(pubkeys=pubkeys, compact_validators=compact_validators) -def get_compact_committees_root(state: BeaconState, - epoch: Epoch, - config: CommitteeConfig) -> Hash32: +def get_compact_committees_root( + state: BeaconState, epoch: Epoch, config: CommitteeConfig +) -> Hash32: shard_count = config.SHARD_COUNT committees = (CompactCommittee(),) * shard_count start_shard = get_start_shard(state, epoch, config) - active_validator_indices = get_active_validator_indices( - state.validators, - epoch, - ) + active_validator_indices = get_active_validator_indices(state.validators, epoch) committee_count = get_committee_count( len(active_validator_indices), config.SHARD_COUNT, @@ -307,14 +270,9 @@ def get_compact_committees_root(state: BeaconState, for committee_number in range(committee_count): shard = Shard((start_shard + committee_number) % shard_count) compact_committee = _compute_compact_committee_for_shard_in_epoch( - state, - epoch, - shard, - config, - ) - committees = update_tuple_item( - committees, - shard, - compact_committee, + state, epoch, shard, config ) - return ssz.get_hash_tree_root(committees, sedes=ssz.sedes.Vector(CompactCommittee, shard_count)) + committees = update_tuple_item(committees, shard, compact_committee) + return ssz.get_hash_tree_root( + committees, sedes=ssz.sedes.Vector(CompactCommittee, shard_count) + ) diff --git a/eth2/beacon/constants.py b/eth2/beacon/constants.py index 59a2dc8b57..c82ce79d8d 100644 --- a/eth2/beacon/constants.py +++ b/eth2/beacon/constants.py @@ -1,32 +1,24 @@ -from eth.constants import ( - ZERO_HASH32, -) -from eth_typing import ( - BLSSignature, - BLSPubkey, -) -from eth2.beacon.typing import ( - Epoch, - Timestamp, -) - - -EMPTY_SIGNATURE = BLSSignature(b'\x00' * 96) -EMPTY_PUBKEY = BLSPubkey(b'\x00' * 48) -GWEI_PER_ETH = 10**9 -FAR_FUTURE_EPOCH = Epoch(2**64 - 1) +from eth.constants import ZERO_HASH32 +from eth_typing import BLSPubkey, BLSSignature + +from eth2.beacon.typing import Epoch, Timestamp + +EMPTY_SIGNATURE = BLSSignature(b"\x00" * 96) +EMPTY_PUBKEY = BLSPubkey(b"\x00" * 48) +GWEI_PER_ETH = 10 ** 9 +FAR_FUTURE_EPOCH = Epoch(2 ** 64 - 1) GENESIS_PARENT_ROOT = ZERO_HASH32 ZERO_TIMESTAMP = Timestamp(0) -MAX_INDEX_COUNT = 2**40 +MAX_INDEX_COUNT = 2 ** 40 -MAX_RANDOM_BYTE = 2**8 - 1 +MAX_RANDOM_BYTE = 2 ** 8 - 1 BASE_REWARDS_PER_EPOCH = 5 -DEPOSIT_CONTRACT_TREE_DEPTH = 2**5 +DEPOSIT_CONTRACT_TREE_DEPTH = 2 ** 5 SECONDS_PER_DAY = 86400 diff --git a/eth2/beacon/db/chain.py b/eth2/beacon/db/chain.py index fbd2bd06ec..c34fca98cb 100644 --- a/eth2/beacon/db/chain.py +++ b/eth2/beacon/db/chain.py @@ -1,57 +1,15 @@ from abc import ABC, abstractmethod import functools - -from typing import ( - Iterable, - Optional, - Tuple, - Type, -) -from cytoolz import ( - concat, - first, - sliding_window, -) - +from typing import Iterable, Optional, Tuple, Type + +from cytoolz import concat, first, sliding_window +from eth.abc import AtomicDatabaseAPI, DatabaseAPI +from eth.constants import ZERO_HASH32 +from eth.exceptions import BlockNotFound, CanonicalHeadNotFound, ParentNotFound +from eth.validation import validate_word +from eth_typing import Hash32 +from eth_utils import ValidationError, encode_hex, to_tuple import ssz -from eth_typing import ( - Hash32, -) -from eth_utils import ( - encode_hex, - to_tuple, - ValidationError, -) - -from eth.abc import ( - DatabaseAPI, - AtomicDatabaseAPI, -) -from eth.constants import ( - ZERO_HASH32, -) -from eth.exceptions import ( - BlockNotFound, - CanonicalHeadNotFound, - ParentNotFound, - StateRootNotFound, -) -from eth.validation import ( - validate_word, -) -from eth2.beacon.fork_choice.scoring import ScoringFn as ForkChoiceScoringFn -from eth2.beacon.helpers import ( - compute_epoch_of_slot, -) -from eth2.beacon.typing import ( - Epoch, - Slot, -) -from eth2.beacon.types.states import BeaconState # noqa: F401 -from eth2.beacon.types.blocks import ( # noqa: F401 - BaseBeaconBlock, - BeaconBlock, -) from eth2.beacon.db.exceptions import ( AttestationRootNotFound, @@ -59,27 +17,28 @@ HeadStateSlotNotFound, JustifiedHeadNotFound, MissingForkChoiceScoringFns, - StateSlotNotFound, + StateNotFound, ) from eth2.beacon.db.schema import SchemaV1 - -from eth2.configs import ( - Eth2GenesisConfig, -) +from eth2.beacon.fork_choice.scoring import ScoringFn as ForkChoiceScoringFn +from eth2.beacon.helpers import compute_epoch_of_slot +from eth2.beacon.types.blocks import BaseBeaconBlock, BeaconBlock # noqa: F401 +from eth2.beacon.types.states import BeaconState # noqa: F401 +from eth2.beacon.typing import Epoch, Slot +from eth2.configs import Eth2GenesisConfig class AttestationKey(ssz.Serializable): - fields = [ - ('block_root', ssz.sedes.bytes32), - ('index', ssz.sedes.uint8), - ] + fields = [("block_root", ssz.sedes.bytes32), ("index", ssz.sedes.uint8)] class BaseBeaconChainDB(ABC): db: AtomicDatabaseAPI = None @abstractmethod - def __init__(self, db: AtomicDatabaseAPI, genesis_config: Eth2GenesisConfig) -> None: + def __init__( + self, db: AtomicDatabaseAPI, genesis_config: Eth2GenesisConfig + ) -> None: pass # @@ -87,10 +46,10 @@ def __init__(self, db: AtomicDatabaseAPI, genesis_config: Eth2GenesisConfig) -> # @abstractmethod def persist_block( - self, - block: BaseBeaconBlock, - block_class: Type[BaseBeaconBlock], - fork_choice_scoring: ForkChoiceScoringFn, + self, + block: BaseBeaconBlock, + block_class: Type[BaseBeaconBlock], + fork_choice_scoring: ForkChoiceScoringFn, ) -> Tuple[Tuple[BaseBeaconBlock, ...], Tuple[BaseBeaconBlock, ...]]: pass @@ -103,9 +62,9 @@ def get_genesis_block_root(self) -> Hash32: pass @abstractmethod - def get_canonical_block_by_slot(self, - slot: Slot, - block_class: Type[BaseBeaconBlock]) -> BaseBeaconBlock: + def get_canonical_block_by_slot( + self, slot: Slot, block_class: Type[BaseBeaconBlock] + ) -> BaseBeaconBlock: pass @abstractmethod @@ -125,14 +84,13 @@ def get_justified_head(self, block_class: Type[BaseBeaconBlock]) -> BaseBeaconBl pass @abstractmethod - def get_block_by_root(self, - block_root: Hash32, - block_class: Type[BaseBeaconBlock]) -> BaseBeaconBlock: + def get_block_by_root( + self, block_root: Hash32, block_class: Type[BaseBeaconBlock] + ) -> BaseBeaconBlock: pass @abstractmethod - def get_slot_by_root(self, - block_root: Hash32) -> Slot: + def get_slot_by_root(self, block_root: Hash32) -> Slot: pass @abstractmethod @@ -145,15 +103,15 @@ def block_exists(self, block_root: Hash32) -> bool: @abstractmethod def persist_block_chain( - self, - blocks: Iterable[BaseBeaconBlock], - block_class: Type[BaseBeaconBlock], - fork_choice_scoring: Iterable[ForkChoiceScoringFn], + self, + blocks: Iterable[BaseBeaconBlock], + block_class: Type[BaseBeaconBlock], + fork_choice_scoring: Iterable[ForkChoiceScoringFn], ) -> Tuple[Tuple[BaseBeaconBlock, ...], Tuple[BaseBeaconBlock, ...]]: pass @abstractmethod - def set_score(self, block: BaseBeaconBlock, score: int)-> None: + def set_score(self, block: BaseBeaconBlock, score: int) -> None: pass # @@ -164,23 +122,26 @@ def get_head_state_slot(self) -> Slot: pass @abstractmethod - def get_state_by_slot(self, slot: Slot, state_class: Type[BeaconState]) -> BeaconState: + def get_state_root_by_slot(self, slot: Slot) -> Hash32: pass @abstractmethod - def get_state_by_root(self, state_root: Hash32, state_class: Type[BeaconState]) -> BeaconState: + def get_state_by_root( + self, state_root: Hash32, state_class: Type[BeaconState] + ) -> BeaconState: pass @abstractmethod - def persist_state(self, - state: BeaconState) -> None: + def persist_state(self, state: BeaconState) -> None: pass # # Attestation API # @abstractmethod - def get_attestation_key_by_root(self, attestation_root: Hash32)-> Tuple[Hash32, int]: + def get_attestation_key_by_root( + self, attestation_root: Hash32 + ) -> Tuple[Hash32, int]: pass @abstractmethod @@ -200,7 +161,9 @@ def get(self, key: bytes) -> bytes: class BeaconChainDB(BaseBeaconChainDB): - def __init__(self, db: AtomicDatabaseAPI, genesis_config: Eth2GenesisConfig) -> None: + def __init__( + self, db: AtomicDatabaseAPI, genesis_config: Eth2GenesisConfig + ) -> None: self.db = db self.genesis_config = genesis_config @@ -222,10 +185,10 @@ def _get_highest_justified_epoch(self, db: DatabaseAPI) -> Epoch: return self.genesis_config.GENESIS_EPOCH def persist_block( - self, - block: BaseBeaconBlock, - block_class: Type[BaseBeaconBlock], - fork_choice_scoring: ForkChoiceScoringFn, + self, + block: BaseBeaconBlock, + block_class: Type[BaseBeaconBlock], + fork_choice_scoring: ForkChoiceScoringFn, ) -> Tuple[Tuple[BaseBeaconBlock, ...], Tuple[BaseBeaconBlock, ...]]: """ Persist the given block. @@ -238,19 +201,16 @@ def persist_block( @classmethod def _persist_block( - cls, - db: DatabaseAPI, - block: BaseBeaconBlock, - block_class: Type[BaseBeaconBlock], - fork_choice_scoring: ForkChoiceScoringFn, + cls, + db: DatabaseAPI, + block: BaseBeaconBlock, + block_class: Type[BaseBeaconBlock], + fork_choice_scoring: ForkChoiceScoringFn, ) -> Tuple[Tuple[BaseBeaconBlock, ...], Tuple[BaseBeaconBlock, ...]]: - block_chain = (block, ) - scorings = (fork_choice_scoring, ) + block_chain = (block,) + scorings = (fork_choice_scoring,) new_canonical_blocks, old_canonical_blocks = cls._persist_block_chain( - db, - block_chain, - block_class, - scorings, + db, block_chain, block_class, scorings ) return new_canonical_blocks, old_canonical_blocks @@ -282,15 +242,13 @@ def _get_canonical_block_root(db: DatabaseAPI, slot: Slot) -> Hash32: try: encoded_key = db[slot_to_root_key] except KeyError: - raise BlockNotFound( - "No canonical block for block slot #{0}".format(slot) - ) + raise BlockNotFound("No canonical block for block slot #{0}".format(slot)) else: return ssz.decode(encoded_key, sedes=ssz.sedes.bytes32) - def get_canonical_block_by_slot(self, - slot: Slot, - block_class: Type[BaseBeaconBlock]) -> BaseBeaconBlock: + def get_canonical_block_by_slot( + self, slot: Slot, block_class: Type[BaseBeaconBlock] + ) -> BaseBeaconBlock: """ Return the block with the given slot in the canonical chain. @@ -301,10 +259,8 @@ def get_canonical_block_by_slot(self, @classmethod def _get_canonical_block_by_slot( - cls, - db: DatabaseAPI, - slot: Slot, - block_class: Type[BaseBeaconBlock]) -> BaseBeaconBlock: + cls, db: DatabaseAPI, slot: Slot, block_class: Type[BaseBeaconBlock] + ) -> BaseBeaconBlock: canonical_block_root = cls._get_canonical_block_root(db, slot) return cls._get_block_by_root(db, canonical_block_root, block_class) @@ -315,9 +271,9 @@ def get_canonical_head(self, block_class: Type[BaseBeaconBlock]) -> BaseBeaconBl return self._get_canonical_head(self.db, block_class) @classmethod - def _get_canonical_head(cls, - db: DatabaseAPI, - block_class: Type[BaseBeaconBlock]) -> BaseBeaconBlock: + def _get_canonical_head( + cls, db: DatabaseAPI, block_class: Type[BaseBeaconBlock] + ) -> BaseBeaconBlock: canonical_head_root = cls._get_canonical_head_root(db) return cls._get_block_by_root(db, Hash32(canonical_head_root), block_class) @@ -342,9 +298,9 @@ def get_finalized_head(self, block_class: Type[BaseBeaconBlock]) -> BaseBeaconBl return self._get_finalized_head(self.db, block_class) @classmethod - def _get_finalized_head(cls, - db: DatabaseAPI, - block_class: Type[BaseBeaconBlock]) -> BaseBeaconBlock: + def _get_finalized_head( + cls, db: DatabaseAPI, block_class: Type[BaseBeaconBlock] + ) -> BaseBeaconBlock: finalized_head_root = cls._get_finalized_head_root(db) return cls._get_block_by_root(db, Hash32(finalized_head_root), block_class) @@ -363,9 +319,9 @@ def get_justified_head(self, block_class: Type[BaseBeaconBlock]) -> BaseBeaconBl return self._get_justified_head(self.db, block_class) @classmethod - def _get_justified_head(cls, - db: DatabaseAPI, - block_class: Type[BaseBeaconBlock]) -> BaseBeaconBlock: + def _get_justified_head( + cls, db: DatabaseAPI, block_class: Type[BaseBeaconBlock] + ) -> BaseBeaconBlock: justified_head_root = cls._get_justified_head_root(db) return cls._get_block_by_root(db, Hash32(justified_head_root), block_class) @@ -377,15 +333,15 @@ def _get_justified_head_root(db: DatabaseAPI) -> Hash32: raise JustifiedHeadNotFound("No justified head set for this chain") return justified_head_root - def get_block_by_root(self, - block_root: Hash32, - block_class: Type[BaseBeaconBlock]) -> BaseBeaconBlock: + def get_block_by_root( + self, block_root: Hash32, block_class: Type[BaseBeaconBlock] + ) -> BaseBeaconBlock: return self._get_block_by_root(self.db, block_root, block_class) @staticmethod - def _get_block_by_root(db: DatabaseAPI, - block_root: Hash32, - block_class: Type[BaseBeaconBlock]) -> BaseBeaconBlock: + def _get_block_by_root( + db: DatabaseAPI, block_root: Hash32, block_class: Type[BaseBeaconBlock] + ) -> BaseBeaconBlock: """ Return the requested block header as specified by block root. @@ -395,12 +351,12 @@ def _get_block_by_root(db: DatabaseAPI, try: block_ssz = db[block_root] except KeyError: - raise BlockNotFound("No block with root {0} found".format( - encode_hex(block_root))) + raise BlockNotFound( + "No block with root {0} found".format(encode_hex(block_root)) + ) return _decode_block(block_ssz, block_class) - def get_slot_by_root(self, - block_root: Hash32) -> Slot: + def get_slot_by_root(self, block_root: Hash32) -> Slot: """ Return the requested block header as specified by block root. @@ -409,14 +365,14 @@ def get_slot_by_root(self, return self._get_slot_by_root(self.db, block_root) @staticmethod - def _get_slot_by_root(db: DatabaseAPI, - block_root: Hash32) -> Slot: + def _get_slot_by_root(db: DatabaseAPI, block_root: Hash32) -> Slot: validate_word(block_root, title="block root") try: encoded_slot = db[SchemaV1.make_block_root_to_slot_lookup_key(block_root)] except KeyError: - raise BlockNotFound("No block with root {0} found".format( - encode_hex(block_root))) + raise BlockNotFound( + "No block with root {0} found".format(encode_hex(block_root)) + ) return Slot(ssz.decode(encoded_slot, sedes=ssz.sedes.uint64)) def get_score(self, block_root: Hash32) -> int: @@ -427,8 +383,9 @@ def _get_score(db: DatabaseAPI, block_root: Hash32) -> int: try: encoded_score = db[SchemaV1.make_block_root_to_score_lookup_key(block_root)] except KeyError: - raise BlockNotFound("No block with hash {0} found".format( - encode_hex(block_root))) + raise BlockNotFound( + "No block with hash {0} found".format(encode_hex(block_root)) + ) return ssz.decode(encoded_score, sedes=ssz.sedes.uint64) def block_exists(self, block_root: Hash32) -> bool: @@ -440,23 +397,23 @@ def _block_exists(db: DatabaseAPI, block_root: Hash32) -> bool: return block_root in db def persist_block_chain( - self, - blocks: Iterable[BaseBeaconBlock], - block_class: Type[BaseBeaconBlock], - fork_choice_scorings: Iterable[ForkChoiceScoringFn], + self, + blocks: Iterable[BaseBeaconBlock], + block_class: Type[BaseBeaconBlock], + fork_choice_scorings: Iterable[ForkChoiceScoringFn], ) -> Tuple[Tuple[BaseBeaconBlock, ...], Tuple[BaseBeaconBlock, ...]]: """ Return two iterable of blocks, the first containing the new canonical blocks, the second containing the old canonical headers """ with self.db.atomic_batch() as db: - return self._persist_block_chain(db, blocks, block_class, fork_choice_scorings) + return self._persist_block_chain( + db, blocks, block_class, fork_choice_scorings + ) @staticmethod def _set_block_score_to_db( - db: DatabaseAPI, - block: BaseBeaconBlock, - score: int, + db: DatabaseAPI, block: BaseBeaconBlock, score: int ) -> int: # NOTE if we change the score serialization, we will likely need to # patch up the fork choice logic. @@ -468,19 +425,15 @@ def _set_block_score_to_db( return score def set_score(self, block: BaseBeaconBlock, score: int) -> None: - self.__class__._set_block_score_to_db( - self.db, - block, - score, - ) + self.__class__._set_block_score_to_db(self.db, block, score) @classmethod def _persist_block_chain( - cls, - db: DatabaseAPI, - blocks: Iterable[BaseBeaconBlock], - block_class: Type[BaseBeaconBlock], - fork_choice_scorings: Iterable[ForkChoiceScoringFn], + cls, + db: DatabaseAPI, + blocks: Iterable[BaseBeaconBlock], + block_class: Type[BaseBeaconBlock], + fork_choice_scorings: Iterable[ForkChoiceScoringFn], ) -> Tuple[Tuple[BaseBeaconBlock, ...], Tuple[BaseBeaconBlock, ...]]: blocks_iterator = iter(blocks) scorings_iterator = iter(fork_choice_scorings) @@ -492,7 +445,9 @@ def _persist_block_chain( return tuple(), tuple() try: - previous_canonical_head = cls._get_canonical_head(db, block_class).signing_root + previous_canonical_head = cls._get_canonical_head( + db, block_class + ).signing_root head_score = cls._get_score(db, previous_canonical_head) except CanonicalHeadNotFound: no_canonical_head = True @@ -511,10 +466,7 @@ def _persist_block_chain( score = first_scoring(first_block) curr_block_head = first_block - db.set( - curr_block_head.signing_root, - ssz.encode(curr_block_head), - ) + db.set(curr_block_head.signing_root, ssz.encode(curr_block_head)) cls._add_block_root_to_slot_lookup(db, curr_block_head) cls._set_block_score_to_db(db, curr_block_head, score) cls._add_attestations_root_to_block_lookup(db, curr_block_head) @@ -532,10 +484,7 @@ def _persist_block_chain( ) curr_block_head = child - db.set( - curr_block_head.signing_root, - ssz.encode(curr_block_head), - ) + db.set(curr_block_head.signing_root, ssz.encode(curr_block_head)) cls._add_block_root_to_slot_lookup(db, curr_block_head) cls._add_attestations_root_to_block_lookup(db, curr_block_head) @@ -549,19 +498,20 @@ def _persist_block_chain( cls._set_block_score_to_db(db, curr_block_head, score) if no_canonical_head: - return cls._set_as_canonical_chain_head(db, curr_block_head.signing_root, block_class) + return cls._set_as_canonical_chain_head( + db, curr_block_head.signing_root, block_class + ) if score > head_score: - return cls._set_as_canonical_chain_head(db, curr_block_head.signing_root, block_class) + return cls._set_as_canonical_chain_head( + db, curr_block_head.signing_root, block_class + ) else: return tuple(), tuple() @classmethod def _set_as_canonical_chain_head( - cls, - db: DatabaseAPI, - block_root: Hash32, - block_class: Type[BaseBeaconBlock] + cls, db: DatabaseAPI, block_root: Hash32, block_class: Type[BaseBeaconBlock] ) -> Tuple[Tuple[BaseBeaconBlock, ...], Tuple[BaseBeaconBlock, ...]]: """ Set the canonical chain HEAD to the block as specified by the @@ -577,7 +527,9 @@ def _set_as_canonical_chain_head( "Cannot use unknown block root as canonical head: {}".format(block_root) ) - new_canonical_blocks = tuple(reversed(cls._find_new_ancestors(db, block, block_class))) + new_canonical_blocks = tuple( + reversed(cls._find_new_ancestors(db, block, block_class)) + ) old_canonical_blocks = [] for block in new_canonical_blocks: @@ -587,7 +539,9 @@ def _set_as_canonical_chain_head( # no old_canonical block, and no more possible break else: - old_canonical_block = cls._get_block_by_root(db, old_canonical_root, block_class) + old_canonical_block = cls._get_block_by_root( + db, old_canonical_root, block_class + ) old_canonical_blocks.append(old_canonical_block) for block in new_canonical_blocks: @@ -600,10 +554,8 @@ def _set_as_canonical_chain_head( @classmethod @to_tuple def _find_new_ancestors( - cls, - db: DatabaseAPI, - block: BaseBeaconBlock, - block_class: Type[BaseBeaconBlock]) -> Iterable[BaseBeaconBlock]: + cls, db: DatabaseAPI, block: BaseBeaconBlock, block_class: Type[BaseBeaconBlock] + ) -> Iterable[BaseBeaconBlock]: """ Return the chain leading up from the given block until (but not including) the first ancestor it has in common with our canonical chain. @@ -640,9 +592,7 @@ def _add_block_slot_to_root_lookup(db: DatabaseAPI, block: BaseBeaconBlock) -> N Set a record in the database to allow looking up this block by its block slot. """ - block_slot_to_root_key = SchemaV1.make_block_slot_to_root_lookup_key( - block.slot - ) + block_slot_to_root_key = SchemaV1.make_block_slot_to_root_lookup_key(block.slot) db.set( block_slot_to_root_key, ssz.encode(block.signing_root, sedes=ssz.sedes.bytes32), @@ -657,10 +607,7 @@ def _add_block_root_to_slot_lookup(db: DatabaseAPI, block: BaseBeaconBlock) -> N block_root_to_slot_key = SchemaV1.make_block_root_to_slot_lookup_key( block.signing_root ) - db.set( - block_root_to_slot_key, - ssz.encode(block.slot, sedes=ssz.sedes.uint64), - ) + db.set(block_root_to_slot_key, ssz.encode(block.slot, sedes=ssz.sedes.uint64)) # # Beacon State API @@ -681,8 +628,7 @@ def _add_slot_to_state_root_lookup(self, slot: Slot, state_root: Hash32) -> None """ slot_to_state_root_key = SchemaV1.make_slot_to_state_root_lookup_key(slot) self.db.set( - slot_to_state_root_key, - ssz.encode(state_root, sedes=ssz.sedes.bytes32), + slot_to_state_root_key, ssz.encode(state_root, sedes=ssz.sedes.bytes32) ) def get_head_state_slot(self) -> Slot: @@ -692,59 +638,54 @@ def get_head_state_slot(self) -> Slot: def _get_head_state_slot(db: DatabaseAPI) -> Slot: try: encoded_head_state_slot = db[SchemaV1.make_head_state_slot_lookup_key()] - head_state_slot = ssz.decode(encoded_head_state_slot, sedes=ssz.sedes.uint64) + head_state_slot = ssz.decode( + encoded_head_state_slot, sedes=ssz.sedes.uint64 + ) except KeyError: raise HeadStateSlotNotFound("No head state slot found") return head_state_slot - def get_state_by_slot(self, slot: Slot, state_class: Type[BeaconState]) -> BeaconState: - return self._get_state_by_slot(self.db, slot, state_class) + def get_state_root_by_slot(self, slot: Slot) -> Hash32: + return self._get_state_root_by_slot(self.db, slot) @staticmethod - def _get_state_by_slot(db: DatabaseAPI, - slot: Slot, - state_class: Type[BeaconState]) -> BeaconState: + def _get_state_root_by_slot(db: DatabaseAPI, slot: Slot) -> Hash32: """ Return the requested beacon state as specified by slot. - Raises StateSlotNotFound if it is not present in the db. + Raises StateNotFound if it is not present in the db. """ slot_to_state_root_key = SchemaV1.make_slot_to_state_root_lookup_key(slot) try: state_root_ssz = db[slot_to_state_root_key] except KeyError: - raise StateSlotNotFound( - "No state root for slot #{0}".format(slot) - ) + raise StateNotFound("No state root for slot #{0}".format(slot)) state_root = ssz.decode(state_root_ssz, sedes=ssz.sedes.bytes32) - try: - state_ssz = db[state_root] - except KeyError: - raise StateRootNotFound(f"No state with root {encode_hex(state_root)} found") - return _decode_state(state_ssz, state_class) + return state_root - def get_state_by_root(self, state_root: Hash32, state_class: Type[BeaconState]) -> BeaconState: + def get_state_by_root( + self, state_root: Hash32, state_class: Type[BeaconState] + ) -> BeaconState: return self._get_state_by_root(self.db, state_root, state_class) @staticmethod - def _get_state_by_root(db: DatabaseAPI, - state_root: Hash32, - state_class: Type[BeaconState]) -> BeaconState: + def _get_state_by_root( + db: DatabaseAPI, state_root: Hash32, state_class: Type[BeaconState] + ) -> BeaconState: """ Return the requested beacon state as specified by state hash. - Raises StateRootNotFound if it is not present in the db. + Raises StateNotFound if it is not present in the db. """ # TODO: validate_state_root try: state_ssz = db[state_root] except KeyError: - raise StateRootNotFound(f"No state with root {encode_hex(state_root)} found") + raise StateNotFound(f"No state with root {encode_hex(state_root)} found") return _decode_state(state_ssz, state_class) - def persist_state(self, - state: BeaconState) -> None: + def persist_state(self, state: BeaconState) -> None: """ Persist the given BeaconState. @@ -753,10 +694,7 @@ def persist_state(self, return self._persist_state(state) def _persist_state(self, state: BeaconState) -> None: - self.db.set( - state.hash_tree_root, - ssz.encode(state), - ) + self.db.set(state.hash_tree_root, ssz.encode(state)) self._add_slot_to_state_root_lookup(state.slot, state.hash_tree_root) self._persist_finalized_head(state) @@ -777,10 +715,7 @@ def _update_finalized_head(self, finalized_root: Hash32) -> None: Unconditionally write the ``finalized_root`` as the root of the currently finalized block. """ - self.db.set( - SchemaV1.make_finalized_head_root_lookup_key(), - finalized_root, - ) + self.db.set(SchemaV1.make_finalized_head_root_lookup_key(), finalized_root) self._finalized_root = finalized_root def _persist_finalized_head(self, state: BeaconState) -> None: @@ -801,13 +736,12 @@ def _update_justified_head(self, justified_root: Hash32, epoch: Epoch) -> None: Unconditionally write the ``justified_root`` as the root of the highest justified block. """ - self.db.set( - SchemaV1.make_justified_head_root_lookup_key(), - justified_root, - ) + self.db.set(SchemaV1.make_justified_head_root_lookup_key(), justified_root) self._highest_justified_epoch = epoch - def _find_updated_justified_root(self, state: BeaconState) -> Optional[Tuple[Hash32, Epoch]]: + def _find_updated_justified_root( + self, state: BeaconState + ) -> Optional[Tuple[Hash32, Epoch]]: """ Find the highest epoch that has been justified so far. @@ -820,16 +754,10 @@ def _find_updated_justified_root(self, state: BeaconState) -> Optional[Tuple[Has """ if state.current_justified_checkpoint.epoch > self._highest_justified_epoch: checkpoint = state.current_justified_checkpoint - return ( - checkpoint.root, - checkpoint.epoch, - ) + return (checkpoint.root, checkpoint.epoch) elif state.previous_justified_checkpoint.epoch > self._highest_justified_epoch: checkpoint = state.previous_justified_checkpoint - return ( - checkpoint.root, - checkpoint.epoch, - ) + return (checkpoint.root, checkpoint.epoch) return None def _persist_justified_head(self, state: BeaconState) -> None: @@ -843,9 +771,9 @@ def _persist_justified_head(self, state: BeaconState) -> None: if result: self._update_justified_head(*result) - def _handle_exceptional_justification_and_finality(self, - db: DatabaseAPI, - genesis_block: BaseBeaconBlock) -> None: + def _handle_exceptional_justification_and_finality( + self, db: DatabaseAPI, genesis_block: BaseBeaconBlock + ) -> None: """ The genesis ``BeaconState`` lacks the correct justification and finality data in the early epochs. The invariants of this class require an exceptional @@ -861,24 +789,32 @@ def _handle_exceptional_justification_and_finality(self, # @staticmethod - def _add_attestations_root_to_block_lookup(db: DatabaseAPI, - block: BaseBeaconBlock) -> None: + def _add_attestations_root_to_block_lookup( + db: DatabaseAPI, block: BaseBeaconBlock + ) -> None: root = block.signing_root for index, attestation in enumerate(block.body.attestations): attestation_key = AttestationKey(root, index) db.set( - SchemaV1.make_attestation_root_to_block_lookup_key(attestation.hash_tree_root), + SchemaV1.make_attestation_root_to_block_lookup_key( + attestation.hash_tree_root + ), ssz.encode(attestation_key), ) - def get_attestation_key_by_root(self, attestation_root: Hash32)-> Tuple[Hash32, int]: + def get_attestation_key_by_root( + self, attestation_root: Hash32 + ) -> Tuple[Hash32, int]: return self._get_attestation_key_by_root(self.db, attestation_root) @staticmethod - def _get_attestation_key_by_root(db: DatabaseAPI, - attestation_root: Hash32) -> Tuple[Hash32, int]: + def _get_attestation_key_by_root( + db: DatabaseAPI, attestation_root: Hash32 + ) -> Tuple[Hash32, int]: try: - encoded_key = db[SchemaV1.make_attestation_root_to_block_lookup_key(attestation_root)] + encoded_key = db[ + SchemaV1.make_attestation_root_to_block_lookup_key(attestation_root) + ] except KeyError: raise AttestationRootNotFound( "Attestation root {0} not found".format(encode_hex(attestation_root)) @@ -887,7 +823,9 @@ def _get_attestation_key_by_root(db: DatabaseAPI, return attestation_key.block_root, attestation_key.index def attestation_exists(self, attestation_root: Hash32) -> bool: - lookup_key = SchemaV1.make_attestation_root_to_block_lookup_key(attestation_root) + lookup_key = SchemaV1.make_attestation_root_to_block_lookup_key( + attestation_root + ) return self.exists(lookup_key) # diff --git a/eth2/beacon/db/exceptions.py b/eth2/beacon/db/exceptions.py index 760a965121..26d8971bcb 100644 --- a/eth2/beacon/db/exceptions.py +++ b/eth2/beacon/db/exceptions.py @@ -2,6 +2,7 @@ class BeaconDBException(Exception): """ Base class for exceptions raised by this package. """ + pass @@ -9,13 +10,15 @@ class HeadStateSlotNotFound(BeaconDBException): """ Exception raised if head state slot does not exist. """ + pass -class StateSlotNotFound(BeaconDBException): +class StateNotFound(BeaconDBException): """ - Exception raised if state root with the given slot number does not exist. + Exception raised if state with the given state does not exist. """ + pass @@ -23,6 +26,7 @@ class FinalizedHeadNotFound(BeaconDBException): """ Exception raised if no finalized head is set in this database. """ + pass @@ -30,6 +34,7 @@ class JustifiedHeadNotFound(BeaconDBException): """ Exception raised if no justified head is set in this database. """ + pass @@ -37,6 +42,7 @@ class AttestationRootNotFound(BeaconDBException): """ Exception raised if no attestation root is set in this database. """ + pass @@ -45,4 +51,5 @@ class MissingForkChoiceScoringFns(BeaconDBException): Exception raised if a client tries to score a block without providing the ability to generate a score via a ``scoring``. """ + pass diff --git a/eth2/beacon/db/schema.py b/eth2/beacon/db/schema.py index 7547a18594..f3c2969829 100644 --- a/eth2/beacon/db/schema.py +++ b/eth2/beacon/db/schema.py @@ -1,8 +1,6 @@ from abc import ABC, abstractmethod -from eth_typing import ( - Hash32, -) +from eth_typing import Hash32 class BaseSchema(ABC): @@ -67,43 +65,43 @@ class SchemaV1(BaseSchema): # @staticmethod def make_head_state_slot_lookup_key() -> bytes: - return b'v1:beacon:head-state-slot' + return b"v1:beacon:head-state-slot" @staticmethod def make_slot_to_state_root_lookup_key(slot: int) -> bytes: - return b'v1:beacon:slot-to-state-root%d' % slot + return b"v1:beacon:slot-to-state-root%d" % slot # # Block # @staticmethod def make_canonical_head_root_lookup_key() -> bytes: - return b'v1:beacon:canonical-head-root' + return b"v1:beacon:canonical-head-root" @staticmethod def make_finalized_head_root_lookup_key() -> bytes: - return b'v1:beacon:finalized-head-root' + return b"v1:beacon:finalized-head-root" @staticmethod def make_justified_head_root_lookup_key() -> bytes: - return b'v1:beacon:justified-head-root' + return b"v1:beacon:justified-head-root" @staticmethod def make_block_slot_to_root_lookup_key(slot: int) -> bytes: - slot_to_root_key = b'v1:beacon:block-slot-to-root:%d' % slot + slot_to_root_key = b"v1:beacon:block-slot-to-root:%d" % slot return slot_to_root_key @staticmethod def make_block_root_to_score_lookup_key(block_root: Hash32) -> bytes: - return b'v1:beacon:block-root-to-score:%s' % block_root + return b"v1:beacon:block-root-to-score:%s" % block_root @staticmethod def make_block_root_to_slot_lookup_key(block_root: Hash32) -> bytes: - return b'v1:beacon:block-root-to-slot:%s' % block_root + return b"v1:beacon:block-root-to-slot:%s" % block_root # # Attestaion # @staticmethod def make_attestation_root_to_block_lookup_key(attestaton_root: Hash32) -> bytes: - return b'v1:beacon:attestation-root-to-block:%s' % attestaton_root + return b"v1:beacon:attestation-root-to-block:%s" % attestaton_root diff --git a/eth2/beacon/deposit_helpers.py b/eth2/beacon/deposit_helpers.py index 9eeb76f68a..bbf1b8a412 100644 --- a/eth2/beacon/deposit_helpers.py +++ b/eth2/beacon/deposit_helpers.py @@ -1,36 +1,21 @@ -from eth_utils import ( - encode_hex, - ValidationError, -) -from eth2._utils.bls import bls +from eth_utils import ValidationError, encode_hex -from eth2._utils.merkle.common import ( - verify_merkle_branch, -) -from eth2.beacon.constants import ( - DEPOSIT_CONTRACT_TREE_DEPTH, -) -from eth2.beacon.signature_domain import ( - SignatureDomain, -) -from eth2.beacon.helpers import ( - compute_domain, -) -from eth2.beacon.epoch_processing_helpers import ( - increase_balance, -) +from eth2._utils.bls import bls +from eth2._utils.merkle.common import verify_merkle_branch +from eth2.beacon.constants import DEPOSIT_CONTRACT_TREE_DEPTH +from eth2.beacon.epoch_processing_helpers import increase_balance +from eth2.beacon.helpers import compute_domain +from eth2.beacon.signature_domain import SignatureDomain from eth2.beacon.types.deposits import Deposit from eth2.beacon.types.states import BeaconState from eth2.beacon.types.validators import Validator -from eth2.beacon.typing import ( - ValidatorIndex, -) +from eth2.beacon.typing import ValidatorIndex from eth2.configs import Eth2Config -def validate_deposit_proof(state: BeaconState, - deposit: Deposit, - deposit_contract_tree_depth: int) -> None: +def validate_deposit_proof( + state: BeaconState, deposit: Deposit, deposit_contract_tree_depth: int +) -> None: """ Validate if deposit branch proof is valid. """ @@ -51,9 +36,9 @@ def validate_deposit_proof(state: BeaconState, ) -def process_deposit(state: BeaconState, - deposit: Deposit, - config: Eth2Config) -> BeaconState: +def process_deposit( + state: BeaconState, deposit: Deposit, config: Eth2Config +) -> BeaconState: """ Process a deposit from Ethereum 1.0. """ @@ -63,9 +48,7 @@ def process_deposit(state: BeaconState, # needs to be done here because while the deposit contract will never # create an invalid Merkle branch, it may admit an invalid deposit # object, and we need to be able to skip over it - state = state.copy( - eth1_deposit_index=state.eth1_deposit_index + 1, - ) + state = state.copy(eth1_deposit_index=state.eth1_deposit_index + 1) pubkey = deposit.data.pubkey amount = deposit.data.amount @@ -79,29 +62,20 @@ def process_deposit(state: BeaconState, message_hash=deposit.data.signing_root, pubkey=pubkey, signature=deposit.data.signature, - domain=compute_domain( - SignatureDomain.DOMAIN_DEPOSIT, - ), + domain=compute_domain(SignatureDomain.DOMAIN_DEPOSIT), ) if not is_valid_proof_of_possession: return state withdrawal_credentials = deposit.data.withdrawal_credentials validator = Validator.create_pending_validator( - pubkey, - withdrawal_credentials, - amount, - config, + pubkey, withdrawal_credentials, amount, config ) return state.copy( validators=state.validators + (validator,), - balances=state.balances + (amount, ), + balances=state.balances + (amount,), ) else: index = ValidatorIndex(validator_pubkeys.index(pubkey)) - return increase_balance( - state, - index, - amount, - ) + return increase_balance(state, index, amount) diff --git a/eth2/beacon/epoch_processing_helpers.py b/eth2/beacon/epoch_processing_helpers.py index cf6d3bac95..a3a82cbb3a 100644 --- a/eth2/beacon/epoch_processing_helpers.py +++ b/eth2/beacon/epoch_processing_helpers.py @@ -1,121 +1,76 @@ -from typing import ( - Iterable, - Sequence, - Set, - Tuple, -) - -from eth_typing import ( - Hash32, -) +from typing import Iterable, Sequence, Set, Tuple -from eth_utils import ( - to_tuple, - ValidationError, -) -from eth_utils.toolz import ( - curry, - groupby, - thread_first, -) +from eth_typing import Hash32 +from eth_utils import ValidationError, to_tuple +from eth_utils.toolz import curry, groupby, thread_first -from eth2._utils.bitfield import ( - Bitfield, - has_voted -) +from eth2._utils.bitfield import Bitfield, has_voted from eth2._utils.numeric import integer_squareroot from eth2._utils.tuple import update_tuple_item_with_fn -from eth2.beacon.attestation_helpers import ( - get_attestation_data_slot, -) -from eth2.beacon.committee_helpers import ( - get_crosslink_committee, -) -from eth2.beacon.constants import ( - BASE_REWARDS_PER_EPOCH, -) -from eth2.configs import ( - Eth2Config, - CommitteeConfig, -) -from eth2.beacon.exceptions import ( - InvalidEpochError, -) +from eth2.beacon.attestation_helpers import get_attestation_data_slot +from eth2.beacon.committee_helpers import get_crosslink_committee +from eth2.beacon.constants import BASE_REWARDS_PER_EPOCH +from eth2.beacon.exceptions import InvalidEpochError from eth2.beacon.helpers import ( get_active_validator_indices, get_block_root, get_block_root_at_slot, get_total_balance, ) -from eth2.beacon.typing import ( - Epoch, - Gwei, - Shard, - ValidatorIndex, -) - -from eth2.beacon.types.crosslinks import Crosslink -from eth2.beacon.types.pending_attestations import ( - PendingAttestation, -) -from eth2.beacon.types.attestations import ( - Attestation, - IndexedAttestation, -) from eth2.beacon.types.attestation_data import AttestationData +from eth2.beacon.types.attestations import Attestation, IndexedAttestation +from eth2.beacon.types.crosslinks import Crosslink +from eth2.beacon.types.pending_attestations import PendingAttestation from eth2.beacon.types.states import BeaconState +from eth2.beacon.typing import Epoch, Gwei, Shard, ValidatorIndex +from eth2.configs import CommitteeConfig, Eth2Config -def increase_balance(state: BeaconState, index: ValidatorIndex, delta: Gwei) -> BeaconState: +def increase_balance( + state: BeaconState, index: ValidatorIndex, delta: Gwei +) -> BeaconState: return state.copy( balances=update_tuple_item_with_fn( - state.balances, - index, - lambda balance, *_: Gwei(balance + delta) - ), + state.balances, index, lambda balance, *_: Gwei(balance + delta) + ) ) -def decrease_balance(state: BeaconState, index: ValidatorIndex, delta: Gwei) -> BeaconState: +def decrease_balance( + state: BeaconState, index: ValidatorIndex, delta: Gwei +) -> BeaconState: return state.copy( balances=update_tuple_item_with_fn( state.balances, index, - lambda balance, *_: Gwei(0) if delta > balance else Gwei(balance - delta) - ), + lambda balance, *_: Gwei(0) if delta > balance else Gwei(balance - delta), + ) ) -def get_attesting_indices(state: BeaconState, - attestation_data: AttestationData, - bitfield: Bitfield, - config: CommitteeConfig) -> Set[ValidatorIndex]: +def get_attesting_indices( + state: BeaconState, + attestation_data: AttestationData, + bitfield: Bitfield, + config: CommitteeConfig, +) -> Set[ValidatorIndex]: """ Return the sorted attesting indices corresponding to ``attestation_data`` and ``bitfield``. """ committee = get_crosslink_committee( - state, - attestation_data.target.epoch, - attestation_data.crosslink.shard, - config, + state, attestation_data.target.epoch, attestation_data.crosslink.shard, config ) return set(index for i, index in enumerate(committee) if has_voted(bitfield, i)) -def get_indexed_attestation(state: BeaconState, - attestation: Attestation, - config: CommitteeConfig) -> IndexedAttestation: +def get_indexed_attestation( + state: BeaconState, attestation: Attestation, config: CommitteeConfig +) -> IndexedAttestation: attesting_indices = get_attesting_indices( - state, - attestation.data, - attestation.aggregation_bits, - config, + state, attestation.data, attestation.aggregation_bits, config ) custody_bit_1_indices = get_attesting_indices( - state, - attestation.data, - attestation.custody_bits, - config, + state, attestation.data, attestation.custody_bits, config ) if not custody_bit_1_indices.issubset(attesting_indices): raise ValidationError( @@ -133,8 +88,7 @@ def get_indexed_attestation(state: BeaconState, ) -def compute_activation_exit_epoch(epoch: Epoch, - activation_exit_delay: int) -> Epoch: +def compute_activation_exit_epoch(epoch: Epoch, activation_exit_delay: int) -> Epoch: """ An entry or exit triggered in the ``epoch`` given by the input takes effect at the epoch given by the output. @@ -149,24 +103,24 @@ def get_validator_churn_limit(state: BeaconState, config: Eth2Config) -> int: current_epoch = state.current_epoch(slots_per_epoch) active_validator_indices = get_active_validator_indices( - state.validators, - current_epoch, + state.validators, current_epoch ) return max( - min_per_epoch_churn_limit, - len(active_validator_indices) // churn_limit_quotient + min_per_epoch_churn_limit, len(active_validator_indices) // churn_limit_quotient ) def get_total_active_balance(state: BeaconState, config: Eth2Config) -> Gwei: current_epoch = state.current_epoch(config.SLOTS_PER_EPOCH) - active_validator_indices = get_active_validator_indices(state.validators, current_epoch) + active_validator_indices = get_active_validator_indices( + state.validators, current_epoch + ) return get_total_balance(state, set(active_validator_indices)) -def get_matching_source_attestations(state: BeaconState, - epoch: Epoch, - config: Eth2Config) -> Tuple[PendingAttestation, ...]: +def get_matching_source_attestations( + state: BeaconState, epoch: Epoch, config: Eth2Config +) -> Tuple[PendingAttestation, ...]: if epoch == state.current_epoch(config.SLOTS_PER_EPOCH): return state.current_epoch_attestations elif epoch == state.previous_epoch(config.SLOTS_PER_EPOCH, config.GENESIS_EPOCH): @@ -176,14 +130,11 @@ def get_matching_source_attestations(state: BeaconState, @to_tuple -def get_matching_target_attestations(state: BeaconState, - epoch: Epoch, - config: Eth2Config) -> Iterable[PendingAttestation]: +def get_matching_target_attestations( + state: BeaconState, epoch: Epoch, config: Eth2Config +) -> Iterable[PendingAttestation]: target_root = get_block_root( - state, - epoch, - config.SLOTS_PER_EPOCH, - config.SLOTS_PER_HISTORICAL_ROOT, + state, epoch, config.SLOTS_PER_EPOCH, config.SLOTS_PER_HISTORICAL_ROOT ) for a in get_matching_source_attestations(state, epoch, config): @@ -192,76 +143,63 @@ def get_matching_target_attestations(state: BeaconState, @to_tuple -def get_matching_head_attestations(state: BeaconState, - epoch: Epoch, - config: Eth2Config) -> Iterable[PendingAttestation]: +def get_matching_head_attestations( + state: BeaconState, epoch: Epoch, config: Eth2Config +) -> Iterable[PendingAttestation]: for a in get_matching_source_attestations(state, epoch, config): beacon_block_root = get_block_root_at_slot( state, - get_attestation_data_slot( - state, - a.data, - config, - ), + get_attestation_data_slot(state, a.data, config), config.SLOTS_PER_HISTORICAL_ROOT, ) if a.data.beacon_block_root == beacon_block_root: yield a -def get_unslashed_attesting_indices(state: BeaconState, - attestations: Sequence[PendingAttestation], - config: CommitteeConfig) -> Set[ValidatorIndex]: +def get_unslashed_attesting_indices( + state: BeaconState, + attestations: Sequence[PendingAttestation], + config: CommitteeConfig, +) -> Set[ValidatorIndex]: output: Set[ValidatorIndex] = set() for a in attestations: - output = output.union(get_attesting_indices(state, a.data, a.aggregation_bits, config)) - return set( - filter( - lambda index: not state.validators[index].slashed, - output, + output = output.union( + get_attesting_indices(state, a.data, a.aggregation_bits, config) ) - ) + return set(filter(lambda index: not state.validators[index].slashed, output)) -def get_attesting_balance(state: BeaconState, - attestations: Sequence[PendingAttestation], - config: Eth2Config) -> Gwei: +def get_attesting_balance( + state: BeaconState, attestations: Sequence[PendingAttestation], config: Eth2Config +) -> Gwei: return get_total_balance( state, - get_unslashed_attesting_indices(state, attestations, CommitteeConfig(config)) + get_unslashed_attesting_indices(state, attestations, CommitteeConfig(config)), ) -def _score_crosslink(state: BeaconState, - crosslink: Crosslink, - attestations: Sequence[PendingAttestation], - config: Eth2Config) -> Tuple[Gwei, Hash32]: - return ( - get_attesting_balance( - state, - attestations, - config, - ), - crosslink.data_root - ) +def _score_crosslink( + state: BeaconState, + crosslink: Crosslink, + attestations: Sequence[PendingAttestation], + config: Eth2Config, +) -> Tuple[Gwei, Hash32]: + return (get_attesting_balance(state, attestations, config), crosslink.data_root) def _find_winning_crosslink_and_attesting_indices_from_candidates( - state: BeaconState, - candidate_attestations: Sequence[PendingAttestation], - config: Eth2Config) -> Tuple[Crosslink, Set[ValidatorIndex]]: + state: BeaconState, + candidate_attestations: Sequence[PendingAttestation], + config: Eth2Config, +) -> Tuple[Crosslink, Set[ValidatorIndex]]: attestations_by_crosslink = groupby( - lambda a: a.data.crosslink, - candidate_attestations, + lambda a: a.data.crosslink, candidate_attestations ) winning_crosslink, winning_attestations = max( attestations_by_crosslink.items(), key=lambda pair: _score_crosslink( - state, - pair[0], # crosslink - pair[1], # attestations - config, + state, pair[0], pair[1], config # crosslink # attestations ), default=(Crosslink(), tuple()), ) @@ -269,45 +207,48 @@ def _find_winning_crosslink_and_attesting_indices_from_candidates( return ( winning_crosslink, get_unslashed_attesting_indices( - state, - winning_attestations, - CommitteeConfig(config), + state, winning_attestations, CommitteeConfig(config) ), ) @to_tuple def _get_attestations_for_shard( - attestations: Sequence[PendingAttestation], - shard: Shard) -> Iterable[PendingAttestation]: + attestations: Sequence[PendingAttestation], shard: Shard +) -> Iterable[PendingAttestation]: for a in attestations: if a.data.crosslink.shard == shard: yield a @curry -def _crosslink_or_parent_is_valid(valid_crosslink: Crosslink, candidate: Crosslink) -> bool: - return valid_crosslink.hash_tree_root in (candidate.parent_root, candidate.hash_tree_root) +def _crosslink_or_parent_is_valid( + valid_crosslink: Crosslink, candidate: Crosslink +) -> bool: + return valid_crosslink.hash_tree_root in ( + candidate.parent_root, + candidate.hash_tree_root, + ) @to_tuple -def _get_attestations_for_valid_crosslink(attestations: Sequence[PendingAttestation], - state: BeaconState, - shard: Shard, - config: Eth2Config) -> Iterable[PendingAttestation]: +def _get_attestations_for_valid_crosslink( + attestations: Sequence[PendingAttestation], + state: BeaconState, + shard: Shard, + config: Eth2Config, +) -> Iterable[PendingAttestation]: return filter( lambda a: _crosslink_or_parent_is_valid( - state.current_crosslinks[shard], - a.data.crosslink, + state.current_crosslinks[shard], a.data.crosslink ), attestations, ) -def _find_candidate_attestations_for_shard(state: BeaconState, - epoch: Epoch, - shard: Shard, - config: Eth2Config) -> Tuple[PendingAttestation, ...]: +def _find_candidate_attestations_for_shard( + state: BeaconState, epoch: Epoch, shard: Shard, config: Eth2Config +) -> Tuple[PendingAttestation, ...]: return thread_first( state, (get_matching_source_attestations, epoch, config), @@ -317,26 +258,25 @@ def _find_candidate_attestations_for_shard(state: BeaconState, def get_winning_crosslink_and_attesting_indices( - *, - state: BeaconState, - epoch: Epoch, - shard: Shard, - config: Eth2Config) -> Tuple[Crosslink, Set[ValidatorIndex]]: - candidate_attestations = _find_candidate_attestations_for_shard(state, epoch, shard, config) + *, state: BeaconState, epoch: Epoch, shard: Shard, config: Eth2Config +) -> Tuple[Crosslink, Set[ValidatorIndex]]: + candidate_attestations = _find_candidate_attestations_for_shard( + state, epoch, shard, config + ) return _find_winning_crosslink_and_attesting_indices_from_candidates( - state, - candidate_attestations, - config, + state, candidate_attestations, config ) -def get_base_reward(state: BeaconState, - index: ValidatorIndex, - config: Eth2Config) -> Gwei: +def get_base_reward( + state: BeaconState, index: ValidatorIndex, config: Eth2Config +) -> Gwei: total_balance = get_total_active_balance(state, config) effective_balance = state.validators[index].effective_balance return Gwei( - effective_balance * config.BASE_REWARD_FACTOR // - integer_squareroot(total_balance) // BASE_REWARDS_PER_EPOCH + effective_balance + * config.BASE_REWARD_FACTOR + // integer_squareroot(total_balance) + // BASE_REWARDS_PER_EPOCH ) diff --git a/eth2/beacon/exceptions.py b/eth2/beacon/exceptions.py index 74dabd6b28..7dfda6a740 100644 --- a/eth2/beacon/exceptions.py +++ b/eth2/beacon/exceptions.py @@ -1,16 +1,12 @@ -from eth.exceptions import ( - PyEVMError, -) - -from eth_utils import ( - ValidationError, -) +from eth.exceptions import PyEVMError +from eth_utils import ValidationError class StateMachineNotFound(PyEVMError): """ Raised when no ``StateMachine`` is available for the provided block slot number. """ + pass @@ -18,6 +14,7 @@ class BlockClassError(PyEVMError): """ Raised when the given ``block`` doesn't match the block class version """ + pass @@ -26,6 +23,7 @@ class ProposerIndexError(PyEVMError): Raised when the given ``validator_index`` doesn't match the ``validator_index`` of proposer of the given ``slot`` """ + pass @@ -33,6 +31,7 @@ class NoCommitteeAssignment(PyEVMError): """ Raised when no potential crosslink committee assignment. """ + pass @@ -42,6 +41,7 @@ class InvalidEpochError(ValidationError): Example: asking the ``BeaconState`` about an epoch that is not derivable given the current data. """ + pass @@ -49,6 +49,7 @@ class BLSValidationError(ValidationError): """ Raised when a verification of public keys, messages, and signature fails. """ + pass @@ -56,6 +57,7 @@ class SignatureError(BLSValidationError): """ Signature is ill-formed """ + pass @@ -63,4 +65,5 @@ class PublicKeyError(BLSValidationError): """ Public Key is ill-formed """ + pass diff --git a/eth2/beacon/fork_choice/lmd_ghost.py b/eth2/beacon/fork_choice/lmd_ghost.py index 21774d9a79..485c8bd3b7 100644 --- a/eth2/beacon/fork_choice/lmd_ghost.py +++ b/eth2/beacon/fork_choice/lmd_ghost.py @@ -1,45 +1,21 @@ from typing import Dict, Iterable, Optional, Sequence, Tuple, Type, Union -from eth_typing import ( - Hash32, -) -from eth_utils import ( - to_tuple, -) -from eth_utils.toolz import ( - curry, - first, - mapcat, - merge, - merge_with, - second, - valmap, -) - -from eth2.beacon.attestation_helpers import ( - get_attestation_data_slot, -) -from eth2.beacon.epoch_processing_helpers import ( - get_attesting_indices, -) -from eth2.beacon.helpers import ( - get_active_validator_indices, - compute_epoch_of_slot, -) +from eth_typing import Hash32 +from eth_utils import to_tuple +from eth_utils.toolz import curry, first, mapcat, merge, merge_with, second, valmap + +from eth2.beacon.attestation_helpers import get_attestation_data_slot from eth2.beacon.db.chain import BeaconChainDB +from eth2.beacon.epoch_processing_helpers import get_attesting_indices +from eth2.beacon.helpers import compute_epoch_of_slot, get_active_validator_indices from eth2.beacon.operations.attestation_pool import AttestationPool -from eth2.beacon.types.attestations import Attestation from eth2.beacon.types.attestation_data import AttestationData +from eth2.beacon.types.attestations import Attestation from eth2.beacon.types.blocks import BaseBeaconBlock from eth2.beacon.types.pending_attestations import PendingAttestation from eth2.beacon.types.states import BeaconState -from eth2.beacon.typing import ( - Gwei, - Slot, - ValidatorIndex, -) -from eth2.configs import Eth2Config, CommitteeConfig - +from eth2.beacon.typing import Gwei, Slot, ValidatorIndex +from eth2.configs import CommitteeConfig, Eth2Config # TODO(ralexstokes) integrate `AttestationPool` once it has been merged AttestationIndex = Dict[ValidatorIndex, AttestationData] @@ -48,7 +24,8 @@ def _take_latest_attestation_by_slot( - candidates: Sequence[Tuple[Slot, AttestationData]]) -> Tuple[Slot, AttestationData]: + candidates: Sequence[Tuple[Slot, AttestationData]] +) -> Tuple[Slot, AttestationData]: return max(candidates, key=first) @@ -56,21 +33,24 @@ class Store: """ A private class meant to encapsulate data access for the functionality in this module. """ - def __init__(self, - chain_db: BeaconChainDB, - state: BeaconState, - attestation_pool: AttestationPool, - block_class: Type[BaseBeaconBlock], - config: Eth2Config): + + def __init__( + self, + chain_db: BeaconChainDB, + state: BeaconState, + attestation_pool: AttestationPool, + block_class: Type[BaseBeaconBlock], + config: Eth2Config, + ): self._db = chain_db self._block_class = block_class self._config = config self._attestation_index = self._build_attestation_index(state, attestation_pool) @curry - def _mk_pre_index_from_attestation(self, - state: BeaconState, - attestation: AttestationLike) -> Iterable[PreIndex]: + def _mk_pre_index_from_attestation( + self, state: BeaconState, attestation: AttestationLike + ) -> Iterable[PreIndex]: attestation_data = attestation.data slot = get_attestation_data_slot(state, attestation_data, self._config) @@ -84,22 +64,17 @@ def _mk_pre_index_from_attestation(self, ) ) - def _mk_pre_index_from_attestations(self, - state: BeaconState, - attestations: Sequence[AttestationLike]) -> PreIndex: + def _mk_pre_index_from_attestations( + self, state: BeaconState, attestations: Sequence[AttestationLike] + ) -> PreIndex: """ A 'pre-index' is a Dict[ValidatorIndex, Tuple[Slot, AttestationData]]. """ - return merge( - *mapcat( - self._mk_pre_index_from_attestation(state), - attestations, - ) - ) + return merge(*mapcat(self._mk_pre_index_from_attestation(state), attestations)) - def _build_attestation_index(self, - state: BeaconState, - attestation_pool: AttestationPool) -> AttestationIndex: + def _build_attestation_index( + self, state: BeaconState, attestation_pool: AttestationPool + ) -> AttestationIndex: """ Assembles a dictionary of latest attestations keyed by validator index. Any attestation made by a validator in the ``attestation_pool`` that occur after the @@ -111,18 +86,15 @@ def _build_attestation_index(self, duplicates in the pre-indices keyed by validator index. """ previous_epoch_index = self._mk_pre_index_from_attestations( - state, - state.previous_epoch_attestations + state, state.previous_epoch_attestations ) current_epoch_index = self._mk_pre_index_from_attestations( - state, - state.current_epoch_attestations + state, state.current_epoch_attestations ) pool_index = self._mk_pre_index_from_attestations( - state, - tuple(attestation for _, attestation in attestation_pool) + state, tuple(attestation for _, attestation in attestation_pool) ) index_by_latest_slot = merge_with( @@ -132,12 +104,11 @@ def _build_attestation_index(self, pool_index, ) # convert the index to a mapping of ValidatorIndex -> (latest) Attestation - return valmap( - second, - index_by_latest_slot, - ) + return valmap(second, index_by_latest_slot) - def _get_latest_attestation(self, index: ValidatorIndex) -> Optional[AttestationData]: + def _get_latest_attestation( + self, index: ValidatorIndex + ) -> Optional[AttestationData]: """ Return the latest attesation we know from the validator with the given ``index``. @@ -147,7 +118,9 @@ def _get_latest_attestation(self, index: ValidatorIndex) -> Optional[Attestation def _get_block_by_root(self, root: Hash32) -> BaseBeaconBlock: return self._db.get_block_by_root(root, self._block_class) - def get_latest_attestation_target(self, index: ValidatorIndex) -> Optional[BaseBeaconBlock]: + def get_latest_attestation_target( + self, index: ValidatorIndex + ) -> Optional[BaseBeaconBlock]: attestation = self._get_latest_attestation(index) if not attestation: return None @@ -179,29 +152,19 @@ def get_ancestor(self, block: BaseBeaconBlock, slot: Slot) -> BaseBeaconBlock: @curry def _find_latest_attestation_target( - store: Store, - index: ValidatorIndex) -> AttestationTarget: - return ( - index, - store.get_latest_attestation_target(index), - ) + store: Store, index: ValidatorIndex +) -> AttestationTarget: + return (index, store.get_latest_attestation_target(index)) @to_tuple -def _find_latest_attestation_targets(state: BeaconState, - store: Store, - config: Eth2Config) -> Iterable[AttestationTarget]: +def _find_latest_attestation_targets( + state: BeaconState, store: Store, config: Eth2Config +) -> Iterable[AttestationTarget]: epoch = compute_epoch_of_slot(state.slot, config.SLOTS_PER_EPOCH) - active_validators = get_active_validator_indices( - state.validators, - epoch, - ) + active_validators = get_active_validator_indices(state.validators, epoch) return filter( - second, - map( - _find_latest_attestation_target(store), - active_validators, - ) + second, map(_find_latest_attestation_target(store), active_validators) ) @@ -213,10 +176,12 @@ def _balance_for_validator(state: BeaconState, validator_index: ValidatorIndex) return state.validators[validator_index].effective_balance -def score_block_by_attestations(state: BeaconState, - store: Store, - attestation_targets: Sequence[AttestationTarget], - block: BaseBeaconBlock) -> int: +def score_block_by_attestations( + state: BeaconState, + store: Store, + attestation_targets: Sequence[AttestationTarget], + block: BaseBeaconBlock, +) -> int: """ Return the total balance attesting to ``block`` based on the ``attestation_targets``. """ @@ -228,16 +193,18 @@ def score_block_by_attestations(state: BeaconState, def score_block_by_root(block: BaseBeaconBlock) -> int: - return int.from_bytes(block.hash_tree_root[:8], byteorder='big') + return int.from_bytes(block.hash_tree_root[:8], byteorder="big") @curry -def lmd_ghost_scoring(chain_db: BeaconChainDB, - attestation_pool: AttestationPool, - state: BeaconState, - config: Eth2Config, - block_class: Type[BaseBeaconBlock], - block: BaseBeaconBlock) -> int: +def lmd_ghost_scoring( + chain_db: BeaconChainDB, + attestation_pool: AttestationPool, + state: BeaconState, + config: Eth2Config, + block_class: Type[BaseBeaconBlock], + block: BaseBeaconBlock, +) -> int: """ Return the score of the ``target_block`` according to the LMD GHOST algorithm, using the lexicographic ordering of the block root to break ties. @@ -247,10 +214,7 @@ def lmd_ghost_scoring(chain_db: BeaconChainDB, attestation_targets = _find_latest_attestation_targets(state, store, config) attestation_score = score_block_by_attestations( - state, - store, - attestation_targets, - block, + state, store, attestation_targets, block ) block_root_score = score_block_by_root(block) diff --git a/eth2/beacon/genesis.py b/eth2/beacon/genesis.py index 241af46778..137b54214e 100644 --- a/eth2/beacon/genesis.py +++ b/eth2/beacon/genesis.py @@ -1,54 +1,27 @@ -from typing import ( - Sequence, - Type, -) - -from eth_typing import ( - Hash32, -) +from typing import Sequence, Type +from eth_typing import Hash32 import ssz -from eth2.beacon.constants import ( - SECONDS_PER_DAY, - DEPOSIT_CONTRACT_TREE_DEPTH, -) -from eth2.beacon.committee_helpers import ( - get_compact_committees_root, -) -from eth2.beacon.deposit_helpers import ( - process_deposit, -) -from eth2.beacon.helpers import ( - get_active_validator_indices, -) - -from eth2.beacon.types.blocks import ( - BaseBeaconBlock, - BeaconBlockBody, -) -from eth2.beacon.types.block_headers import ( - BeaconBlockHeader, -) -from eth2.beacon.types.deposits import Deposit +from eth2.beacon.committee_helpers import get_compact_committees_root +from eth2.beacon.constants import DEPOSIT_CONTRACT_TREE_DEPTH, SECONDS_PER_DAY +from eth2.beacon.deposit_helpers import process_deposit +from eth2.beacon.helpers import get_active_validator_indices +from eth2.beacon.types.block_headers import BeaconBlockHeader +from eth2.beacon.types.blocks import BaseBeaconBlock, BeaconBlockBody from eth2.beacon.types.deposit_data import DepositData +from eth2.beacon.types.deposits import Deposit from eth2.beacon.types.eth1_data import Eth1Data from eth2.beacon.types.states import BeaconState from eth2.beacon.types.validators import calculate_effective_balance -from eth2.beacon.typing import ( - Timestamp, - ValidatorIndex, -) -from eth2.beacon.validator_status_helpers import ( - activate_validator, -) -from eth2.configs import ( - Eth2Config, - CommitteeConfig, -) - - -def is_genesis_trigger(deposits: Sequence[Deposit], timestamp: int, config: Eth2Config) -> bool: +from eth2.beacon.typing import Timestamp, ValidatorIndex +from eth2.beacon.validator_status_helpers import activate_validator +from eth2.configs import CommitteeConfig, Eth2Config + + +def is_genesis_trigger( + deposits: Sequence[Deposit], timestamp: int, config: Eth2Config +) -> bool: state = BeaconState(config=config) for deposit in deposits: @@ -64,24 +37,16 @@ def is_genesis_trigger(deposits: Sequence[Deposit], timestamp: int, config: Eth2 def state_with_validator_digests(state: BeaconState, config: Eth2Config) -> BeaconState: active_validator_indices = get_active_validator_indices( - state.validators, - config.GENESIS_EPOCH, + state.validators, config.GENESIS_EPOCH ) active_index_root = ssz.get_hash_tree_root( - active_validator_indices, - ssz.List(ssz.uint64, config.VALIDATOR_REGISTRY_LIMIT), - ) - active_index_roots = ( - (active_index_root,) * config.EPOCHS_PER_HISTORICAL_VECTOR + active_validator_indices, ssz.List(ssz.uint64, config.VALIDATOR_REGISTRY_LIMIT) ) + active_index_roots = (active_index_root,) * config.EPOCHS_PER_HISTORICAL_VECTOR committee_root = get_compact_committees_root( - state, - config.GENESIS_EPOCH, - CommitteeConfig(config), - ) - compact_committees_roots = ( - (committee_root,) * config.EPOCHS_PER_HISTORICAL_VECTOR + state, config.GENESIS_EPOCH, CommitteeConfig(config) ) + compact_committees_roots = (committee_root,) * config.EPOCHS_PER_HISTORICAL_VECTOR return state.copy( active_index_roots=active_index_roots, compact_committees_roots=compact_committees_roots, @@ -90,89 +55,71 @@ def state_with_validator_digests(state: BeaconState, config: Eth2Config) -> Beac def _genesis_time_from_eth1_timestamp(eth1_timestamp: Timestamp) -> Timestamp: return Timestamp( - eth1_timestamp - eth1_timestamp % SECONDS_PER_DAY + 2 * SECONDS_PER_DAY, + eth1_timestamp - eth1_timestamp % SECONDS_PER_DAY + 2 * SECONDS_PER_DAY ) -def initialize_beacon_state_from_eth1(*, - eth1_block_hash: Hash32, - eth1_timestamp: Timestamp, - deposits: Sequence[Deposit], - config: Eth2Config) -> BeaconState: +def initialize_beacon_state_from_eth1( + *, + eth1_block_hash: Hash32, + eth1_timestamp: Timestamp, + deposits: Sequence[Deposit], + config: Eth2Config +) -> BeaconState: state = BeaconState( genesis_time=_genesis_time_from_eth1_timestamp(eth1_timestamp), - eth1_data=Eth1Data( - block_hash=eth1_block_hash, - deposit_count=len(deposits), - ), + eth1_data=Eth1Data(block_hash=eth1_block_hash, deposit_count=len(deposits)), latest_block_header=BeaconBlockHeader( - body_root=BeaconBlockBody().hash_tree_root, + body_root=BeaconBlockBody().hash_tree_root ), config=config, ) # Process genesis deposits for index, deposit in enumerate(deposits): - deposit_data_list = tuple( - deposit.data - for deposit in deposits[:index + 1] - ) + deposit_data_list = tuple(deposit.data for deposit in deposits[: index + 1]) state = state.copy( eth1_data=state.eth1_data.copy( deposit_root=ssz.get_hash_tree_root( deposit_data_list, - ssz.List(DepositData, 2**DEPOSIT_CONTRACT_TREE_DEPTH), - ), - ), - ) - state = process_deposit( - state=state, - deposit=deposit, - config=config, + ssz.List(DepositData, 2 ** DEPOSIT_CONTRACT_TREE_DEPTH), + ) + ) ) + state = process_deposit(state=state, deposit=deposit, config=config) # Process genesis activations for validator_index in range(len(state.validators)): validator_index = ValidatorIndex(validator_index) balance = state.balances[validator_index] - effective_balance = calculate_effective_balance( - balance, - config, - ) + effective_balance = calculate_effective_balance(balance, config) state = state.update_validator_with_fn( - validator_index, - lambda v, *_: v.copy( - effective_balance=effective_balance, - ), + validator_index, lambda v, *_: v.copy(effective_balance=effective_balance) ) if effective_balance == config.MAX_EFFECTIVE_BALANCE: state = state.update_validator_with_fn( - validator_index, - activate_validator, - config.GENESIS_EPOCH, + validator_index, activate_validator, config.GENESIS_EPOCH ) - return state_with_validator_digests( - state, - config, - ) + return state_with_validator_digests(state, config) def is_valid_genesis_state(state: BeaconState, config: Eth2Config) -> bool: if state.genesis_time < config.MIN_GENESIS_TIME: return False - validator_count = len(get_active_validator_indices(state.validators, config.GENESIS_EPOCH)) + validator_count = len( + get_active_validator_indices(state.validators, config.GENESIS_EPOCH) + ) if validator_count < config.MIN_GENESIS_ACTIVE_VALIDATOR_COUNT: return False return True -def get_genesis_block(genesis_state_root: Hash32, - block_class: Type[BaseBeaconBlock]) -> BaseBeaconBlock: - return block_class( - state_root=genesis_state_root, - ) +def get_genesis_block( + genesis_state_root: Hash32, block_class: Type[BaseBeaconBlock] +) -> BaseBeaconBlock: + return block_class(state_root=genesis_state_root) diff --git a/eth2/beacon/helpers.py b/eth2/beacon/helpers.py index 0d966ce526..833d3326ad 100644 --- a/eth2/beacon/helpers.py +++ b/eth2/beacon/helpers.py @@ -1,41 +1,23 @@ -from typing import ( - Callable, - Sequence, - Set, - Tuple, - TYPE_CHECKING, -) - -from eth_utils import ( - ValidationError, -) -from eth_typing import ( - Hash32, -) +from typing import TYPE_CHECKING, Callable, Sequence, Set, Tuple +from eth_typing import Hash32 +from eth_utils import ValidationError from py_ecc.bls.typing import Domain -from eth2._utils.hash import ( - hash_eth2, -) -from eth2.beacon.signature_domain import ( - SignatureDomain, -) +from eth2._utils.hash import hash_eth2 +from eth2.beacon.signature_domain import SignatureDomain +from eth2.beacon.types.forks import Fork +from eth2.beacon.types.validators import Validator from eth2.beacon.typing import ( + DomainType, Epoch, Gwei, Slot, ValidatorIndex, Version, default_version, - DomainType, -) -from eth2.configs import ( - CommitteeConfig, ) - -from eth2.beacon.types.forks import Fork -from eth2.beacon.types.validators import Validator +from eth2.configs import CommitteeConfig if TYPE_CHECKING: from eth2.beacon.types.states import BeaconState # noqa: F401 @@ -49,8 +31,9 @@ def compute_start_slot_of_epoch(epoch: Epoch, slots_per_epoch: int) -> Slot: return Slot(epoch * slots_per_epoch) -def get_active_validator_indices(validators: Sequence[Validator], - epoch: Epoch) -> Tuple[ValidatorIndex, ...]: +def get_active_validator_indices( + validators: Sequence[Validator], epoch: Epoch +) -> Tuple[ValidatorIndex, ...]: """ Get indices of active validators from ``validators``. """ @@ -62,10 +45,11 @@ def get_active_validator_indices(validators: Sequence[Validator], def _get_historical_root( - historical_roots: Sequence[Hash32], - state_slot: Slot, - slot: Slot, - slots_per_historical_root: int) -> Hash32: + historical_roots: Sequence[Hash32], + state_slot: Slot, + slot: Slot, + slots_per_historical_root: int, +) -> Hash32: """ Return the historical root at a recent ``slot``. """ @@ -84,24 +68,23 @@ def _get_historical_root( return historical_roots[slot % slots_per_historical_root] -def get_block_root_at_slot(state: 'BeaconState', - slot: Slot, - slots_per_historical_root: int) -> Hash32: +def get_block_root_at_slot( + state: "BeaconState", slot: Slot, slots_per_historical_root: int +) -> Hash32: """ Return the block root at a recent ``slot``. """ return _get_historical_root( - state.block_roots, - state.slot, - slot, - slots_per_historical_root, + state.block_roots, state.slot, slot, slots_per_historical_root ) -def get_block_root(state: 'BeaconState', - epoch: Epoch, - slots_per_epoch: int, - slots_per_historical_root: int) -> Hash32: +def get_block_root( + state: "BeaconState", + epoch: Epoch, + slots_per_epoch: int, + slots_per_historical_root: int, +) -> Hash32: return get_block_root_at_slot( state, compute_start_slot_of_epoch(epoch, slots_per_epoch), @@ -109,18 +92,18 @@ def get_block_root(state: 'BeaconState', ) -def get_randao_mix(state: 'BeaconState', - epoch: Epoch, - epochs_per_historical_vector: int) -> Hash32: +def get_randao_mix( + state: "BeaconState", epoch: Epoch, epochs_per_historical_vector: int +) -> Hash32: """ Return the randao mix at a recent ``epoch``. """ return state.randao_mixes[epoch % epochs_per_historical_vector] -def get_active_index_root(state: 'BeaconState', - epoch: Epoch, - epochs_per_historical_vector: int) -> Hash32: +def get_active_index_root( + state: "BeaconState", epoch: Epoch, epochs_per_historical_vector: int +) -> Hash32: """ Return the index root at a recent ``epoch``. """ @@ -131,39 +114,39 @@ def _epoch_for_seed(epoch: Epoch) -> Hash32: return Hash32(epoch.to_bytes(32, byteorder="little")) -RandaoProvider = Callable[['BeaconState', Epoch, int], Hash32] -ActiveIndexRootProvider = Callable[['BeaconState', Epoch, int], Hash32] +RandaoProvider = Callable[["BeaconState", Epoch, int], Hash32] +ActiveIndexRootProvider = Callable[["BeaconState", Epoch, int], Hash32] -def _get_seed(state: 'BeaconState', - epoch: Epoch, - randao_provider: RandaoProvider, - active_index_root_provider: ActiveIndexRootProvider, - epoch_provider: Callable[[Epoch], Hash32], - committee_config: CommitteeConfig) -> Hash32: +def _get_seed( + state: "BeaconState", + epoch: Epoch, + randao_provider: RandaoProvider, + active_index_root_provider: ActiveIndexRootProvider, + epoch_provider: Callable[[Epoch], Hash32], + committee_config: CommitteeConfig, +) -> Hash32: randao_mix = randao_provider( state, Epoch( - epoch + - committee_config.EPOCHS_PER_HISTORICAL_VECTOR - - committee_config.MIN_SEED_LOOKAHEAD - - 1 + epoch + + committee_config.EPOCHS_PER_HISTORICAL_VECTOR + - committee_config.MIN_SEED_LOOKAHEAD + - 1 ), committee_config.EPOCHS_PER_HISTORICAL_VECTOR, ) active_index_root = active_index_root_provider( - state, - epoch, - committee_config.EPOCHS_PER_HISTORICAL_VECTOR, + state, epoch, committee_config.EPOCHS_PER_HISTORICAL_VECTOR ) epoch_as_bytes = epoch_provider(epoch) return hash_eth2(randao_mix + active_index_root + epoch_as_bytes) -def get_seed(state: 'BeaconState', - epoch: Epoch, - committee_config: CommitteeConfig) -> Hash32: +def get_seed( + state: "BeaconState", epoch: Epoch, committee_config: CommitteeConfig +) -> Hash32: """ Generate a seed for the given ``epoch``. """ @@ -177,18 +160,18 @@ def get_seed(state: 'BeaconState', ) -def get_total_balance(state: 'BeaconState', - validator_indices: Set[ValidatorIndex]) -> Gwei: +def get_total_balance( + state: "BeaconState", validator_indices: Set[ValidatorIndex] +) -> Gwei: """ Return the combined effective balance of an array of validators. """ return Gwei( max( sum( - state.validators[index].effective_balance - for index in validator_indices + state.validators[index].effective_balance for index in validator_indices ), - 1 + 1, ) ) @@ -207,8 +190,9 @@ def _signature_domain_to_domain_type(s: SignatureDomain) -> DomainType: return DomainType(s.to_bytes(4, byteorder="little")) -def compute_domain(signature_domain: SignatureDomain, - fork_version: Version=default_version) -> Domain: +def compute_domain( + signature_domain: SignatureDomain, fork_version: Version = default_version +) -> Domain: """ NOTE: we deviate from the spec here by taking the enum ``SignatureDomain`` and converting before creating the domain. @@ -217,13 +201,17 @@ def compute_domain(signature_domain: SignatureDomain, return Domain(domain_type + fork_version) -def get_domain(state: 'BeaconState', - signature_domain: SignatureDomain, - slots_per_epoch: int, - message_epoch: Epoch=None) -> Domain: +def get_domain( + state: "BeaconState", + signature_domain: SignatureDomain, + slots_per_epoch: int, + message_epoch: Epoch = None, +) -> Domain: """ Return the domain number of the current fork and ``domain_type``. """ - epoch = state.current_epoch(slots_per_epoch) if message_epoch is None else message_epoch + epoch = ( + state.current_epoch(slots_per_epoch) if message_epoch is None else message_epoch + ) fork_version = _get_fork_version(state.fork, epoch) return compute_domain(signature_domain, fork_version) diff --git a/eth2/beacon/operations/attestation_pool.py b/eth2/beacon/operations/attestation_pool.py index 06ec7771f5..103dabfb48 100644 --- a/eth2/beacon/operations/attestation_pool.py +++ b/eth2/beacon/operations/attestation_pool.py @@ -1,7 +1,7 @@ -from .pool import OperationPool - from eth2.beacon.types.attestations import Attestation +from .pool import OperationPool + class AttestationPool(OperationPool[Attestation]): pass diff --git a/eth2/beacon/operations/pool.py b/eth2/beacon/operations/pool.py index fd965731f5..0817ea5ac7 100644 --- a/eth2/beacon/operations/pool.py +++ b/eth2/beacon/operations/pool.py @@ -1,8 +1,7 @@ from typing import Dict, Generic, Iterator, Tuple, TypeVar -from typing_extensions import Protocol from eth_typing import Hash32 - +from typing_extensions import Protocol HashTreeRoot = Hash32 @@ -11,7 +10,7 @@ class Operation(Protocol): hash_tree_root: HashTreeRoot -T = TypeVar('T', bound='Operation') +T = TypeVar("T", bound="Operation") class OperationPool(Generic[T]): diff --git a/eth2/beacon/scripts/run_beacon_nodes.py b/eth2/beacon/scripts/run_beacon_nodes.py index 5faa04990e..b94d33041c 100755 --- a/eth2/beacon/scripts/run_beacon_nodes.py +++ b/eth2/beacon/scripts/run_beacon_nodes.py @@ -1,51 +1,25 @@ #!/usr/bin/env python import asyncio -from collections import ( - defaultdict, -) +from collections import defaultdict import logging +from pathlib import Path import signal import sys import time -from typing import ( - ClassVar, - Dict, - List, - MutableSet, - NamedTuple, - Optional, - Tuple, -) - -from pathlib import Path - -from libp2p.peer.id import ( - ID, -) - -from eth_keys.datatypes import ( - PrivateKey, -) - -from eth_utils import ( - remove_0x_prefix, -) +from typing import ClassVar, Dict, List, MutableSet, NamedTuple, Optional, Tuple -from multiaddr import ( - Multiaddr, -) +from eth_keys.datatypes import PrivateKey +from eth_utils import remove_0x_prefix +from libp2p.peer.id import ID +from multiaddr import Multiaddr -from trinity.protocol.bcc_libp2p.utils import ( - peer_id_from_pubkey, -) +from trinity.protocol.bcc_libp2p.utils import peer_id_from_pubkey async def run(cmd): proc = await asyncio.create_subprocess_shell( - cmd, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, + cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) return proc @@ -85,11 +59,12 @@ class Node: ) def __init__( - self, - name: str, - node_privkey: str, - port: int, - preferred_nodes: Optional[Tuple["Node", ...]] = None) -> None: + self, + name: str, + node_privkey: str, + port: int, + preferred_nodes: Optional[Tuple["Node", ...]] = None, + ) -> None: self.name = name self.node_privkey = PrivateKey(bytes.fromhex(node_privkey)) self.port = port @@ -137,7 +112,9 @@ def cmd(self) -> str: "-l debug2", ] if len(self.preferred_nodes) != 0: - preferred_nodes_str = ",".join([str(node.maddr) for node in self.preferred_nodes]) + preferred_nodes_str = ",".join( + [str(node.maddr) for node in self.preferred_nodes] + ) _cmds.append(f"--preferred_nodes={preferred_nodes_str}") _cmd = " ".join(_cmds) return _cmd @@ -162,8 +139,12 @@ async def run(self) -> None: print(f"Spinning up {self.name}") self.proc = await run(self.cmd) self.running_nodes.append(self) - self.tasks.append(asyncio.ensure_future(self._print_logs('stdout', self.proc.stdout))) - self.tasks.append(asyncio.ensure_future(self._print_logs('stderr', self.proc.stderr))) + self.tasks.append( + asyncio.ensure_future(self._print_logs("stdout", self.proc.stdout)) + ) + self.tasks.append( + asyncio.ensure_future(self._print_logs("stderr", self.proc.stderr)) + ) try: await self._log_monitor() except EventTimeOutError as e: @@ -185,9 +166,11 @@ async def _log_monitor(self) -> None: ) await asyncio.sleep(0.1) - async def _print_logs(self, from_stream: str, stream_reader: asyncio.StreamReader) -> None: + async def _print_logs( + self, from_stream: str, stream_reader: asyncio.StreamReader + ) -> None: async for line_bytes in stream_reader: - line = line_bytes.decode('utf-8').replace('\n', '') + line = line_bytes.decode("utf-8").replace("\n", "") # TODO: Preprocessing self._record_happenning_logs(from_stream, line) print(f"{self.logging_name}.{from_stream}\t: {line}") @@ -203,24 +186,22 @@ async def main(): num_validators = 9 genesis_delay = 20 - proc = await run( - f"rm -rf {Node.dir_root}" - ) + proc = await run(f"rm -rf {Node.dir_root}") await proc.wait() - proc = await run( - f"mkdir -p {Node.dir_root}" - ) + proc = await run(f"mkdir -p {Node.dir_root}") await proc.wait() print("Generating genesis file") proc = await run( - " ".join(( - "trinity-beacon", - "testnet", - f"--num={num_validators}", - f"--network-dir={Node.dir_root}", - f"--genesis-delay={genesis_delay}", - )) + " ".join( + ( + "trinity-beacon", + "testnet", + f"--num={num_validators}", + f"--network-dir={Node.dir_root}", + f"--genesis-delay={genesis_delay}", + ) + ) ) await proc.wait() diff --git a/eth2/beacon/state_machines/base.py b/eth2/beacon/state_machines/base.py index 07df99e9c0..38c0326554 100644 --- a/eth2/beacon/state_machines/base.py +++ b/eth2/beacon/state_machines/base.py @@ -1,32 +1,17 @@ -from abc import ( - ABC, - abstractmethod, -) -from typing import ( - Tuple, - Type, -) - -from eth._utils.datatypes import ( - Configurable, -) - -from eth2.configs import ( # noqa: F401 - Eth2Config, -) +from abc import ABC, abstractmethod +from typing import Tuple, Type + +from eth._utils.datatypes import Configurable + from eth2.beacon.db.chain import BaseBeaconChainDB from eth2.beacon.fork_choice.scoring import ScoringFn as ForkChoiceScoringFn from eth2.beacon.operations.attestation_pool import AttestationPool from eth2.beacon.types.blocks import BaseBeaconBlock from eth2.beacon.types.states import BeaconState -from eth2.beacon.typing import ( - FromBlockParams, - Slot, -) +from eth2.beacon.typing import FromBlockParams, Slot +from eth2.configs import Eth2Config # noqa: F401 -from .state_transitions import ( - BaseStateTransition, -) +from .state_transitions import BaseStateTransition class BaseBeaconStateMachine(Configurable, ABC): @@ -42,11 +27,13 @@ class BaseBeaconStateMachine(Configurable, ABC): state_transition_class = None # type: Type[BaseStateTransition] @abstractmethod - def __init__(self, - chaindb: BaseBeaconChainDB, - attestation_pool: AttestationPool, - slot: Slot, - state: BeaconState=None) -> None: + def __init__( + self, + chaindb: BaseBeaconChainDB, + attestation_pool: AttestationPool, + slot: Slot, + state: BeaconState = None, + ) -> None: ... @classmethod @@ -77,24 +64,30 @@ def get_fork_choice_scoring(self) -> ForkChoiceScoringFn: # Import block API # @abstractmethod - def import_block(self, - block: BaseBeaconBlock, - check_proposer_signature: bool=True) -> Tuple[BeaconState, BaseBeaconBlock]: + def import_block( + self, + block: BaseBeaconBlock, + state: BeaconState, + check_proposer_signature: bool = True, + ) -> Tuple[BeaconState, BaseBeaconBlock]: ... @staticmethod @abstractmethod - def create_block_from_parent(parent_block: BaseBeaconBlock, - block_params: FromBlockParams) -> BaseBeaconBlock: + def create_block_from_parent( + parent_block: BaseBeaconBlock, block_params: FromBlockParams + ) -> BaseBeaconBlock: ... class BeaconStateMachine(BaseBeaconStateMachine): - def __init__(self, - chaindb: BaseBeaconChainDB, - attestation_pool: AttestationPool, - slot: Slot, - state: BeaconState=None) -> None: + def __init__( + self, + chaindb: BaseBeaconChainDB, + attestation_pool: AttestationPool, + slot: Slot, + state: BeaconState = None, + ) -> None: self.chaindb = chaindb self.attestation_pool = attestation_pool if state is not None: @@ -102,15 +95,6 @@ def __init__(self, else: self.slot = slot - @property - def state(self) -> BeaconState: - if self._state is None: - self._state = self.chaindb.get_state_by_slot( - self.slot, - self.get_state_class() - ) - return self._state - @classmethod def get_block_class(cls) -> Type[BaseBeaconBlock]: """ @@ -140,7 +124,9 @@ def get_state_transiton_class(cls) -> Type[BaseStateTransition]: class that this StateTransition uses for StateTransition. """ if cls.state_transition_class is None: - raise AttributeError("No `state_transition_class` has been set for this StateMachine") + raise AttributeError( + "No `state_transition_class` has been set for this StateMachine" + ) else: return cls.state_transition_class @@ -151,17 +137,16 @@ def state_transition(self) -> BaseStateTransition: # # Import block API # - def import_block(self, - block: BaseBeaconBlock, - check_proposer_signature: bool=True) -> Tuple[BeaconState, BaseBeaconBlock]: + def import_block( + self, + block: BaseBeaconBlock, + state: BeaconState, + check_proposer_signature: bool = True, + ) -> Tuple[BeaconState, BaseBeaconBlock]: state = self.state_transition.apply_state_transition( - self.state, - block=block, - check_proposer_signature=check_proposer_signature, + state, block=block, check_proposer_signature=check_proposer_signature ) - block = block.copy( - state_root=state.hash_tree_root, - ) + block = block.copy(state_root=state.hash_tree_root) return state, block diff --git a/eth2/beacon/state_machines/forks/serenity/__init__.py b/eth2/beacon/state_machines/forks/serenity/__init__.py index 93c597af7f..c4d21eb080 100644 --- a/eth2/beacon/state_machines/forks/serenity/__init__.py +++ b/eth2/beacon/state_machines/forks/serenity/__init__.py @@ -1,31 +1,24 @@ from typing import Type # noqa: F401 +from eth2.beacon.fork_choice.lmd_ghost import lmd_ghost_scoring from eth2.beacon.fork_choice.scoring import ScoringFn as ForkChoiceScoringFn -from eth2.beacon.fork_choice.lmd_ghost import ( - lmd_ghost_scoring, -) -from eth2.beacon.typing import ( - FromBlockParams, +from eth2.beacon.state_machines.base import BeaconStateMachine +from eth2.beacon.state_machines.state_transitions import ( # noqa: F401 + BaseStateTransition, ) - from eth2.beacon.types.blocks import BaseBeaconBlock # noqa: F401 from eth2.beacon.types.states import BeaconState # noqa: F401 +from eth2.beacon.typing import FromBlockParams -from eth2.beacon.state_machines.base import BeaconStateMachine -from eth2.beacon.state_machines.state_transitions import BaseStateTransition # noqa: F401 - +from .blocks import SerenityBeaconBlock, create_serenity_block_from_parent from .configs import SERENITY_CONFIG -from .blocks import ( - create_serenity_block_from_parent, - SerenityBeaconBlock, -) -from .states import SerenityBeaconState from .state_transitions import SerenityStateTransition +from .states import SerenityBeaconState class SerenityStateMachine(BeaconStateMachine): # fork name - fork = 'serenity' # type: str + fork = "serenity" # type: str config = SERENITY_CONFIG # classes @@ -35,20 +28,19 @@ class SerenityStateMachine(BeaconStateMachine): # methods @staticmethod - def create_block_from_parent(parent_block: BaseBeaconBlock, - block_params: FromBlockParams) -> BaseBeaconBlock: + def create_block_from_parent( + parent_block: BaseBeaconBlock, block_params: FromBlockParams + ) -> BaseBeaconBlock: return create_serenity_block_from_parent(parent_block, block_params) def _get_justified_head_state(self) -> BeaconState: justified_head = self.chaindb.get_justified_head(self.block_class) - return self.chaindb.get_state_by_root(justified_head.state_root, self.state_class) + return self.chaindb.get_state_by_root( + justified_head.state_root, self.state_class + ) def get_fork_choice_scoring(self) -> ForkChoiceScoringFn: state = self._get_justified_head_state() return lmd_ghost_scoring( - self.chaindb, - self.attestation_pool, - state, - self.config, - self.block_class + self.chaindb, self.attestation_pool, state, self.config, self.block_class ) diff --git a/eth2/beacon/state_machines/forks/serenity/block_processing.py b/eth2/beacon/state_machines/forks/serenity/block_processing.py index 51d1fbd13f..ba8f23faa4 100644 --- a/eth2/beacon/state_machines/forks/serenity/block_processing.py +++ b/eth2/beacon/state_machines/forks/serenity/block_processing.py @@ -1,57 +1,38 @@ from eth2._utils.hash import hash_eth2 +from eth2._utils.numeric import bitwise_xor from eth2._utils.tuple import update_tuple_item -from eth2._utils.numeric import ( - bitwise_xor, -) - -from eth2.configs import ( - Eth2Config, - CommitteeConfig, -) -from eth2.beacon.types.states import BeaconState -from eth2.beacon.types.blocks import BaseBeaconBlock -from eth2.beacon.types.block_headers import BeaconBlockHeader - +from eth2.beacon.committee_helpers import get_beacon_proposer_index +from eth2.beacon.helpers import get_randao_mix from eth2.beacon.state_machines.forks.serenity.block_validation import ( validate_randao_reveal, ) - -from eth2.beacon.helpers import ( - get_randao_mix, -) -from eth2.beacon.committee_helpers import ( - get_beacon_proposer_index, -) +from eth2.beacon.types.block_headers import BeaconBlockHeader +from eth2.beacon.types.blocks import BaseBeaconBlock +from eth2.beacon.types.states import BeaconState +from eth2.configs import CommitteeConfig, Eth2Config from .block_validation import ( - validate_block_slot, validate_block_parent_root, - validate_proposer_signature, + validate_block_slot, validate_proposer_is_not_slashed, + validate_proposer_signature, ) - -from .operation_processing import ( - process_operations, -) +from .operation_processing import process_operations -def process_block_header(state: BeaconState, - block: BaseBeaconBlock, - config: Eth2Config, - check_proposer_signature: bool) -> BeaconState: +def process_block_header( + state: BeaconState, + block: BaseBeaconBlock, + config: Eth2Config, + check_proposer_signature: bool, +) -> BeaconState: validate_block_slot(state, block) validate_block_parent_root(state, block) - validate_proposer_is_not_slashed( - state, - block.signing_root, - CommitteeConfig(config), - ) + validate_proposer_is_not_slashed(state, block.signing_root, CommitteeConfig(config)) if check_proposer_signature: validate_proposer_signature( - state, - block, - committee_config=CommitteeConfig(config), + state, block, committee_config=CommitteeConfig(config) ) return state.copy( @@ -59,16 +40,15 @@ def process_block_header(state: BeaconState, slot=block.slot, parent_root=block.parent_root, body_root=block.body.hash_tree_root, - ), + ) ) -def process_randao(state: BeaconState, - block: BaseBeaconBlock, - config: Eth2Config) -> BeaconState: +def process_randao( + state: BeaconState, block: BaseBeaconBlock, config: Eth2Config +) -> BeaconState: proposer_index = get_beacon_proposer_index( - state=state, - committee_config=CommitteeConfig(config), + state=state, committee_config=CommitteeConfig(config) ) epoch = state.current_epoch(config.SLOTS_PER_EPOCH) @@ -93,34 +73,34 @@ def process_randao(state: BeaconState, return state.copy( randao_mixes=update_tuple_item( - state.randao_mixes, - randao_mix_index, - new_randao_mix, - ), + state.randao_mixes, randao_mix_index, new_randao_mix + ) ) -def process_eth1_data(state: BeaconState, - block: BaseBeaconBlock, - config: Eth2Config) -> BeaconState: +def process_eth1_data( + state: BeaconState, block: BaseBeaconBlock, config: Eth2Config +) -> BeaconState: body = block.body new_eth1_data_votes = state.eth1_data_votes + (body.eth1_data,) new_eth1_data = state.eth1_data - if new_eth1_data_votes.count(body.eth1_data) * 2 > config.SLOTS_PER_ETH1_VOTING_PERIOD: + if ( + new_eth1_data_votes.count(body.eth1_data) * 2 + > config.SLOTS_PER_ETH1_VOTING_PERIOD + ): new_eth1_data = body.eth1_data - return state.copy( - eth1_data=new_eth1_data, - eth1_data_votes=new_eth1_data_votes, - ) + return state.copy(eth1_data=new_eth1_data, eth1_data_votes=new_eth1_data_votes) -def process_block(state: BeaconState, - block: BaseBeaconBlock, - config: Eth2Config, - check_proposer_signature: bool=True) -> BeaconState: +def process_block( + state: BeaconState, + block: BaseBeaconBlock, + config: Eth2Config, + check_proposer_signature: bool = True, +) -> BeaconState: state = process_block_header(state, block, config, check_proposer_signature) state = process_randao(state, block, config) state = process_eth1_data(state, block, config) diff --git a/eth2/beacon/state_machines/forks/serenity/block_validation.py b/eth2/beacon/state_machines/forks/serenity/block_validation.py index 93af2c2d0d..b98d97b266 100644 --- a/eth2/beacon/state_machines/forks/serenity/block_validation.py +++ b/eth2/beacon/state_machines/forks/serenity/block_validation.py @@ -1,55 +1,25 @@ -from typing import ( # noqa: F401 - cast, - Iterable, - Sequence, - Tuple, -) +from typing import Iterable, Sequence, Tuple, cast # noqa: F401 -from eth_typing import ( - BLSPubkey, - BLSSignature, - Hash32, -) -from eth_utils import ( - encode_hex, - ValidationError, -) +from eth.constants import ZERO_HASH32 +from eth_typing import BLSPubkey, BLSSignature, Hash32 +from eth_utils import ValidationError, encode_hex import ssz -from eth.constants import ( - ZERO_HASH32, -) -from eth2._utils.hash import ( - hash_eth2, -) from eth2._utils.bls import bls - -from eth2.configs import ( - CommitteeConfig, -) +from eth2._utils.hash import hash_eth2 from eth2.beacon.attestation_helpers import ( get_attestation_data_slot, - validate_indexed_attestation, is_slashable_attestation_data, + validate_indexed_attestation, ) -from eth2.beacon.committee_helpers import ( - get_beacon_proposer_index, -) -from eth2.beacon.epoch_processing_helpers import ( - get_indexed_attestation, -) -from eth2.beacon.constants import ( - FAR_FUTURE_EPOCH, -) -from eth2.beacon.signature_domain import ( - SignatureDomain, -) -from eth2.beacon.helpers import ( - get_domain, - compute_epoch_of_slot, -) -from eth2.beacon.types.attestations import Attestation, IndexedAttestation +from eth2.beacon.committee_helpers import get_beacon_proposer_index +from eth2.beacon.constants import FAR_FUTURE_EPOCH +from eth2.beacon.epoch_processing_helpers import get_indexed_attestation +from eth2.beacon.exceptions import SignatureError +from eth2.beacon.helpers import compute_epoch_of_slot, get_domain +from eth2.beacon.signature_domain import SignatureDomain from eth2.beacon.types.attestation_data import AttestationData +from eth2.beacon.types.attestations import Attestation, IndexedAttestation from eth2.beacon.types.attester_slashings import AttesterSlashing from eth2.beacon.types.blocks import BaseBeaconBlock, BeaconBlockHeader from eth2.beacon.types.checkpoints import Checkpoint @@ -57,29 +27,19 @@ from eth2.beacon.types.proposer_slashings import ProposerSlashing from eth2.beacon.types.states import BeaconState from eth2.beacon.types.transfers import Transfer -from eth2.beacon.types.voluntary_exits import VoluntaryExit from eth2.beacon.types.validators import Validator -from eth2.beacon.typing import ( - Epoch, - Shard, - Slot, -) -from eth2.configs import ( - Eth2Config, -) -from eth2.beacon.exceptions import ( - SignatureError, -) +from eth2.beacon.types.voluntary_exits import VoluntaryExit +from eth2.beacon.typing import Epoch, Shard, Slot +from eth2.configs import CommitteeConfig, Eth2Config -def validate_correct_number_of_deposits(state: BeaconState, - block: BaseBeaconBlock, - config: Eth2Config) -> None: +def validate_correct_number_of_deposits( + state: BeaconState, block: BaseBeaconBlock, config: Eth2Config +) -> None: body = block.body deposit_count_in_block = len(body.deposits) expected_deposit_count = min( - config.MAX_DEPOSITS, - state.eth1_data.deposit_count - state.eth1_deposit_index, + config.MAX_DEPOSITS, state.eth1_data.deposit_count - state.eth1_deposit_index ) if deposit_count_in_block != expected_deposit_count: @@ -91,9 +51,9 @@ def validate_correct_number_of_deposits(state: BeaconState, ) -def validate_unique_transfers(state: BeaconState, - block: BaseBeaconBlock, - config: Eth2Config) -> None: +def validate_unique_transfers( + state: BeaconState, block: BaseBeaconBlock, config: Eth2Config +) -> None: body = block.body transfer_count_in_block = len(body.transfers) unique_transfer_count = len(set(body.transfers)) @@ -107,16 +67,14 @@ def validate_unique_transfers(state: BeaconState, # # Block validatation # -def validate_block_slot(state: BeaconState, - block: BaseBeaconBlock) -> None: +def validate_block_slot(state: BeaconState, block: BaseBeaconBlock) -> None: if block.slot != state.slot: raise ValidationError( f"block.slot ({block.slot}) is not equal to state.slot ({state.slot})" ) -def validate_block_parent_root(state: BeaconState, - block: BaseBeaconBlock) -> None: +def validate_block_parent_root(state: BeaconState, block: BaseBeaconBlock) -> None: expected_root = state.latest_block_header.signing_root parent_root = block.parent_root if parent_root != expected_root: @@ -126,32 +84,25 @@ def validate_block_parent_root(state: BeaconState, ) -def validate_proposer_is_not_slashed(state: BeaconState, - block_root: Hash32, - config: CommitteeConfig) -> None: +def validate_proposer_is_not_slashed( + state: BeaconState, block_root: Hash32, config: CommitteeConfig +) -> None: proposer_index = get_beacon_proposer_index(state, config) proposer = state.validators[proposer_index] if proposer.slashed: - raise ValidationError( - f"Proposer for block {encode_hex(block_root)} is slashed" - ) + raise ValidationError(f"Proposer for block {encode_hex(block_root)} is slashed") -def validate_proposer_signature(state: BeaconState, - block: BaseBeaconBlock, - committee_config: CommitteeConfig) -> None: +def validate_proposer_signature( + state: BeaconState, block: BaseBeaconBlock, committee_config: CommitteeConfig +) -> None: message_hash = block.signing_root # Get the public key of proposer - beacon_proposer_index = get_beacon_proposer_index( - state, - committee_config, - ) + beacon_proposer_index = get_beacon_proposer_index(state, committee_config) proposer_pubkey = state.validators[beacon_proposer_index].pubkey domain = get_domain( - state, - SignatureDomain.DOMAIN_BEACON_PROPOSER, - committee_config.SLOTS_PER_EPOCH, + state, SignatureDomain.DOMAIN_BEACON_PROPOSER, committee_config.SLOTS_PER_EPOCH ) try: @@ -171,11 +122,13 @@ def validate_proposer_signature(state: BeaconState, # # RANDAO validatation # -def validate_randao_reveal(state: BeaconState, - proposer_index: int, - epoch: Epoch, - randao_reveal: Hash32, - slots_per_epoch: int) -> None: +def validate_randao_reveal( + state: BeaconState, + proposer_index: int, + epoch: Epoch, + randao_reveal: Hash32, + slots_per_epoch: int, +) -> None: proposer = state.validators[proposer_index] proposer_pubkey = proposer.pubkey message_hash = ssz.get_hash_tree_root(epoch, sedes=ssz.sedes.uint64) @@ -195,9 +148,9 @@ def validate_randao_reveal(state: BeaconState, # # Proposer slashing validation # -def validate_proposer_slashing(state: BeaconState, - proposer_slashing: ProposerSlashing, - slots_per_epoch: int) -> None: +def validate_proposer_slashing( + state: BeaconState, proposer_slashing: ProposerSlashing, slots_per_epoch: int +) -> None: """ Validate the given ``proposer_slashing``. Raise ``ValidationError`` if it's invalid. @@ -225,8 +178,9 @@ def validate_proposer_slashing(state: BeaconState, ) -def validate_proposer_slashing_epoch(proposer_slashing: ProposerSlashing, - slots_per_epoch: int) -> None: +def validate_proposer_slashing_epoch( + proposer_slashing: ProposerSlashing, slots_per_epoch: int +) -> None: epoch_1 = compute_epoch_of_slot(proposer_slashing.header_1.slot, slots_per_epoch) epoch_2 = compute_epoch_of_slot(proposer_slashing.header_2.slot, slots_per_epoch) @@ -246,9 +200,9 @@ def validate_proposer_slashing_headers(proposer_slashing: ProposerSlashing) -> N ) -def validate_proposer_slashing_is_slashable(state: BeaconState, - proposer: Validator, - slots_per_epoch: int) -> None: +def validate_proposer_slashing_is_slashable( + state: BeaconState, proposer: Validator, slots_per_epoch: int +) -> None: current_epoch = state.current_epoch(slots_per_epoch) is_slashable = proposer.is_slashable(current_epoch) if not is_slashable: @@ -257,10 +211,12 @@ def validate_proposer_slashing_is_slashable(state: BeaconState, ) -def validate_block_header_signature(state: BeaconState, - header: BeaconBlockHeader, - pubkey: BLSPubkey, - slots_per_epoch: int) -> None: +def validate_block_header_signature( + state: BeaconState, + header: BeaconBlockHeader, + pubkey: BLSPubkey, + slots_per_epoch: int, +) -> None: try: bls.validate( pubkey=pubkey, @@ -271,7 +227,7 @@ def validate_block_header_signature(state: BeaconState, SignatureDomain.DOMAIN_BEACON_PROPOSER, slots_per_epoch, compute_epoch_of_slot(header.slot, slots_per_epoch), - ) + ), ) except SignatureError as error: raise ValidationError("Header signature is invalid:", error) @@ -280,9 +236,12 @@ def validate_block_header_signature(state: BeaconState, # # Attester slashing validation # -def validate_is_slashable_attestation_data(attestation_1: IndexedAttestation, - attestation_2: IndexedAttestation) -> None: - is_slashable_data = is_slashable_attestation_data(attestation_1.data, attestation_2.data) +def validate_is_slashable_attestation_data( + attestation_1: IndexedAttestation, attestation_2: IndexedAttestation +) -> None: + is_slashable_data = is_slashable_attestation_data( + attestation_1.data, attestation_2.data + ) if not is_slashable_data: raise ValidationError( @@ -290,34 +249,29 @@ def validate_is_slashable_attestation_data(attestation_1: IndexedAttestation, ) -def validate_attester_slashing(state: BeaconState, - attester_slashing: AttesterSlashing, - max_validators_per_committee: int, - slots_per_epoch: int) -> None: +def validate_attester_slashing( + state: BeaconState, + attester_slashing: AttesterSlashing, + max_validators_per_committee: int, + slots_per_epoch: int, +) -> None: attestation_1 = attester_slashing.attestation_1 attestation_2 = attester_slashing.attestation_2 - validate_is_slashable_attestation_data( - attestation_1, - attestation_2, - ) + validate_is_slashable_attestation_data(attestation_1, attestation_2) validate_indexed_attestation( - state, - attestation_1, - max_validators_per_committee, - slots_per_epoch, + state, attestation_1, max_validators_per_committee, slots_per_epoch ) validate_indexed_attestation( - state, - attestation_2, - max_validators_per_committee, - slots_per_epoch, + state, attestation_2, max_validators_per_committee, slots_per_epoch ) -def validate_some_slashing(slashed_any: bool, attester_slashing: AttesterSlashing) -> None: +def validate_some_slashing( + slashed_any: bool, attester_slashing: AttesterSlashing +) -> None: if not slashed_any: raise ValidationError( f"Attesting slashing {attester_slashing} did not yield any slashable validators." @@ -334,9 +288,9 @@ def _validate_eligible_shard_number(shard: Shard, shard_count: int) -> None: ) -def _validate_eligible_target_epoch(target_epoch: Epoch, - current_epoch: Epoch, - previous_epoch: Epoch) -> None: +def _validate_eligible_target_epoch( + target_epoch: Epoch, current_epoch: Epoch, previous_epoch: Epoch +) -> None: if target_epoch not in (previous_epoch, current_epoch): raise ValidationError( f"Attestation with target epoch {target_epoch} must be in either the" @@ -344,10 +298,12 @@ def _validate_eligible_target_epoch(target_epoch: Epoch, ) -def validate_attestation_slot(attestation_slot: Slot, - state_slot: Slot, - slots_per_epoch: int, - min_attestation_inclusion_delay: int) -> None: +def validate_attestation_slot( + attestation_slot: Slot, + state_slot: Slot, + slots_per_epoch: int, + min_attestation_inclusion_delay: int, +) -> None: if attestation_slot + min_attestation_inclusion_delay > state_slot: raise ValidationError( f"Attestation at slot {attestation_slot} can only be included after the" @@ -362,7 +318,9 @@ def validate_attestation_slot(attestation_slot: Slot, ) -def _validate_checkpoint(checkpoint: Checkpoint, expected_checkpoint: Checkpoint) -> None: +def _validate_checkpoint( + checkpoint: Checkpoint, expected_checkpoint: Checkpoint +) -> None: if checkpoint != expected_checkpoint: raise ValidationError( f"Attestation with source checkpoint {checkpoint} did not match the expected" @@ -370,10 +328,12 @@ def _validate_checkpoint(checkpoint: Checkpoint, expected_checkpoint: Checkpoint ) -def _validate_crosslink(crosslink: Crosslink, - target_epoch: Epoch, - parent_crosslink: Crosslink, - max_epochs_per_crosslink: int) -> None: +def _validate_crosslink( + crosslink: Crosslink, + target_epoch: Epoch, + parent_crosslink: Crosslink, + max_epochs_per_crosslink: int, +) -> None: if crosslink.start_epoch != parent_crosslink.end_epoch: raise ValidationError( f"Crosslink with start_epoch {crosslink.start_epoch} did not match the parent" @@ -381,8 +341,7 @@ def _validate_crosslink(crosslink: Crosslink, ) expected_end_epoch = min( - target_epoch, - parent_crosslink.end_epoch + max_epochs_per_crosslink, + target_epoch, parent_crosslink.end_epoch + max_epochs_per_crosslink ) if crosslink.end_epoch != expected_end_epoch: raise ValidationError( @@ -405,9 +364,9 @@ def _validate_crosslink(crosslink: Crosslink, ) -def _validate_attestation_data(state: BeaconState, - data: AttestationData, - config: Eth2Config) -> None: +def _validate_attestation_data( + state: BeaconState, data: AttestationData, config: Eth2Config +) -> None: slots_per_epoch = config.SLOTS_PER_EPOCH current_epoch = state.current_epoch(slots_per_epoch) previous_epoch = state.previous_epoch(slots_per_epoch, config.GENESIS_EPOCH) @@ -427,20 +386,20 @@ def _validate_attestation_data(state: BeaconState, attestation_slot, state.slot, slots_per_epoch, - config.MIN_ATTESTATION_INCLUSION_DELAY + config.MIN_ATTESTATION_INCLUSION_DELAY, ) _validate_checkpoint(data.source, expected_checkpoint) _validate_crosslink( data.crosslink, data.target.epoch, parent_crosslink, - config.MAX_EPOCHS_PER_CROSSLINK + config.MAX_EPOCHS_PER_CROSSLINK, ) -def validate_attestation(state: BeaconState, - attestation: Attestation, - config: Eth2Config) -> None: +def validate_attestation( + state: BeaconState, attestation: Attestation, config: Eth2Config +) -> None: """ Validate the given ``attestation``. Raise ``ValidationError`` if it's invalid. @@ -480,9 +439,9 @@ def _validate_eligible_exit_epoch(exit_epoch: Epoch, current_epoch: Epoch) -> No ) -def _validate_validator_minimum_lifespan(validator: Validator, - current_epoch: Epoch, - persistent_committee_period: int) -> None: +def _validate_validator_minimum_lifespan( + validator: Validator, current_epoch: Epoch, persistent_committee_period: int +) -> None: if current_epoch < validator.activation_epoch + persistent_committee_period: raise ValidationError( f"Validator in voluntary exit has not completed the minimum number of epochs" @@ -491,10 +450,12 @@ def _validate_validator_minimum_lifespan(validator: Validator, ) -def _validate_voluntary_exit_signature(state: BeaconState, - voluntary_exit: VoluntaryExit, - validator: Validator, - slots_per_epoch: int) -> None: +def _validate_voluntary_exit_signature( + state: BeaconState, + voluntary_exit: VoluntaryExit, + validator: Validator, + slots_per_epoch: int, +) -> None: domain = get_domain( state, SignatureDomain.DOMAIN_VOLUNTARY_EXIT, @@ -515,10 +476,12 @@ def _validate_voluntary_exit_signature(state: BeaconState, ) -def validate_voluntary_exit(state: BeaconState, - voluntary_exit: VoluntaryExit, - slots_per_epoch: int, - persistent_committee_period: int) -> None: +def validate_voluntary_exit( + state: BeaconState, + voluntary_exit: VoluntaryExit, + slots_per_epoch: int, + persistent_committee_period: int, +) -> None: validator = state.validators[voluntary_exit.validator_index] current_epoch = state.current_epoch(slots_per_epoch) @@ -526,11 +489,11 @@ def validate_voluntary_exit(state: BeaconState, _validate_validator_has_not_exited(validator) _validate_eligible_exit_epoch(voluntary_exit.epoch, current_epoch) _validate_validator_minimum_lifespan( - validator, - current_epoch, - persistent_committee_period, + validator, current_epoch, persistent_committee_period + ) + _validate_voluntary_exit_signature( + state, voluntary_exit, validator, slots_per_epoch ) - _validate_voluntary_exit_signature(state, voluntary_exit, validator, slots_per_epoch) def _validate_amount_and_fee_magnitude(state: BeaconState, transfer: Transfer) -> None: @@ -551,9 +514,9 @@ def _validate_transfer_slot(state_slot: Slot, transfer_slot: Slot) -> None: ) -def _validate_sender_eligibility(state: BeaconState, - transfer: Transfer, - config: Eth2Config) -> None: +def _validate_sender_eligibility( + state: BeaconState, transfer: Transfer, config: Eth2Config +) -> None: current_epoch = state.current_epoch(config.SLOTS_PER_EPOCH) sender = state.validators[transfer.sender] sender_balance = state.balances[transfer.sender] @@ -573,9 +536,7 @@ def _validate_sender_eligibility(state: BeaconState, ) if not is_withdrawable: - raise ValidationError( - f"Sender in transfer {transfer} is not withdrawable." - ) + raise ValidationError(f"Sender in transfer {transfer} is not withdrawable.") if not is_transfer_total_allowed: raise ValidationError( @@ -583,12 +544,14 @@ def _validate_sender_eligibility(state: BeaconState, ) -def _validate_sender_pubkey(state: BeaconState, transfer: Transfer, config: Eth2Config) -> None: +def _validate_sender_pubkey( + state: BeaconState, transfer: Transfer, config: Eth2Config +) -> None: sender = state.validators[transfer.sender] - expected_withdrawal_credentials = config.BLS_WITHDRAWAL_PREFIX.to_bytes( - 1, - byteorder='little', - ) + hash_eth2(transfer.pubkey)[1:] + expected_withdrawal_credentials = ( + config.BLS_WITHDRAWAL_PREFIX.to_bytes(1, byteorder="little") + + hash_eth2(transfer.pubkey)[1:] + ) are_withdrawal_credentials_valid = ( sender.withdrawal_credentials == expected_withdrawal_credentials ) @@ -600,14 +563,10 @@ def _validate_sender_pubkey(state: BeaconState, transfer: Transfer, config: Eth2 ) -def _validate_transfer_signature(state: BeaconState, - transfer: Transfer, - config: Eth2Config) -> None: - domain = get_domain( - state, - SignatureDomain.DOMAIN_TRANSFER, - config.SLOTS_PER_EPOCH, - ) +def _validate_transfer_signature( + state: BeaconState, transfer: Transfer, config: Eth2Config +) -> None: + domain = get_domain(state, SignatureDomain.DOMAIN_TRANSFER, config.SLOTS_PER_EPOCH) try: bls.validate( pubkey=transfer.pubkey, @@ -616,20 +575,18 @@ def _validate_transfer_signature(state: BeaconState, domain=domain, ) except SignatureError as error: - raise ValidationError( - f"Invalid signature for transfer {transfer}", - error, - ) + raise ValidationError(f"Invalid signature for transfer {transfer}", error) -def _validate_transfer_does_not_result_in_dust(state: BeaconState, - transfer: Transfer, - config: Eth2Config) -> None: +def _validate_transfer_does_not_result_in_dust( + state: BeaconState, transfer: Transfer, config: Eth2Config +) -> None: resulting_sender_balance = max( - 0, - state.balances[transfer.sender] - (transfer.amount + transfer.fee), + 0, state.balances[transfer.sender] - (transfer.amount + transfer.fee) + ) + resulting_sender_balance_is_dust = ( + 0 < resulting_sender_balance < config.MIN_DEPOSIT_AMOUNT ) - resulting_sender_balance_is_dust = 0 < resulting_sender_balance < config.MIN_DEPOSIT_AMOUNT if resulting_sender_balance_is_dust: raise ValidationError( f"Effect of transfer {transfer} results in dust balance for sender." @@ -645,9 +602,9 @@ def _validate_transfer_does_not_result_in_dust(state: BeaconState, ) -def validate_transfer(state: BeaconState, - transfer: Transfer, - config: Eth2Config) -> None: +def validate_transfer( + state: BeaconState, transfer: Transfer, config: Eth2Config +) -> None: _validate_amount_and_fee_magnitude(state, transfer) _validate_transfer_slot(state.slot, transfer.slot) _validate_sender_eligibility(state, transfer, config) diff --git a/eth2/beacon/state_machines/forks/serenity/blocks.py b/eth2/beacon/state_machines/forks/serenity/blocks.py index cdc20ad9d1..d1c3cd2698 100644 --- a/eth2/beacon/state_machines/forks/serenity/blocks.py +++ b/eth2/beacon/state_machines/forks/serenity/blocks.py @@ -1,19 +1,14 @@ -from eth2.beacon.typing import ( - FromBlockParams, -) - -from eth2.beacon.types.blocks import ( - BaseBeaconBlock, - BeaconBlock, -) +from eth2.beacon.types.blocks import BaseBeaconBlock, BeaconBlock +from eth2.beacon.typing import FromBlockParams class SerenityBeaconBlock(BeaconBlock): pass -def create_serenity_block_from_parent(parent_block: BaseBeaconBlock, - block_params: FromBlockParams) -> BaseBeaconBlock: +def create_serenity_block_from_parent( + parent_block: BaseBeaconBlock, block_params: FromBlockParams +) -> BaseBeaconBlock: block = SerenityBeaconBlock.from_parent(parent_block, block_params) return block diff --git a/eth2/beacon/state_machines/forks/serenity/configs.py b/eth2/beacon/state_machines/forks/serenity/configs.py index 6de2e33835..1a7dc5dc82 100644 --- a/eth2/beacon/state_machines/forks/serenity/configs.py +++ b/eth2/beacon/state_machines/forks/serenity/configs.py @@ -1,68 +1,61 @@ -from eth_utils import ( - decode_hex, -) -from eth2.configs import Eth2Config -from eth2.beacon.constants import ( - GWEI_PER_ETH, -) -from eth2.beacon.typing import ( - Epoch, - Gwei, - Second, - Slot, -) +from eth_utils import decode_hex +from eth2.beacon.constants import GWEI_PER_ETH +from eth2.beacon.typing import Epoch, Gwei, Second, Slot +from eth2.configs import Eth2Config SERENITY_CONFIG = Eth2Config( # Misc - SHARD_COUNT=2**10, # (= 1,024) shards - TARGET_COMMITTEE_SIZE=2**7, # (= 128) validators - MAX_VALIDATORS_PER_COMMITTEE=2**12, # (= 4,096) validators - MIN_PER_EPOCH_CHURN_LIMIT=2**2, - CHURN_LIMIT_QUOTIENT=2**16, + SHARD_COUNT=2 ** 10, # (= 1,024) shards + TARGET_COMMITTEE_SIZE=2 ** 7, # (= 128) validators + MAX_VALIDATORS_PER_COMMITTEE=2 ** 12, # (= 4,096) validators + MIN_PER_EPOCH_CHURN_LIMIT=2 ** 2, + CHURN_LIMIT_QUOTIENT=2 ** 16, SHUFFLE_ROUND_COUNT=90, # Genesis - MIN_GENESIS_ACTIVE_VALIDATOR_COUNT=2**16, + MIN_GENESIS_ACTIVE_VALIDATOR_COUNT=2 ** 16, MIN_GENESIS_TIME=1578009600, # (= Jan 3, 2020) # Gwei values - MIN_DEPOSIT_AMOUNT=Gwei(2**0 * GWEI_PER_ETH), # (= 1,000,000,000) Gwei - MAX_EFFECTIVE_BALANCE=Gwei(2**5 * GWEI_PER_ETH), # (= 32,000,000,00) Gwei - EJECTION_BALANCE=Gwei(2**4 * GWEI_PER_ETH), # (= 16,000,000,000) Gwei - EFFECTIVE_BALANCE_INCREMENT=Gwei(2**0 * GWEI_PER_ETH), # (= 1,000,000,000) Gwei + MIN_DEPOSIT_AMOUNT=Gwei(2 ** 0 * GWEI_PER_ETH), # (= 1,000,000,000) Gwei + MAX_EFFECTIVE_BALANCE=Gwei(2 ** 5 * GWEI_PER_ETH), # (= 32,000,000,00) Gwei + EJECTION_BALANCE=Gwei(2 ** 4 * GWEI_PER_ETH), # (= 16,000,000,000) Gwei + EFFECTIVE_BALANCE_INCREMENT=Gwei(2 ** 0 * GWEI_PER_ETH), # (= 1,000,000,000) Gwei # Initial values GENESIS_SLOT=Slot(0), GENESIS_EPOCH=Epoch(0), BLS_WITHDRAWAL_PREFIX=0, # Time parameters SECONDS_PER_SLOT=Second(6), # seconds - MIN_ATTESTATION_INCLUSION_DELAY=2**0, # (= 1) slots - SLOTS_PER_EPOCH=2**6, # (= 64) slots - MIN_SEED_LOOKAHEAD=2**0, # (= 1) epochs - ACTIVATION_EXIT_DELAY=2**2, # (= 4) epochs - SLOTS_PER_ETH1_VOTING_PERIOD=2**10, # (= 16) epochs - SLOTS_PER_HISTORICAL_ROOT=2**13, # (= 8,192) slots - MIN_VALIDATOR_WITHDRAWABILITY_DELAY=2**8, # (= 256) epochs - PERSISTENT_COMMITTEE_PERIOD=2**11, # (= 2,048) epochs - MAX_EPOCHS_PER_CROSSLINK=2**6, - MIN_EPOCHS_TO_INACTIVITY_PENALTY=2**2, + MIN_ATTESTATION_INCLUSION_DELAY=2 ** 0, # (= 1) slots + SLOTS_PER_EPOCH=2 ** 6, # (= 64) slots + MIN_SEED_LOOKAHEAD=2 ** 0, # (= 1) epochs + ACTIVATION_EXIT_DELAY=2 ** 2, # (= 4) epochs + SLOTS_PER_ETH1_VOTING_PERIOD=2 ** 10, # (= 16) epochs + SLOTS_PER_HISTORICAL_ROOT=2 ** 13, # (= 8,192) slots + MIN_VALIDATOR_WITHDRAWABILITY_DELAY=2 ** 8, # (= 256) epochs + PERSISTENT_COMMITTEE_PERIOD=2 ** 11, # (= 2,048) epochs + MAX_EPOCHS_PER_CROSSLINK=2 ** 6, + MIN_EPOCHS_TO_INACTIVITY_PENALTY=2 ** 2, # State list lengths - EPOCHS_PER_HISTORICAL_VECTOR=2**16, - EPOCHS_PER_SLASHINGS_VECTOR=2**13, - HISTORICAL_ROOTS_LIMIT=2**24, - VALIDATOR_REGISTRY_LIMIT=2**40, + EPOCHS_PER_HISTORICAL_VECTOR=2 ** 16, + EPOCHS_PER_SLASHINGS_VECTOR=2 ** 13, + HISTORICAL_ROOTS_LIMIT=2 ** 24, + VALIDATOR_REGISTRY_LIMIT=2 ** 40, # Reward and penalty quotients - BASE_REWARD_FACTOR=2**6, # (= 64) - WHISTLEBLOWER_REWARD_QUOTIENT=2**9, # (= 512) - PROPOSER_REWARD_QUOTIENT=2**3, - INACTIVITY_PENALTY_QUOTIENT=2**25, # (= 33,554,432) - MIN_SLASHING_PENALTY_QUOTIENT=2**5, + BASE_REWARD_FACTOR=2 ** 6, # (= 64) + WHISTLEBLOWER_REWARD_QUOTIENT=2 ** 9, # (= 512) + PROPOSER_REWARD_QUOTIENT=2 ** 3, + INACTIVITY_PENALTY_QUOTIENT=2 ** 25, # (= 33,554,432) + MIN_SLASHING_PENALTY_QUOTIENT=2 ** 5, # Max operations per block - MAX_PROPOSER_SLASHINGS=2**4, # (= 16) - MAX_ATTESTER_SLASHINGS=2**0, # (= 1) - MAX_ATTESTATIONS=2**7, # (= 128) - MAX_DEPOSITS=2**4, # (= 16) - MAX_VOLUNTARY_EXITS=2**4, # (= 16) + MAX_PROPOSER_SLASHINGS=2 ** 4, # (= 16) + MAX_ATTESTER_SLASHINGS=2 ** 0, # (= 1) + MAX_ATTESTATIONS=2 ** 7, # (= 128) + MAX_DEPOSITS=2 ** 4, # (= 16) + MAX_VOLUNTARY_EXITS=2 ** 4, # (= 16) MAX_TRANSFERS=0, # Deposit contract - DEPOSIT_CONTRACT_ADDRESS=decode_hex('0x1234567890123456789012345678901234567890'), # TBD + DEPOSIT_CONTRACT_ADDRESS=decode_hex( + "0x1234567890123456789012345678901234567890" + ), # TBD ) diff --git a/eth2/beacon/state_machines/forks/serenity/epoch_processing.py b/eth2/beacon/state_machines/forks/serenity/epoch_processing.py index b3dfb4741f..b304a84341 100644 --- a/eth2/beacon/state_machines/forks/serenity/epoch_processing.py +++ b/eth2/beacon/state_machines/forks/serenity/epoch_processing.py @@ -1,50 +1,31 @@ -from typing import ( - Sequence, - Set, - Tuple, -) +from typing import Sequence, Set, Tuple -from eth_utils.toolz import ( - curry, -) +from eth_typing import Hash32 +from eth_utils.toolz import curry import ssz -from eth_typing import ( - Hash32, -) - -from eth2._utils.tuple import ( - update_tuple_item, - update_tuple_item_with_fn, -) -from eth2.configs import ( - Eth2Config, - CommitteeConfig, -) -from eth2.beacon.constants import ( - BASE_REWARDS_PER_EPOCH, - FAR_FUTURE_EPOCH, -) +from eth2._utils.tuple import update_tuple_item, update_tuple_item_with_fn from eth2.beacon.committee_helpers import ( - get_crosslink_committee, - get_compact_committees_root, get_committee_count, - get_start_shard, + get_compact_committees_root, + get_crosslink_committee, get_shard_delta, + get_start_shard, ) +from eth2.beacon.constants import BASE_REWARDS_PER_EPOCH, FAR_FUTURE_EPOCH from eth2.beacon.epoch_processing_helpers import ( + compute_activation_exit_epoch, decrease_balance, get_attesting_balance, get_attesting_indices, get_base_reward, - get_validator_churn_limit, - compute_activation_exit_epoch, get_matching_head_attestations, get_matching_source_attestations, get_matching_target_attestations, get_total_active_balance, get_total_balance, get_unslashed_attesting_indices, + get_validator_churn_limit, get_winning_crosslink_and_attesting_indices, increase_balance, ) @@ -53,96 +34,65 @@ get_block_root, get_randao_mix, ) -from eth2.beacon.validator_status_helpers import ( - initiate_exit_for_validator, -) from eth2.beacon.types.checkpoints import Checkpoint from eth2.beacon.types.eth1_data import Eth1Data from eth2.beacon.types.historical_batch import HistoricalBatch from eth2.beacon.types.pending_attestations import PendingAttestation from eth2.beacon.types.states import BeaconState from eth2.beacon.types.validators import Validator -from eth2.beacon.typing import ( - Bitfield, - Epoch, - Gwei, - Shard, - ValidatorIndex, -) +from eth2.beacon.typing import Bitfield, Epoch, Gwei, Shard, ValidatorIndex +from eth2.beacon.validator_status_helpers import initiate_exit_for_validator +from eth2.configs import CommitteeConfig, Eth2Config def _bft_threshold_met(participation: Gwei, total: Gwei) -> bool: return 3 * participation >= 2 * total -def _is_threshold_met_against_active_set(state: BeaconState, - attestations: Sequence[PendingAttestation], - config: Eth2Config) -> bool: +def _is_threshold_met_against_active_set( + state: BeaconState, attestations: Sequence[PendingAttestation], config: Eth2Config +) -> bool: """ Predicate indicating if the balance at risk of validators making an attestation in ``attestations`` is greater than the fault tolerance threshold of the total balance. """ - attesting_balance = get_attesting_balance( - state, - attestations, - config - ) + attesting_balance = get_attesting_balance(state, attestations, config) total_balance = get_total_active_balance(state, config) - return _bft_threshold_met( - attesting_balance, - total_balance, - ) + return _bft_threshold_met(attesting_balance, total_balance) def _is_epoch_justifiable(state: BeaconState, epoch: Epoch, config: Eth2Config) -> bool: - attestations = get_matching_target_attestations( - state, - epoch, - config, - ) - return _is_threshold_met_against_active_set( - state, - attestations, - config, - ) + attestations = get_matching_target_attestations(state, epoch, config) + return _is_threshold_met_against_active_set(state, attestations, config) -def _determine_updated_justification_data(justified_epoch: Epoch, - bitfield: Bitfield, - is_epoch_justifiable: bool, - candidate_epoch: Epoch, - bit_offset: int) -> Tuple[Epoch, Bitfield]: +def _determine_updated_justification_data( + justified_epoch: Epoch, + bitfield: Bitfield, + is_epoch_justifiable: bool, + candidate_epoch: Epoch, + bit_offset: int, +) -> Tuple[Epoch, Bitfield]: if is_epoch_justifiable: return ( candidate_epoch, - Bitfield( - update_tuple_item( - bitfield, - bit_offset, - True, - ) - ) + Bitfield(update_tuple_item(bitfield, bit_offset, True)), ) else: - return ( - justified_epoch, - bitfield, - ) + return (justified_epoch, bitfield) def _determine_updated_justifications( - previous_epoch_justifiable: bool, - previous_epoch: Epoch, - current_epoch_justifiable: bool, - current_epoch: Epoch, - justified_epoch: Epoch, - justification_bits: Bitfield) -> Tuple[Epoch, Bitfield]: - ( - justified_epoch, - justification_bits, - ) = _determine_updated_justification_data( + previous_epoch_justifiable: bool, + previous_epoch: Epoch, + current_epoch_justifiable: bool, + current_epoch: Epoch, + justified_epoch: Epoch, + justification_bits: Bitfield, +) -> Tuple[Epoch, Bitfield]: + (justified_epoch, justification_bits) = _determine_updated_justification_data( justified_epoch, justification_bits, previous_epoch_justifiable, @@ -150,39 +100,22 @@ def _determine_updated_justifications( 1, ) - ( - justified_epoch, - justification_bits, - ) = _determine_updated_justification_data( - justified_epoch, - justification_bits, - current_epoch_justifiable, - current_epoch, - 0, + (justified_epoch, justification_bits) = _determine_updated_justification_data( + justified_epoch, justification_bits, current_epoch_justifiable, current_epoch, 0 ) - return ( - justified_epoch, - justification_bits, - ) + return (justified_epoch, justification_bits) -def _determine_new_justified_epoch_and_bitfield(state: BeaconState, - config: Eth2Config) -> Tuple[Epoch, Bitfield]: +def _determine_new_justified_epoch_and_bitfield( + state: BeaconState, config: Eth2Config +) -> Tuple[Epoch, Bitfield]: genesis_epoch = config.GENESIS_EPOCH previous_epoch = state.previous_epoch(config.SLOTS_PER_EPOCH, genesis_epoch) current_epoch = state.current_epoch(config.SLOTS_PER_EPOCH) - previous_epoch_justifiable = _is_epoch_justifiable( - state, - previous_epoch, - config, - ) - current_epoch_justifiable = _is_epoch_justifiable( - state, - current_epoch, - config, - ) + previous_epoch_justifiable = _is_epoch_justifiable(state, previous_epoch, config) + current_epoch_justifiable = _is_epoch_justifiable(state, current_epoch, config) ( new_current_justified_epoch, @@ -196,22 +129,16 @@ def _determine_new_justified_epoch_and_bitfield(state: BeaconState, (False,) + state.justification_bits[:-1], ) - return ( - new_current_justified_epoch, - justification_bits, - ) + return (new_current_justified_epoch, justification_bits) def _determine_new_justified_checkpoint_and_bitfield( - state: BeaconState, - config: Eth2Config) -> Tuple[Checkpoint, Bitfield]: + state: BeaconState, config: Eth2Config +) -> Tuple[Checkpoint, Bitfield]: ( new_current_justified_epoch, justification_bits, - ) = _determine_new_justified_epoch_and_bitfield( - state, - config, - ) + ) = _determine_new_justified_epoch_and_bitfield(state, config) if new_current_justified_epoch != state.current_justified_checkpoint.epoch: new_current_justified_root = get_block_root( @@ -224,56 +151,54 @@ def _determine_new_justified_checkpoint_and_bitfield( new_current_justified_root = state.current_justified_checkpoint.root return ( - Checkpoint( - epoch=new_current_justified_epoch, - root=new_current_justified_root, - ), + Checkpoint(epoch=new_current_justified_epoch, root=new_current_justified_root), justification_bits, ) -def _bitfield_matches(bitfield: Bitfield, - offset: slice) -> bool: +def _bitfield_matches(bitfield: Bitfield, offset: slice) -> bool: return all(bitfield[offset]) -def _determine_new_finalized_epoch(last_finalized_epoch: Epoch, - previous_justified_epoch: Epoch, - current_justified_epoch: Epoch, - current_epoch: Epoch, - justification_bits: Bitfield) -> Epoch: +def _determine_new_finalized_epoch( + last_finalized_epoch: Epoch, + previous_justified_epoch: Epoch, + current_justified_epoch: Epoch, + current_epoch: Epoch, + justification_bits: Bitfield, +) -> Epoch: new_finalized_epoch = last_finalized_epoch if ( - _bitfield_matches(justification_bits, slice(1, 4)) and - previous_justified_epoch + 3 == current_epoch + _bitfield_matches(justification_bits, slice(1, 4)) + and previous_justified_epoch + 3 == current_epoch ): new_finalized_epoch = previous_justified_epoch if ( - _bitfield_matches(justification_bits, slice(1, 3)) and - previous_justified_epoch + 2 == current_epoch + _bitfield_matches(justification_bits, slice(1, 3)) + and previous_justified_epoch + 2 == current_epoch ): new_finalized_epoch = previous_justified_epoch if ( - _bitfield_matches(justification_bits, slice(0, 3)) and - current_justified_epoch + 2 == current_epoch + _bitfield_matches(justification_bits, slice(0, 3)) + and current_justified_epoch + 2 == current_epoch ): new_finalized_epoch = current_justified_epoch if ( - _bitfield_matches(justification_bits, slice(0, 2)) and - current_justified_epoch + 1 == current_epoch + _bitfield_matches(justification_bits, slice(0, 2)) + and current_justified_epoch + 1 == current_epoch ): new_finalized_epoch = current_justified_epoch return new_finalized_epoch -def _determine_new_finalized_checkpoint(state: BeaconState, - justification_bits: Bitfield, - config: Eth2Config) -> Checkpoint: +def _determine_new_finalized_checkpoint( + state: BeaconState, justification_bits: Bitfield, config: Eth2Config +) -> Checkpoint: current_epoch = state.current_epoch(config.SLOTS_PER_EPOCH) new_finalized_epoch = _determine_new_finalized_epoch( @@ -298,13 +223,12 @@ def _determine_new_finalized_checkpoint(state: BeaconState, else: new_finalized_root = state.finalized_checkpoint.root - return Checkpoint( - epoch=new_finalized_epoch, - root=new_finalized_root, - ) + return Checkpoint(epoch=new_finalized_epoch, root=new_finalized_root) -def process_justification_and_finalization(state: BeaconState, config: Eth2Config) -> BeaconState: +def process_justification_and_finalization( + state: BeaconState, config: Eth2Config +) -> BeaconState: current_epoch = state.current_epoch(config.SLOTS_PER_EPOCH) genesis_epoch = config.GENESIS_EPOCH @@ -314,15 +238,10 @@ def process_justification_and_finalization(state: BeaconState, config: Eth2Confi ( new_current_justified_checkpoint, justification_bits, - ) = _determine_new_justified_checkpoint_and_bitfield( - state, - config, - ) + ) = _determine_new_justified_checkpoint_and_bitfield(state, config) new_finalized_checkpoint = _determine_new_finalized_checkpoint( - state, - justification_bits, - config, + state, justification_bits, config ) return state.copy( @@ -333,17 +252,13 @@ def process_justification_and_finalization(state: BeaconState, config: Eth2Confi ) -def _is_threshold_met_against_committee(state: BeaconState, - attesting_indices: Set[ValidatorIndex], - committee: Set[ValidatorIndex]) -> bool: - total_attesting_balance = get_total_balance( - state, - attesting_indices, - ) - total_committee_balance = get_total_balance( - state, - committee, - ) +def _is_threshold_met_against_committee( + state: BeaconState, + attesting_indices: Set[ValidatorIndex], + committee: Set[ValidatorIndex], +) -> bool: + total_attesting_balance = get_total_balance(state, attesting_indices) + total_committee_balance = get_total_balance(state, committee) return _bft_threshold_met(total_attesting_balance, total_committee_balance) @@ -354,47 +269,35 @@ def process_crosslinks(state: BeaconState, config: Eth2Config) -> BeaconState: new_current_crosslinks = state.current_crosslinks for epoch in (previous_epoch, current_epoch): - active_validators_indices = get_active_validator_indices(state.validators, epoch) + active_validators_indices = get_active_validator_indices( + state.validators, epoch + ) epoch_committee_count = get_committee_count( len(active_validators_indices), config.SHARD_COUNT, config.SLOTS_PER_EPOCH, config.TARGET_COMMITTEE_SIZE, ) - epoch_start_shard = get_start_shard( - state, - epoch, - CommitteeConfig(config), - ) + epoch_start_shard = get_start_shard(state, epoch, CommitteeConfig(config)) for shard_offset in range(epoch_committee_count): shard = Shard((epoch_start_shard + shard_offset) % config.SHARD_COUNT) - crosslink_committee = set(get_crosslink_committee( - state, - epoch, - shard, - CommitteeConfig(config), - )) + crosslink_committee = set( + get_crosslink_committee(state, epoch, shard, CommitteeConfig(config)) + ) if not crosslink_committee: # empty crosslink committee this epoch continue winning_crosslink, attesting_indices = get_winning_crosslink_and_attesting_indices( - state=state, - epoch=epoch, - shard=shard, - config=config, + state=state, epoch=epoch, shard=shard, config=config ) threshold_met = _is_threshold_met_against_committee( - state, - attesting_indices, - crosslink_committee, + state, attesting_indices, crosslink_committee ) if threshold_met: new_current_crosslinks = update_tuple_item( - new_current_crosslinks, - shard, - winning_crosslink, + new_current_crosslinks, shard, winning_crosslink ) return state.copy( @@ -403,49 +306,38 @@ def process_crosslinks(state: BeaconState, config: Eth2Config) -> BeaconState: ) -def get_attestation_deltas(state: BeaconState, - config: Eth2Config) -> Tuple[Sequence[Gwei], Sequence[Gwei]]: +def get_attestation_deltas( + state: BeaconState, config: Eth2Config +) -> Tuple[Sequence[Gwei], Sequence[Gwei]]: committee_config = CommitteeConfig(config) - rewards = tuple( - 0 for _ in range(len(state.validators)) - ) - penalties = tuple( - 0 for _ in range(len(state.validators)) - ) + rewards = tuple(0 for _ in range(len(state.validators))) + penalties = tuple(0 for _ in range(len(state.validators))) previous_epoch = state.previous_epoch(config.SLOTS_PER_EPOCH, config.GENESIS_EPOCH) total_balance = get_total_active_balance(state, config) eligible_validator_indices = tuple( - ValidatorIndex(index) for index, v in enumerate(state.validators) - if v.is_active(previous_epoch) or ( - v.slashed and previous_epoch + 1 < v.withdrawable_epoch - ) + ValidatorIndex(index) + for index, v in enumerate(state.validators) + if v.is_active(previous_epoch) + or (v.slashed and previous_epoch + 1 < v.withdrawable_epoch) ) matching_source_attestations = get_matching_source_attestations( - state, - previous_epoch, - config, + state, previous_epoch, config ) matching_target_attestations = get_matching_target_attestations( - state, - previous_epoch, - config, + state, previous_epoch, config ) matching_head_attestations = get_matching_head_attestations( - state, - previous_epoch, - config, + state, previous_epoch, config ) for attestations in ( - matching_source_attestations, - matching_target_attestations, - matching_head_attestations + matching_source_attestations, + matching_target_attestations, + matching_head_attestations, ): unslashed_attesting_indices = get_unslashed_attesting_indices( - state, - attestations, - committee_config, + state, attestations, committee_config ) attesting_balance = get_total_balance(state, unslashed_attesting_indices) for index in eligible_validator_indices: @@ -454,37 +346,28 @@ def get_attestation_deltas(state: BeaconState, rewards, index, lambda balance, delta: balance + delta, - get_base_reward( - state, - index, - config, - ) * attesting_balance // total_balance, + get_base_reward(state, index, config) + * attesting_balance + // total_balance, ) else: penalties = update_tuple_item_with_fn( penalties, index, lambda balance, delta: balance + delta, - get_base_reward( - state, - index, - config, - ), + get_base_reward(state, index, config), ) for index in get_unslashed_attesting_indices( - state, - matching_source_attestations, - committee_config, + state, matching_source_attestations, committee_config ): attestation = min( ( - a for a in matching_source_attestations - if index in get_attesting_indices( - state, - a.data, - a.aggregation_bits, - committee_config, + a + for a in matching_source_attestations + if index + in get_attesting_indices( + state, a.data, a.aggregation_bits, committee_config ) ), key=lambda a: a.inclusion_delay, @@ -503,31 +386,27 @@ def get_attestation_deltas(state: BeaconState, index, lambda balance, delta: balance + delta, ( - max_attester_reward * ( - config.SLOTS_PER_EPOCH + - config.MIN_ATTESTATION_INCLUSION_DELAY - - attestation.inclusion_delay - ) // config.SLOTS_PER_EPOCH - ) + max_attester_reward + * ( + config.SLOTS_PER_EPOCH + + config.MIN_ATTESTATION_INCLUSION_DELAY + - attestation.inclusion_delay + ) + // config.SLOTS_PER_EPOCH + ), ) finality_delay = previous_epoch - state.finalized_checkpoint.epoch if finality_delay > config.MIN_EPOCHS_TO_INACTIVITY_PENALTY: matching_target_attesting_indices = get_unslashed_attesting_indices( - state, - matching_target_attestations, - committee_config, + state, matching_target_attestations, committee_config ) for index in eligible_validator_indices: penalties = update_tuple_item_with_fn( penalties, index, lambda balance, delta: balance + delta, - BASE_REWARDS_PER_EPOCH * get_base_reward( - state, - index, - config, - ), + BASE_REWARDS_PER_EPOCH * get_base_reward(state, index, config), ) if index not in matching_target_attesting_indices: effective_balance = state.validators[index].effective_balance @@ -535,23 +414,21 @@ def get_attestation_deltas(state: BeaconState, penalties, index, lambda balance, delta: balance + delta, - effective_balance * finality_delay // config.INACTIVITY_PENALTY_QUOTIENT, + effective_balance + * finality_delay + // config.INACTIVITY_PENALTY_QUOTIENT, ) - return tuple( - Gwei(reward) for reward in rewards - ), tuple( - Gwei(penalty) for penalty in penalties + return ( + tuple(Gwei(reward) for reward in rewards), + tuple(Gwei(penalty) for penalty in penalties), ) -def get_crosslink_deltas(state: BeaconState, - config: Eth2Config) -> Tuple[Sequence[Gwei], Sequence[Gwei]]: - rewards = tuple( - 0 for _ in range(len(state.validators)) - ) - penalties = tuple( - 0 for _ in range(len(state.validators)) - ) +def get_crosslink_deltas( + state: BeaconState, config: Eth2Config +) -> Tuple[Sequence[Gwei], Sequence[Gwei]]: + rewards = tuple(0 for _ in range(len(state.validators))) + penalties = tuple(0 for _ in range(len(state.validators))) epoch = state.previous_epoch(config.SLOTS_PER_EPOCH, config.GENESIS_EPOCH) active_validators_indices = get_active_validator_indices(state.validators, epoch) epoch_committee_count = get_committee_count( @@ -560,33 +437,17 @@ def get_crosslink_deltas(state: BeaconState, config.SLOTS_PER_EPOCH, config.TARGET_COMMITTEE_SIZE, ) - epoch_start_shard = get_start_shard( - state, - epoch, - CommitteeConfig(config), - ) + epoch_start_shard = get_start_shard(state, epoch, CommitteeConfig(config)) for shard_offset in range(epoch_committee_count): shard = Shard((epoch_start_shard + shard_offset) % config.SHARD_COUNT) - crosslink_committee = set(get_crosslink_committee( - state, - epoch, - shard, - CommitteeConfig(config), - )) - _, attesting_indices = get_winning_crosslink_and_attesting_indices( - state=state, - epoch=epoch, - shard=shard, - config=config, - ) - total_attesting_balance = get_total_balance( - state, - attesting_indices, + crosslink_committee = set( + get_crosslink_committee(state, epoch, shard, CommitteeConfig(config)) ) - total_committee_balance = get_total_balance( - state, - crosslink_committee, + _, attesting_indices = get_winning_crosslink_and_attesting_indices( + state=state, epoch=epoch, shard=shard, config=config ) + total_attesting_balance = get_total_balance(state, attesting_indices) + total_committee_balance = get_total_balance(state, crosslink_committee) for index in crosslink_committee: base_reward = get_base_reward(state, index, config) if index in attesting_indices: @@ -594,7 +455,7 @@ def get_crosslink_deltas(state: BeaconState, rewards, index, lambda balance, delta: balance + delta, - base_reward * total_attesting_balance // total_committee_balance + base_reward * total_attesting_balance // total_committee_balance, ) else: penalties = update_tuple_item_with_fn( @@ -603,50 +464,57 @@ def get_crosslink_deltas(state: BeaconState, lambda balance, delta: balance + delta, base_reward, ) - return tuple( - Gwei(reward) for reward in rewards - ), tuple( - Gwei(penalty) for penalty in penalties + return ( + tuple(Gwei(reward) for reward in rewards), + tuple(Gwei(penalty) for penalty in penalties), ) -def process_rewards_and_penalties(state: BeaconState, config: Eth2Config) -> BeaconState: +def process_rewards_and_penalties( + state: BeaconState, config: Eth2Config +) -> BeaconState: current_epoch = state.current_epoch(config.SLOTS_PER_EPOCH) if current_epoch == config.GENESIS_EPOCH: return state - rewards_for_attestations, penalties_for_attestations = get_attestation_deltas(state, config) - rewards_for_crosslinks, penalties_for_crosslinks = get_crosslink_deltas(state, config) + rewards_for_attestations, penalties_for_attestations = get_attestation_deltas( + state, config + ) + rewards_for_crosslinks, penalties_for_crosslinks = get_crosslink_deltas( + state, config + ) for index in range(len(state.validators)): index = ValidatorIndex(index) - state = increase_balance(state, index, Gwei( - rewards_for_attestations[index] + rewards_for_crosslinks[index] - )) - state = decrease_balance(state, index, Gwei( - penalties_for_attestations[index] + penalties_for_crosslinks[index] - )) + state = increase_balance( + state, + index, + Gwei(rewards_for_attestations[index] + rewards_for_crosslinks[index]), + ) + state = decrease_balance( + state, + index, + Gwei(penalties_for_attestations[index] + penalties_for_crosslinks[index]), + ) return state @curry -def _process_activation_eligibility_or_ejections(state: BeaconState, - validator: Validator, - config: Eth2Config) -> Validator: +def _process_activation_eligibility_or_ejections( + state: BeaconState, validator: Validator, config: Eth2Config +) -> Validator: current_epoch = state.current_epoch(config.SLOTS_PER_EPOCH) if ( - validator.activation_eligibility_epoch == FAR_FUTURE_EPOCH and - validator.effective_balance == config.MAX_EFFECTIVE_BALANCE + validator.activation_eligibility_epoch == FAR_FUTURE_EPOCH + and validator.effective_balance == config.MAX_EFFECTIVE_BALANCE ): - validator = validator.copy( - activation_eligibility_epoch=current_epoch, - ) + validator = validator.copy(activation_eligibility_epoch=current_epoch) if ( - validator.is_active(current_epoch) and - validator.effective_balance <= config.EJECTION_BALANCE + validator.is_active(current_epoch) + and validator.effective_balance <= config.EJECTION_BALANCE ): validator = initiate_exit_for_validator(validator, state, config) @@ -654,9 +522,9 @@ def _process_activation_eligibility_or_ejections(state: BeaconState, @curry -def _update_validator_activation_epoch(state: BeaconState, - config: Eth2Config, - validator: Validator) -> Validator: +def _update_validator_activation_epoch( + state: BeaconState, config: Eth2Config, validator: Validator +) -> Validator: if validator.activation_epoch == FAR_FUTURE_EPOCH: return validator.copy( activation_epoch=compute_activation_exit_epoch( @@ -675,34 +543,29 @@ def process_registry_updates(state: BeaconState, config: Eth2Config) -> BeaconSt ) activation_exit_epoch = compute_activation_exit_epoch( - state.finalized_checkpoint.epoch, - config.ACTIVATION_EXIT_DELAY, + state.finalized_checkpoint.epoch, config.ACTIVATION_EXIT_DELAY ) activation_queue = sorted( ( - index for index, validator in enumerate(new_validators) if - validator.activation_eligibility_epoch != FAR_FUTURE_EPOCH and - validator.activation_epoch >= activation_exit_epoch + index + for index, validator in enumerate(new_validators) + if validator.activation_eligibility_epoch != FAR_FUTURE_EPOCH + and validator.activation_epoch >= activation_exit_epoch ), key=lambda index: new_validators[index].activation_eligibility_epoch, ) - for index in activation_queue[:get_validator_churn_limit(state, config)]: + for index in activation_queue[: get_validator_churn_limit(state, config)]: new_validators = update_tuple_item_with_fn( - new_validators, - index, - _update_validator_activation_epoch(state, config), + new_validators, index, _update_validator_activation_epoch(state, config) ) - return state.copy( - validators=new_validators, - ) + return state.copy(validators=new_validators) -def _determine_slashing_penalty(total_penalties: Gwei, - total_balance: Gwei, - balance: Gwei, - increment: Gwei) -> Gwei: +def _determine_slashing_penalty( + total_penalties: Gwei, total_balance: Gwei, balance: Gwei, increment: Gwei +) -> Gwei: penalty_numerator = balance // increment * min(total_penalties * 3, total_balance) penalty = penalty_numerator // total_balance * increment return Gwei(penalty) @@ -715,7 +578,10 @@ def process_slashings(state: BeaconState, config: Eth2Config) -> BeaconState: slashing_period = config.EPOCHS_PER_SLASHINGS_VECTOR // 2 for index, validator in enumerate(state.validators): index = ValidatorIndex(index) - if validator.slashed and current_epoch + slashing_period == validator.withdrawable_epoch: + if ( + validator.slashed + and current_epoch + slashing_period == validator.withdrawable_epoch + ): penalty = _determine_slashing_penalty( Gwei(sum(state.slashings)), total_balance, @@ -726,14 +592,18 @@ def process_slashings(state: BeaconState, config: Eth2Config) -> BeaconState: return state -def _determine_next_eth1_votes(state: BeaconState, config: Eth2Config) -> Tuple[Eth1Data, ...]: +def _determine_next_eth1_votes( + state: BeaconState, config: Eth2Config +) -> Tuple[Eth1Data, ...]: if (state.slot + 1) % config.SLOTS_PER_ETH1_VOTING_PERIOD == 0: return tuple() else: return state.eth1_data_votes -def _update_effective_balances(state: BeaconState, config: Eth2Config) -> Tuple[Validator, ...]: +def _update_effective_balances( + state: BeaconState, config: Eth2Config +) -> Tuple[Validator, ...]: half_increment = config.EFFECTIVE_BALANCE_INCREMENT // 2 new_validators = state.validators for index, validator in enumerate(state.validators): @@ -748,9 +618,7 @@ def _update_effective_balances(state: BeaconState, config: Eth2Config) -> Tuple[ new_validators = update_tuple_item_with_fn( new_validators, index, - lambda v, new_balance: v.copy( - effective_balance=new_balance, - ), + lambda v, new_balance: v.copy(effective_balance=new_balance), new_effective_balance, ) return new_validators @@ -758,35 +626,34 @@ def _update_effective_balances(state: BeaconState, config: Eth2Config) -> Tuple[ def _compute_next_start_shard(state: BeaconState, config: Eth2Config) -> Shard: current_epoch = state.current_epoch(config.SLOTS_PER_EPOCH) - return (state.start_shard + get_shard_delta( - state, - current_epoch, - CommitteeConfig(config), - )) % config.SHARD_COUNT + return ( + state.start_shard + + get_shard_delta(state, current_epoch, CommitteeConfig(config)) + ) % config.SHARD_COUNT -def _compute_next_active_index_roots(state: BeaconState, config: Eth2Config) -> Tuple[Hash32, ...]: +def _compute_next_active_index_roots( + state: BeaconState, config: Eth2Config +) -> Tuple[Hash32, ...]: next_epoch = state.next_epoch(config.SLOTS_PER_EPOCH) index_root_position = ( next_epoch + config.ACTIVATION_EXIT_DELAY ) % config.EPOCHS_PER_HISTORICAL_VECTOR validator_indices_for_new_active_index_root = get_active_validator_indices( - state.validators, - Epoch(next_epoch + config.ACTIVATION_EXIT_DELAY), + state.validators, Epoch(next_epoch + config.ACTIVATION_EXIT_DELAY) ) new_active_index_root = ssz.get_hash_tree_root( validator_indices_for_new_active_index_root, ssz.sedes.List(ssz.uint64, config.VALIDATOR_REGISTRY_LIMIT), ) return update_tuple_item( - state.active_index_roots, - index_root_position, - new_active_index_root, + state.active_index_roots, index_root_position, new_active_index_root ) -def _compute_next_compact_committees_roots(state: BeaconState, - config: Eth2Config) -> Tuple[Hash32, ...]: +def _compute_next_compact_committees_roots( + state: BeaconState, config: Eth2Config +) -> Tuple[Hash32, ...]: next_epoch = state.next_epoch(config.SLOTS_PER_EPOCH) committee_root_position = next_epoch % config.EPOCHS_PER_HISTORICAL_VECTOR return update_tuple_item( @@ -799,33 +666,30 @@ def _compute_next_compact_committees_roots(state: BeaconState, def _compute_next_slashings(state: BeaconState, config: Eth2Config) -> Tuple[Gwei, ...]: next_epoch = state.next_epoch(config.SLOTS_PER_EPOCH) return update_tuple_item( - state.slashings, - next_epoch % config.EPOCHS_PER_SLASHINGS_VECTOR, - Gwei(0), + state.slashings, next_epoch % config.EPOCHS_PER_SLASHINGS_VECTOR, Gwei(0) ) -def _compute_next_randao_mixes(state: BeaconState, config: Eth2Config) -> Tuple[Hash32, ...]: +def _compute_next_randao_mixes( + state: BeaconState, config: Eth2Config +) -> Tuple[Hash32, ...]: current_epoch = state.current_epoch(config.SLOTS_PER_EPOCH) next_epoch = state.next_epoch(config.SLOTS_PER_EPOCH) return update_tuple_item( state.randao_mixes, next_epoch % config.EPOCHS_PER_HISTORICAL_VECTOR, - get_randao_mix( - state, - current_epoch, - config.EPOCHS_PER_HISTORICAL_VECTOR, - ), + get_randao_mix(state, current_epoch, config.EPOCHS_PER_HISTORICAL_VECTOR), ) -def _compute_next_historical_roots(state: BeaconState, config: Eth2Config) -> Tuple[Hash32, ...]: +def _compute_next_historical_roots( + state: BeaconState, config: Eth2Config +) -> Tuple[Hash32, ...]: next_epoch = state.next_epoch(config.SLOTS_PER_EPOCH) new_historical_roots = state.historical_roots if next_epoch % (config.SLOTS_PER_HISTORICAL_ROOT // config.SLOTS_PER_EPOCH) == 0: historical_batch = HistoricalBatch( - block_roots=state.block_roots, - state_roots=state.state_roots, + block_roots=state.block_roots, state_roots=state.state_roots ) new_historical_roots += (historical_batch.hash_tree_root,) return new_historical_roots @@ -837,11 +701,7 @@ def process_final_updates(state: BeaconState, config: Eth2Config) -> BeaconState new_start_shard = _compute_next_start_shard(state, config) new_active_index_roots = _compute_next_active_index_roots(state, config) new_compact_committees_roots = _compute_next_compact_committees_roots( - state.copy( - validators=new_validators, - start_shard=new_start_shard, - ), - config, + state.copy(validators=new_validators, start_shard=new_start_shard), config ) new_slashings = _compute_next_slashings(state, config) new_randao_mixes = _compute_next_randao_mixes(state, config) diff --git a/eth2/beacon/state_machines/forks/serenity/operation_processing.py b/eth2/beacon/state_machines/forks/serenity/operation_processing.py index 5f9a28e333..fa02911666 100644 --- a/eth2/beacon/state_machines/forks/serenity/operation_processing.py +++ b/eth2/beacon/state_machines/forks/serenity/operation_processing.py @@ -1,51 +1,35 @@ -from typing import ( - Tuple, -) +from typing import Tuple -from eth_utils import ( - ValidationError, -) +from eth_utils import ValidationError -from eth2.configs import ( - Eth2Config, - CommitteeConfig, -) -from eth2.beacon.validator_status_helpers import ( - initiate_validator_exit, - slash_validator, -) -from eth2.beacon.attestation_helpers import ( - get_attestation_data_slot, -) -from eth2.beacon.committee_helpers import ( - get_beacon_proposer_index, -) -from eth2.beacon.epoch_processing_helpers import ( - increase_balance, - decrease_balance, -) +from eth2.beacon.attestation_helpers import get_attestation_data_slot +from eth2.beacon.committee_helpers import get_beacon_proposer_index +from eth2.beacon.deposit_helpers import process_deposit +from eth2.beacon.epoch_processing_helpers import decrease_balance, increase_balance from eth2.beacon.types.blocks import BaseBeaconBlock from eth2.beacon.types.pending_attestations import PendingAttestation from eth2.beacon.types.states import BeaconState -from eth2.beacon.deposit_helpers import ( - process_deposit, +from eth2.beacon.validator_status_helpers import ( + initiate_validator_exit, + slash_validator, ) +from eth2.configs import CommitteeConfig, Eth2Config from .block_validation import ( validate_attestation, validate_attester_slashing, - validate_proposer_slashing, - validate_voluntary_exit, validate_correct_number_of_deposits, + validate_proposer_slashing, validate_some_slashing, validate_transfer, validate_unique_transfers, + validate_voluntary_exit, ) -def process_proposer_slashings(state: BeaconState, - block: BaseBeaconBlock, - config: Eth2Config) -> BeaconState: +def process_proposer_slashings( + state: BeaconState, block: BaseBeaconBlock, config: Eth2Config +) -> BeaconState: if len(block.body.proposer_slashings) > config.MAX_PROPOSER_SLASHINGS: raise ValidationError( f"The block ({block}) has too many proposer slashings:\n" @@ -56,18 +40,14 @@ def process_proposer_slashings(state: BeaconState, for proposer_slashing in block.body.proposer_slashings: validate_proposer_slashing(state, proposer_slashing, config.SLOTS_PER_EPOCH) - state = slash_validator( - state, - proposer_slashing.proposer_index, - config, - ) + state = slash_validator(state, proposer_slashing.proposer_index, config) return state -def process_attester_slashings(state: BeaconState, - block: BaseBeaconBlock, - config: Eth2Config) -> BeaconState: +def process_attester_slashings( + state: BeaconState, block: BaseBeaconBlock, config: Eth2Config +) -> BeaconState: if len(block.body.attester_slashings) > config.MAX_ATTESTER_SLASHINGS: raise ValidationError( f"The block ({block}) has too many attester slashings:\n" @@ -95,24 +75,22 @@ def process_attester_slashings(state: BeaconState, attestation_2.custody_bit_0_indices + attestation_2.custody_bit_1_indices ) - eligible_indices = sorted(set(attesting_indices_1).intersection(attesting_indices_2)) + eligible_indices = sorted( + set(attesting_indices_1).intersection(attesting_indices_2) + ) for index in eligible_indices: validator = state.validators[index] if validator.is_slashable(current_epoch): - state = slash_validator( - state, - index, - config, - ) + state = slash_validator(state, index, config) slashed_any = True validate_some_slashing(slashed_any, attester_slashing) return state -def process_attestations(state: BeaconState, - block: BaseBeaconBlock, - config: Eth2Config) -> BeaconState: +def process_attestations( + state: BeaconState, block: BaseBeaconBlock, config: Eth2Config +) -> BeaconState: if len(block.body.attestations) > config.MAX_ATTESTATIONS: raise ValidationError( f"The block has too many attestations:\n" @@ -124,21 +102,10 @@ def process_attestations(state: BeaconState, new_current_epoch_attestations: Tuple[PendingAttestation, ...] = tuple() new_previous_epoch_attestations: Tuple[PendingAttestation, ...] = tuple() for attestation in block.body.attestations: - validate_attestation( - state, - attestation, - config, - ) + validate_attestation(state, attestation, config) - attestation_slot = get_attestation_data_slot( - state, - attestation.data, - config, - ) - proposer_index = get_beacon_proposer_index( - state, - CommitteeConfig(config), - ) + attestation_slot = get_attestation_data_slot(state, attestation.data, config) + proposer_index = get_beacon_proposer_index(state, CommitteeConfig(config)) pending_attestation = PendingAttestation( aggregation_bits=attestation.aggregation_bits, data=attestation.data, @@ -161,9 +128,9 @@ def process_attestations(state: BeaconState, ) -def process_deposits(state: BeaconState, - block: BaseBeaconBlock, - config: Eth2Config) -> BeaconState: +def process_deposits( + state: BeaconState, block: BaseBeaconBlock, config: Eth2Config +) -> BeaconState: if len(block.body.deposits) > config.MAX_DEPOSITS: raise ValidationError( f"The block ({block}) has too many deposits:\n" @@ -172,18 +139,14 @@ def process_deposits(state: BeaconState, ) for deposit in block.body.deposits: - state = process_deposit( - state, - deposit, - config, - ) + state = process_deposit(state, deposit, config) return state -def process_voluntary_exits(state: BeaconState, - block: BaseBeaconBlock, - config: Eth2Config) -> BeaconState: +def process_voluntary_exits( + state: BeaconState, block: BaseBeaconBlock, config: Eth2Config +) -> BeaconState: if len(block.body.voluntary_exits) > config.MAX_VOLUNTARY_EXITS: raise ValidationError( f"The block ({block}) has too many voluntary exits:\n" @@ -203,9 +166,9 @@ def process_voluntary_exits(state: BeaconState, return state -def process_transfers(state: BeaconState, - block: BaseBeaconBlock, - config: Eth2Config) -> BeaconState: +def process_transfers( + state: BeaconState, block: BaseBeaconBlock, config: Eth2Config +) -> BeaconState: if len(block.body.transfers) > config.MAX_TRANSFERS: raise ValidationError( f"The block ({block}) has too many transfers:\n" @@ -214,36 +177,21 @@ def process_transfers(state: BeaconState, ) for transfer in block.body.transfers: - validate_transfer( - state, - transfer, - config, - ) - state = decrease_balance( - state, - transfer.sender, - transfer.amount + transfer.fee, - ) - state = increase_balance( - state, - transfer.recipient, - transfer.amount, - ) + validate_transfer(state, transfer, config) + state = decrease_balance(state, transfer.sender, transfer.amount + transfer.fee) + state = increase_balance(state, transfer.recipient, transfer.amount) state = increase_balance( state, - get_beacon_proposer_index( - state, - CommitteeConfig(config), - ), + get_beacon_proposer_index(state, CommitteeConfig(config)), transfer.fee, ) return state -def process_operations(state: BeaconState, - block: BaseBeaconBlock, - config: Eth2Config) -> BeaconState: +def process_operations( + state: BeaconState, block: BaseBeaconBlock, config: Eth2Config +) -> BeaconState: validate_correct_number_of_deposits(state, block, config) validate_unique_transfers(state, block, config) diff --git a/eth2/beacon/state_machines/forks/serenity/slot_processing.py b/eth2/beacon/state_machines/forks/serenity/slot_processing.py index a716a32478..88926272a6 100644 --- a/eth2/beacon/state_machines/forks/serenity/slot_processing.py +++ b/eth2/beacon/state_machines/forks/serenity/slot_processing.py @@ -1,42 +1,24 @@ -from typing import ( - Sequence, - Tuple, -) - -from eth.constants import ( - ZERO_HASH32, -) - -from eth_typing import ( - Hash32, -) -from eth_utils import ( - ValidationError, -) +from typing import Sequence, Tuple + +from eth.constants import ZERO_HASH32 +from eth_typing import Hash32 +from eth_utils import ValidationError from eth2._utils.tuple import update_tuple_item -from eth2.configs import ( - Eth2Config, -) -from eth2.beacon.typing import ( - Slot, -) from eth2.beacon.types.states import BeaconState +from eth2.beacon.typing import Slot +from eth2.configs import Eth2Config -from .epoch_processing import ( - process_epoch, -) +from .epoch_processing import process_epoch -def _update_historical_root(roots: Tuple[Hash32, ...], - index: Slot, - slots_per_historical_root: int, - new_root: Hash32) -> Sequence[Hash32]: - return update_tuple_item( - roots, - index % slots_per_historical_root, - new_root, - ) +def _update_historical_root( + roots: Tuple[Hash32, ...], + index: Slot, + slots_per_historical_root: int, + new_root: Hash32, +) -> Sequence[Hash32]: + return update_tuple_item(roots, index % slots_per_historical_root, new_root) def _process_slot(state: BeaconState, config: Eth2Config) -> BeaconState: @@ -44,18 +26,13 @@ def _process_slot(state: BeaconState, config: Eth2Config) -> BeaconState: previous_state_root = state.hash_tree_root updated_state_roots = _update_historical_root( - state.state_roots, - state.slot, - slots_per_historical_root, - previous_state_root, + state.state_roots, state.slot, slots_per_historical_root, previous_state_root ) if state.latest_block_header.state_root == ZERO_HASH32: latest_block_header = state.latest_block_header state = state.copy( - latest_block_header=latest_block_header.copy( - state_root=previous_state_root, - ), + latest_block_header=latest_block_header.copy(state_root=previous_state_root) ) updated_block_roots = _update_historical_root( @@ -65,16 +42,11 @@ def _process_slot(state: BeaconState, config: Eth2Config) -> BeaconState: state.latest_block_header.signing_root, ) - return state.copy( - block_roots=updated_block_roots, - state_roots=updated_state_roots, - ) + return state.copy(block_roots=updated_block_roots, state_roots=updated_state_roots) def _increment_slot(state: BeaconState) -> BeaconState: - return state.copy( - slot=state.slot + 1, - ) + return state.copy(slot=state.slot + 1) def process_slots(state: BeaconState, slot: Slot, config: Eth2Config) -> BeaconState: diff --git a/eth2/beacon/state_machines/forks/serenity/state_transitions.py b/eth2/beacon/state_machines/forks/serenity/state_transitions.py index 27e353f3e7..fdab7486a5 100644 --- a/eth2/beacon/state_machines/forks/serenity/state_transitions.py +++ b/eth2/beacon/state_machines/forks/serenity/state_transitions.py @@ -1,17 +1,11 @@ -from eth2.configs import ( - Eth2Config, -) from eth2.beacon.state_machines.state_transitions import BaseStateTransition from eth2.beacon.types.blocks import BaseBeaconBlock from eth2.beacon.types.states import BeaconState from eth2.beacon.typing import Slot +from eth2.configs import Eth2Config -from .block_processing import ( - process_block, -) -from .slot_processing import ( - process_slots, -) +from .block_processing import process_block +from .slot_processing import process_slots class SerenityStateTransition(BaseStateTransition): @@ -20,11 +14,13 @@ class SerenityStateTransition(BaseStateTransition): def __init__(self, config: Eth2Config): self.config = config - def apply_state_transition(self, - state: BeaconState, - block: BaseBeaconBlock=None, - future_slot: Slot=None, - check_proposer_signature: bool=True) -> BeaconState: + def apply_state_transition( + self, + state: BeaconState, + block: BaseBeaconBlock = None, + future_slot: Slot = None, + check_proposer_signature: bool = True, + ) -> BeaconState: # NOTE: Callers should request a transition to some slot past the ``state.slot``. # This can be done by providing either a ``block`` *or* a ``future_slot``. # We enforce this invariant with the assertion on ``target_slot``. diff --git a/eth2/beacon/state_machines/forks/xiao_long_bao/__init__.py b/eth2/beacon/state_machines/forks/xiao_long_bao/__init__.py index 190df4e9e3..50ffff0e82 100644 --- a/eth2/beacon/state_machines/forks/xiao_long_bao/__init__.py +++ b/eth2/beacon/state_machines/forks/xiao_long_bao/__init__.py @@ -1,10 +1,6 @@ +from eth2.beacon.fork_choice.higher_slot import higher_slot_scoring from eth2.beacon.fork_choice.scoring import ScoringFn as ForkChoiceScoringFn -from eth2.beacon.fork_choice.higher_slot import ( - higher_slot_scoring, -) -from eth2.beacon.state_machines.base import ( - BeaconStateMachine, -) +from eth2.beacon.state_machines.base import BeaconStateMachine from eth2.beacon.state_machines.forks.serenity.blocks import ( SerenityBeaconBlock, create_serenity_block_from_parent, @@ -12,24 +8,16 @@ from eth2.beacon.state_machines.forks.serenity.state_transitions import ( SerenityStateTransition, ) -from eth2.beacon.state_machines.forks.serenity.states import ( - SerenityBeaconState, -) -from eth2.beacon.types.blocks import ( - BaseBeaconBlock, -) -from eth2.beacon.typing import ( - FromBlockParams, -) +from eth2.beacon.state_machines.forks.serenity.states import SerenityBeaconState +from eth2.beacon.types.blocks import BaseBeaconBlock +from eth2.beacon.typing import FromBlockParams -from .configs import ( - XIAO_LONG_BAO_CONFIG, -) +from .configs import XIAO_LONG_BAO_CONFIG class XiaoLongBaoStateMachine(BeaconStateMachine): # fork name - fork = 'xiao_long_bao' + fork = "xiao_long_bao" config = XIAO_LONG_BAO_CONFIG # classes @@ -39,8 +27,9 @@ class XiaoLongBaoStateMachine(BeaconStateMachine): # methods @staticmethod - def create_block_from_parent(parent_block: BaseBeaconBlock, - block_params: FromBlockParams) -> BaseBeaconBlock: + def create_block_from_parent( + parent_block: BaseBeaconBlock, block_params: FromBlockParams + ) -> BaseBeaconBlock: return create_serenity_block_from_parent(parent_block, block_params) def get_fork_choice_scoring(self) -> ForkChoiceScoringFn: diff --git a/eth2/beacon/state_machines/forks/xiao_long_bao/configs.py b/eth2/beacon/state_machines/forks/xiao_long_bao/configs.py index 964b0a2bce..54534b32e9 100644 --- a/eth2/beacon/state_machines/forks/xiao_long_bao/configs.py +++ b/eth2/beacon/state_machines/forks/xiao_long_bao/configs.py @@ -1,7 +1,4 @@ -from eth2.beacon.state_machines.forks.serenity.configs import ( - SERENITY_CONFIG, -) - +from eth2.beacon.state_machines.forks.serenity.configs import SERENITY_CONFIG XIAO_LONG_BAO_CONFIG = SERENITY_CONFIG._replace( SLOTS_PER_EPOCH=4, @@ -9,7 +6,7 @@ SHARD_COUNT=4, MIN_ATTESTATION_INCLUSION_DELAY=2, # Shorten the HISTORICAL lengths to make genesis yaml lighter - EPOCHS_PER_HISTORICAL_VECTOR=2**7, - SLOTS_PER_HISTORICAL_ROOT=2**4, - EPOCHS_PER_SLASHINGS_VECTOR=2**4, + EPOCHS_PER_HISTORICAL_VECTOR=2 ** 7, + SLOTS_PER_HISTORICAL_ROOT=2 ** 4, + EPOCHS_PER_SLASHINGS_VECTOR=2 ** 4, ) diff --git a/eth2/beacon/state_machines/state_transitions.py b/eth2/beacon/state_machines/state_transitions.py index 20f9f19b9d..532bbfc994 100644 --- a/eth2/beacon/state_machines/state_transitions.py +++ b/eth2/beacon/state_machines/state_transitions.py @@ -1,16 +1,11 @@ -from abc import ( - ABC, - abstractmethod, -) +from abc import ABC, abstractmethod -from eth._utils.datatypes import ( - Configurable, -) +from eth._utils.datatypes import Configurable -from eth2.configs import Eth2Config from eth2.beacon.types.blocks import BaseBeaconBlock from eth2.beacon.types.states import BeaconState from eth2.beacon.typing import Slot +from eth2.configs import Eth2Config class BaseStateTransition(Configurable, ABC): @@ -20,11 +15,13 @@ def __init__(self, config: Eth2Config): self.config = config @abstractmethod - def apply_state_transition(self, - state: BeaconState, - block: BaseBeaconBlock=None, - future_slot: Slot=None, - check_proposer_signature: bool=True) -> BeaconState: + def apply_state_transition( + self, + state: BeaconState, + block: BaseBeaconBlock = None, + future_slot: Slot = None, + check_proposer_signature: bool = True, + ) -> BeaconState: """ Applies the state transition function to ``state`` based on data in ``block`` or ``future_slot``. The ``block.slot`` or the ``future_slot`` diff --git a/eth2/beacon/tools/builder/committee_assignment.py b/eth2/beacon/tools/builder/committee_assignment.py index cd60c60485..7a78710f2b 100644 --- a/eth2/beacon/tools/builder/committee_assignment.py +++ b/eth2/beacon/tools/builder/committee_assignment.py @@ -1,54 +1,40 @@ -from typing import ( - Tuple, - NamedTuple, -) +from typing import NamedTuple, Tuple -from eth_utils import ( - ValidationError, -) +from eth_utils import ValidationError -from eth2.configs import ( - CommitteeConfig, - Eth2Config, -) from eth2.beacon.committee_helpers import ( get_beacon_proposer_index, - get_crosslink_committee, get_committee_count, + get_crosslink_committee, get_start_shard, ) +from eth2.beacon.exceptions import NoCommitteeAssignment from eth2.beacon.helpers import ( - get_active_validator_indices, compute_start_slot_of_epoch, + get_active_validator_indices, ) from eth2.beacon.types.states import BeaconState -from eth2.beacon.typing import ( - Shard, - Slot, - ValidatorIndex, - Epoch, -) -from eth2.beacon.exceptions import ( - NoCommitteeAssignment, -) - +from eth2.beacon.typing import Epoch, Shard, Slot, ValidatorIndex +from eth2.configs import CommitteeConfig, Eth2Config CommitteeAssignment = NamedTuple( - 'CommitteeAssignment', + "CommitteeAssignment", ( - ('committee', Tuple[ValidatorIndex, ...]), - ('shard', Shard), - ('slot', Slot), - ('is_proposer', bool) - ) + ("committee", Tuple[ValidatorIndex, ...]), + ("shard", Shard), + ("slot", Slot), + ("is_proposer", bool), + ), ) # TODO(ralexstokes) refactor using other helpers, also likely to have duplicated in tests -def get_committee_assignment(state: BeaconState, - config: Eth2Config, - epoch: Epoch, - validator_index: ValidatorIndex) -> CommitteeAssignment: +def get_committee_assignment( + state: BeaconState, + config: Eth2Config, + epoch: Epoch, + validator_index: ValidatorIndex, +) -> CommitteeAssignment: """ Return the ``CommitteeAssignment`` in the ``epoch`` for ``validator_index``. ``CommitteeAssignment.committee`` is the tuple array of validators in the committee @@ -64,16 +50,16 @@ def get_committee_assignment(state: BeaconState, ) active_validators = get_active_validator_indices(state.validators, epoch) - committees_per_slot = get_committee_count( - len(active_validators), - config.SHARD_COUNT, - config.SLOTS_PER_EPOCH, - config.TARGET_COMMITTEE_SIZE, - ) // config.SLOTS_PER_EPOCH - epoch_start_slot = compute_start_slot_of_epoch( - epoch, - config.SLOTS_PER_EPOCH, + committees_per_slot = ( + get_committee_count( + len(active_validators), + config.SHARD_COUNT, + config.SLOTS_PER_EPOCH, + config.TARGET_COMMITTEE_SIZE, + ) + // config.SLOTS_PER_EPOCH ) + epoch_start_slot = compute_start_slot_of_epoch(epoch, config.SLOTS_PER_EPOCH) epoch_start_shard = get_start_shard(state, epoch, CommitteeConfig(config)) for slot in range(epoch_start_slot, epoch_start_slot + config.SLOTS_PER_EPOCH): @@ -81,14 +67,15 @@ def get_committee_assignment(state: BeaconState, slot_start_shard = (epoch_start_shard + offset) % config.SHARD_COUNT for i in range(committees_per_slot): shard = Shard((slot_start_shard + i) % config.SHARD_COUNT) - committee = get_crosslink_committee(state, epoch, shard, CommitteeConfig(config)) + committee = get_crosslink_committee( + state, epoch, shard, CommitteeConfig(config) + ) if validator_index in committee: is_proposer = validator_index == get_beacon_proposer_index( - state.copy( - slot=slot, - ), - CommitteeConfig(config), + state.copy(slot=slot), CommitteeConfig(config) + ) + return CommitteeAssignment( + committee, Shard(shard), Slot(slot), is_proposer ) - return CommitteeAssignment(committee, Shard(shard), Slot(slot), is_proposer) raise NoCommitteeAssignment diff --git a/eth2/beacon/tools/builder/initializer.py b/eth2/beacon/tools/builder/initializer.py index 276564c83b..9f6a351ed5 100644 --- a/eth2/beacon/tools/builder/initializer.py +++ b/eth2/beacon/tools/builder/initializer.py @@ -1,64 +1,32 @@ -from typing import ( - cast, - Dict, - Sequence, - Tuple, - Type, -) - -from eth_typing import ( - BLSPubkey, - Hash32, -) - -from eth.constants import ( - ZERO_HASH32, -) - -from eth2._utils.hash import ( - hash_eth2, -) -from eth2._utils.merkle.common import ( - get_merkle_proof, -) -from eth2._utils.merkle.sparse import ( - calc_merkle_tree_from_leaves, - get_root, -) -from eth2.configs import Eth2Config -from eth2.beacon.constants import ( - ZERO_TIMESTAMP, -) -from eth2.beacon.genesis import ( - get_genesis_block, - initialize_beacon_state_from_eth1, -) -from eth2.beacon.types.blocks import ( - BaseBeaconBlock, -) -from eth2.beacon.types.deposits import Deposit +from typing import Dict, Sequence, Tuple, Type, cast + +from eth.constants import ZERO_HASH32 +from eth_typing import BLSPubkey, Hash32 + +from eth2._utils.hash import hash_eth2 +from eth2._utils.merkle.common import get_merkle_proof +from eth2._utils.merkle.sparse import calc_merkle_tree_from_leaves, get_root +from eth2.beacon.constants import ZERO_TIMESTAMP +from eth2.beacon.genesis import get_genesis_block, initialize_beacon_state_from_eth1 +from eth2.beacon.tools.builder.validator import create_mock_deposit_data +from eth2.beacon.types.blocks import BaseBeaconBlock from eth2.beacon.types.deposit_data import DepositData # noqa: F401 +from eth2.beacon.types.deposits import Deposit from eth2.beacon.types.eth1_data import Eth1Data from eth2.beacon.types.states import BeaconState from eth2.beacon.types.validators import Validator -from eth2.beacon.typing import ( - Timestamp, -) -from eth2.beacon.validator_status_helpers import ( - activate_validator, -) - -from eth2.beacon.tools.builder.validator import ( - create_mock_deposit_data, -) +from eth2.beacon.typing import Timestamp +from eth2.beacon.validator_status_helpers import activate_validator +from eth2.configs import Eth2Config def create_mock_deposits_and_root( - pubkeys: Sequence[BLSPubkey], - keymap: Dict[BLSPubkey, int], - config: Eth2Config, - withdrawal_credentials: Sequence[Hash32]=None, - leaves: Sequence[Hash32]=None) -> Tuple[Tuple[Deposit, ...], Hash32]: + pubkeys: Sequence[BLSPubkey], + keymap: Dict[BLSPubkey, int], + config: Eth2Config, + withdrawal_credentials: Sequence[Hash32] = None, + leaves: Sequence[Hash32] = None, +) -> Tuple[Tuple[Deposit, ...], Hash32]: """ Creates as many new deposits as there are keys in ``pubkeys``. @@ -69,7 +37,9 @@ def create_mock_deposits_and_root( empty, this function simulates the genesis deposit tree calculation. """ if not withdrawal_credentials: - withdrawal_credentials = tuple(Hash32(b'\x22' * 32) for _ in range(len(pubkeys))) + withdrawal_credentials = tuple( + Hash32(b"\x22" * 32) for _ in range(len(pubkeys)) + ) else: assert len(withdrawal_credentials) == len(pubkeys) if not leaves: @@ -93,12 +63,10 @@ def create_mock_deposits_and_root( deposits: Tuple[Deposit, ...] = tuple() for index, data in enumerate(deposit_datas): length_mix_in = Hash32((index + 1).to_bytes(32, byteorder="little")) - tree = calc_merkle_tree_from_leaves(deposit_data_leaves[:index + 1]) + tree = calc_merkle_tree_from_leaves(deposit_data_leaves[: index + 1]) deposit = Deposit( - proof=( - get_merkle_proof(tree, item_index=index) + (length_mix_in,) - ), + proof=(get_merkle_proof(tree, item_index=index) + (length_mix_in,)), data=data, ) deposits += (deposit,) @@ -107,12 +75,14 @@ def create_mock_deposits_and_root( return deposits, hash_eth2(tree_root + length_mix_in) -def create_mock_deposit(state: BeaconState, - pubkey: BLSPubkey, - keymap: Dict[BLSPubkey, int], - withdrawal_credentials: Hash32, - config: Eth2Config, - leaves: Sequence[Hash32]=None) -> Tuple[BeaconState, Deposit]: +def create_mock_deposit( + state: BeaconState, + pubkey: BLSPubkey, + keymap: Dict[BLSPubkey, int], + withdrawal_credentials: Hash32, + config: Eth2Config, + leaves: Sequence[Hash32] = None, +) -> Tuple[BeaconState, Deposit]: deposits, root = create_mock_deposits_and_root( (pubkey,), keymap, @@ -136,15 +106,14 @@ def create_mock_deposit(state: BeaconState, def create_mock_genesis( - pubkeys: Sequence[BLSPubkey], - config: Eth2Config, - keymap: Dict[BLSPubkey, int], - genesis_block_class: Type[BaseBeaconBlock], - genesis_time: Timestamp=ZERO_TIMESTAMP) -> Tuple[BeaconState, BaseBeaconBlock]: + pubkeys: Sequence[BLSPubkey], + config: Eth2Config, + keymap: Dict[BLSPubkey, int], + genesis_block_class: Type[BaseBeaconBlock], + genesis_time: Timestamp = ZERO_TIMESTAMP, +) -> Tuple[BeaconState, BaseBeaconBlock]: genesis_deposits, deposit_root = create_mock_deposits_and_root( - pubkeys=pubkeys, - keymap=keymap, - config=config, + pubkeys=pubkeys, keymap=keymap, config=config ) genesis_eth1_data = Eth1Data( @@ -161,28 +130,23 @@ def create_mock_genesis( ) block = get_genesis_block( - genesis_state_root=state.hash_tree_root, - block_class=genesis_block_class, + genesis_state_root=state.hash_tree_root, block_class=genesis_block_class ) assert len(state.validators) == len(pubkeys) return state, block -def create_mock_validator(pubkey: BLSPubkey, - config: Eth2Config, - withdrawal_credentials: Hash32=ZERO_HASH32, - is_active: bool=True) -> Validator: +def create_mock_validator( + pubkey: BLSPubkey, + config: Eth2Config, + withdrawal_credentials: Hash32 = ZERO_HASH32, + is_active: bool = True, +) -> Validator: validator = Validator.create_pending_validator( - pubkey, - withdrawal_credentials, - config.MAX_EFFECTIVE_BALANCE, - config, + pubkey, withdrawal_credentials, config.MAX_EFFECTIVE_BALANCE, config ) if is_active: - return activate_validator( - validator, - config.GENESIS_EPOCH, - ) + return activate_validator(validator, config.GENESIS_EPOCH) else: return validator diff --git a/eth2/beacon/tools/builder/proposer.py b/eth2/beacon/tools/builder/proposer.py index 6ad1fdb961..5570bf5081 100644 --- a/eth2/beacon/tools/builder/proposer.py +++ b/eth2/beacon/tools/builder/proposer.py @@ -1,56 +1,24 @@ -from typing import ( - Dict, - Sequence, - Type, -) -from eth_typing import ( - BLSPubkey, - BLSSignature, -) -import ssz - +from typing import Dict, Sequence, Type -from eth2.configs import ( - CommitteeConfig, - Eth2Config, -) -from eth2.beacon.signature_domain import ( - SignatureDomain, -) -from eth2.beacon.committee_helpers import ( - get_beacon_proposer_index, -) -from eth2.beacon.exceptions import ( - ProposerIndexError, -) -from eth2.beacon.helpers import ( - compute_epoch_of_slot, -) -from eth2.beacon.state_machines.base import ( - BaseBeaconStateMachine, -) +from eth_typing import BLSPubkey, BLSSignature +import ssz +from eth2.beacon.committee_helpers import get_beacon_proposer_index +from eth2.beacon.exceptions import ProposerIndexError +from eth2.beacon.helpers import compute_epoch_of_slot +from eth2.beacon.signature_domain import SignatureDomain +from eth2.beacon.state_machines.base import BaseBeaconStateMachine +from eth2.beacon.tools.builder.validator import sign_transaction from eth2.beacon.types.attestations import Attestation -from eth2.beacon.types.blocks import ( - BaseBeaconBlock, - BeaconBlockBody, -) +from eth2.beacon.types.blocks import BaseBeaconBlock, BeaconBlockBody from eth2.beacon.types.states import BeaconState -from eth2.beacon.typing import ( - FromBlockParams, - Slot, - ValidatorIndex, -) +from eth2.beacon.typing import FromBlockParams, Slot, ValidatorIndex +from eth2.configs import CommitteeConfig, Eth2Config -from eth2.beacon.tools.builder.validator import ( - sign_transaction, -) - -def _generate_randao_reveal(privkey: int, - slot: Slot, - state: BeaconState, - config: Eth2Config) -> BLSSignature: +def _generate_randao_reveal( + privkey: int, slot: Slot, state: BeaconState, config: Eth2Config +) -> BLSSignature: """ Return the RANDAO reveal for the validator represented by ``privkey``. The current implementation requires a validator to provide the BLS signature @@ -71,15 +39,11 @@ def _generate_randao_reveal(privkey: int, return randao_reveal -def validate_proposer_index(state: BeaconState, - config: Eth2Config, - slot: Slot, - validator_index: ValidatorIndex) -> None: +def validate_proposer_index( + state: BeaconState, config: Eth2Config, slot: Slot, validator_index: ValidatorIndex +) -> None: beacon_proposer_index = get_beacon_proposer_index( - state.copy( - slot=slot, - ), - CommitteeConfig(config), + state.copy(slot=slot), CommitteeConfig(config) ) if validator_index != beacon_proposer_index: @@ -87,17 +51,18 @@ def validate_proposer_index(state: BeaconState, def create_block_on_state( - *, - state: BeaconState, - config: Eth2Config, - state_machine: BaseBeaconStateMachine, - block_class: Type[BaseBeaconBlock], - parent_block: BaseBeaconBlock, - slot: Slot, - validator_index: ValidatorIndex, - privkey: int, - attestations: Sequence[Attestation], - check_proposer_index: bool=True) -> BaseBeaconBlock: + *, + state: BeaconState, + config: Eth2Config, + state_machine: BaseBeaconStateMachine, + block_class: Type[BaseBeaconBlock], + parent_block: BaseBeaconBlock, + slot: Slot, + validator_index: ValidatorIndex, + privkey: int, + attestations: Sequence[Attestation], + check_proposer_index: bool = True +) -> BaseBeaconBlock: """ Create a beacon block with the given parameters. """ @@ -105,25 +70,22 @@ def create_block_on_state( validate_proposer_index(state, config, slot, validator_index) block = block_class.from_parent( - parent_block=parent_block, - block_params=FromBlockParams(slot=slot), + parent_block=parent_block, block_params=FromBlockParams(slot=slot) ) # TODO: Add more operations randao_reveal = _generate_randao_reveal(privkey, slot, state, config) eth1_data = state.eth1_data body = BeaconBlockBody( - randao_reveal=randao_reveal, - eth1_data=eth1_data, - attestations=attestations, + randao_reveal=randao_reveal, eth1_data=eth1_data, attestations=attestations ) - block = block.copy( - body=body, - ) + block = block.copy(body=body) # Apply state transition to get state root - state, block = state_machine.import_block(block, check_proposer_signature=False) + state, block = state_machine.import_block( + block, state, check_proposer_signature=False + ) # Sign signature = sign_transaction( @@ -135,50 +97,43 @@ def create_block_on_state( slots_per_epoch=config.SLOTS_PER_EPOCH, ) - block = block.copy( - signature=signature, - ) + block = block.copy(signature=signature) return block -def advance_to_slot(state_machine: BaseBeaconStateMachine, - state: BeaconState, - slot: Slot) -> BeaconState: +def advance_to_slot( + state_machine: BaseBeaconStateMachine, state: BeaconState, slot: Slot +) -> BeaconState: # advance the state to the ``slot``. state_transition = state_machine.state_transition state = state_transition.apply_state_transition(state, future_slot=slot) return state -def _get_proposer_index(state: BeaconState, - config: Eth2Config) -> ValidatorIndex: - proposer_index = get_beacon_proposer_index( - state, - CommitteeConfig(config), - ) +def _get_proposer_index(state: BeaconState, config: Eth2Config) -> ValidatorIndex: + proposer_index = get_beacon_proposer_index(state, CommitteeConfig(config)) return proposer_index -def create_mock_block(*, - state: BeaconState, - config: Eth2Config, - state_machine: BaseBeaconStateMachine, - block_class: Type[BaseBeaconBlock], - parent_block: BaseBeaconBlock, - keymap: Dict[BLSPubkey, int], - slot: Slot=None, - attestations: Sequence[Attestation]=()) -> BaseBeaconBlock: +def create_mock_block( + *, + state: BeaconState, + config: Eth2Config, + state_machine: BaseBeaconStateMachine, + block_class: Type[BaseBeaconBlock], + parent_block: BaseBeaconBlock, + keymap: Dict[BLSPubkey, int], + slot: Slot = None, + attestations: Sequence[Attestation] = () +) -> BaseBeaconBlock: """ Create a mocking block at ``slot`` with the given block parameters and ``keymap``. Note that it doesn't return the correct ``state_root``. """ future_state = advance_to_slot(state_machine, state, slot) - proposer_index = _get_proposer_index( - future_state, - config - ) + proposer_index = _get_proposer_index(future_state, config) proposer_pubkey = state.validators[proposer_index].pubkey proposer_privkey = keymap[proposer_pubkey] diff --git a/eth2/beacon/tools/builder/state.py b/eth2/beacon/tools/builder/state.py index 6db824ac59..3cd9e3676a 100644 --- a/eth2/beacon/tools/builder/state.py +++ b/eth2/beacon/tools/builder/state.py @@ -1,20 +1,14 @@ -from typing import ( - Sequence, -) +from typing import Sequence -from eth2.configs import Eth2Config from eth2.beacon.genesis import ( - state_with_validator_digests, initialize_beacon_state_from_eth1, + state_with_validator_digests, ) from eth2.beacon.types.eth1_data import Eth1Data from eth2.beacon.types.states import BeaconState from eth2.beacon.types.validators import Validator -from eth2.beacon.typing import ( - Epoch, - Gwei, - Timestamp, -) +from eth2.beacon.typing import Epoch, Gwei, Timestamp +from eth2.configs import Eth2Config def _check_distinct_pubkeys(validators: Sequence[Validator]) -> None: @@ -22,8 +16,9 @@ def _check_distinct_pubkeys(validators: Sequence[Validator]) -> None: assert len(set(pubkeys)) == len(pubkeys) -def _check_no_missing_balances(validators: Sequence[Validator], - balances: Sequence[Gwei]) -> None: +def _check_no_missing_balances( + validators: Sequence[Validator], balances: Sequence[Gwei] +) -> None: assert len(validators) == len(balances) @@ -33,23 +28,27 @@ def _check_sufficient_balance(balances: Sequence[Gwei], threshold: Gwei) -> None assert False -def _check_activated_validators(validators: Sequence[Validator], - genesis_epoch: Epoch) -> None: +def _check_activated_validators( + validators: Sequence[Validator], genesis_epoch: Epoch +) -> None: for validator in validators: assert validator.activation_eligibility_epoch == genesis_epoch assert validator.activation_epoch == genesis_epoch -def _check_correct_eth1_data(eth1_data: Eth1Data, - validators: Sequence[Validator]) -> None: +def _check_correct_eth1_data( + eth1_data: Eth1Data, validators: Sequence[Validator] +) -> None: assert eth1_data.deposit_count == len(validators) -def create_mock_genesis_state_from_validators(genesis_time: Timestamp, - genesis_eth1_data: Eth1Data, - genesis_validators: Sequence[Validator], - genesis_balances: Sequence[Gwei], - config: Eth2Config) -> BeaconState: +def create_mock_genesis_state_from_validators( + genesis_time: Timestamp, + genesis_eth1_data: Eth1Data, + genesis_validators: Sequence[Validator], + genesis_balances: Sequence[Gwei], + config: Eth2Config, +) -> BeaconState: """ Produce a valid genesis state without creating the corresponding deposits. @@ -73,13 +72,10 @@ def create_mock_genesis_state_from_validators(genesis_time: Timestamp, state_with_validators = empty_state.copy( eth1_deposit_index=empty_state.eth1_deposit_index + len(genesis_validators), eth1_data=empty_state.eth1_data.copy( - deposit_count=empty_state.eth1_data.deposit_count + len(genesis_validators), + deposit_count=empty_state.eth1_data.deposit_count + len(genesis_validators) ), validators=genesis_validators, balances=genesis_balances, ) - return state_with_validator_digests( - state_with_validators, - config, - ) + return state_with_validator_digests(state_with_validators, config) diff --git a/eth2/beacon/tools/builder/validator.py b/eth2/beacon/tools/builder/validator.py index 35575a914a..7919f93685 100644 --- a/eth2/beacon/tools/builder/validator.py +++ b/eth2/beacon/tools/builder/validator.py @@ -1,62 +1,37 @@ import math import random +from typing import Dict, Iterable, Sequence, Tuple -from typing import ( - Dict, - Iterable, - Sequence, - Tuple, -) - -from eth_typing import ( - BLSPubkey, - BLSSignature, - Hash32, -) - -from eth.constants import ( - ZERO_HASH32, -) -from eth_utils import ( - to_tuple, -) -from eth_utils.toolz import ( - pipe, - keymap as keymapper, -) -from eth2._utils.bls import bls, Domain +from eth.constants import ZERO_HASH32 +from eth_typing import BLSPubkey, BLSSignature, Hash32 +from eth_utils import to_tuple +from eth_utils.toolz import keymap as keymapper +from eth_utils.toolz import pipe -from eth2._utils.bitfield import ( - get_empty_bitfield, - set_voted, -) -from eth2.configs import ( - CommitteeConfig, - Eth2Config, -) -from eth2.beacon.signature_domain import ( - SignatureDomain, -) +from eth2._utils.bitfield import get_empty_bitfield, set_voted +from eth2._utils.bls import Domain, bls from eth2.beacon.committee_helpers import ( - get_crosslink_committee, get_committee_count, - get_start_shard, + get_crosslink_committee, get_shard_delta, + get_start_shard, ) from eth2.beacon.helpers import ( compute_domain, - get_block_root_at_slot, - get_block_root, - get_domain, - compute_start_slot_of_epoch, compute_epoch_of_slot, + compute_start_slot_of_epoch, get_active_validator_indices, + get_block_root, + get_block_root_at_slot, + get_domain, ) -from eth2.beacon.types.attestations import Attestation, IndexedAttestation +from eth2.beacon.signature_domain import SignatureDomain +from eth2.beacon.state_machines.base import BaseBeaconStateMachine from eth2.beacon.types.attestation_data import AttestationData from eth2.beacon.types.attestation_data_and_custody_bits import ( AttestationDataAndCustodyBit, ) +from eth2.beacon.types.attestations import Attestation, IndexedAttestation from eth2.beacon.types.attester_slashings import AttesterSlashing from eth2.beacon.types.blocks import BeaconBlockHeader from eth2.beacon.types.checkpoints import Checkpoint @@ -78,43 +53,42 @@ default_epoch, default_shard, ) -from eth2.beacon.state_machines.base import ( - BaseBeaconStateMachine, -) +from eth2.configs import CommitteeConfig, Eth2Config # TODO(ralexstokes) merge w/ below -def _mk_pending_attestation(bitfield: Bitfield=default_bitfield, - target_root: Hash32=ZERO_HASH32, - target_epoch: Epoch=default_epoch, - shard: Shard=default_shard, - start_epoch: Epoch=default_epoch, - parent_root: Hash32=ZERO_HASH32, - data_root: Hash32=ZERO_HASH32) -> PendingAttestation: +def _mk_pending_attestation( + bitfield: Bitfield = default_bitfield, + target_root: Hash32 = ZERO_HASH32, + target_epoch: Epoch = default_epoch, + shard: Shard = default_shard, + start_epoch: Epoch = default_epoch, + parent_root: Hash32 = ZERO_HASH32, + data_root: Hash32 = ZERO_HASH32, +) -> PendingAttestation: return PendingAttestation( aggregation_bits=bitfield, data=AttestationData( - target=Checkpoint( - epoch=target_epoch, - root=target_root, - ), + target=Checkpoint(epoch=target_epoch, root=target_root), crosslink=Crosslink( shard=shard, parent_root=parent_root, start_epoch=start_epoch, end_epoch=target_epoch, data_root=data_root, - ) + ), ), ) -def mk_pending_attestation_from_committee(parent: Crosslink, - committee_size: int, - shard: Shard, - target_epoch: Epoch=default_epoch, - target_root: Hash32=ZERO_HASH32, - data_root: Hash32=ZERO_HASH32) -> PendingAttestation: +def mk_pending_attestation_from_committee( + parent: Crosslink, + committee_size: int, + shard: Shard, + target_epoch: Epoch = default_epoch, + target_root: Hash32 = ZERO_HASH32, + data_root: Hash32 = ZERO_HASH32, +) -> PendingAttestation: bitfield = get_empty_bitfield(committee_size) for i in range(committee_size): bitfield = set_voted(bitfield, i) @@ -131,35 +105,28 @@ def mk_pending_attestation_from_committee(parent: Crosslink, def _mk_some_pending_attestations_with_some_participation_in_epoch( - state: BeaconState, - epoch: Epoch, - config: Eth2Config, - participation_ratio: float, - number_of_shards_to_check: int) -> Iterable[PendingAttestation]: + state: BeaconState, + epoch: Epoch, + config: Eth2Config, + participation_ratio: float, + number_of_shards_to_check: int, +) -> Iterable[PendingAttestation]: block_root = get_block_root( - state, - epoch, - config.SLOTS_PER_EPOCH, - config.SLOTS_PER_HISTORICAL_ROOT, - ) - epoch_start_shard = get_start_shard( - state, - epoch, - CommitteeConfig(config), + state, epoch, config.SLOTS_PER_EPOCH, config.SLOTS_PER_HISTORICAL_ROOT ) + epoch_start_shard = get_start_shard(state, epoch, CommitteeConfig(config)) if epoch == state.current_epoch(config.SLOTS_PER_EPOCH): parent_crosslinks = state.current_crosslinks else: parent_crosslinks = state.previous_crosslinks - for shard in range(epoch_start_shard, epoch_start_shard + number_of_shards_to_check): + for shard in range( + epoch_start_shard, epoch_start_shard + number_of_shards_to_check + ): shard = Shard(shard % config.SHARD_COUNT) crosslink_committee = get_crosslink_committee( - state, - epoch, - shard, - CommitteeConfig(config), + state, epoch, shard, CommitteeConfig(config) ) if not crosslink_committee: continue @@ -178,33 +145,23 @@ def _mk_some_pending_attestations_with_some_participation_in_epoch( def mk_all_pending_attestations_with_some_participation_in_epoch( - state: BeaconState, - epoch: Epoch, - config: Eth2Config, - participation_ratio: float) -> Iterable[PendingAttestation]: + state: BeaconState, epoch: Epoch, config: Eth2Config, participation_ratio: float +) -> Iterable[PendingAttestation]: return _mk_some_pending_attestations_with_some_participation_in_epoch( state, epoch, config, participation_ratio, - get_shard_delta( - state, - epoch, - CommitteeConfig(config), - ), + get_shard_delta(state, epoch, CommitteeConfig(config)), ) @to_tuple def mk_all_pending_attestations_with_full_participation_in_epoch( - state: BeaconState, - epoch: Epoch, - config: Eth2Config) -> Iterable[PendingAttestation]: + state: BeaconState, epoch: Epoch, config: Eth2Config +) -> Iterable[PendingAttestation]: return mk_all_pending_attestations_with_some_participation_in_epoch( - state, - epoch, - config, - 1.0, + state, epoch, config, 1.0 ) @@ -212,21 +169,18 @@ def mk_all_pending_attestations_with_full_participation_in_epoch( # Aggregation # def verify_votes( - message_hash: Hash32, - votes: Iterable[Tuple[ValidatorIndex, BLSSignature, BLSPubkey]], - domain: Domain) -> Tuple[Tuple[BLSSignature, ...], Tuple[ValidatorIndex, ...]]: + message_hash: Hash32, + votes: Iterable[Tuple[ValidatorIndex, BLSSignature, BLSPubkey]], + domain: Domain, +) -> Tuple[Tuple[BLSSignature, ...], Tuple[ValidatorIndex, ...]]: """ Verify the given votes. """ sigs_with_committee_info = tuple( (sig, committee_index) - for (committee_index, sig, pubkey) - in votes + for (committee_index, sig, pubkey) in votes if bls.verify( - message_hash=message_hash, - pubkey=pubkey, - signature=sig, - domain=domain, + message_hash=message_hash, pubkey=pubkey, signature=sig, domain=domain ) ) try: @@ -239,10 +193,10 @@ def verify_votes( def aggregate_votes( - bitfield: Bitfield, - sigs: Sequence[BLSSignature], - voting_sigs: Sequence[BLSSignature], - attesting_indices: Sequence[CommitteeIndex] + bitfield: Bitfield, + sigs: Sequence[BLSSignature], + voting_sigs: Sequence[BLSSignature], + attesting_indices: Sequence[CommitteeIndex], ) -> Tuple[Bitfield, BLSSignature]: """ Aggregate the votes. @@ -251,10 +205,7 @@ def aggregate_votes( sigs = tuple(sigs) + tuple(voting_sigs) bitfield = pipe( bitfield, - *( - set_voted(index=committee_index) - for committee_index in attesting_indices - ) + *(set_voted(index=committee_index) for committee_index in attesting_indices) ) return bitfield, bls.aggregate_signatures(sigs) @@ -263,8 +214,7 @@ def aggregate_votes( # # Signer # -def sign_proof_of_possession(deposit_data: DepositData, - privkey: int) -> BLSSignature: +def sign_proof_of_possession(deposit_data: DepositData, privkey: int) -> BLSSignature: return bls.sign( message_hash=deposit_data.signing_root, privkey=privkey, @@ -272,37 +222,36 @@ def sign_proof_of_possession(deposit_data: DepositData, ) -def sign_transaction(*, - message_hash: Hash32, - privkey: int, - state: BeaconState, - slot: Slot, - signature_domain: SignatureDomain, - slots_per_epoch: int) -> BLSSignature: +def sign_transaction( + *, + message_hash: Hash32, + privkey: int, + state: BeaconState, + slot: Slot, + signature_domain: SignatureDomain, + slots_per_epoch: int +) -> BLSSignature: domain = get_domain( state, signature_domain, slots_per_epoch, message_epoch=compute_epoch_of_slot(slot, slots_per_epoch), ) - return bls.sign( - message_hash=message_hash, - privkey=privkey, - domain=domain, - ) + return bls.sign(message_hash=message_hash, privkey=privkey, domain=domain) -SAMPLE_HASH_1 = Hash32(b'\x11' * 32) -SAMPLE_HASH_2 = Hash32(b'\x22' * 32) +SAMPLE_HASH_1 = Hash32(b"\x11" * 32) +SAMPLE_HASH_2 = Hash32(b"\x22" * 32) def create_block_header_with_signature( - state: BeaconState, - body_root: Hash32, - privkey: int, - slots_per_epoch: int, - parent_root: Hash32=SAMPLE_HASH_1, - state_root: Hash32=SAMPLE_HASH_2)-> BeaconBlockHeader: + state: BeaconState, + body_root: Hash32, + privkey: int, + slots_per_epoch: int, + parent_root: Hash32 = SAMPLE_HASH_1, + state_root: Hash32 = SAMPLE_HASH_2, +) -> BeaconBlockHeader: block_header = BeaconBlockHeader( slot=state.slot, parent_root=parent_root, @@ -331,12 +280,13 @@ def create_block_header_with_signature( # ProposerSlashing # def create_mock_proposer_slashing_at_block( - state: BeaconState, - config: Eth2Config, - keymap: Dict[BLSPubkey, int], - block_root_1: Hash32, - block_root_2: Hash32, - proposer_index: ValidatorIndex) -> ProposerSlashing: + state: BeaconState, + config: Eth2Config, + keymap: Dict[BLSPubkey, int], + block_root_1: Hash32, + block_root_2: Hash32, + proposer_index: ValidatorIndex, +) -> ProposerSlashing: """ Return a `ProposerSlashing` derived from the given block roots. @@ -360,19 +310,19 @@ def create_mock_proposer_slashing_at_block( ) return ProposerSlashing( - proposer_index=proposer_index, - header_1=block_header_1, - header_2=block_header_2, + proposer_index=proposer_index, header_1=block_header_1, header_2=block_header_2 ) # # AttesterSlashing # -def create_mock_slashable_attestation(state: BeaconState, - config: Eth2Config, - keymap: Dict[BLSPubkey, int], - attestation_slot: Slot) -> IndexedAttestation: +def create_mock_slashable_attestation( + state: BeaconState, + config: Eth2Config, + keymap: Dict[BLSPubkey, int], + attestation_slot: Slot, +) -> IndexedAttestation: """ Create an `IndexedAttestation` that is signed by one attester. """ @@ -382,9 +332,7 @@ def create_mock_slashable_attestation(state: BeaconState, # Use genesis block root as `beacon_block_root`, only for tests. beacon_block_root = get_block_root_at_slot( - state, - attestation_slot, - config.SLOTS_PER_HISTORICAL_ROOT, + state, attestation_slot, config.SLOTS_PER_HISTORICAL_ROOT ) # Get `target_root` @@ -393,8 +341,7 @@ def create_mock_slashable_attestation(state: BeaconState, source_root = get_block_root_at_slot( state, compute_start_slot_of_epoch( - state.current_justified_checkpoint.epoch, - config.SLOTS_PER_EPOCH, + state.current_justified_checkpoint.epoch, config.SLOTS_PER_EPOCH ), config.SLOTS_PER_HISTORICAL_ROOT, ) @@ -403,32 +350,22 @@ def create_mock_slashable_attestation(state: BeaconState, attestation_data = AttestationData( beacon_block_root=beacon_block_root, source=Checkpoint( - epoch=state.current_justified_checkpoint.epoch, - root=source_root, + epoch=state.current_justified_checkpoint.epoch, root=source_root ), target=Checkpoint( - epoch=compute_epoch_of_slot( - attestation_slot, - config.SLOTS_PER_EPOCH, - ), + epoch=compute_epoch_of_slot(attestation_slot, config.SLOTS_PER_EPOCH), root=target_root, ), crosslink=previous_crosslink, ) message_hash, attesting_indices = _get_mock_message_and_attesting_indices( - attestation_data, - committee, - num_voted_attesters=1, + attestation_data, committee, num_voted_attesters=1 ) signature = sign_transaction( message_hash=message_hash, - privkey=keymap[ - state.validators[ - attesting_indices[0] - ].pubkey - ], + privkey=keymap[state.validators[attesting_indices[0]].pubkey], state=state, slot=attestation_slot, signature_domain=SignatureDomain.DOMAIN_ATTESTATION, @@ -445,45 +382,43 @@ def create_mock_slashable_attestation(state: BeaconState, def create_mock_attester_slashing_is_double_vote( - state: BeaconState, - config: Eth2Config, - keymap: Dict[BLSPubkey, int], - attestation_epoch: Epoch) -> AttesterSlashing: - attestation_slot_1 = compute_start_slot_of_epoch(attestation_epoch, config.SLOTS_PER_EPOCH) + state: BeaconState, + config: Eth2Config, + keymap: Dict[BLSPubkey, int], + attestation_epoch: Epoch, +) -> AttesterSlashing: + attestation_slot_1 = compute_start_slot_of_epoch( + attestation_epoch, config.SLOTS_PER_EPOCH + ) attestation_slot_2 = Slot(attestation_slot_1 + 1) slashable_attestation_1 = create_mock_slashable_attestation( - state, - config, - keymap, - attestation_slot_1, + state, config, keymap, attestation_slot_1 ) slashable_attestation_2 = create_mock_slashable_attestation( - state, - config, - keymap, - attestation_slot_2, + state, config, keymap, attestation_slot_2 ) return AttesterSlashing( - attestation_1=slashable_attestation_1, - attestation_2=slashable_attestation_2, + attestation_1=slashable_attestation_1, attestation_2=slashable_attestation_2 ) def create_mock_attester_slashing_is_surround_vote( - state: BeaconState, - config: Eth2Config, - keymap: Dict[BLSPubkey, int], - attestation_epoch: Epoch) -> AttesterSlashing: + state: BeaconState, + config: Eth2Config, + keymap: Dict[BLSPubkey, int], + attestation_epoch: Epoch, +) -> AttesterSlashing: # target_epoch_2 < target_epoch_1 - attestation_slot_2 = compute_start_slot_of_epoch(attestation_epoch, config.SLOTS_PER_EPOCH) + attestation_slot_2 = compute_start_slot_of_epoch( + attestation_epoch, config.SLOTS_PER_EPOCH + ) attestation_slot_1 = Slot(attestation_slot_2 + config.SLOTS_PER_EPOCH) slashable_attestation_1 = create_mock_slashable_attestation( state.copy( - slot=attestation_slot_1, - current_justified_epoch=config.GENESIS_EPOCH, + slot=attestation_slot_1, current_justified_epoch=config.GENESIS_EPOCH ), config, keymap, @@ -492,7 +427,8 @@ def create_mock_attester_slashing_is_surround_vote( slashable_attestation_2 = create_mock_slashable_attestation( state.copy( slot=attestation_slot_1, - current_justified_epoch=config.GENESIS_EPOCH + 1, # source_epoch_1 < source_epoch_2 + current_justified_epoch=config.GENESIS_EPOCH + + 1, # source_epoch_1 < source_epoch_2 ), config, keymap, @@ -500,81 +436,71 @@ def create_mock_attester_slashing_is_surround_vote( ) return AttesterSlashing( - attestation_1=slashable_attestation_1, - attestation_2=slashable_attestation_2, + attestation_1=slashable_attestation_1, attestation_2=slashable_attestation_2 ) # # Attestation # -def _get_target_root(state: BeaconState, - config: Eth2Config, - beacon_block_root: Hash32) -> Hash32: +def _get_target_root( + state: BeaconState, config: Eth2Config, beacon_block_root: Hash32 +) -> Hash32: epoch = compute_epoch_of_slot(state.slot, config.SLOTS_PER_EPOCH) - epoch_start_slot = compute_start_slot_of_epoch( - epoch, - config.SLOTS_PER_EPOCH, - ) + epoch_start_slot = compute_start_slot_of_epoch(epoch, config.SLOTS_PER_EPOCH) if epoch_start_slot == state.slot: return beacon_block_root else: return get_block_root( - state, - epoch, - config.SLOTS_PER_EPOCH, - config.SLOTS_PER_HISTORICAL_ROOT, + state, epoch, config.SLOTS_PER_EPOCH, config.SLOTS_PER_HISTORICAL_ROOT ) def _get_mock_message_and_attesting_indices( - attestation_data: AttestationData, - committee: Sequence[ValidatorIndex], - num_voted_attesters: int) -> Tuple[Hash32, Tuple[CommitteeIndex, ...]]: + attestation_data: AttestationData, + committee: Sequence[ValidatorIndex], + num_voted_attesters: int, +) -> Tuple[Hash32, Tuple[CommitteeIndex, ...]]: """ Get ``message_hash`` and voting indices of the given ``committee``. """ message_hash = AttestationDataAndCustodyBit( - data=attestation_data, - custody_bit=False + data=attestation_data, custody_bit=False ).hash_tree_root committee_size = len(committee) assert num_voted_attesters <= committee_size attesting_indices = tuple( - CommitteeIndex(i) for i in random.sample(range(committee_size), num_voted_attesters) + CommitteeIndex(i) + for i in random.sample(range(committee_size), num_voted_attesters) ) return message_hash, tuple(sorted(attesting_indices)) -def _create_mock_signed_attestation(state: BeaconState, - attestation_data: AttestationData, - attestation_slot: Slot, - committee: Sequence[ValidatorIndex], - num_voted_attesters: int, - keymap: Dict[BLSPubkey, int], - slots_per_epoch: int) -> Attestation: +def _create_mock_signed_attestation( + state: BeaconState, + attestation_data: AttestationData, + attestation_slot: Slot, + committee: Sequence[ValidatorIndex], + num_voted_attesters: int, + keymap: Dict[BLSPubkey, int], + slots_per_epoch: int, +) -> Attestation: """ Create a mocking attestation of the given ``attestation_data`` slot with ``keymap``. """ message_hash, attesting_indices = _get_mock_message_and_attesting_indices( - attestation_data, - committee, - num_voted_attesters, + attestation_data, committee, num_voted_attesters ) # Use privkeys to sign the attestation signatures = [ sign_transaction( message_hash=message_hash, - privkey=keymap[ - state.validators[ - committee[committee_index] - ].pubkey - ], + privkey=keymap[state.validators[committee[committee_index]].pubkey], state=state, slot=attestation_slot, signature_domain=SignatureDomain.DOMAIN_ATTESTATION, @@ -602,51 +528,52 @@ def _create_mock_signed_attestation(state: BeaconState, # TODO(ralexstokes) merge in w/ ``get_committee_assignment`` def get_crosslink_committees_at_slot( - state: BeaconState, - slot: Slot, - config: Eth2Config) -> Tuple[Tuple[Tuple[ValidatorIndex, ...], Shard], ...]: + state: BeaconState, slot: Slot, config: Eth2Config +) -> Tuple[Tuple[Tuple[ValidatorIndex, ...], Shard], ...]: epoch = compute_epoch_of_slot(slot, config.SLOTS_PER_EPOCH) active_validators = get_active_validator_indices(state.validators, epoch) - committees_per_slot = get_committee_count( - len(active_validators), - config.SHARD_COUNT, - config.SLOTS_PER_EPOCH, - config.TARGET_COMMITTEE_SIZE, - ) // config.SLOTS_PER_EPOCH + committees_per_slot = ( + get_committee_count( + len(active_validators), + config.SHARD_COUNT, + config.SLOTS_PER_EPOCH, + config.TARGET_COMMITTEE_SIZE, + ) + // config.SLOTS_PER_EPOCH + ) results = [] offset = committees_per_slot * (slot % config.SLOTS_PER_EPOCH) - slot_start_shard = Shard(( - get_start_shard(state, epoch, CommitteeConfig(config)) + offset - ) % config.SHARD_COUNT) + slot_start_shard = Shard( + (get_start_shard(state, epoch, CommitteeConfig(config)) + offset) + % config.SHARD_COUNT + ) for i in range(committees_per_slot): shard = (slot_start_shard + i) % config.SHARD_COUNT - committee = get_crosslink_committee(state, epoch, shard, CommitteeConfig(config)) + committee = get_crosslink_committee( + state, epoch, shard, CommitteeConfig(config) + ) results.append((committee, Shard(shard))) return tuple(results) -def create_signed_attestation_at_slot(state: BeaconState, - config: Eth2Config, - state_machine: BaseBeaconStateMachine, - attestation_slot: Slot, - beacon_block_root: Hash32, - validator_privkeys: Dict[ValidatorIndex, int], - committee: Tuple[ValidatorIndex, ...], - shard: Shard) -> Attestation: +def create_signed_attestation_at_slot( + state: BeaconState, + config: Eth2Config, + state_machine: BaseBeaconStateMachine, + attestation_slot: Slot, + beacon_block_root: Hash32, + validator_privkeys: Dict[ValidatorIndex, int], + committee: Tuple[ValidatorIndex, ...], + shard: Shard, +) -> Attestation: """ Create the attestations of the given ``attestation_slot`` slot with ``validator_privkeys``. """ state_transition = state_machine.state_transition - state = state_transition.apply_state_transition( - state, - future_slot=attestation_slot, - ) + state = state_transition.apply_state_transition(state, future_slot=attestation_slot) - target_epoch = compute_epoch_of_slot( - attestation_slot, - config.SLOTS_PER_EPOCH, - ) + target_epoch = compute_epoch_of_slot(attestation_slot, config.SLOTS_PER_EPOCH) target_root = _get_target_root(state, config, beacon_block_root) @@ -658,16 +585,13 @@ def create_signed_attestation_at_slot(state: BeaconState, epoch=state.current_justified_checkpoint.epoch, root=state.current_justified_checkpoint.root, ), - target=Checkpoint( - root=target_root, - epoch=target_epoch, - ), + target=Checkpoint(root=target_root, epoch=target_epoch), crosslink=Crosslink( shard=shard, parent_root=parent_crosslink.hash_tree_root, start_epoch=parent_crosslink.end_epoch, end_epoch=target_epoch, - ) + ), ) return _create_mock_signed_attestation( @@ -683,28 +607,24 @@ def create_signed_attestation_at_slot(state: BeaconState, @to_tuple def create_mock_signed_attestations_at_slot( - state: BeaconState, - config: Eth2Config, - state_machine: BaseBeaconStateMachine, - attestation_slot: Slot, - beacon_block_root: Hash32, - keymap: Dict[BLSPubkey, int], - voted_attesters_ratio: float=1.0) -> Iterable[Attestation]: + state: BeaconState, + config: Eth2Config, + state_machine: BaseBeaconStateMachine, + attestation_slot: Slot, + beacon_block_root: Hash32, + keymap: Dict[BLSPubkey, int], + voted_attesters_ratio: float = 1.0, +) -> Iterable[Attestation]: """ Create the mocking attestations of the given ``attestation_slot`` slot with ``keymap``. """ crosslink_committees_at_slot = get_crosslink_committees_at_slot( - state, - attestation_slot, - config, + state, attestation_slot, config ) # Get `target_root` target_root = _get_target_root(state, config, beacon_block_root) - target_epoch = compute_epoch_of_slot( - state.slot, - config.SLOTS_PER_EPOCH, - ) + target_epoch = compute_epoch_of_slot(state.slot, config.SLOTS_PER_EPOCH) for crosslink_committee in crosslink_committees_at_slot: committee, shard = crosslink_committee @@ -717,19 +637,16 @@ def create_mock_signed_attestations_at_slot( epoch=state.current_justified_checkpoint.epoch, root=state.current_justified_checkpoint.root, ), - target=Checkpoint( - root=target_root, - epoch=target_epoch, - ), + target=Checkpoint(root=target_root, epoch=target_epoch), crosslink=Crosslink( shard=shard, parent_root=parent_crosslink.hash_tree_root, start_epoch=parent_crosslink.end_epoch, end_epoch=min( target_epoch, - parent_crosslink.end_epoch + config.MAX_EPOCHS_PER_CROSSLINK + parent_crosslink.end_epoch + config.MAX_EPOCHS_PER_CROSSLINK, ), - ) + ), ) num_voted_attesters = int(len(committee) * voted_attesters_ratio) @@ -748,17 +665,16 @@ def create_mock_signed_attestations_at_slot( # # VoluntaryExit # -def create_mock_voluntary_exit(state: BeaconState, - config: Eth2Config, - keymap: Dict[BLSPubkey, int], - validator_index: ValidatorIndex, - exit_epoch: Epoch=None) -> VoluntaryExit: +def create_mock_voluntary_exit( + state: BeaconState, + config: Eth2Config, + keymap: Dict[BLSPubkey, int], + validator_index: ValidatorIndex, + exit_epoch: Epoch = None, +) -> VoluntaryExit: current_epoch = state.current_epoch(config.SLOTS_PER_EPOCH) target_epoch = current_epoch if exit_epoch is None else exit_epoch - voluntary_exit = VoluntaryExit( - epoch=target_epoch, - validator_index=validator_index, - ) + voluntary_exit = VoluntaryExit(epoch=target_epoch, validator_index=validator_index) return voluntary_exit.copy( signature=sign_transaction( message_hash=voluntary_exit.signing_root, @@ -774,24 +690,19 @@ def create_mock_voluntary_exit(state: BeaconState, # # Deposit # -def create_mock_deposit_data(*, - config: Eth2Config, - pubkey: BLSPubkey, - privkey: int, - withdrawal_credentials: Hash32, - amount: Gwei=None) -> DepositData: +def create_mock_deposit_data( + *, + config: Eth2Config, + pubkey: BLSPubkey, + privkey: int, + withdrawal_credentials: Hash32, + amount: Gwei = None +) -> DepositData: if amount is None: amount = config.MAX_EFFECTIVE_BALANCE data = DepositData( - pubkey=pubkey, - withdrawal_credentials=withdrawal_credentials, - amount=amount, - ) - signature = sign_proof_of_possession( - deposit_data=data, - privkey=privkey, - ) - return data.copy( - signature=signature, + pubkey=pubkey, withdrawal_credentials=withdrawal_credentials, amount=amount ) + signature = sign_proof_of_possession(deposit_data=data, privkey=privkey) + return data.copy(signature=signature) diff --git a/eth2/beacon/tools/factories.py b/eth2/beacon/tools/factories.py index 9af12279fd..0d30831991 100644 --- a/eth2/beacon/tools/factories.py +++ b/eth2/beacon/tools/factories.py @@ -1,46 +1,24 @@ import time -import factory -from typing import ( - Any, - Type, - TypeVar, -) - -from eth2._utils.bls import bls - -from eth2._utils.hash import ( - hash_eth2, -) -from eth2.beacon.state_machines.forks.serenity.blocks import ( - SerenityBeaconBlock, -) -from eth2.beacon.state_machines.forks.xiao_long_bao.configs import ( - XIAO_LONG_BAO_CONFIG, -) -from eth2.beacon.tools.builder.initializer import ( - create_mock_genesis, -) +from typing import Any, Type, TypeVar from eth.db.atomic import AtomicDB -from eth2.beacon.typing import ( - Timestamp, -) -from eth2.beacon.chains.base import ( - BaseBeaconChain, -) -from eth2.beacon.chains.testnet import ( - TestnetChain, -) -from eth2.configs import ( - Eth2GenesisConfig, -) +import factory +from eth2._utils.bls import bls +from eth2._utils.hash import hash_eth2 +from eth2.beacon.chains.base import BaseBeaconChain +from eth2.beacon.chains.testnet import TestnetChain +from eth2.beacon.state_machines.forks.serenity.blocks import SerenityBeaconBlock +from eth2.beacon.state_machines.forks.xiao_long_bao.configs import XIAO_LONG_BAO_CONFIG +from eth2.beacon.tools.builder.initializer import create_mock_genesis +from eth2.beacon.typing import Timestamp +from eth2.configs import Eth2GenesisConfig NUM_VALIDATORS = 8 -privkeys = tuple(int.from_bytes( - hash_eth2(str(i).encode('utf-8'))[:4], 'big') +privkeys = tuple( + int.from_bytes(hash_eth2(str(i).encode("utf-8"))[:4], "big") for i in range(NUM_VALIDATORS) ) index_to_pubkey = {} diff --git a/eth2/beacon/tools/fixtures/conditions.py b/eth2/beacon/tools/fixtures/conditions.py new file mode 100644 index 0000000000..d2e044762c --- /dev/null +++ b/eth2/beacon/tools/fixtures/conditions.py @@ -0,0 +1,18 @@ +from ssz.tools import to_formatted_dict + +from eth2.beacon.types.states import BeaconState + + +def validate_state(post_state: BeaconState, expected_state: BeaconState) -> None: + # Use dict diff, easier to see the diff + dict_post_state = to_formatted_dict(post_state, BeaconState) + dict_expected_state = to_formatted_dict(expected_state, BeaconState) + for key, value in dict_expected_state.items(): + if isinstance(value, list): + value = tuple(value) + if dict_post_state[key] != value: + raise AssertionError( + f"state.{key} is incorrect:\n" + f"\tExpected: {value}\n" + f"\tResult: {dict_post_state[key]}\n" + ) diff --git a/eth2/beacon/tools/fixtures/config_name.py b/eth2/beacon/tools/fixtures/config_name.py deleted file mode 100644 index 5a68db2f5a..0000000000 --- a/eth2/beacon/tools/fixtures/config_name.py +++ /dev/null @@ -1,12 +0,0 @@ -from typing import ( - NewType, -) - - -ConfigName = NewType('ConfigName', str) - -Mainnet = "mainnet" -Minimal = "minimal" - -ALL_CONFIG_NAMES = (Mainnet, Minimal) -ONLY_MINIMAL = (Minimal,) diff --git a/eth2/beacon/tools/fixtures/config_types.py b/eth2/beacon/tools/fixtures/config_types.py new file mode 100644 index 0000000000..6e2d2b0e68 --- /dev/null +++ b/eth2/beacon/tools/fixtures/config_types.py @@ -0,0 +1,24 @@ +import abc + + +class ConfigType(abc.ABC): + name: str + path: str + + +class Mainnet(ConfigType): + # name is the human-readable name for this configuration. + name = "mainnet" + # path is the file system path to the config as YAML, relative to the project root. + path = "tests/eth2/fixtures/mainnet.yaml" + + +class Full(Mainnet): + name = "full" + + +class Minimal(ConfigType): + # name is the human-readable name for this configuration. + name = "minimal" + # path is the file system path to the config as YAML, relative to the project root. + path = "tests/eth2/fixtures/minimal.yaml" diff --git a/eth2/beacon/tools/fixtures/helpers.py b/eth2/beacon/tools/fixtures/helpers.py deleted file mode 100644 index 11eabd0ace..0000000000 --- a/eth2/beacon/tools/fixtures/helpers.py +++ /dev/null @@ -1,96 +0,0 @@ -from typing import ( - Tuple, - Type, -) - -from eth_utils import ( - encode_hex, - ValidationError, -) -from ssz.tools import ( - to_formatted_dict, -) - -from eth2.beacon.db.chain import BeaconChainDB -from eth2.beacon.tools.builder.proposer import ( - advance_to_slot, -) -from eth2.beacon.operations.attestation_pool import AttestationPool -from eth2.beacon.state_machines.forks.serenity import ( - SerenityStateMachine, -) -from eth2.beacon.tools.fixtures.test_case import ( - StateTestCase, -) -from eth2.beacon.types.states import BeaconState - - -def run_state_execution(test_case: StateTestCase, - sm_class: Type[SerenityStateMachine], - chaindb: BeaconChainDB, - attestation_pool: AttestationPool, - state: BeaconState) -> BeaconState: - chaindb.persist_state(state) - post_state = state - post_state, chaindb = apply_advance_to_slot( - test_case, - sm_class, - chaindb, - attestation_pool, - post_state, - ) - post_state, chaindb = apply_blocks( - test_case, - sm_class, - chaindb, - attestation_pool, - post_state, - ) - return post_state - - -def apply_advance_to_slot(test_case: StateTestCase, - sm_class: Type[SerenityStateMachine], - chaindb: BeaconChainDB, - attestation_pool: AttestationPool, - state: BeaconState) -> Tuple[BeaconState, BeaconChainDB]: - post_state = state.copy() - sm = sm_class(chaindb, attestation_pool, None, post_state) - slot = test_case.pre.slot + test_case.slots - chaindb.persist_state(post_state) - return advance_to_slot(sm, post_state, slot), chaindb - - -def apply_blocks(test_case: StateTestCase, - sm_class: Type[SerenityStateMachine], - chaindb: BeaconChainDB, - attestation_pool: AttestationPool, - state: BeaconState) -> Tuple[BeaconState, BeaconChainDB]: - post_state = state.copy() - for block in test_case.blocks: - sm = sm_class(chaindb, attestation_pool, None, post_state) - post_state, imported_block = sm.import_block(block) - chaindb.persist_state(post_state) - if imported_block.state_root != block.state_root: - raise ValidationError( - f"Block did not have the expected state root:\n" - f"\tExpected: {encode_hex(block.state_root)}\n" - f"\tResult: {encode_hex(imported_block.state_root)}\n" - ) - - return post_state, chaindb - - -def validate_state(test_case_state: BeaconState, post_state: BeaconState) -> None: - # Use dict diff, easier to see the diff - dict_post_state = to_formatted_dict(post_state, BeaconState) - dict_expected_state = to_formatted_dict(test_case_state, BeaconState) - for key, value in dict_expected_state.items(): - if isinstance(value, list): - value = tuple(value) - if dict_post_state[key] != value: - raise AssertionError( - f"state.{key} is incorrect:\n" - f"\tExpected: {value}\n" - f"\tResult: {dict_post_state[key]}\n" - ) diff --git a/eth2/beacon/tools/fixtures/loading.py b/eth2/beacon/tools/fixtures/loading.py index eae3113073..ee9816025d 100644 --- a/eth2/beacon/tools/fixtures/loading.py +++ b/eth2/beacon/tools/fixtures/loading.py @@ -1,247 +1,88 @@ -import os from pathlib import Path -from typing import ( - Any, - Callable, - Dict, - Iterable, - Sequence, - Tuple, - Type, - Union, -) -from ruamel.yaml import ( - YAML, -) - -from eth_utils import ( - decode_hex, - to_tuple, -) -from eth_typing import ( - BLSPubkey, - BLSSignature, -) -from eth_utils.toolz import ( - assoc, - keyfilter, -) -from ssz.tools import ( - from_formatted_dict, -) - -from eth2.beacon.helpers import ( - compute_epoch_of_slot, -) -from eth2.beacon.types.blocks import BaseBeaconBlock -from eth2.beacon.types.deposits import Deposit -from eth2.beacon.types.states import BeaconState -from eth2.configs import ( - Eth2Config, -) - -from eth2.beacon.tools.fixtures.config_name import ( - ALL_CONFIG_NAMES, - ConfigName, -) -from eth2.beacon.tools.fixtures.test_case import ( - OperationOrBlockHeader, -) -from eth2.beacon.tools.fixtures.test_file import ( - TestFile, -) - - -# -# Eth2Config -# +from typing import Any, Dict, Tuple, Union + +from eth_typing import BLSPubkey, BLSSignature +from eth_utils import decode_hex +from eth_utils.toolz import assoc, keyfilter +from ruamel.yaml import YAML + +from eth2.beacon.helpers import compute_epoch_of_slot +from eth2.configs import Eth2Config + + def generate_config_by_dict(dict_config: Dict[str, Any]) -> Eth2Config: - config_without_domains = keyfilter(lambda name: "DOMAIN_" not in name, dict_config) - config_without_phase_1 = keyfilter( - lambda name: "EARLY_DERIVED_SECRET_PENALTY_MAX_FUTURE_EPOCHS" not in name, - config_without_domains, - ) + filtered_keys = ("DOMAIN_", "EARLY_DERIVED_SECRET_PENALTY_MAX_FUTURE_EPOCHS") return Eth2Config( **assoc( - config_without_phase_1, + keyfilter( + lambda name: all(key not in name for key in filtered_keys), dict_config + ), "GENESIS_EPOCH", compute_epoch_of_slot( - dict_config['GENESIS_SLOT'], - dict_config['SLOTS_PER_EPOCH'], - ) + dict_config["GENESIS_SLOT"], dict_config["SLOTS_PER_EPOCH"] + ), ) ) -config_cache: Dict[str, Eth2Config] = {} - +def _load_yaml_at(p: Path) -> Dict[str, Any]: + y = YAML(typ="unsafe") + return y.load(p) -def get_config(root_project_dir: Path, config_name: ConfigName) -> Eth2Config: - if config_name in config_cache: - return config_cache[config_name] - - # TODO: change the path after the constants presets are copied to submodule - path = root_project_dir / 'tests/eth2/fixtures' - yaml = YAML(typ="unsafe") - file_name = config_name + '.yaml' - file_to_open = path / file_name - with open(file_to_open, 'U') as f: - new_text = f.read() - data = yaml.load(new_text) - config = generate_config_by_dict(data) - config_cache[config_name] = config - return config - - -def get_test_file_from_dict(data: Dict[str, Any], - root_project_dir: Path, - file_name: str, - parse_test_case_fn: Callable[..., Any]) -> TestFile: - config_name = data['config'] - assert config_name in ALL_CONFIG_NAMES - config_name = ConfigName(config_name) - config = get_config(root_project_dir, config_name) - handler = data['handler'] - parsed_test_cases = tuple( - parse_test_case_fn(test_case, handler, index, config) - for index, test_case in enumerate(data['test_cases']) - ) - return TestFile( - file_name=file_name, - config=config, - test_cases=parsed_test_cases, - ) - - -@to_tuple -def get_yaml_files_pathes(dir_path: Path) -> Iterable[str]: - for root, _, files in os.walk(dir_path): - for name in files: - yield os.path.join(root, name) - - -@to_tuple -def load_from_yaml_files(root_project_dir: Path, - dir_path: Path, - config_names: Sequence[ConfigName], - parse_test_case_fn: Callable[..., Any]) -> Iterable[TestFile]: - entries = get_yaml_files_pathes(dir_path) - for file_path in entries: - file_name = os.path.basename(file_path) - if len(config_names) == 0: - yield load_from_yaml_file(root_project_dir, file_path, file_name, parse_test_case_fn) - for config_name in config_names: - if config_name in file_name: - yield load_from_yaml_file( - root_project_dir, - file_path, - file_name, - parse_test_case_fn, - ) - - -def load_from_yaml_file(root_project_dir: Path, - file_path: str, - file_name: str, - parse_test_case_fn: Callable[..., Any]) -> TestFile: - yaml = YAML(typ="unsafe") - with open(file_path, 'U') as f: - new_text = f.read() - data = yaml.load(new_text) - test_file = get_test_file_from_dict( - data, - root_project_dir, - file_name, - parse_test_case_fn, - ) - return test_file +# NOTE: should cache test suite data if users are running +# the same test suite at different points during testing. +def load_test_suite_at(p: Path) -> Dict[str, Any]: + return _load_yaml_at(p) -@to_tuple -def get_all_test_files(root_project_dir: Path, - fixture_pathes: Tuple[Path, ...], - config_names: Sequence[ConfigName], - parse_test_case_fn: Callable[..., Any]) -> Iterable[TestFile]: - for path in fixture_pathes: - yield from load_from_yaml_files(root_project_dir, path, config_names, parse_test_case_fn) +config_cache: Dict[Path, Eth2Config] = {} -# -# Parser helpers -# -def get_bls_setting(test_case: Dict[str, Any]) -> bool: - # Default is free to choose, so we choose OFF. - if 'bls_setting' not in test_case or test_case['bls_setting'] == 2: - return False - else: - return True +def load_config_at_path(p: Path) -> Eth2Config: + if p in config_cache: + return config_cache[p] -def get_states(test_case: Dict[str, Any], - cls_state: Type[BeaconState]) -> Tuple[BeaconState, BeaconState, bool]: - pre = from_formatted_dict(test_case['pre'], cls_state) - if test_case['post'] is not None: - post = from_formatted_dict(test_case['post'], cls_state) - is_valid = True - else: - post = None - is_valid = False - - return pre, post, is_valid - - -def get_slots(test_case: Dict[str, Any]) -> int: - return test_case['slots'] if 'slots' in test_case else 0 - - -def get_blocks(test_case: Dict[str, Any], - cls_block: Type[BaseBeaconBlock]) -> Tuple[BaseBeaconBlock, ...]: - if 'blocks' in test_case: - return tuple(from_formatted_dict(block, cls_block) for block in test_case['blocks']) - else: - return () + config_data = _load_yaml_at(p) + config = generate_config_by_dict(config_data) + config_cache[p] = config + return config -def get_deposits(test_case: Dict[str, Any], - cls_deposit: Type[Deposit]) -> Tuple[Deposit, ...]: - return tuple(from_formatted_dict(deposit, cls_deposit) for deposit in test_case['deposits']) +def get_input_bls_pubkeys( + test_case: Dict[str, Any] +) -> Dict[str, Tuple[BLSPubkey, ...]]: + return { + "pubkeys": tuple(BLSPubkey(decode_hex(item)) for item in test_case["input"]) + } -def get_operation_or_header(test_case: Dict[str, Any], - cls_operation_or_header: Type[OperationOrBlockHeader], - handler: str) -> Tuple[OperationOrBlockHeader, ...]: - if handler in test_case: - return from_formatted_dict(test_case[handler], cls_operation_or_header) - else: - raise NameError( - f"Operation {handler} is not supported." +def get_input_bls_signatures( + test_case: Dict[str, Any] +) -> Dict[str, Tuple[BLSSignature, ...]]: + return { + "signatures": tuple( + BLSSignature(decode_hex(item)) for item in test_case["input"] ) - - -def get_input_bls_pubkeys(test_case: Dict[str, Any]) -> Dict[str, Tuple[BLSPubkey, ...]]: - return {'pubkeys': tuple(BLSPubkey(decode_hex(item)) for item in test_case['input'])} - - -def get_input_bls_signatures(test_case: Dict[str, Any]) -> Dict[str, Tuple[BLSSignature, ...]]: - return {'signatures': tuple(BLSSignature(decode_hex(item)) for item in test_case['input'])} + } def get_input_bls_privkey(test_case: Dict[str, Any]) -> Dict[str, int]: - return {'privkey': int.from_bytes(decode_hex(test_case['input']), 'big')} + return {"privkey": int.from_bytes(decode_hex(test_case["input"]), "big")} -def get_input_sign_message(test_case: Dict[str, Any]) -> Dict[str, Union[int, bytes, bytes]]: +def get_input_sign_message(test_case: Dict[str, Any]) -> Dict[str, Union[int, bytes]]: return { - 'privkey': int.from_bytes(decode_hex(test_case['input']['privkey']), 'big'), - 'message_hash': decode_hex(test_case['input']['message']), - 'domain': decode_hex(test_case['input']['domain']), + "privkey": int.from_bytes(decode_hex(test_case["input"]["privkey"]), "big"), + "message_hash": decode_hex(test_case["input"]["message"]), + "domain": decode_hex(test_case["input"]["domain"]), } def get_output_bls_pubkey(test_case: Dict[str, Any]) -> BLSPubkey: - return BLSPubkey(decode_hex(test_case['output'])) + return BLSPubkey(decode_hex(test_case["output"])) def get_output_bls_signature(test_case: Dict[str, Any]) -> BLSSignature: - return BLSSignature(decode_hex(test_case['output'])) + return BLSSignature(decode_hex(test_case["output"])) diff --git a/eth2/beacon/tools/fixtures/parser.py b/eth2/beacon/tools/fixtures/parser.py new file mode 100644 index 0000000000..43de2e4265 --- /dev/null +++ b/eth2/beacon/tools/fixtures/parser.py @@ -0,0 +1,95 @@ +from pathlib import Path +from typing import Any, Dict, Generator, Optional, Sequence + +from eth2.beacon.tools.fixtures.config_types import ConfigType +from eth2.beacon.tools.fixtures.loading import load_config_at_path, load_test_suite_at +from eth2.beacon.tools.fixtures.test_case import TestCase +from eth2.beacon.tools.fixtures.test_handler import Input, Output, TestHandler +from eth2.beacon.tools.fixtures.test_types import HandlerType, TestType +from eth2.beacon.tools.misc.ssz_vector import override_lengths +from eth2.configs import Eth2Config + +# NOTE: if the tests_root_path keeps changing, can turn into +# a ``pytest.config.Option`` and supply from the command line. +TESTS_ROOT_PATH = Path("eth2-fixtures") +TESTS_PATH = Path("tests") + +TestSuite = Generator[TestCase, None, None] + + +def _build_test_suite_path( + tests_root_path: Path, + test_type: TestType[HandlerType], + test_handler: TestHandler[Input, Output], + config_type: Optional[ConfigType], +) -> Path: + return test_type.build_path(tests_root_path, test_handler, config_type) + + +def _parse_test_cases( + config: Eth2Config, + test_handler: TestHandler[Input, Output], + test_cases: Sequence[Dict[str, Any]], +) -> TestSuite: + for index, test_case in enumerate(test_cases): + yield TestCase(index, test_case, test_handler, config) + + +def _load_test_suite( + tests_root_path: Path, + test_type: TestType[HandlerType], + test_handler: TestHandler[Input, Output], + config_type: Optional[ConfigType], + config: Optional[Eth2Config], +) -> TestSuite: + test_suite_path = _build_test_suite_path( + tests_root_path, test_type, test_handler, config_type + ) + + test_suite_data = load_test_suite_at(test_suite_path) + + return _parse_test_cases(config, test_handler, test_suite_data["test_cases"]) + + +class DirectoryNotFoundException(Exception): + pass + + +def _search_for_dir(target_dir: Path, p: Path) -> Path: + for child in p.iterdir(): + if not child.is_dir(): + continue + if child.name == target_dir.name: + return child + raise DirectoryNotFoundException() + + +def _find_project_root_dir(target: Path) -> Path: + """ + Search the file tree for a path with a child directory equal to ``target``. + """ + p = Path(".").resolve() + for _ in range(1000): + try: + candidate = _search_for_dir(target, p) + return candidate.parent + except DirectoryNotFoundException: + p = p.parent + raise DirectoryNotFoundException + + +def parse_test_suite( + test_type: TestType[HandlerType], + test_handler: TestHandler[Input, Output], + config_type: Optional[ConfigType], +) -> TestSuite: + project_root_dir = _find_project_root_dir(TESTS_ROOT_PATH) + tests_path = project_root_dir / TESTS_ROOT_PATH / TESTS_PATH + if config_type: + config_path = project_root_dir / config_type.path + config = load_config_at_path(config_path) + override_lengths(config) + else: + config = None + + return _load_test_suite(tests_path, test_type, test_handler, config_type, config) diff --git a/eth2/beacon/tools/fixtures/test_case.py b/eth2/beacon/tools/fixtures/test_case.py index 73c7ed7910..eeb5063d8a 100644 --- a/eth2/beacon/tools/fixtures/test_case.py +++ b/eth2/beacon/tools/fixtures/test_case.py @@ -1,56 +1,50 @@ -from typing import ( - Tuple, - Union, -) - -from dataclasses import ( - dataclass, - field, -) - - -from eth2.beacon.types.attestations import Attestation -from eth2.beacon.types.attester_slashings import AttesterSlashing -from eth2.beacon.types.blocks import BeaconBlock -from eth2.beacon.types.block_headers import BeaconBlockHeader -from eth2.beacon.types.deposits import Deposit -from eth2.beacon.types.proposer_slashings import ProposerSlashing -from eth2.beacon.types.transfers import Transfer -from eth2.beacon.types.voluntary_exits import VoluntaryExit - - -from eth2.beacon.types.states import BeaconState -from eth2.beacon.typing import ( - Slot, -) - - -Operation = Union[ProposerSlashing, AttesterSlashing, Attestation, Deposit, VoluntaryExit, Transfer] -OperationOrBlockHeader = Union[Operation, BeaconBlockHeader] - - -@dataclass -class BaseTestCase: - handler: str - index: int - - -@dataclass -class StateTestCase(BaseTestCase): - bls_setting: bool - description: str - pre: BeaconState - post: BeaconState - slots: Slot = Slot(0) - blocks: Tuple[BeaconBlock, ...] = field(default_factory=tuple) - is_valid: bool = True - - -@dataclass -class OperationCase(BaseTestCase): - bls_setting: bool - description: str - pre: BeaconState - operation: OperationOrBlockHeader - post: BeaconState - is_valid: bool = True +from enum import Enum +from typing import Any, Dict + +from eth2._utils.bls import bls +from eth2._utils.bls.backends import MilagroBackend +from eth2.configs import Eth2Config + +from .test_handler import Input, Output, TestHandler + + +class BLSSetting(Enum): + Optional = 0 + Enabled = 1 + Disabled = 2 + + +def _select_bls_backend(bls_setting: BLSSetting) -> None: + if bls_setting == BLSSetting.Disabled: + bls.use_noop_backend() + elif bls_setting == BLSSetting.Enabled: + bls.use(MilagroBackend) + elif bls_setting == BLSSetting.Optional: + # do not verify BLS to save time + bls.use_noop_backend() + + +class TestCase: + def __init__( + self, + index: int, + test_case_data: Dict[str, Any], + handler: TestHandler[Input, Output], + config: Eth2Config, + ) -> None: + self.index = index + self.description = test_case_data.get("description", "") + self.bls_setting = BLSSetting(test_case_data.get("bls_setting", 0)) + self.config = config + self.test_case_data = test_case_data + self.handler = handler + + def valid(self) -> bool: + return self.handler.valid(self.test_case_data) + + def execute(self) -> None: + _select_bls_backend(self.bls_setting) + inputs = self.handler.parse_inputs(self.test_case_data) + outputs = self.handler.run_with(inputs, self.config) + expected_outputs = self.handler.parse_outputs(self.test_case_data) + self.handler.condition(outputs, expected_outputs) diff --git a/eth2/beacon/tools/fixtures/test_file.py b/eth2/beacon/tools/fixtures/test_file.py deleted file mode 100644 index 7d4cb03f31..0000000000 --- a/eth2/beacon/tools/fixtures/test_file.py +++ /dev/null @@ -1,18 +0,0 @@ -from typing import ( - Sequence, -) -from dataclasses import ( - dataclass, -) - -from eth2.configs import ( - Eth2Config, -) -from eth2.beacon.tools.fixtures.test_case import BaseTestCase - - -@dataclass -class TestFile: - file_name: str - config: Eth2Config - test_cases: Sequence[BaseTestCase] diff --git a/eth2/beacon/tools/fixtures/test_gen.py b/eth2/beacon/tools/fixtures/test_gen.py new file mode 100644 index 0000000000..c60b77dc7f --- /dev/null +++ b/eth2/beacon/tools/fixtures/test_gen.py @@ -0,0 +1,166 @@ +import itertools +from typing import Any, Callable, Dict, Generator, Iterator, Set, Tuple + +from eth_utils.toolz import thread_last +from typing_extensions import Protocol + +from eth2.beacon.tools.fixtures.config_types import ConfigType +from eth2.beacon.tools.fixtures.parser import parse_test_suite +from eth2.beacon.tools.fixtures.test_case import TestCase +from eth2.beacon.tools.fixtures.test_handler import Input, Output, TestHandler +from eth2.beacon.tools.fixtures.test_types import HandlerType, TestType + +TestSuiteDescriptor = Tuple[Tuple[TestType[Any], TestHandler[Any, Any]], ConfigType] + + +class DecoratorTarget(Protocol): + __eth2_fixture_config: Dict[str, Any] + + +# NOTE: ``pytest`` does not export the ``Metafunc`` class so we +# make a new type here to stand in for it. +class Metafunc(Protocol): + function: DecoratorTarget + + def parametrize( + self, param_name: str, argvals: Tuple[TestCase, ...], ids: Tuple[str, ...] + ) -> None: + ... + + +def pytest_from_eth2_fixture( + config: Dict[str, Any] +) -> Callable[[DecoratorTarget], DecoratorTarget]: + """ + This function attaches the ``config`` to the ``func`` via the + ``decorator``. The idea here is to just communicate this data to + later stages of the test generation. + """ + + def decorator(func: DecoratorTarget) -> DecoratorTarget: + func.__eth2_fixture_config = config + return func + + return decorator + + +def _read_request_from_metafunc(metafunc: Metafunc) -> Dict[str, Any]: + fn = metafunc.function + return fn.__eth2_fixture_config + + +requested_config_types: Set[ConfigType] = set() + + +def _add_config_type_to_tracking_set(config: ConfigType) -> None: + if len(requested_config_types) == 0: + requested_config_types.add(config) + + +def _check_only_one_config_type(config_type: ConfigType) -> None: + """ + Given the way we currently handle setting the size of dynamic SSZ types, + we can only run one type of configuration *per process*. + """ + if config_type not in requested_config_types: + raise Exception( + "Can only run a _single_ type of configuration per process; " + "please inspect pytest configuration." + ) + + +def _generate_test_suite_descriptors_from( + eth2_fixture_request: Dict[str, Any] +) -> Tuple[TestSuiteDescriptor, ...]: + if "config_types" in eth2_fixture_request: + config_types = eth2_fixture_request["config_types"] + # NOTE: in an ideal world, a user of the test generator can + # specify multiple types of config in one test run. They could also specify + # multiple test runs w/ disparate configurations. Given the way we currently + # handle setting SSZ bounds (globally!), we have to enforce the invariant of only + # one type of config per process. + if len(config_types) != 1: + raise Exception( + "only run one config type per process, due to overwriting SSZ bounds" + ) + config_type = config_types[0] + _add_config_type_to_tracking_set(config_type) + _check_only_one_config_type(config_type) + else: + config_types = (None,) + + test_types = eth2_fixture_request["test_types"] + + # special case only one handler, "core" + if not isinstance(test_types, Dict): + test_types = { + _type: lambda handler: handler.name == "core" for _type in test_types + } + + selected_handlers: Tuple[Tuple[TestType[Any], TestHandler[Any, Any]], ...] = tuple() + for test_type, handler_filter in test_types.items(): + for handler in test_type.handlers: + if handler_filter(handler) or handler.name == "core": + selected_handler = (test_type, handler) + selected_handlers += selected_handler + result: Iterator[Any] = itertools.product((selected_handlers,), config_types) + return tuple(result) + + +def _generate_pytest_case_from( + test_type: TestType[HandlerType], + handler_type: TestHandler[Input, Output], + config_type: ConfigType, + test_case: TestCase, +) -> Tuple[TestCase, str]: + # special case only one handler "core" + test_name = test_type.name + if len(test_type.handlers) == 1 or handler_type.name == "core": + handler_name = "" + else: + handler_name = handler_type.name + + if config_type: + config_name = config_type.name + else: + config_name = "" + + test_id_prefix = thread_last( + (test_name, handler_name, config_name), + (filter, lambda component: component != ""), + lambda components: "_".join(components), + ) + test_id = f"{test_id_prefix}.yaml:{test_case.index}" + + if test_case.description: + test_id += f":{test_case.description}" + return test_case, test_id + + +def _generate_pytest_cases_from_test_suite_descriptors( + test_suite_descriptors: Tuple[TestSuiteDescriptor, ...] +) -> Generator[Tuple[TestCase, str], None, None]: + for (test_type, handler_type), config_type in test_suite_descriptors: + test_suite = parse_test_suite(test_type, handler_type, config_type) + for test_case in test_suite: + yield _generate_pytest_case_from( + test_type, handler_type, config_type, test_case + ) + + +def generate_pytests_from_eth2_fixture(metafunc: Metafunc) -> None: + """ + Generate all the test cases requested by the config (attached to ``metafunc``'s + function object) and inject them via ``metafunc.parametrize``. + """ + eth2_fixture_request = _read_request_from_metafunc(metafunc) + test_suite_descriptors = _generate_test_suite_descriptors_from(eth2_fixture_request) + pytest_cases = tuple( + _generate_pytest_cases_from_test_suite_descriptors(test_suite_descriptors) + ) + if pytest_cases: + argvals, ids = zip(*pytest_cases) + else: + argvals, ids = (), () + + metafunc.parametrize("test_case", argvals, ids=ids) diff --git a/eth2/beacon/tools/fixtures/test_handler.py b/eth2/beacon/tools/fixtures/test_handler.py new file mode 100644 index 0000000000..06892ae25c --- /dev/null +++ b/eth2/beacon/tools/fixtures/test_handler.py @@ -0,0 +1,36 @@ +import abc +from abc import abstractmethod +from typing import Any, Dict, Generic, TypeVar + +from eth2.configs import Eth2Config + +Input = TypeVar("Input") +Output = TypeVar("Output") + + +class TestHandler(abc.ABC, Generic[Input, Output]): + name: str + + @classmethod + @abstractmethod + def parse_inputs(cls, test_case_data: Dict[str, Any]) -> Input: + ... + + @staticmethod + @abstractmethod + def parse_outputs(test_case_data: Dict[str, Any]) -> Output: + ... + + @staticmethod + def valid(data: Dict[str, Any]) -> bool: + return True + + @classmethod + @abstractmethod + def run_with(cls, inputs: Input, config: Eth2Config) -> Output: + ... + + @staticmethod + @abstractmethod + def condition(output: Output, expected_output: Output) -> None: + ... diff --git a/eth2/beacon/tools/fixtures/test_types/__init__.py b/eth2/beacon/tools/fixtures/test_types/__init__.py new file mode 100644 index 0000000000..0e73c2865c --- /dev/null +++ b/eth2/beacon/tools/fixtures/test_types/__init__.py @@ -0,0 +1,34 @@ +import abc +from pathlib import Path +from typing import Generic, Optional, Sized, TypeVar + +from eth2.beacon.tools.fixtures.config_types import ConfigType +from eth2.beacon.tools.fixtures.test_handler import Input, Output, TestHandler + +HandlerType = TypeVar("HandlerType", bound=Sized) + + +class TestType(abc.ABC, Generic[HandlerType]): + name: str + handlers: HandlerType + + @classmethod + def build_path( + cls, + tests_root_path: Path, + test_handler: TestHandler[Input, Output], + config_type: Optional[ConfigType], + ) -> Path: + if len(cls.handlers) == 1: + file_name = f"{cls.name}" + else: + file_name = f"{cls.name}_{test_handler.name}" + + if config_type: + file_name += f"_{config_type.name}" + + file_name += ".yaml" + + return ( + tests_root_path / Path(cls.name) / Path(test_handler.name) / Path(file_name) + ) diff --git a/eth2/beacon/tools/fixtures/test_types/bls.py b/eth2/beacon/tools/fixtures/test_types/bls.py new file mode 100644 index 0000000000..87d46b1b14 --- /dev/null +++ b/eth2/beacon/tools/fixtures/test_types/bls.py @@ -0,0 +1,158 @@ +from pathlib import Path +from typing import Any, Dict, Optional, Tuple, Type, Union, cast + +from py_ecc.bls.typing import Domain + +from eth2._utils.bls import BLSPubkey, BLSSignature, Hash32, bls +from eth2._utils.bls.backends import MilagroBackend +from eth2.beacon.tools.fixtures.config_types import ConfigType +from eth2.beacon.tools.fixtures.loading import ( + get_input_bls_privkey, + get_input_bls_pubkeys, + get_input_bls_signatures, + get_input_sign_message, + get_output_bls_pubkey, + get_output_bls_signature, +) +from eth2.beacon.tools.fixtures.test_handler import Input, Output, TestHandler +from eth2.configs import Eth2Config + +from . import TestType + +SequenceOfBLSPubkey = Tuple[BLSPubkey, ...] +SequenceOfBLSSignature = Tuple[BLSSignature, ...] +SignatureDescriptor = Dict[str, Union[int, bytes]] + + +class AggregatePubkeysHandler(TestHandler[SequenceOfBLSPubkey, BLSPubkey]): + name = "aggregate_pubkeys" + + @classmethod + def parse_inputs(_cls, test_case_data: Dict[str, Any]) -> SequenceOfBLSPubkey: + return get_input_bls_pubkeys(test_case_data)["pubkeys"] + + @staticmethod + def parse_outputs(test_case_data: Dict[str, Any]) -> BLSPubkey: + return get_output_bls_pubkey(test_case_data) + + @classmethod + def run_with(_cls, inputs: SequenceOfBLSPubkey, _config: Eth2Config) -> BLSPubkey: + # BLS override + bls.use(MilagroBackend) + + return bls.aggregate_pubkeys(inputs) + + @staticmethod + def condition(output: BLSPubkey, expected_output: BLSPubkey) -> None: + assert output == expected_output + + +class AggregateSignaturesHandler(TestHandler[SequenceOfBLSSignature, BLSSignature]): + name = "aggregate_sigs" + + @classmethod + def parse_inputs(_cls, test_case_data: Dict[str, Any]) -> SequenceOfBLSSignature: + return get_input_bls_signatures(test_case_data)["signatures"] + + @staticmethod + def parse_outputs(test_case_data: Dict[str, Any]) -> BLSSignature: + return get_output_bls_signature(test_case_data) + + @classmethod + def run_with( + _cls, inputs: SequenceOfBLSSignature, _config: Eth2Config + ) -> BLSSignature: + # BLS override + bls.use(MilagroBackend) + + return bls.aggregate_signatures(inputs) + + @staticmethod + def condition(output: BLSSignature, expected_output: BLSSignature) -> None: + assert output == expected_output + + +class PrivateToPublicKeyHandler(TestHandler[int, BLSPubkey]): + name = "priv_to_pub" + + @classmethod + def parse_inputs(_cls, test_case_data: Dict[str, Any]) -> int: + return get_input_bls_privkey(test_case_data)["privkey"] + + @staticmethod + def parse_outputs(test_case_data: Dict[str, Any]) -> BLSPubkey: + return get_output_bls_pubkey(test_case_data) + + @classmethod + def run_with(_cls, inputs: int, _config: Eth2Config) -> BLSPubkey: + # BLS override + bls.use(MilagroBackend) + + return bls.privtopub(inputs) + + @staticmethod + def condition(output: BLSPubkey, expected_output: BLSPubkey) -> None: + assert output == expected_output + + +class SignMessageHandler(TestHandler[SignatureDescriptor, BLSSignature]): + name = "sign_msg" + + @classmethod + def parse_inputs(_cls, test_case_data: Dict[str, Any]) -> SignatureDescriptor: + return get_input_sign_message(test_case_data) + + @staticmethod + def parse_outputs(test_case_data: Dict[str, Any]) -> BLSSignature: + return get_output_bls_signature(test_case_data) + + @classmethod + def run_with( + _cls, inputs: SignatureDescriptor, _config: Eth2Config + ) -> BLSSignature: + # BLS override + bls.use(MilagroBackend) + + return bls.sign( + cast(Hash32, inputs["message_hash"]), + int(inputs["privkey"]), + cast(Domain, (inputs["domain"])), + ) + + @staticmethod + def condition(output: BLSSignature, expected_output: BLSSignature) -> None: + assert output == expected_output + + +BLSHandlerType = Tuple[ + Type[AggregatePubkeysHandler], + Type[AggregateSignaturesHandler], + Type[PrivateToPublicKeyHandler], + Type[SignMessageHandler], +] + + +class BLSTestType(TestType[BLSHandlerType]): + name = "bls" + + handlers = ( + AggregatePubkeysHandler, + AggregateSignaturesHandler, + # MsgHashG2CompressedHandler, # NOTE: not exposed via public API in py_ecc + # MsgHashG2UncompressedHandler, # NOTE: not exposed via public API in py_ecc + PrivateToPublicKeyHandler, + SignMessageHandler, + ) + + @classmethod + def build_path( + cls, + tests_root_path: Path, + test_handler: TestHandler[Input, Output], + config_type: Optional[ConfigType], + ) -> Path: + file_name = f"{test_handler.name}.yaml" + + return ( + tests_root_path / Path(cls.name) / Path(test_handler.name) / Path(file_name) + ) diff --git a/eth2/beacon/tools/fixtures/test_types/epoch_processing.py b/eth2/beacon/tools/fixtures/test_types/epoch_processing.py new file mode 100644 index 0000000000..b66c4040a5 --- /dev/null +++ b/eth2/beacon/tools/fixtures/test_types/epoch_processing.py @@ -0,0 +1,98 @@ +from pathlib import Path +from typing import Any, Callable, Dict, Tuple, Type + +from ssz.tools import from_formatted_dict + +from eth2.beacon.state_machines.forks.serenity.epoch_processing import ( + process_crosslinks, + process_final_updates, + process_justification_and_finalization, + process_registry_updates, + process_slashings, +) +from eth2.beacon.tools.fixtures.conditions import validate_state +from eth2.beacon.tools.fixtures.config_types import ConfigType +from eth2.beacon.tools.fixtures.test_handler import Input, Output, TestHandler +from eth2.beacon.types.states import BeaconState +from eth2.configs import Eth2Config + +from . import TestType + + +class EpochProcessingHandler(TestHandler[BeaconState, BeaconState]): + processor: Callable[[BeaconState, Eth2Config], BeaconState] + + @classmethod + def parse_inputs(_cls, test_case_data: Dict[str, Any]) -> BeaconState: + return from_formatted_dict(test_case_data["pre"], BeaconState) + + @staticmethod + def parse_outputs(test_case_data: Dict[str, Any]) -> BeaconState: + return from_formatted_dict(test_case_data["post"], BeaconState) + + @classmethod + def run_with(cls, inputs: BeaconState, config: Eth2Config) -> BeaconState: + state = inputs + return cls.processor(state, config) + + @staticmethod + def condition(output: BeaconState, expected_output: BeaconState) -> None: + validate_state(output, expected_output) + + +class JustificationAndFinalizationHandler(EpochProcessingHandler): + name = "justification_and_finalization" + processor = process_justification_and_finalization + + +class CrosslinksHandler(EpochProcessingHandler): + name = "crosslinks" + processor = process_crosslinks + + +class RegistryUpdatesHandler(EpochProcessingHandler): + name = "registry_updates" + processor = process_registry_updates + + +class SlashingsHandler(EpochProcessingHandler): + name = "slashings" + processor = process_slashings + + +class FinalUpdatesHandler(EpochProcessingHandler): + name = "final_updates" + processor = process_final_updates + + +EpochProcessingHandlerType = Tuple[ + Type[JustificationAndFinalizationHandler], + Type[CrosslinksHandler], + Type[RegistryUpdatesHandler], + Type[SlashingsHandler], + Type[FinalUpdatesHandler], +] + + +class EpochProcessingTestType(TestType[EpochProcessingHandlerType]): + name = "epoch_processing" + + handlers = ( + JustificationAndFinalizationHandler, + CrosslinksHandler, + RegistryUpdatesHandler, + SlashingsHandler, + FinalUpdatesHandler, + ) + + @classmethod + def build_path( + cls, + tests_root_path: Path, + test_handler: TestHandler[Input, Output], + config_type: ConfigType, + ) -> Path: + file_name = f"{test_handler.name}_{config_type.name}.yaml" + return ( + tests_root_path / Path(cls.name) / Path(test_handler.name) / Path(file_name) + ) diff --git a/eth2/beacon/tools/fixtures/test_types/genesis.py b/eth2/beacon/tools/fixtures/test_types/genesis.py new file mode 100644 index 0000000000..5bfa1e4f47 --- /dev/null +++ b/eth2/beacon/tools/fixtures/test_types/genesis.py @@ -0,0 +1,87 @@ +from typing import Any, Dict, Tuple, Type, cast + +from eth_typing import Hash32 +from eth_utils import decode_hex +from ssz.tools import from_formatted_dict + +from eth2.beacon.genesis import ( + initialize_beacon_state_from_eth1, + is_valid_genesis_state, +) +from eth2.beacon.tools.fixtures.conditions import validate_state +from eth2.beacon.tools.fixtures.test_handler import TestHandler +from eth2.beacon.types.deposits import Deposit +from eth2.beacon.types.states import BeaconState +from eth2.beacon.typing import Timestamp +from eth2.configs import Eth2Config + +from . import TestType + + +class ValidityHandler(TestHandler[BeaconState, bool]): + name = "validity" + + @classmethod + def parse_inputs(_cls, test_case_data: Dict[str, Any]) -> BeaconState: + return from_formatted_dict(test_case_data["genesis"], BeaconState) + + @staticmethod + def parse_outputs(test_case_data: Dict[str, Any]) -> bool: + return bool(test_case_data["is_valid"]) + + @classmethod + def run_with(_cls, genesis_state: BeaconState, config: Eth2Config) -> bool: + return is_valid_genesis_state(genesis_state, config) + + @staticmethod + def condition(output: bool, expected_output: bool) -> None: + assert output == expected_output + + +class InitializationHandler( + TestHandler[Tuple[Hash32, Timestamp, Tuple[Deposit, ...]], BeaconState] +): + name = "initialization" + + @classmethod + def parse_inputs( + _cls, test_case_data: Dict[str, Any] + ) -> Tuple[Hash32, Timestamp, Tuple[Deposit, ...]]: + return ( + cast(Hash32, decode_hex(test_case_data["eth1_block_hash"])), + Timestamp(test_case_data["eth1_timestamp"]), + tuple( + cast(Deposit, from_formatted_dict(deposit_data, Deposit)) + for deposit_data in test_case_data["deposits"] + ), + ) + + @staticmethod + def parse_outputs(test_case_data: Dict[str, Any]) -> BeaconState: + return from_formatted_dict(test_case_data["state"], BeaconState) + + @classmethod + def run_with( + _cls, inputs: Tuple[Hash32, Timestamp, Tuple[Deposit, ...]], config: Eth2Config + ) -> BeaconState: + eth1_block_hash, eth1_timestamp, deposits = inputs + + return initialize_beacon_state_from_eth1( + eth1_block_hash=eth1_block_hash, + eth1_timestamp=eth1_timestamp, + deposits=deposits, + config=config, + ) + + @staticmethod + def condition(output: BeaconState, expected_output: BeaconState) -> None: + validate_state(output, expected_output) + + +GenesisHandlerType = Tuple[Type[ValidityHandler], Type[InitializationHandler]] + + +class GenesisTestType(TestType[GenesisHandlerType]): + name = "genesis" + + handlers = (ValidityHandler, InitializationHandler) diff --git a/eth2/beacon/tools/fixtures/test_types/operations.py b/eth2/beacon/tools/fixtures/test_types/operations.py new file mode 100644 index 0000000000..85922e8287 --- /dev/null +++ b/eth2/beacon/tools/fixtures/test_types/operations.py @@ -0,0 +1,201 @@ +from pathlib import Path +from typing import Any, Dict, Optional, Tuple, Type, Union + +from eth_utils import ValidationError +import ssz +from ssz.tools import from_formatted_dict + +from eth2._utils.bls import SignatureError +from eth2.beacon.state_machines.forks.serenity.block_processing import ( + process_block_header, +) +from eth2.beacon.state_machines.forks.serenity.operation_processing import ( + process_attestations, + process_attester_slashings, + process_deposits, + process_proposer_slashings, + process_transfers, + process_voluntary_exits, +) +from eth2.beacon.tools.fixtures.conditions import validate_state +from eth2.beacon.tools.fixtures.config_types import ConfigType +from eth2.beacon.tools.fixtures.test_handler import Input, Output, TestHandler +from eth2.beacon.types.attestations import Attestation +from eth2.beacon.types.attester_slashings import AttesterSlashing +from eth2.beacon.types.blocks import BeaconBlock, BeaconBlockBody +from eth2.beacon.types.deposits import Deposit +from eth2.beacon.types.proposer_slashings import ProposerSlashing +from eth2.beacon.types.states import BeaconState +from eth2.beacon.types.transfers import Transfer +from eth2.beacon.types.voluntary_exits import VoluntaryExit +from eth2.configs import Eth2Config + +from . import TestType + +Operation = Union[ + ProposerSlashing, AttesterSlashing, Attestation, Deposit, VoluntaryExit, Transfer +] +OperationOrBlockHeader = Union[Operation, BeaconBlock] + + +class OperationHandler( + TestHandler[Tuple[BeaconState, OperationOrBlockHeader], BeaconState] +): + name: str + operation_name: Optional[str] + operation_type: ssz.Serializable + processor: staticmethod # Optional[Callable[[BeaconState, BeaconBlock, Eth2Config], BeaconState]] # noqa: E501 + expected_exceptions: Tuple[Type[Exception], ...] = () + + @classmethod + def parse_inputs(cls, test_case_data: Dict[str, Any]) -> Tuple[BeaconState, Any]: + operation_name = ( + cls.operation_name if hasattr(cls, "operation_name") else cls.name + ) + return ( + from_formatted_dict(test_case_data["pre"], BeaconState), + from_formatted_dict(test_case_data[operation_name], cls.operation_type), + ) + + @staticmethod + def parse_outputs(test_case_data: Dict[str, Any]) -> BeaconState: + return from_formatted_dict(test_case_data["post"], BeaconState) + + @staticmethod + def valid(test_case_data: Dict[str, Any]) -> bool: + return bool(test_case_data["post"]) + + @classmethod + def _update_config_if_needed(cls, config: Eth2Config) -> Eth2Config: + """ + Some ad-hoc work arounds... + + - Increase the count of allowed Transfer operations, even though we start with 0. + """ + if cls.name == "transfer": + return config._replace(MAX_TRANSFERS=1) + return config + + @classmethod + def run_with( + cls, inputs: Tuple[BeaconState, OperationOrBlockHeader], config: Eth2Config + ) -> BeaconState: + config = cls._update_config_if_needed(config) + state, operation = inputs + # NOTE: we do not have an easy way to evaluate a single operation on the state + # So, we wrap it in a beacon block. The following statement lets us rely on + # the config given in a particular handler class while working w/in the + # update API provided by `py-ssz`. + # NOTE: we ignore the type here, otherwise need to spell out each of the keyword + # arguments individually... save some work and just build them dynamically + block = BeaconBlock( + body=BeaconBlockBody(**{f"{cls.name}s": (operation,)}) # type: ignore + ) + try: + return cls.processor(state, block, config) + except ValidationError as e: + # if already a ValidationError, re-raise + raise e + except Exception as e: + # check if the exception is expected... + for exception in cls.expected_exceptions: + if isinstance(e, exception): + raise ValidationError(e) + # else raise (and fail the pytest test case ...) + raise e + + @staticmethod + def condition(output: BeaconState, expected_output: BeaconState) -> None: + validate_state(output, expected_output) + + +class AttestationHandler(OperationHandler): + name = "attestation" + operation_type = Attestation + processor = staticmethod(process_attestations) + expected_exceptions = (IndexError,) + + +class AttesterSlashingHandler(OperationHandler): + name = "attester_slashing" + operation_type = AttesterSlashing + processor = staticmethod(process_attester_slashings) + expected_exceptions = (SignatureError,) + + +class BlockHeaderHandler(OperationHandler): + name = "block_header" + operation_name = "block" + operation_type = BeaconBlock + + @classmethod + def run_with( + _cls, inputs: Tuple[BeaconState, BeaconBlock], config: Eth2Config + ) -> BeaconState: + state, block = inputs + check_proposer_signature = True + return process_block_header(state, block, config, check_proposer_signature) + + +class DepositHandler(OperationHandler): + name = "deposit" + operation_type = Deposit + processor = staticmethod(process_deposits) + + +class ProposerSlashingHandler(OperationHandler): + name = "proposer_slashing" + operation_type = ProposerSlashing + processor = staticmethod(process_proposer_slashings) + expected_exceptions = (IndexError,) + + +class TransferHandler(OperationHandler): + name = "transfer" + operation_type = Transfer + processor = staticmethod(process_transfers) + expected_exceptions = (IndexError,) + + +class VoluntaryExitHandler(OperationHandler): + name = "voluntary_exit" + operation_type = VoluntaryExit + processor = staticmethod(process_voluntary_exits) + expected_exceptions = (IndexError,) + + +OperationsHandlerType = Tuple[ + Type[AttestationHandler], + Type[AttesterSlashingHandler], + Type[BlockHeaderHandler], + Type[DepositHandler], + Type[ProposerSlashingHandler], + Type[TransferHandler], + Type[VoluntaryExitHandler], +] + + +class OperationsTestType(TestType[OperationsHandlerType]): + name = "operations" + + handlers = ( + AttestationHandler, + AttesterSlashingHandler, + BlockHeaderHandler, + DepositHandler, + ProposerSlashingHandler, + TransferHandler, + VoluntaryExitHandler, + ) + + @classmethod + def build_path( + cls, + tests_root_path: Path, + test_handler: TestHandler[Input, Output], + config_type: ConfigType, + ) -> Path: + file_name = f"{test_handler.name}_{config_type.name}.yaml" + return ( + tests_root_path / Path(cls.name) / Path(test_handler.name) / Path(file_name) + ) diff --git a/eth2/beacon/tools/fixtures/test_types/sanity.py b/eth2/beacon/tools/fixtures/test_types/sanity.py new file mode 100644 index 0000000000..d3c6a6a532 --- /dev/null +++ b/eth2/beacon/tools/fixtures/test_types/sanity.py @@ -0,0 +1,97 @@ +from typing import Any, Dict, Tuple, Type + +from eth_utils import ValidationError +from ssz.tools import from_formatted_dict + +from eth2.beacon.state_machines.forks.serenity.slot_processing import process_slots +from eth2.beacon.state_machines.forks.serenity.state_transitions import ( + SerenityStateTransition, +) +from eth2.beacon.tools.fixtures.conditions import validate_state +from eth2.beacon.tools.fixtures.test_handler import TestHandler +from eth2.beacon.types.blocks import BeaconBlock +from eth2.beacon.types.states import BeaconState +from eth2.beacon.typing import Slot +from eth2.configs import Eth2Config + +from . import TestType + + +class BlocksHandler( + TestHandler[Tuple[BeaconState, Tuple[BeaconBlock, ...]], BeaconState] +): + name = "blocks" + + @classmethod + def parse_inputs( + _cls, test_case_data: Dict[str, Any] + ) -> Tuple[BeaconState, Tuple[BeaconBlock, ...]]: + return ( + from_formatted_dict(test_case_data["pre"], BeaconState), + tuple( + from_formatted_dict(block_data, BeaconBlock) + for block_data in test_case_data["blocks"] + ), + ) + + @staticmethod + def parse_outputs(test_case_data: Dict[str, Any]) -> BeaconState: + return from_formatted_dict(test_case_data["post"], BeaconState) + + @staticmethod + def valid(test_case_data: Dict[str, Any]) -> bool: + return bool(test_case_data["post"]) + + @classmethod + def run_with( + _cls, inputs: Tuple[BeaconState, Tuple[BeaconBlock, ...]], config: Eth2Config + ) -> BeaconState: + state, blocks = inputs + state_transition = SerenityStateTransition(config) + for block in blocks: + state = state_transition.apply_state_transition(state, block) + if block.state_root != state.hash_tree_root: + raise ValidationError( + "block's state root did not match computed state root" + ) + return state + + @staticmethod + def condition(output: BeaconState, expected_output: BeaconState) -> None: + validate_state(output, expected_output) + + +class SlotsHandler(TestHandler[Tuple[BeaconState, int], BeaconState]): + name = "slots" + + @classmethod + def parse_inputs(_cls, test_case_data: Dict[str, Any]) -> Tuple[BeaconState, int]: + return ( + from_formatted_dict(test_case_data["pre"], BeaconState), + test_case_data["slots"], + ) + + @staticmethod + def parse_outputs(test_case_data: Dict[str, Any]) -> BeaconState: + return from_formatted_dict(test_case_data["post"], BeaconState) + + @classmethod + def run_with( + _cls, inputs: Tuple[BeaconState, int], config: Eth2Config + ) -> BeaconState: + state, offset = inputs + target_slot = Slot(state.slot + offset) + return process_slots(state, target_slot, config) + + @staticmethod + def condition(output: BeaconState, expected_output: BeaconState) -> None: + validate_state(output, expected_output) + + +SanityHandlerType = Tuple[Type[BlocksHandler], Type[SlotsHandler]] + + +class SanityTestType(TestType[SanityHandlerType]): + name = "sanity" + + handlers = (BlocksHandler, SlotsHandler) diff --git a/eth2/beacon/tools/fixtures/test_types/shuffling.py b/eth2/beacon/tools/fixtures/test_types/shuffling.py new file mode 100644 index 0000000000..f9492aeb91 --- /dev/null +++ b/eth2/beacon/tools/fixtures/test_types/shuffling.py @@ -0,0 +1,42 @@ +from typing import Any, Dict, Tuple, Type + +from eth_utils import decode_hex + +from eth2.beacon.committee_helpers import compute_shuffled_index +from eth2.beacon.tools.fixtures.test_handler import TestHandler +from eth2.configs import Eth2Config + +from . import TestType + + +class CoreHandler(TestHandler[Tuple[int, bytes], Tuple[int, ...]]): + name = "core" + + @classmethod + def parse_inputs(_cls, test_case_data: Dict[str, Any]) -> Tuple[int, bytes]: + return (test_case_data["count"], decode_hex(test_case_data["seed"])) + + @staticmethod + def parse_outputs(test_case_data: Dict[str, Any]) -> Tuple[int, ...]: + return tuple(int(data) for data in test_case_data["shuffled"]) + + @classmethod + def run_with(_cls, inputs: Any, config: Eth2Config) -> Tuple[int, ...]: + count, seed = inputs + return tuple( + compute_shuffled_index(index, count, seed, config.SHUFFLE_ROUND_COUNT) + for index in range(count) + ) + + @staticmethod + def condition(output: Any, expected_output: Any) -> None: + assert output == expected_output + + +ShufflingHandlerType = Tuple[Type[CoreHandler]] + + +class ShufflingTestType(TestType[ShufflingHandlerType]): + name = "shuffling" + + handlers = (CoreHandler,) diff --git a/eth2/beacon/tools/fixtures/test_types/ssz_generic.py b/eth2/beacon/tools/fixtures/test_types/ssz_generic.py new file mode 100644 index 0000000000..39d5acbe05 --- /dev/null +++ b/eth2/beacon/tools/fixtures/test_types/ssz_generic.py @@ -0,0 +1,10 @@ +from typing import Tuple + +from . import TestType + +# temporary +SSZGenericHandlerType = Tuple[None] + + +class SSZGeneric(TestType[SSZGenericHandlerType]): + pass diff --git a/eth2/beacon/tools/fixtures/test_types/ssz_static.py b/eth2/beacon/tools/fixtures/test_types/ssz_static.py new file mode 100644 index 0000000000..dca7287ad0 --- /dev/null +++ b/eth2/beacon/tools/fixtures/test_types/ssz_static.py @@ -0,0 +1,10 @@ +from typing import Tuple + +from . import TestType + +# temporary +SSZStaticHandlerType = Tuple[None] + + +class SSZStatic(TestType[SSZStaticHandlerType]): + pass diff --git a/eth2/beacon/tools/misc/ssz_vector.py b/eth2/beacon/tools/misc/ssz_vector.py index 2d564f0847..3f4729023f 100644 --- a/eth2/beacon/tools/misc/ssz_vector.py +++ b/eth2/beacon/tools/misc/ssz_vector.py @@ -3,16 +3,13 @@ import ssz import ssz.sedes as sedes -from eth2.configs import ( - Eth2Config, -) - from eth2.beacon.types.attestations import Attestation, IndexedAttestation from eth2.beacon.types.blocks import BeaconBlockBody from eth2.beacon.types.compact_committees import CompactCommittee from eth2.beacon.types.historical_batch import HistoricalBatch from eth2.beacon.types.pending_attestations import PendingAttestation from eth2.beacon.types.states import BeaconState +from eth2.configs import Eth2Config def _mk_overrides(config: Eth2Config) -> Dict[ssz.Serializable, Dict[str, int]]: @@ -40,8 +37,10 @@ def _mk_overrides(config: Eth2Config) -> Dict[ssz.Serializable, Dict[str, int]]: "active_index_roots": config.EPOCHS_PER_HISTORICAL_VECTOR, "compact_committees_roots": config.EPOCHS_PER_HISTORICAL_VECTOR, "slashings": config.EPOCHS_PER_SLASHINGS_VECTOR, - "previous_epoch_attestations": config.MAX_ATTESTATIONS * config.SLOTS_PER_EPOCH, - "current_epoch_attestations": config.MAX_ATTESTATIONS * config.SLOTS_PER_EPOCH, + "previous_epoch_attestations": config.MAX_ATTESTATIONS + * config.SLOTS_PER_EPOCH, + "current_epoch_attestations": config.MAX_ATTESTATIONS + * config.SLOTS_PER_EPOCH, "previous_crosslinks": config.SHARD_COUNT, "current_crosslinks": config.SHARD_COUNT, }, @@ -57,9 +56,7 @@ def _mk_overrides(config: Eth2Config) -> Dict[ssz.Serializable, Dict[str, int]]: "custody_bit_0_indices": config.MAX_VALIDATORS_PER_COMMITTEE, "custody_bit_1_indices": config.MAX_VALIDATORS_PER_COMMITTEE, }, - PendingAttestation: { - "aggregation_bits": config.MAX_VALIDATORS_PER_COMMITTEE, - }, + PendingAttestation: {"aggregation_bits": config.MAX_VALIDATORS_PER_COMMITTEE}, } diff --git a/eth2/beacon/types/attestation_data.py b/eth2/beacon/types/attestation_data.py index 0019804d0a..78e87770d6 100644 --- a/eth2/beacon/types/attestation_data.py +++ b/eth2/beacon/types/attestation_data.py @@ -1,48 +1,32 @@ -from eth.constants import ( - ZERO_HASH32, -) - -from eth_typing import ( - Hash32, -) - +from eth.constants import ZERO_HASH32 +from eth_typing import Hash32 +from eth_utils import humanize_hash import ssz -from ssz.sedes import ( - bytes32, -) - -from eth2.beacon.types.checkpoints import ( - Checkpoint, - default_checkpoint, -) -from eth2.beacon.types.crosslinks import ( - Crosslink, - default_crosslink, -) -from eth_utils import ( - humanize_hash, -) +from ssz.sedes import bytes32 + +from eth2.beacon.types.checkpoints import Checkpoint, default_checkpoint +from eth2.beacon.types.crosslinks import Crosslink, default_crosslink class AttestationData(ssz.Serializable): fields = [ # LMD GHOST vote - ('beacon_block_root', bytes32), - + ("beacon_block_root", bytes32), # FFG vote - ('source', Checkpoint), - ('target', Checkpoint), - + ("source", Checkpoint), + ("target", Checkpoint), # Crosslink vote - ('crosslink', Crosslink), + ("crosslink", Crosslink), ] - def __init__(self, - beacon_block_root: Hash32=ZERO_HASH32, - source: Checkpoint=default_checkpoint, - target: Checkpoint=default_checkpoint, - crosslink: Crosslink=default_crosslink) -> None: + def __init__( + self, + beacon_block_root: Hash32 = ZERO_HASH32, + source: Checkpoint = default_checkpoint, + target: Checkpoint = default_checkpoint, + crosslink: Crosslink = default_crosslink, + ) -> None: super().__init__( beacon_block_root=beacon_block_root, source=source, diff --git a/eth2/beacon/types/attestation_data_and_custody_bits.py b/eth2/beacon/types/attestation_data_and_custody_bits.py index 22c80b4e14..d016d1b83d 100644 --- a/eth2/beacon/types/attestation_data_and_custody_bits.py +++ b/eth2/beacon/types/attestation_data_and_custody_bits.py @@ -1,27 +1,21 @@ import ssz -from ssz.sedes import ( - boolean, -) +from ssz.sedes import boolean -from .attestation_data import ( - AttestationData, - default_attestation_data, -) +from .attestation_data import AttestationData, default_attestation_data class AttestationDataAndCustodyBit(ssz.Serializable): fields = [ # Attestation data - ('data', AttestationData), + ("data", AttestationData), # Custody bit - ('custody_bit', boolean), + ("custody_bit", boolean), ] - def __init__(self, - data: AttestationData=default_attestation_data, - custody_bit: bool=False)-> None: - super().__init__( - data=data, - custody_bit=custody_bit, - ) + def __init__( + self, + data: AttestationData = default_attestation_data, + custody_bit: bool = False, + ) -> None: + super().__init__(data=data, custody_bit=custody_bit) diff --git a/eth2/beacon/types/attestations.py b/eth2/beacon/types/attestations.py index 195fdc23eb..684462a84d 100644 --- a/eth2/beacon/types/attestations.py +++ b/eth2/beacon/types/attestations.py @@ -1,54 +1,33 @@ -from typing import ( - Sequence, -) +from typing import Sequence +from eth_typing import BLSSignature import ssz -from ssz.sedes import ( - Bitlist, - bytes96, - List, - uint64, -) +from ssz.sedes import Bitlist, List, bytes96, uint64 -from .attestation_data import ( - AttestationData, - default_attestation_data, -) - -from eth2.beacon.typing import ( - Bitfield, - ValidatorIndex, -) from eth2.beacon.constants import EMPTY_SIGNATURE -from eth_typing import ( - BLSSignature, -) +from eth2.beacon.typing import Bitfield, ValidatorIndex -from .defaults import ( - default_bitfield, - default_tuple, -) +from .attestation_data import AttestationData, default_attestation_data +from .defaults import default_bitfield, default_tuple class Attestation(ssz.Serializable): fields = [ - ('aggregation_bits', Bitlist(1)), - ('data', AttestationData), - ('custody_bits', Bitlist(1)), - ('signature', bytes96), + ("aggregation_bits", Bitlist(1)), + ("data", AttestationData), + ("custody_bits", Bitlist(1)), + ("signature", bytes96), ] - def __init__(self, aggregation_bits: Bitfield=default_bitfield, - data: AttestationData=default_attestation_data, - custody_bits: Bitfield=default_bitfield, - signature: BLSSignature=EMPTY_SIGNATURE) -> None: - super().__init__( - aggregation_bits, - data, - custody_bits, - signature, - ) + def __init__( + self, + aggregation_bits: Bitfield = default_bitfield, + data: AttestationData = default_attestation_data, + custody_bits: Bitfield = default_bitfield, + signature: BLSSignature = EMPTY_SIGNATURE, + ) -> None: + super().__init__(aggregation_bits, data, custody_bits, signature) def __repr__(self) -> str: return f"" @@ -58,25 +37,22 @@ class IndexedAttestation(ssz.Serializable): fields = [ # Validator indices - ('custody_bit_0_indices', List(uint64, 1)), - ('custody_bit_1_indices', List(uint64, 1)), + ("custody_bit_0_indices", List(uint64, 1)), + ("custody_bit_1_indices", List(uint64, 1)), # Attestation data - ('data', AttestationData), + ("data", AttestationData), # Aggregate signature - ('signature', bytes96), + ("signature", bytes96), ] - def __init__(self, - custody_bit_0_indices: Sequence[ValidatorIndex]=default_tuple, - custody_bit_1_indices: Sequence[ValidatorIndex]=default_tuple, - data: AttestationData=default_attestation_data, - signature: BLSSignature=EMPTY_SIGNATURE) -> None: - super().__init__( - custody_bit_0_indices, - custody_bit_1_indices, - data, - signature, - ) + def __init__( + self, + custody_bit_0_indices: Sequence[ValidatorIndex] = default_tuple, + custody_bit_1_indices: Sequence[ValidatorIndex] = default_tuple, + data: AttestationData = default_attestation_data, + signature: BLSSignature = EMPTY_SIGNATURE, + ) -> None: + super().__init__(custody_bit_0_indices, custody_bit_1_indices, data, signature) def __repr__(self) -> str: return f"" diff --git a/eth2/beacon/types/attester_slashings.py b/eth2/beacon/types/attester_slashings.py index 07f9bbd9e7..ebf832cb6f 100644 --- a/eth2/beacon/types/attester_slashings.py +++ b/eth2/beacon/types/attester_slashings.py @@ -1,24 +1,20 @@ import ssz -from .attestations import ( - IndexedAttestation, - default_indexed_attestation, -) +from .attestations import IndexedAttestation, default_indexed_attestation class AttesterSlashing(ssz.Serializable): fields = [ # First attestation - ('attestation_1', IndexedAttestation), + ("attestation_1", IndexedAttestation), # Second attestation - ('attestation_2', IndexedAttestation), + ("attestation_2", IndexedAttestation), ] - def __init__(self, - attestation_1: IndexedAttestation=default_indexed_attestation, - attestation_2: IndexedAttestation=default_indexed_attestation)-> None: - super().__init__( - attestation_1, - attestation_2, - ) + def __init__( + self, + attestation_1: IndexedAttestation = default_indexed_attestation, + attestation_2: IndexedAttestation = default_indexed_attestation, + ) -> None: + super().__init__(attestation_1, attestation_2) diff --git a/eth2/beacon/types/block_headers.py b/eth2/beacon/types/block_headers.py index 5cfac4d7d0..8fdd08905a 100644 --- a/eth2/beacon/types/block_headers.py +++ b/eth2/beacon/types/block_headers.py @@ -1,49 +1,34 @@ -from eth.constants import ( - ZERO_HASH32, -) - -from eth_typing import ( - BLSSignature, - Hash32, -) -from eth_utils import ( - encode_hex, -) - +from eth.constants import ZERO_HASH32 +from eth_typing import BLSSignature, Hash32 +from eth_utils import encode_hex import ssz -from ssz.sedes import ( - bytes32, - bytes96, - uint64, -) +from ssz.sedes import bytes32, bytes96, uint64 from eth2.beacon.constants import EMPTY_SIGNATURE -from eth2.beacon.typing import ( - Slot, -) +from eth2.beacon.typing import Slot -from .defaults import ( - default_slot, -) +from .defaults import default_slot class BeaconBlockHeader(ssz.SignedSerializable): fields = [ - ('slot', uint64), - ('parent_root', bytes32), - ('state_root', bytes32), - ('body_root', bytes32), - ('signature', bytes96), + ("slot", uint64), + ("parent_root", bytes32), + ("state_root", bytes32), + ("body_root", bytes32), + ("signature", bytes96), ] - def __init__(self, - *, - slot: Slot=default_slot, - parent_root: Hash32=ZERO_HASH32, - state_root: Hash32=ZERO_HASH32, - body_root: Hash32=ZERO_HASH32, - signature: BLSSignature=EMPTY_SIGNATURE): + def __init__( + self, + *, + slot: Slot = default_slot, + parent_root: Hash32 = ZERO_HASH32, + state_root: Hash32 = ZERO_HASH32, + body_root: Hash32 = ZERO_HASH32, + signature: BLSSignature = EMPTY_SIGNATURE, + ): super().__init__( slot=slot, parent_root=parent_root, @@ -54,9 +39,9 @@ def __init__(self, def __repr__(self) -> str: return ( - f'' + f"" ) diff --git a/eth2/beacon/types/blocks.py b/eth2/beacon/types/blocks.py index 0ac489b644..c2198b36fd 100644 --- a/eth2/beacon/types/blocks.py +++ b/eth2/beacon/types/blocks.py @@ -1,64 +1,26 @@ -from abc import ( - ABC, - abstractmethod, -) - -from typing import ( - Sequence, - TYPE_CHECKING, -) - -from eth.constants import ( - ZERO_HASH32, -) - -from eth_typing import ( - BLSSignature, - Hash32, -) -from eth_utils import ( - encode_hex, -) -import ssz -from ssz.sedes import ( - List, - bytes32, - bytes96, - uint64, -) - - -from eth._utils.datatypes import ( - Configurable, -) +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Sequence -from eth2.beacon.constants import ( - EMPTY_SIGNATURE, - GENESIS_PARENT_ROOT, -) -from eth2.beacon.typing import ( - Slot, - FromBlockParams, -) +from eth._utils.datatypes import Configurable +from eth.constants import ZERO_HASH32 +from eth_typing import BLSSignature, Hash32 +from eth_utils import encode_hex +import ssz +from ssz.sedes import List, bytes32, bytes96, uint64 +from eth2.beacon.constants import EMPTY_SIGNATURE, GENESIS_PARENT_ROOT +from eth2.beacon.typing import FromBlockParams, Slot from .attestations import Attestation from .attester_slashings import AttesterSlashing from .block_headers import BeaconBlockHeader -from .defaults import ( - default_tuple, - default_slot, -) +from .defaults import default_slot, default_tuple from .deposits import Deposit -from .eth1_data import ( - Eth1Data, - default_eth1_data, -) +from .eth1_data import Eth1Data, default_eth1_data from .proposer_slashings import ProposerSlashing from .transfers import Transfer from .voluntary_exits import VoluntaryExit - if TYPE_CHECKING: from eth2.beacon.db.chain import BaseBeaconChainDB # noqa: F401 @@ -66,28 +28,30 @@ class BeaconBlockBody(ssz.Serializable): fields = [ - ('randao_reveal', bytes96), - ('eth1_data', Eth1Data), - ('graffiti', bytes32), - ('proposer_slashings', List(ProposerSlashing, 1)), - ('attester_slashings', List(AttesterSlashing, 1)), - ('attestations', List(Attestation, 1)), - ('deposits', List(Deposit, 1)), - ('voluntary_exits', List(VoluntaryExit, 1)), - ('transfers', List(Transfer, 1)), + ("randao_reveal", bytes96), + ("eth1_data", Eth1Data), + ("graffiti", bytes32), + ("proposer_slashings", List(ProposerSlashing, 1)), + ("attester_slashings", List(AttesterSlashing, 1)), + ("attestations", List(Attestation, 1)), + ("deposits", List(Deposit, 1)), + ("voluntary_exits", List(VoluntaryExit, 1)), + ("transfers", List(Transfer, 1)), ] - def __init__(self, - *, - randao_reveal: bytes96=EMPTY_SIGNATURE, - eth1_data: Eth1Data=default_eth1_data, - graffiti: Hash32=ZERO_HASH32, - proposer_slashings: Sequence[ProposerSlashing]=default_tuple, - attester_slashings: Sequence[AttesterSlashing]=default_tuple, - attestations: Sequence[Attestation]=default_tuple, - deposits: Sequence[Deposit]=default_tuple, - voluntary_exits: Sequence[VoluntaryExit]=default_tuple, - transfers: Sequence[Transfer]=default_tuple)-> None: + def __init__( + self, + *, + randao_reveal: bytes96 = EMPTY_SIGNATURE, + eth1_data: Eth1Data = default_eth1_data, + graffiti: Hash32 = ZERO_HASH32, + proposer_slashings: Sequence[ProposerSlashing] = default_tuple, + attester_slashings: Sequence[AttesterSlashing] = default_tuple, + attestations: Sequence[Attestation] = default_tuple, + deposits: Sequence[Deposit] = default_tuple, + voluntary_exits: Sequence[VoluntaryExit] = default_tuple, + transfers: Sequence[Transfer] = default_tuple, + ) -> None: super().__init__( randao_reveal=randao_reveal, eth1_data=eth1_data, @@ -110,20 +74,22 @@ def is_empty(self) -> bool: class BaseBeaconBlock(ssz.SignedSerializable, Configurable, ABC): fields = [ - ('slot', uint64), - ('parent_root', bytes32), - ('state_root', bytes32), - ('body', BeaconBlockBody), - ('signature', bytes96), + ("slot", uint64), + ("parent_root", bytes32), + ("state_root", bytes32), + ("body", BeaconBlockBody), + ("signature", bytes96), ] - def __init__(self, - *, - slot: Slot=default_slot, - parent_root: Hash32=ZERO_HASH32, - state_root: Hash32=ZERO_HASH32, - body: BeaconBlockBody=default_beacon_block_body, - signature: BLSSignature=EMPTY_SIGNATURE) -> None: + def __init__( + self, + *, + slot: Slot = default_slot, + parent_root: Hash32 = ZERO_HASH32, + state_root: Hash32 = ZERO_HASH32, + body: BeaconBlockBody = default_beacon_block_body, + signature: BLSSignature = EMPTY_SIGNATURE, + ) -> None: super().__init__( slot=slot, parent_root=parent_root, @@ -134,12 +100,12 @@ def __init__(self, def __repr__(self) -> str: return ( - f'' + f"" ) @property @@ -158,7 +124,7 @@ def header(self) -> BeaconBlockHeader: @classmethod @abstractmethod - def from_root(cls, root: Hash32, chaindb: 'BaseBeaconChainDB') -> 'BaseBeaconBlock': + def from_root(cls, root: Hash32, chaindb: "BaseBeaconChainDB") -> "BaseBeaconBlock": """ Return the block denoted by the given block root. """ @@ -169,7 +135,7 @@ class BeaconBlock(BaseBeaconBlock): block_body_class = BeaconBlockBody @classmethod - def from_root(cls, root: Hash32, chaindb: 'BaseBeaconChainDB') -> 'BeaconBlock': + def from_root(cls, root: Hash32, chaindb: "BaseBeaconChainDB") -> "BeaconBlock": """ Return the block denoted by the given block ``root``. """ @@ -195,9 +161,9 @@ def from_root(cls, root: Hash32, chaindb: 'BaseBeaconChainDB') -> 'BeaconBlock': ) @classmethod - def from_parent(cls, - parent_block: 'BaseBeaconBlock', - block_params: FromBlockParams) -> 'BaseBeaconBlock': + def from_parent( + cls, parent_block: "BaseBeaconBlock", block_params: FromBlockParams + ) -> "BaseBeaconBlock": """ Initialize a new block with the ``parent_block`` as the block's previous block root. @@ -215,8 +181,7 @@ def from_parent(cls, ) @classmethod - def convert_block(cls, - block: 'BaseBeaconBlock') -> 'BeaconBlock': + def convert_block(cls, block: "BaseBeaconBlock") -> "BeaconBlock": return cls( slot=block.slot, parent_root=block.parent_root, @@ -226,7 +191,7 @@ def convert_block(cls, ) @classmethod - def from_header(cls, header: BeaconBlockHeader) -> 'BeaconBlock': + def from_header(cls, header: BeaconBlockHeader) -> "BeaconBlock": return cls( slot=header.slot, parent_root=header.parent_root, diff --git a/eth2/beacon/types/checkpoints.py b/eth2/beacon/types/checkpoints.py index d39c6fa6c7..fac7f881e7 100644 --- a/eth2/beacon/types/checkpoints.py +++ b/eth2/beacon/types/checkpoints.py @@ -1,43 +1,22 @@ -from eth_typing import ( - Hash32, -) -from eth_utils import ( - encode_hex, -) - -from eth.constants import ( - ZERO_HASH32, -) - +from eth.constants import ZERO_HASH32 +from eth_typing import Hash32 +from eth_utils import encode_hex import ssz -from ssz.sedes import ( - bytes32, - uint64, -) +from ssz.sedes import bytes32, uint64 -from eth2.beacon.typing import ( - Epoch, -) +from eth2.beacon.typing import Epoch -from .defaults import ( - default_epoch, -) +from .defaults import default_epoch class Checkpoint(ssz.Serializable): - fields = [ - ('epoch', uint64), - ('root', bytes32) - ] - - def __init__(self, - epoch: Epoch=default_epoch, - root: Hash32=ZERO_HASH32) -> None: - super().__init__( - epoch=epoch, - root=root, - ) + fields = [("epoch", uint64), ("root", bytes32)] + + def __init__( + self, epoch: Epoch = default_epoch, root: Hash32 = ZERO_HASH32 + ) -> None: + super().__init__(epoch=epoch, root=root) def __str__(self) -> str: return f"({self.epoch}, {encode_hex(self.root)[0:8]})" diff --git a/eth2/beacon/types/compact_committees.py b/eth2/beacon/types/compact_committees.py index fc140fc21b..3314944e7c 100644 --- a/eth2/beacon/types/compact_committees.py +++ b/eth2/beacon/types/compact_committees.py @@ -1,37 +1,22 @@ -from typing import ( - Sequence, -) - -from eth_typing import ( - BLSPubkey, -) +from typing import Sequence +from eth_typing import BLSPubkey import ssz -from ssz.sedes import ( - bytes48, - List, - uint64, -) +from ssz.sedes import List, bytes48, uint64 -from .defaults import ( - default_tuple, -) +from .defaults import default_tuple class CompactCommittee(ssz.Serializable): - fields = [ - ('pubkeys', List(bytes48, 1)), - ('compact_validators', List(uint64, 1)), - ] - - def __init__(self, - pubkeys: Sequence[BLSPubkey]=default_tuple, - compact_validators: Sequence[int]=default_tuple) -> None: - super().__init__( - pubkeys=pubkeys, - compact_validators=compact_validators, - ) + fields = [("pubkeys", List(bytes48, 1)), ("compact_validators", List(uint64, 1))] + + def __init__( + self, + pubkeys: Sequence[BLSPubkey] = default_tuple, + compact_validators: Sequence[int] = default_tuple, + ) -> None: + super().__init__(pubkeys=pubkeys, compact_validators=compact_validators) default_compact_committee = CompactCommittee() diff --git a/eth2/beacon/types/crosslinks.py b/eth2/beacon/types/crosslinks.py index 32fc90d5b7..d289c17038 100644 --- a/eth2/beacon/types/crosslinks.py +++ b/eth2/beacon/types/crosslinks.py @@ -1,49 +1,34 @@ -from eth_typing import ( - Hash32, -) - +from eth_typing import Hash32 +from eth_utils import encode_hex, humanize_hash import ssz -from ssz.sedes import ( - uint64, - bytes32, -) +from ssz.sedes import bytes32, uint64 -from eth2.beacon.constants import ( - ZERO_HASH32, -) -from eth2.beacon.typing import ( - Epoch, - Shard, -) -from eth_utils import ( - encode_hex, - humanize_hash, -) +from eth2.beacon.constants import ZERO_HASH32 +from eth2.beacon.typing import Epoch, Shard -from .defaults import ( - default_shard, - default_epoch, -) +from .defaults import default_epoch, default_shard class Crosslink(ssz.Serializable): fields = [ - ('shard', uint64), - ('parent_root', bytes32), + ("shard", uint64), + ("parent_root", bytes32), # Crosslinking data from epochs [start....end-1] - ('start_epoch', uint64), - ('end_epoch', uint64), + ("start_epoch", uint64), + ("end_epoch", uint64), # Root of the crosslinked shard data since the previous crosslink - ('data_root', bytes32), + ("data_root", bytes32), ] - def __init__(self, - shard: Shard=default_shard, - parent_root: Hash32=ZERO_HASH32, - start_epoch: Epoch=default_epoch, - end_epoch: Epoch=default_epoch, - data_root: Hash32=ZERO_HASH32) -> None: + def __init__( + self, + shard: Shard = default_shard, + parent_root: Hash32 = ZERO_HASH32, + start_epoch: Epoch = default_epoch, + end_epoch: Epoch = default_epoch, + data_root: Hash32 = ZERO_HASH32, + ) -> None: super().__init__( shard=shard, parent_root=parent_root, diff --git a/eth2/beacon/types/defaults.py b/eth2/beacon/types/defaults.py index ea268e97e8..d3d5078b05 100644 --- a/eth2/beacon/types/defaults.py +++ b/eth2/beacon/types/defaults.py @@ -1,24 +1,18 @@ """ This module contains default values to be shared across types in the parent module. """ -from typing import ( - Tuple, - TypeVar, - TYPE_CHECKING, -) +from typing import TYPE_CHECKING, Tuple, TypeVar -from eth2.beacon.constants import ( - EMPTY_PUBKEY, -) +from eth2.beacon.constants import EMPTY_PUBKEY from eth2.beacon.typing import ( # noqa: F401 + default_bitfield, default_epoch, - default_slot, - default_shard, - default_validator_index, default_gwei, - default_timestamp, default_second, - default_bitfield, + default_shard, + default_slot, + default_timestamp, + default_validator_index, default_version, ) @@ -35,12 +29,12 @@ # # for more info, see: https://stackoverflow.com/q/51885518 # updating to ``flake8==3.7.7`` fixes this bug but introduces many other breaking changes. -SomeElement = TypeVar('SomeElement') +SomeElement = TypeVar("SomeElement") default_tuple = tuple() # type: Tuple[Any, ...] def default_tuple_of_size( - size: int, - default_element: SomeElement) -> Tuple[SomeElement, ...]: + size: int, default_element: SomeElement +) -> Tuple[SomeElement, ...]: return (default_element,) * size diff --git a/eth2/beacon/types/deposit_data.py b/eth2/beacon/types/deposit_data.py index 4261e53857..0b37f2071f 100644 --- a/eth2/beacon/types/deposit_data.py +++ b/eth2/beacon/types/deposit_data.py @@ -1,31 +1,13 @@ -from eth.constants import ( - ZERO_HASH32, -) -from eth_typing import ( - BLSPubkey, - BLSSignature, - Hash32, -) -from eth_utils import ( - encode_hex, -) +from eth.constants import ZERO_HASH32 +from eth_typing import BLSPubkey, BLSSignature, Hash32 +from eth_utils import encode_hex import ssz -from ssz.sedes import ( - uint64, - bytes32, - bytes48, - bytes96, -) +from ssz.sedes import bytes32, bytes48, bytes96, uint64 from eth2.beacon.constants import EMPTY_SIGNATURE -from eth2.beacon.typing import ( - Gwei, -) +from eth2.beacon.typing import Gwei -from .defaults import ( - default_bls_pubkey, - default_gwei, -) +from .defaults import default_bls_pubkey, default_gwei class DepositData(ssz.SignedSerializable): @@ -34,19 +16,22 @@ class DepositData(ssz.SignedSerializable): Ethereum 1.0 deposit contract after a successful call to the ``deposit`` function on that contract. """ + fields = [ - ('pubkey', bytes48), - ('withdrawal_credentials', bytes32), - ('amount', uint64), + ("pubkey", bytes48), + ("withdrawal_credentials", bytes32), + ("amount", uint64), # BLS proof of possession (a BLS signature) - ('signature', bytes96), + ("signature", bytes96), ] - def __init__(self, - pubkey: BLSPubkey=default_bls_pubkey, - withdrawal_credentials: Hash32=ZERO_HASH32, - amount: Gwei=default_gwei, - signature: BLSSignature=EMPTY_SIGNATURE) -> None: + def __init__( + self, + pubkey: BLSPubkey = default_bls_pubkey, + withdrawal_credentials: Hash32 = ZERO_HASH32, + amount: Gwei = default_gwei, + signature: BLSSignature = EMPTY_SIGNATURE, + ) -> None: super().__init__( pubkey=pubkey, withdrawal_credentials=withdrawal_credentials, diff --git a/eth2/beacon/types/deposits.py b/eth2/beacon/types/deposits.py index 709b99d7fe..13618d0f52 100644 --- a/eth2/beacon/types/deposits.py +++ b/eth2/beacon/types/deposits.py @@ -1,34 +1,15 @@ -from typing import ( - Sequence, -) - -from eth.constants import ( - ZERO_HASH32, -) -from eth_utils import ( - encode_hex, -) -from eth_typing import ( - Hash32, -) -import ssz -from ssz.sedes import ( - Vector, - bytes32, -) +from typing import Sequence -from eth2.beacon.constants import ( - DEPOSIT_CONTRACT_TREE_DEPTH, -) +from eth.constants import ZERO_HASH32 +from eth_typing import Hash32 +from eth_utils import encode_hex +import ssz +from ssz.sedes import Vector, bytes32 -from .deposit_data import ( - DepositData, - default_deposit_data, -) +from eth2.beacon.constants import DEPOSIT_CONTRACT_TREE_DEPTH -from .defaults import ( - default_tuple_of_size, -) +from .defaults import default_tuple_of_size +from .deposit_data import DepositData, default_deposit_data DEPOSIT_PROOF_VECTOR_SIZE = DEPOSIT_CONTRACT_TREE_DEPTH + 1 @@ -44,17 +25,16 @@ class Deposit(ssz.Serializable): fields = [ # Merkle path to deposit root - ('proof', Vector(bytes32, DEPOSIT_PROOF_VECTOR_SIZE)), - ('data', DepositData), + ("proof", Vector(bytes32, DEPOSIT_PROOF_VECTOR_SIZE)), + ("data", DepositData), ] - def __init__(self, - proof: Sequence[Hash32]=default_proof_tuple, - data: DepositData=default_deposit_data)-> None: - super().__init__( - proof, - data, - ) + def __init__( + self, + proof: Sequence[Hash32] = default_proof_tuple, + data: DepositData = default_deposit_data, + ) -> None: + super().__init__(proof, data) def __repr__(self) -> str: return f"" diff --git a/eth2/beacon/types/eth1_data.py b/eth2/beacon/types/eth1_data.py index 58b6467304..c594c9c583 100644 --- a/eth2/beacon/types/eth1_data.py +++ b/eth2/beacon/types/eth1_data.py @@ -1,30 +1,23 @@ -from eth_typing import ( - Hash32, -) - +from eth.constants import ZERO_HASH32 +from eth_typing import Hash32 import ssz -from ssz.sedes import ( - bytes32, - uint64, -) - -from eth.constants import ( - ZERO_HASH32, -) +from ssz.sedes import bytes32, uint64 class Eth1Data(ssz.Serializable): fields = [ - ('deposit_root', bytes32), - ('deposit_count', uint64), - ('block_hash', bytes32), + ("deposit_root", bytes32), + ("deposit_count", uint64), + ("block_hash", bytes32), ] - def __init__(self, - deposit_root: Hash32=ZERO_HASH32, - deposit_count: int=0, - block_hash: Hash32=ZERO_HASH32) -> None: + def __init__( + self, + deposit_root: Hash32 = ZERO_HASH32, + deposit_count: int = 0, + block_hash: Hash32 = ZERO_HASH32, + ) -> None: super().__init__( deposit_root=deposit_root, deposit_count=deposit_count, diff --git a/eth2/beacon/types/forks.py b/eth2/beacon/types/forks.py index 4388d10981..8ebd2ed85c 100644 --- a/eth2/beacon/types/forks.py +++ b/eth2/beacon/types/forks.py @@ -1,33 +1,26 @@ import ssz -from ssz.sedes import ( - bytes4, - uint64, -) +from ssz.sedes import bytes4, uint64 -from eth2.beacon.typing import ( - Epoch, - Version, -) +from eth2.beacon.typing import Epoch, Version -from .defaults import ( - default_epoch, - default_version, -) +from .defaults import default_epoch, default_version class Fork(ssz.Serializable): fields = [ - ('previous_version', bytes4), - ('current_version', bytes4), + ("previous_version", bytes4), + ("current_version", bytes4), # Epoch of latest fork - ('epoch', uint64) + ("epoch", uint64), ] - def __init__(self, - previous_version: Version=default_version, - current_version: Version=default_version, - epoch: Epoch=default_epoch) -> None: + def __init__( + self, + previous_version: Version = default_version, + current_version: Version = default_version, + epoch: Epoch = default_epoch, + ) -> None: super().__init__( previous_version=previous_version, current_version=current_version, diff --git a/eth2/beacon/types/historical_batch.py b/eth2/beacon/types/historical_batch.py index 28a36733d1..db40a9951b 100644 --- a/eth2/beacon/types/historical_batch.py +++ b/eth2/beacon/types/historical_batch.py @@ -1,50 +1,35 @@ -from typing import ( - Sequence, -) - -from eth.constants import ( - ZERO_HASH32, -) -from eth_typing import ( - Hash32, -) +from typing import Sequence +from eth.constants import ZERO_HASH32 +from eth_typing import Hash32 import ssz -from ssz.sedes import ( - bytes32, - Vector, -) +from ssz.sedes import Vector, bytes32 -from eth2.configs import ( - Eth2Config, -) +from eth2.configs import Eth2Config -from .defaults import ( - default_tuple, - default_tuple_of_size, -) +from .defaults import default_tuple, default_tuple_of_size class HistoricalBatch(ssz.Serializable): - fields = [ - ('block_roots', Vector(bytes32, 1)), - ('state_roots', Vector(bytes32, 1)), - ] + fields = [("block_roots", Vector(bytes32, 1)), ("state_roots", Vector(bytes32, 1))] - def __init__(self, - *, - block_roots: Sequence[Hash32]=default_tuple, - state_roots: Sequence[Hash32]=default_tuple, - config: Eth2Config=None) -> None: + def __init__( + self, + *, + block_roots: Sequence[Hash32] = default_tuple, + state_roots: Sequence[Hash32] = default_tuple, + config: Eth2Config = None + ) -> None: if config: # try to provide sane defaults if block_roots == default_tuple: - block_roots = default_tuple_of_size(config.SLOTS_PER_HISTORICAL_ROOT, ZERO_HASH32) + block_roots = default_tuple_of_size( + config.SLOTS_PER_HISTORICAL_ROOT, ZERO_HASH32 + ) if state_roots == default_tuple: - state_roots = default_tuple_of_size(config.SLOTS_PER_HISTORICAL_ROOT, ZERO_HASH32) + state_roots = default_tuple_of_size( + config.SLOTS_PER_HISTORICAL_ROOT, ZERO_HASH32 + ) - super().__init__( - block_roots=block_roots, - state_roots=state_roots, - ) + super().__init__(block_roots=block_roots, state_roots=state_roots) diff --git a/eth2/beacon/types/pending_attestations.py b/eth2/beacon/types/pending_attestations.py index aa2b394b25..5f5d0ffddd 100644 --- a/eth2/beacon/types/pending_attestations.py +++ b/eth2/beacon/types/pending_attestations.py @@ -1,39 +1,28 @@ import ssz -from ssz.sedes import ( - Bitlist, - uint64, -) +from ssz.sedes import Bitlist, uint64 -from eth2.beacon.typing import ( - Bitfield, - ValidatorIndex, -) +from eth2.beacon.typing import Bitfield, ValidatorIndex -from .attestation_data import ( - AttestationData, - default_attestation_data, -) - -from .defaults import ( - default_bitfield, - default_validator_index, -) +from .attestation_data import AttestationData, default_attestation_data +from .defaults import default_bitfield, default_validator_index class PendingAttestation(ssz.Serializable): fields = [ - ('aggregation_bits', Bitlist(1)), - ('data', AttestationData), - ('inclusion_delay', uint64), - ('proposer_index', uint64), + ("aggregation_bits", Bitlist(1)), + ("data", AttestationData), + ("inclusion_delay", uint64), + ("proposer_index", uint64), ] - def __init__(self, - aggregation_bits: Bitfield=default_bitfield, - data: AttestationData=default_attestation_data, - inclusion_delay: int=0, - proposer_index: ValidatorIndex=default_validator_index) -> None: + def __init__( + self, + aggregation_bits: Bitfield = default_bitfield, + data: AttestationData = default_attestation_data, + inclusion_delay: int = 0, + proposer_index: ValidatorIndex = default_validator_index, + ) -> None: super().__init__( aggregation_bits=aggregation_bits, data=data, diff --git a/eth2/beacon/types/proposer_slashings.py b/eth2/beacon/types/proposer_slashings.py index 49f876e0ab..7512cf724d 100644 --- a/eth2/beacon/types/proposer_slashings.py +++ b/eth2/beacon/types/proposer_slashings.py @@ -1,38 +1,29 @@ import ssz -from ssz.sedes import ( - uint64, -) +from ssz.sedes import uint64 -from .block_headers import ( - BeaconBlockHeader, - default_beacon_block_header, -) -from eth2.beacon.typing import ( - ValidatorIndex, -) +from eth2.beacon.typing import ValidatorIndex -from .defaults import ( - default_validator_index, -) +from .block_headers import BeaconBlockHeader, default_beacon_block_header +from .defaults import default_validator_index class ProposerSlashing(ssz.Serializable): fields = [ # Proposer index - ('proposer_index', uint64), + ("proposer_index", uint64), # First block header - ('header_1', BeaconBlockHeader), + ("header_1", BeaconBlockHeader), # Second block header - ('header_2', BeaconBlockHeader), + ("header_2", BeaconBlockHeader), ] - def __init__(self, - proposer_index: ValidatorIndex=default_validator_index, - header_1: BeaconBlockHeader=default_beacon_block_header, - header_2: BeaconBlockHeader=default_beacon_block_header) -> None: + def __init__( + self, + proposer_index: ValidatorIndex = default_validator_index, + header_1: BeaconBlockHeader = default_beacon_block_header, + header_2: BeaconBlockHeader = default_beacon_block_header, + ) -> None: super().__init__( - proposer_index=proposer_index, - header_1=header_1, - header_2=header_2, + proposer_index=proposer_index, header_1=header_1, header_2=header_2 ) diff --git a/eth2/beacon/types/states.py b/eth2/beacon/types/states.py index 948da405f3..3ebe31c8eb 100644 --- a/eth2/beacon/types/states.py +++ b/eth2/beacon/types/states.py @@ -1,81 +1,39 @@ -from typing import ( - Any, - Callable, - Sequence, -) - -from eth_typing import ( - Hash32, -) -from eth_utils import ( - encode_hex, -) +from typing import Any, Callable, Sequence +from eth.constants import ZERO_HASH32 +from eth_typing import Hash32 +from eth_utils import encode_hex import ssz -from ssz.sedes import ( - Bitvector, - List, - Vector, - bytes32, - uint64, -) +from ssz.sedes import Bitvector, List, Vector, bytes32, uint64 -from eth.constants import ( - ZERO_HASH32, -) - -from eth2._utils.tuple import ( - update_tuple_item, - update_tuple_item_with_fn, -) -from eth2.configs import Eth2Config -from eth2.beacon.constants import ( - JUSTIFICATION_BITS_LENGTH, -) -from eth2.beacon.helpers import ( - compute_epoch_of_slot, -) +from eth2._utils.tuple import update_tuple_item, update_tuple_item_with_fn +from eth2.beacon.constants import JUSTIFICATION_BITS_LENGTH +from eth2.beacon.helpers import compute_epoch_of_slot from eth2.beacon.typing import ( + Bitfield, Epoch, Gwei, Shard, Slot, Timestamp, ValidatorIndex, - Bitfield, -) - -from .block_headers import ( - BeaconBlockHeader, - default_beacon_block_header, -) -from .eth1_data import ( - Eth1Data, - default_eth1_data, -) -from .checkpoints import ( - Checkpoint, - default_checkpoint, ) -from .crosslinks import ( - Crosslink, - default_crosslink, -) -from .forks import ( - Fork, - default_fork, -) -from .pending_attestations import PendingAttestation -from .validators import Validator +from eth2.configs import Eth2Config +from .block_headers import BeaconBlockHeader, default_beacon_block_header +from .checkpoints import Checkpoint, default_checkpoint +from .crosslinks import Crosslink, default_crosslink from .defaults import ( - default_timestamp, + default_shard, default_slot, + default_timestamp, default_tuple, default_tuple_of_size, - default_shard, ) - +from .eth1_data import Eth1Data, default_eth1_data +from .forks import Fork, default_fork +from .pending_attestations import PendingAttestation +from .validators import Validator default_justification_bits = Bitfield((False,) * JUSTIFICATION_BITS_LENGTH) @@ -84,80 +42,81 @@ class BeaconState(ssz.Serializable): fields = [ # Versioning - ('genesis_time', uint64), - ('slot', uint64), - ('fork', Fork), - + ("genesis_time", uint64), + ("slot", uint64), + ("fork", Fork), # History - ('latest_block_header', BeaconBlockHeader), - ('block_roots', Vector(bytes32, 1)), # Needed to process attestations, older to newer # noqa: E501 - ('state_roots', Vector(bytes32, 1)), - ('historical_roots', List(bytes32, 1)), # allow for a log-sized Merkle proof from any block to any historical block root # noqa: E501 - + ("latest_block_header", BeaconBlockHeader), + ( + "block_roots", + Vector(bytes32, 1), + ), # Needed to process attestations, older to newer # noqa: E501 + ("state_roots", Vector(bytes32, 1)), + ( + "historical_roots", + List(bytes32, 1), + ), # allow for a log-sized Merkle proof from any block to any historical block root # noqa: E501 # Ethereum 1.0 chain - ('eth1_data', Eth1Data), - ('eth1_data_votes', List(Eth1Data, 1)), - ('eth1_deposit_index', uint64), - + ("eth1_data", Eth1Data), + ("eth1_data_votes", List(Eth1Data, 1)), + ("eth1_deposit_index", uint64), # Validator registry - ('validators', List(Validator, 1)), - ('balances', List(uint64, 1)), - + ("validators", List(Validator, 1)), + ("balances", List(uint64, 1)), # Shuffling - ('start_shard', uint64), - ('randao_mixes', Vector(bytes32, 1)), - ('active_index_roots', Vector(bytes32, 1)), - ('compact_committees_roots', Vector(bytes32, 1)), - + ("start_shard", uint64), + ("randao_mixes", Vector(bytes32, 1)), + ("active_index_roots", Vector(bytes32, 1)), + ("compact_committees_roots", Vector(bytes32, 1)), # Slashings - ('slashings', Vector(uint64, 1)), # Balances slashed at every withdrawal period # noqa: E501 - + ( + "slashings", + Vector(uint64, 1), + ), # Balances slashed at every withdrawal period # noqa: E501 # Attestations - ('previous_epoch_attestations', List(PendingAttestation, 1)), - ('current_epoch_attestations', List(PendingAttestation, 1)), - + ("previous_epoch_attestations", List(PendingAttestation, 1)), + ("current_epoch_attestations", List(PendingAttestation, 1)), # Crosslinks - ('previous_crosslinks', Vector(Crosslink, 1)), - ('current_crosslinks', Vector(Crosslink, 1)), - + ("previous_crosslinks", Vector(Crosslink, 1)), + ("current_crosslinks", Vector(Crosslink, 1)), # Justification - ('justification_bits', Bitvector(JUSTIFICATION_BITS_LENGTH)), - ('previous_justified_checkpoint', Checkpoint), - ('current_justified_checkpoint', Checkpoint), - + ("justification_bits", Bitvector(JUSTIFICATION_BITS_LENGTH)), + ("previous_justified_checkpoint", Checkpoint), + ("current_justified_checkpoint", Checkpoint), # Finality - ('finalized_checkpoint', Checkpoint), + ("finalized_checkpoint", Checkpoint), ] def __init__( - self, - *, - genesis_time: Timestamp=default_timestamp, - slot: Slot=default_slot, - fork: Fork=default_fork, - latest_block_header: BeaconBlockHeader=default_beacon_block_header, - block_roots: Sequence[Hash32]=default_tuple, - state_roots: Sequence[Hash32]=default_tuple, - historical_roots: Sequence[Hash32]=default_tuple, - eth1_data: Eth1Data=default_eth1_data, - eth1_data_votes: Sequence[Eth1Data]=default_tuple, - eth1_deposit_index: int=0, - validators: Sequence[Validator]=default_tuple, - balances: Sequence[Gwei]=default_tuple, - start_shard: Shard=default_shard, - randao_mixes: Sequence[Hash32]=default_tuple, - active_index_roots: Sequence[Hash32]=default_tuple, - compact_committees_roots: Sequence[Hash32]=default_tuple, - slashings: Sequence[Gwei]=default_tuple, - previous_epoch_attestations: Sequence[PendingAttestation]=default_tuple, - current_epoch_attestations: Sequence[PendingAttestation]=default_tuple, - previous_crosslinks: Sequence[Crosslink]=default_tuple, - current_crosslinks: Sequence[Crosslink]=default_tuple, - justification_bits: Bitfield=default_justification_bits, - previous_justified_checkpoint: Checkpoint=default_checkpoint, - current_justified_checkpoint: Checkpoint=default_checkpoint, - finalized_checkpoint: Checkpoint=default_checkpoint, - config: Eth2Config=None) -> None: + self, + *, + genesis_time: Timestamp = default_timestamp, + slot: Slot = default_slot, + fork: Fork = default_fork, + latest_block_header: BeaconBlockHeader = default_beacon_block_header, + block_roots: Sequence[Hash32] = default_tuple, + state_roots: Sequence[Hash32] = default_tuple, + historical_roots: Sequence[Hash32] = default_tuple, + eth1_data: Eth1Data = default_eth1_data, + eth1_data_votes: Sequence[Eth1Data] = default_tuple, + eth1_deposit_index: int = 0, + validators: Sequence[Validator] = default_tuple, + balances: Sequence[Gwei] = default_tuple, + start_shard: Shard = default_shard, + randao_mixes: Sequence[Hash32] = default_tuple, + active_index_roots: Sequence[Hash32] = default_tuple, + compact_committees_roots: Sequence[Hash32] = default_tuple, + slashings: Sequence[Gwei] = default_tuple, + previous_epoch_attestations: Sequence[PendingAttestation] = default_tuple, + current_epoch_attestations: Sequence[PendingAttestation] = default_tuple, + previous_crosslinks: Sequence[Crosslink] = default_tuple, + current_crosslinks: Sequence[Crosslink] = default_tuple, + justification_bits: Bitfield = default_justification_bits, + previous_justified_checkpoint: Checkpoint = default_checkpoint, + current_justified_checkpoint: Checkpoint = default_checkpoint, + finalized_checkpoint: Checkpoint = default_checkpoint, + config: Eth2Config = None, + ) -> None: if len(validators) != len(balances): raise ValueError( "The length of validators and balances lists should be the same." @@ -166,38 +125,36 @@ def __init__( if config: # try to provide sane defaults if block_roots == default_tuple: - block_roots = default_tuple_of_size(config.SLOTS_PER_HISTORICAL_ROOT, ZERO_HASH32) + block_roots = default_tuple_of_size( + config.SLOTS_PER_HISTORICAL_ROOT, ZERO_HASH32 + ) if state_roots == default_tuple: - state_roots = default_tuple_of_size(config.SLOTS_PER_HISTORICAL_ROOT, ZERO_HASH32) + state_roots = default_tuple_of_size( + config.SLOTS_PER_HISTORICAL_ROOT, ZERO_HASH32 + ) if randao_mixes == default_tuple: randao_mixes = default_tuple_of_size( - config.EPOCHS_PER_HISTORICAL_VECTOR, - ZERO_HASH32 + config.EPOCHS_PER_HISTORICAL_VECTOR, ZERO_HASH32 ) if active_index_roots == default_tuple: active_index_roots = default_tuple_of_size( - config.EPOCHS_PER_HISTORICAL_VECTOR, - ZERO_HASH32 + config.EPOCHS_PER_HISTORICAL_VECTOR, ZERO_HASH32 ) if compact_committees_roots == default_tuple: compact_committees_roots = default_tuple_of_size( - config.EPOCHS_PER_HISTORICAL_VECTOR, - ZERO_HASH32 + config.EPOCHS_PER_HISTORICAL_VECTOR, ZERO_HASH32 ) if slashings == default_tuple: slashings = default_tuple_of_size( - config.EPOCHS_PER_SLASHINGS_VECTOR, - Gwei(0), + config.EPOCHS_PER_SLASHINGS_VECTOR, Gwei(0) ) if previous_crosslinks == default_tuple: previous_crosslinks = default_tuple_of_size( - config.SHARD_COUNT, - default_crosslink, + config.SHARD_COUNT, default_crosslink ) if current_crosslinks == default_tuple: current_crosslinks = default_tuple_of_size( - config.SHARD_COUNT, - default_crosslink, + config.SHARD_COUNT, default_crosslink ) super().__init__( @@ -235,10 +192,12 @@ def __repr__(self) -> str: def validator_count(self) -> int: return len(self.validators) - def update_validator(self, - validator_index: ValidatorIndex, - validator: Validator, - balance: Gwei=None) -> 'BeaconState': + def update_validator( + self, + validator_index: ValidatorIndex, + validator: Validator, + balance: Gwei = None, + ) -> "BeaconState": """ Replace ``self.validators[validator_index]`` with ``validator``. @@ -246,28 +205,24 @@ def update_validator(self, ``self.balances[validator_index] with ``balance``. """ if ( - validator_index >= len(self.validators) or - validator_index >= len(self.balances) or - validator_index < 0 + validator_index >= len(self.validators) + or validator_index >= len(self.balances) + or validator_index < 0 ): raise IndexError("Incorrect validator index") - state = self.update_validator_with_fn( - validator_index, - lambda *_: validator, - ) + state = self.update_validator_with_fn(validator_index, lambda *_: validator) if balance: - return state._update_validator_balance( - validator_index, - balance, - ) + return state._update_validator_balance(validator_index, balance) else: return state - def update_validator_with_fn(self, - validator_index: ValidatorIndex, - fn: Callable[[Validator, Any], Validator], - *args: Any) -> 'BeaconState': + def update_validator_with_fn( + self, + validator_index: ValidatorIndex, + fn: Callable[[Validator, Any], Validator], + *args: Any, + ) -> "BeaconState": """ Replace ``self.validators[validator_index]`` with the result of calling ``fn`` on the existing ``validator``. @@ -279,16 +234,13 @@ def update_validator_with_fn(self, return self.copy( validators=update_tuple_item_with_fn( - self.validators, - validator_index, - fn, - *args, - ), + self.validators, validator_index, fn, *args + ) ) - def _update_validator_balance(self, - validator_index: ValidatorIndex, - balance: Gwei) -> 'BeaconState': + def _update_validator_balance( + self, validator_index: ValidatorIndex, balance: Gwei + ) -> "BeaconState": """ Update the balance of validator of the given ``validator_index``. """ @@ -296,11 +248,7 @@ def _update_validator_balance(self, raise IndexError("Incorrect validator index") return self.copy( - balances=update_tuple_item( - self.balances, - validator_index, - balance, - ) + balances=update_tuple_item(self.balances, validator_index, balance) ) def current_epoch(self, slots_per_epoch: int) -> Epoch: diff --git a/eth2/beacon/types/transfers.py b/eth2/beacon/types/transfers.py index 1d6bf3f08f..02c9cbb98b 100644 --- a/eth2/beacon/types/transfers.py +++ b/eth2/beacon/types/transfers.py @@ -1,52 +1,42 @@ -from eth_typing import ( - BLSPubkey, - BLSSignature, -) +from eth_typing import BLSPubkey, BLSSignature import ssz -from ssz.sedes import ( - bytes48, - bytes96, - uint64 -) +from ssz.sedes import bytes48, bytes96, uint64 from eth2.beacon.constants import EMPTY_SIGNATURE - -from eth2.beacon.typing import ( - Gwei, - Slot, - ValidatorIndex, -) +from eth2.beacon.typing import Gwei, Slot, ValidatorIndex from .defaults import ( - default_validator_index, + default_bls_pubkey, default_gwei, default_slot, - default_bls_pubkey, + default_validator_index, ) class Transfer(ssz.SignedSerializable): fields = [ - ('sender', uint64), - ('recipient', uint64), - ('amount', uint64), - ('fee', uint64), + ("sender", uint64), + ("recipient", uint64), + ("amount", uint64), + ("fee", uint64), # Inclusion slot - ('slot', uint64), + ("slot", uint64), # Sender withdrawal pubkey - ('pubkey', bytes48), + ("pubkey", bytes48), # Sender signature - ('signature', bytes96), + ("signature", bytes96), ] - def __init__(self, - sender: ValidatorIndex=default_validator_index, - recipient: ValidatorIndex=default_validator_index, - amount: Gwei=default_gwei, - fee: Gwei=default_gwei, - slot: Slot=default_slot, - pubkey: BLSPubkey=default_bls_pubkey, - signature: BLSSignature=EMPTY_SIGNATURE) -> None: + def __init__( + self, + sender: ValidatorIndex = default_validator_index, + recipient: ValidatorIndex = default_validator_index, + amount: Gwei = default_gwei, + fee: Gwei = default_gwei, + slot: Slot = default_slot, + pubkey: BLSPubkey = default_bls_pubkey, + signature: BLSSignature = EMPTY_SIGNATURE, + ) -> None: super().__init__( sender=sender, recipient=recipient, diff --git a/eth2/beacon/types/validators.py b/eth2/beacon/types/validators.py index 25e19e3732..5bafdd7bc7 100644 --- a/eth2/beacon/types/validators.py +++ b/eth2/beacon/types/validators.py @@ -1,30 +1,12 @@ -from eth_typing import ( - BLSPubkey, - Hash32, -) +from eth_typing import BLSPubkey, Hash32 import ssz -from ssz.sedes import ( - boolean, - bytes32, - bytes48, - uint64, -) +from ssz.sedes import boolean, bytes32, bytes48, uint64 +from eth2.beacon.constants import FAR_FUTURE_EPOCH, ZERO_HASH32 +from eth2.beacon.typing import Epoch, Gwei from eth2.configs import Eth2Config -from eth2.beacon.constants import ( - FAR_FUTURE_EPOCH, - ZERO_HASH32, -) -from eth2.beacon.typing import ( - Epoch, - Gwei, -) -from .defaults import ( - default_bls_pubkey, - default_epoch, - default_gwei, -) +from .defaults import default_bls_pubkey, default_epoch, default_gwei def _round_down_to_previous_multiple(amount: int, increment: int) -> int: @@ -35,8 +17,7 @@ def calculate_effective_balance(amount: Gwei, config: Eth2Config) -> Gwei: return Gwei( min( _round_down_to_previous_multiple( - amount, - config.EFFECTIVE_BALANCE_INCREMENT, + amount, config.EFFECTIVE_BALANCE_INCREMENT ), config.MAX_EFFECTIVE_BALANCE, ) @@ -46,30 +27,32 @@ def calculate_effective_balance(amount: Gwei, config: Eth2Config) -> Gwei: class Validator(ssz.Serializable): fields = [ - ('pubkey', bytes48), - ('withdrawal_credentials', bytes32), - ('effective_balance', uint64), - ('slashed', boolean), + ("pubkey", bytes48), + ("withdrawal_credentials", bytes32), + ("effective_balance", uint64), + ("slashed", boolean), # Epoch when validator became eligible for activation - ('activation_eligibility_epoch', uint64), + ("activation_eligibility_epoch", uint64), # Epoch when validator activated - ('activation_epoch', uint64), + ("activation_epoch", uint64), # Epoch when validator exited - ('exit_epoch', uint64), + ("exit_epoch", uint64), # Epoch when validator withdrew - ('withdrawable_epoch', uint64), + ("withdrawable_epoch", uint64), ] - def __init__(self, - *, - pubkey: BLSPubkey=default_bls_pubkey, - withdrawal_credentials: Hash32=ZERO_HASH32, - effective_balance: Gwei=default_gwei, - slashed: bool=False, - activation_eligibility_epoch: Epoch=default_epoch, - activation_epoch: Epoch=default_epoch, - exit_epoch: Epoch=default_epoch, - withdrawable_epoch: Epoch=default_epoch) -> None: + def __init__( + self, + *, + pubkey: BLSPubkey = default_bls_pubkey, + withdrawal_credentials: Hash32 = ZERO_HASH32, + effective_balance: Gwei = default_gwei, + slashed: bool = False, + activation_eligibility_epoch: Epoch = default_epoch, + activation_epoch: Epoch = default_epoch, + exit_epoch: Epoch = default_epoch, + withdrawable_epoch: Epoch = default_epoch + ) -> None: super().__init__( pubkey=pubkey, withdrawal_credentials=withdrawal_credentials, @@ -94,25 +77,26 @@ def is_slashable(self, epoch: Epoch) -> bool: From `is_slashable_validator` in the spec. """ not_slashed = self.slashed is False - active_but_not_withdrawable = self.activation_epoch <= epoch < self.withdrawable_epoch + active_but_not_withdrawable = ( + self.activation_epoch <= epoch < self.withdrawable_epoch + ) return not_slashed and active_but_not_withdrawable @classmethod - def create_pending_validator(cls, - pubkey: BLSPubkey, - withdrawal_credentials: Hash32, - amount: Gwei, - config: Eth2Config) -> 'Validator': + def create_pending_validator( + cls, + pubkey: BLSPubkey, + withdrawal_credentials: Hash32, + amount: Gwei, + config: Eth2Config, + ) -> "Validator": """ Return a new pending ``Validator`` with the given fields. """ return cls( pubkey=pubkey, withdrawal_credentials=withdrawal_credentials, - effective_balance=calculate_effective_balance( - amount, - config, - ), + effective_balance=calculate_effective_balance(amount, config), activation_eligibility_epoch=FAR_FUTURE_EPOCH, activation_epoch=FAR_FUTURE_EPOCH, exit_epoch=FAR_FUTURE_EPOCH, diff --git a/eth2/beacon/types/voluntary_exits.py b/eth2/beacon/types/voluntary_exits.py index 2a81a3877c..88935eaecd 100644 --- a/eth2/beacon/types/voluntary_exits.py +++ b/eth2/beacon/types/voluntary_exits.py @@ -1,39 +1,26 @@ -from eth_typing import ( - BLSSignature, -) +from eth_typing import BLSSignature import ssz -from ssz.sedes import ( - bytes96, - uint64, -) +from ssz.sedes import bytes96, uint64 -from eth2.beacon.typing import ( - Epoch, - ValidatorIndex, -) from eth2.beacon.constants import EMPTY_SIGNATURE +from eth2.beacon.typing import Epoch, ValidatorIndex -from .defaults import ( - default_validator_index, - default_epoch, -) +from .defaults import default_epoch, default_validator_index class VoluntaryExit(ssz.SignedSerializable): fields = [ # Minimum epoch for processing exit - ('epoch', uint64), - ('validator_index', uint64), - ('signature', bytes96), + ("epoch", uint64), + ("validator_index", uint64), + ("signature", bytes96), ] - def __init__(self, - epoch: Epoch=default_epoch, - validator_index: ValidatorIndex=default_validator_index, - signature: BLSSignature=EMPTY_SIGNATURE) -> None: - super().__init__( - epoch, - validator_index, - signature, - ) + def __init__( + self, + epoch: Epoch = default_epoch, + validator_index: ValidatorIndex = default_validator_index, + signature: BLSSignature = EMPTY_SIGNATURE, + ) -> None: + super().__init__(epoch, validator_index, signature) diff --git a/eth2/beacon/typing.py b/eth2/beacon/typing.py index bc783bf8b9..056f478ac9 100644 --- a/eth2/beacon/typing.py +++ b/eth2/beacon/typing.py @@ -1,28 +1,25 @@ -from typing import ( - NamedTuple, - NewType, - Tuple, -) +from typing import NamedTuple, NewType, Tuple +Slot = NewType("Slot", int) # uint64 +Epoch = NewType("Epoch", int) # uint64 +Shard = NewType("Shard", int) # uint64 -Slot = NewType('Slot', int) # uint64 -Epoch = NewType('Epoch', int) # uint64 -Shard = NewType('Shard', int) # uint64 +Bitfield = NewType("Bitfield", Tuple[bool, ...]) -Bitfield = NewType('Bitfield', Tuple[bool, ...]) +ValidatorIndex = NewType("ValidatorIndex", int) # uint64 +CommitteeIndex = NewType( + "CommitteeIndex", int +) # uint64 The i-th position in a committee tuple -ValidatorIndex = NewType('ValidatorIndex', int) # uint64 -CommitteeIndex = NewType('CommitteeIndex', int) # uint64 The i-th position in a committee tuple +Gwei = NewType("Gwei", int) # uint64 -Gwei = NewType('Gwei', int) # uint64 +Timestamp = NewType("Timestamp", int) +Second = NewType("Second", int) -Timestamp = NewType('Timestamp', int) -Second = NewType('Second', int) +Version = NewType("Version", bytes) -Version = NewType('Version', bytes) - -DomainType = NewType('DomainType', bytes) # bytes of length 4 +DomainType = NewType("DomainType", bytes) # bytes of length 4 class FromBlockParams(NamedTuple): @@ -38,4 +35,4 @@ class FromBlockParams(NamedTuple): default_timestamp = Timestamp(0) default_second = Second(0) default_bitfield = Bitfield(tuple()) -default_version = Version(b'\x00' * 4) +default_version = Version(b"\x00" * 4) diff --git a/eth2/beacon/validator_status_helpers.py b/eth2/beacon/validator_status_helpers.py index 99b45362b5..97321fc656 100644 --- a/eth2/beacon/validator_status_helpers.py +++ b/eth2/beacon/validator_status_helpers.py @@ -1,57 +1,45 @@ -from eth_utils.toolz import ( - curry, -) +from eth_utils.toolz import curry -from eth2._utils.tuple import ( - update_tuple_item_with_fn, -) -from eth2.configs import ( - CommitteeConfig, - Eth2Config, -) -from eth2.beacon.committee_helpers import ( - get_beacon_proposer_index, -) +from eth2._utils.tuple import update_tuple_item_with_fn +from eth2.beacon.committee_helpers import get_beacon_proposer_index from eth2.beacon.constants import FAR_FUTURE_EPOCH from eth2.beacon.epoch_processing_helpers import ( + compute_activation_exit_epoch, decrease_balance, get_validator_churn_limit, - compute_activation_exit_epoch, increase_balance, ) from eth2.beacon.types.states import BeaconState from eth2.beacon.types.validators import Validator -from eth2.beacon.typing import ( - Epoch, - Gwei, - ValidatorIndex, -) +from eth2.beacon.typing import Epoch, Gwei, ValidatorIndex +from eth2.configs import CommitteeConfig, Eth2Config def activate_validator(validator: Validator, activation_epoch: Epoch) -> Validator: return validator.copy( - activation_eligibility_epoch=activation_epoch, - activation_epoch=activation_epoch, + activation_eligibility_epoch=activation_epoch, activation_epoch=activation_epoch ) -def _compute_exit_queue_epoch(state: BeaconState, churn_limit: int, config: Eth2Config) -> Epoch: +def _compute_exit_queue_epoch( + state: BeaconState, churn_limit: int, config: Eth2Config +) -> Epoch: slots_per_epoch = config.SLOTS_PER_EPOCH exit_epochs = tuple( - v.exit_epoch for v in state.validators - if v.exit_epoch != FAR_FUTURE_EPOCH + v.exit_epoch for v in state.validators if v.exit_epoch != FAR_FUTURE_EPOCH ) exit_queue_epoch = max( - exit_epochs + (compute_activation_exit_epoch( - state.current_epoch(slots_per_epoch), - config.ACTIVATION_EXIT_DELAY, - ),) + exit_epochs + + ( + compute_activation_exit_epoch( + state.current_epoch(slots_per_epoch), config.ACTIVATION_EXIT_DELAY + ), + ) + ) + exit_queue_churn = len( + tuple(v for v in state.validators if v.exit_epoch == exit_queue_epoch) ) - exit_queue_churn = len(tuple( - v for v in state.validators - if v.exit_epoch == exit_queue_epoch - )) if exit_queue_churn >= churn_limit: exit_queue_epoch += 1 return Epoch(exit_queue_epoch) @@ -59,9 +47,9 @@ def _compute_exit_queue_epoch(state: BeaconState, churn_limit: int, config: Eth2 # NOTE: adding ``curry`` here gets mypy to allow use of this elsewhere. @curry -def initiate_exit_for_validator(validator: Validator, - state: BeaconState, - config: Eth2Config) -> Validator: +def initiate_exit_for_validator( + validator: Validator, state: BeaconState, config: Eth2Config +) -> Validator: """ Performs the mutations to ``validator`` used to initiate an exit. More convenient given our immutability patterns compared to ``initiate_validator_exit``. @@ -74,42 +62,42 @@ def initiate_exit_for_validator(validator: Validator, return validator.copy( exit_epoch=exit_queue_epoch, - withdrawable_epoch=Epoch(exit_queue_epoch + config.MIN_VALIDATOR_WITHDRAWABILITY_DELAY), + withdrawable_epoch=Epoch( + exit_queue_epoch + config.MIN_VALIDATOR_WITHDRAWABILITY_DELAY + ), ) -def initiate_validator_exit(state: BeaconState, - index: ValidatorIndex, - config: Eth2Config) -> BeaconState: +def initiate_validator_exit( + state: BeaconState, index: ValidatorIndex, config: Eth2Config +) -> BeaconState: """ Initiate exit for the validator with the given ``index``. Return the updated state (immutable). """ return state.update_validator_with_fn( - index, - initiate_exit_for_validator, - state, - config, + index, initiate_exit_for_validator, state, config ) @curry -def _set_validator_slashed(v: Validator, - current_epoch: Epoch, - epochs_per_slashings_vector: int) -> Validator: +def _set_validator_slashed( + v: Validator, current_epoch: Epoch, epochs_per_slashings_vector: int +) -> Validator: return v.copy( slashed=True, withdrawable_epoch=max( - v.withdrawable_epoch, - Epoch(current_epoch + epochs_per_slashings_vector), + v.withdrawable_epoch, Epoch(current_epoch + epochs_per_slashings_vector) ), ) -def slash_validator(state: BeaconState, - index: ValidatorIndex, - config: Eth2Config, - whistleblower_index: ValidatorIndex=None) -> BeaconState: +def slash_validator( + state: BeaconState, + index: ValidatorIndex, + config: Eth2Config, + whistleblower_index: ValidatorIndex = None, +) -> BeaconState: """ Slash the validator with index ``index``. @@ -124,10 +112,7 @@ def slash_validator(state: BeaconState, state = initiate_validator_exit(state, index, config) state = state.update_validator_with_fn( - index, - _set_validator_slashed, - current_epoch, - config.EPOCHS_PER_SLASHINGS_VECTOR, + index, _set_validator_slashed, current_epoch, config.EPOCHS_PER_SLASHINGS_VECTOR ) slashed_balance = state.validators[index].effective_balance @@ -140,7 +125,9 @@ def slash_validator(state: BeaconState, slashed_balance, ) ) - state = decrease_balance(state, index, slashed_balance // config.MIN_SLASHING_PENALTY_QUOTIENT) + state = decrease_balance( + state, index, slashed_balance // config.MIN_SLASHING_PENALTY_QUOTIENT + ) proposer_index = get_beacon_proposer_index(state, CommitteeConfig(config)) if whistleblower_index is None: @@ -149,9 +136,7 @@ def slash_validator(state: BeaconState, proposer_reward = Gwei(whistleblower_reward // config.PROPOSER_REWARD_QUOTIENT) state = increase_balance(state, proposer_index, proposer_reward) state = increase_balance( - state, - whistleblower_index, - Gwei(whistleblower_reward - proposer_reward), + state, whistleblower_index, Gwei(whistleblower_reward - proposer_reward) ) return state diff --git a/eth2/configs.py b/eth2/configs.py index 615e15417f..d49e965511 100644 --- a/eth2/configs.py +++ b/eth2/configs.py @@ -1,70 +1,62 @@ -from typing import ( - NamedTuple, -) - -from eth2.beacon.typing import ( - Epoch, - Gwei, - Second, - Slot, -) +from typing import NamedTuple +from eth2.beacon.typing import Epoch, Gwei, Second, Slot Eth2Config = NamedTuple( - 'Eth2Config', + "Eth2Config", ( # Misc - ('SHARD_COUNT', int), - ('TARGET_COMMITTEE_SIZE', int), - ('MAX_VALIDATORS_PER_COMMITTEE', int), - ('MIN_PER_EPOCH_CHURN_LIMIT', int), - ('CHURN_LIMIT_QUOTIENT', int), - ('SHUFFLE_ROUND_COUNT', int), + ("SHARD_COUNT", int), + ("TARGET_COMMITTEE_SIZE", int), + ("MAX_VALIDATORS_PER_COMMITTEE", int), + ("MIN_PER_EPOCH_CHURN_LIMIT", int), + ("CHURN_LIMIT_QUOTIENT", int), + ("SHUFFLE_ROUND_COUNT", int), # Genesis - ('MIN_GENESIS_ACTIVE_VALIDATOR_COUNT', int), - ('MIN_GENESIS_TIME', int), + ("MIN_GENESIS_ACTIVE_VALIDATOR_COUNT", int), + ("MIN_GENESIS_TIME", int), # Gwei values, - ('MIN_DEPOSIT_AMOUNT', Gwei), - ('MAX_EFFECTIVE_BALANCE', Gwei), - ('EJECTION_BALANCE', Gwei), - ('EFFECTIVE_BALANCE_INCREMENT', Gwei), + ("MIN_DEPOSIT_AMOUNT", Gwei), + ("MAX_EFFECTIVE_BALANCE", Gwei), + ("EJECTION_BALANCE", Gwei), + ("EFFECTIVE_BALANCE_INCREMENT", Gwei), # Initial values - ('GENESIS_SLOT', Slot), - ('GENESIS_EPOCH', Epoch), - ('BLS_WITHDRAWAL_PREFIX', int), + ("GENESIS_SLOT", Slot), + ("GENESIS_EPOCH", Epoch), + ("BLS_WITHDRAWAL_PREFIX", int), # Time parameters - ('SECONDS_PER_SLOT', Second), - ('MIN_ATTESTATION_INCLUSION_DELAY', int), - ('SLOTS_PER_EPOCH', int), - ('MIN_SEED_LOOKAHEAD', int), - ('ACTIVATION_EXIT_DELAY', int), - ('SLOTS_PER_ETH1_VOTING_PERIOD', int), - ('SLOTS_PER_HISTORICAL_ROOT', int), - ('MIN_VALIDATOR_WITHDRAWABILITY_DELAY', int), - ('PERSISTENT_COMMITTEE_PERIOD', int), - ('MAX_EPOCHS_PER_CROSSLINK', int), - ('MIN_EPOCHS_TO_INACTIVITY_PENALTY', int), + ("SECONDS_PER_SLOT", Second), + ("MIN_ATTESTATION_INCLUSION_DELAY", int), + ("SLOTS_PER_EPOCH", int), + ("MIN_SEED_LOOKAHEAD", int), + ("ACTIVATION_EXIT_DELAY", int), + ("SLOTS_PER_ETH1_VOTING_PERIOD", int), + ("SLOTS_PER_HISTORICAL_ROOT", int), + ("MIN_VALIDATOR_WITHDRAWABILITY_DELAY", int), + ("PERSISTENT_COMMITTEE_PERIOD", int), + ("MAX_EPOCHS_PER_CROSSLINK", int), + ("MIN_EPOCHS_TO_INACTIVITY_PENALTY", int), # State list lengths - ('EPOCHS_PER_HISTORICAL_VECTOR', int), - ('EPOCHS_PER_SLASHINGS_VECTOR', int), - ('HISTORICAL_ROOTS_LIMIT', int), - ('VALIDATOR_REGISTRY_LIMIT', int), + ("EPOCHS_PER_HISTORICAL_VECTOR", int), + ("EPOCHS_PER_SLASHINGS_VECTOR", int), + ("HISTORICAL_ROOTS_LIMIT", int), + ("VALIDATOR_REGISTRY_LIMIT", int), # Rewards and penalties - ('BASE_REWARD_FACTOR', int), - ('WHISTLEBLOWER_REWARD_QUOTIENT', int), - ('PROPOSER_REWARD_QUOTIENT', int), - ('INACTIVITY_PENALTY_QUOTIENT', int), - ('MIN_SLASHING_PENALTY_QUOTIENT', int), + ("BASE_REWARD_FACTOR", int), + ("WHISTLEBLOWER_REWARD_QUOTIENT", int), + ("PROPOSER_REWARD_QUOTIENT", int), + ("INACTIVITY_PENALTY_QUOTIENT", int), + ("MIN_SLASHING_PENALTY_QUOTIENT", int), # Max operations per block - ('MAX_PROPOSER_SLASHINGS', int), - ('MAX_ATTESTER_SLASHINGS', int), - ('MAX_ATTESTATIONS', int), - ('MAX_DEPOSITS', int), - ('MAX_VOLUNTARY_EXITS', int), - ('MAX_TRANSFERS', int), + ("MAX_PROPOSER_SLASHINGS", int), + ("MAX_ATTESTER_SLASHINGS", int), + ("MAX_ATTESTATIONS", int), + ("MAX_DEPOSITS", int), + ("MAX_VOLUNTARY_EXITS", int), + ("MAX_TRANSFERS", int), # Deposit contract - ('DEPOSIT_CONTRACT_ADDRESS', bytes), - ) + ("DEPOSIT_CONTRACT_ADDRESS", bytes), + ), ) diff --git a/newsfragments/917.misc.rst b/newsfragments/917.misc.rst new file mode 100644 index 0000000000..43058064c9 --- /dev/null +++ b/newsfragments/917.misc.rst @@ -0,0 +1 @@ +Update to ``lahja>=0.14.2`` to fix warnings during endpoint shutdown. diff --git a/newsfragments/962.feature.rst b/newsfragments/962.feature.rst new file mode 100644 index 0000000000..df41f8c9b6 --- /dev/null +++ b/newsfragments/962.feature.rst @@ -0,0 +1 @@ +``p2p.peer.BasePeer`` now uses ``ConnectionAPI`` for underlying protocol interactions. diff --git a/newsfragments/963.feature.rst b/newsfragments/963.feature.rst new file mode 100644 index 0000000000..eba2b6b02e --- /dev/null +++ b/newsfragments/963.feature.rst @@ -0,0 +1,2 @@ +Allow Trinity to automatically resolve a checkpoint through the etherscan API +using this syntax: ``--beam-from-checkpoint="eth://block/byetherscan/latest"`` \ No newline at end of file diff --git a/newsfragments/965.misc.rst b/newsfragments/965.misc.rst new file mode 100644 index 0000000000..f0c4b2007a --- /dev/null +++ b/newsfragments/965.misc.rst @@ -0,0 +1 @@ +Remove mutation of geth testing fixtures from ``./tests/integration/test_lightchain_integration.py`` test. diff --git a/newsfragments/975.feature.rst b/newsfragments/975.feature.rst new file mode 100644 index 0000000000..40dec62ec7 --- /dev/null +++ b/newsfragments/975.feature.rst @@ -0,0 +1,2 @@ +Fetch missing data from remote peers, if requested over json-rpc during beam sync. +Requests for data at an old block will fail; remote peers probably don't have it. diff --git a/newsfragments/983.misc.rst b/newsfragments/983.misc.rst new file mode 100644 index 0000000000..30c40070fd --- /dev/null +++ b/newsfragments/983.misc.rst @@ -0,0 +1 @@ +Expand test coverage to ensure Ropsten and custom nets do actually work. \ No newline at end of file diff --git a/newsfragments/985.misc.rst b/newsfragments/985.misc.rst new file mode 100644 index 0000000000..c416a68be7 --- /dev/null +++ b/newsfragments/985.misc.rst @@ -0,0 +1 @@ +Add ``ABC`` base class for ``p2p.service.BaseService`` diff --git a/newsfragments/986.feature.rst b/newsfragments/986.feature.rst new file mode 100644 index 0000000000..68f0959c7f --- /dev/null +++ b/newsfragments/986.feature.rst @@ -0,0 +1 @@ +Add ``ConnectionAPI.get_p2p_receipt`` for fetching the ``HandshakeReceipt`` for the base ``p2p`` protocol. diff --git a/newsfragments/987.feature.rst b/newsfragments/987.feature.rst new file mode 100644 index 0000000000..d3e3402f70 --- /dev/null +++ b/newsfragments/987.feature.rst @@ -0,0 +1 @@ +``p2p.protocol.Protocol.supports_command`` is now a ``classmethod`` diff --git a/newsfragments/988.misc.rst b/newsfragments/988.misc.rst new file mode 100644 index 0000000000..b615c42146 --- /dev/null +++ b/newsfragments/988.misc.rst @@ -0,0 +1 @@ +Add ``ABC`` base class ``p2p.abc.HandshakeReceiptAPI`` diff --git a/newsfragments/989.feature.rst b/newsfragments/989.feature.rst new file mode 100644 index 0000000000..e3612adb1c --- /dev/null +++ b/newsfragments/989.feature.rst @@ -0,0 +1 @@ +The ``HandlerSubscriptionAPI`` now supports a context manager interface, removing/cancelling the subscription when the context exits diff --git a/newsfragments/990.feature.rst b/newsfragments/990.feature.rst new file mode 100644 index 0000000000..5eeda335d9 --- /dev/null +++ b/newsfragments/990.feature.rst @@ -0,0 +1 @@ +Handler functions for ``Connection.add_protocol_handler`` and ``Connection.add_command_handler`` now expect the ``Connection`` instance as the first argument. diff --git a/newsfragments/991.misc.rst b/newsfragments/991.misc.rst new file mode 100644 index 0000000000..7838f2d4f5 --- /dev/null +++ b/newsfragments/991.misc.rst @@ -0,0 +1 @@ +Add ``ConnectionAPI.is_dial_out`` and ``ConnectionAPI.start_protocol_streams`` to ABC definition. diff --git a/newsfragments/992.bugfix.rst b/newsfragments/992.bugfix.rst new file mode 100644 index 0000000000..d5e750ab7f --- /dev/null +++ b/newsfragments/992.bugfix.rst @@ -0,0 +1 @@ +Add missing exception handling inside of ``Connection.run`` for ``PeerConnectionLost`` exception that bubbles from multiplexer. ``Connection`` is now responsible for calling ``Multiplexer.close`` on shutdown. Detect a closed connection during handshake. diff --git a/newsfragments/993.misc.rst b/newsfragments/993.misc.rst new file mode 100644 index 0000000000..77923d90c1 --- /dev/null +++ b/newsfragments/993.misc.rst @@ -0,0 +1 @@ +Relax some input types from ``Tuple[thing, ...]`` to ``Sequence[thing]`` diff --git a/newsfragments/994.bugfix.rst b/newsfragments/994.bugfix.rst new file mode 100644 index 0000000000..276720a361 --- /dev/null +++ b/newsfragments/994.bugfix.rst @@ -0,0 +1 @@ +Fix ``P2PProtocol.send_disconnect`` to accept enum values from ``p2p.disconnect.DisconnectReason`` diff --git a/newsfragments/995.misc.rst b/newsfragments/995.misc.rst new file mode 100644 index 0000000000..21b10b0290 --- /dev/null +++ b/newsfragments/995.misc.rst @@ -0,0 +1 @@ +``BasePeerPool.__aiter_`` now checks if the peer is operational. diff --git a/p2p/abc.py b/p2p/abc.py index e6e0f740bf..6f7cee6c01 100644 --- a/p2p/abc.py +++ b/p2p/abc.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +import asyncio from typing import ( Any, AsyncContextManager, @@ -6,9 +7,11 @@ Awaitable, Callable, ClassVar, + ContextManager, Dict, Generic, List, + Optional, Tuple, Type, TYPE_CHECKING, @@ -21,11 +24,15 @@ from cancel_token import CancelToken +from eth_utils import ExtendedDebugLogger + from eth_keys import datatypes from p2p.typing import Capability, Capabilities, Payload, Structure +from p2p.transport_state import TransportState if TYPE_CHECKING: + from p2p.handshake import DevP2PReceipt # noqa: F401 from p2p.p2p_proto import ( # noqa: F401 BaseP2PProtocol, ) @@ -190,6 +197,7 @@ class RequestAPI(ABC, Generic[TRequestPayload]): class TransportAPI(ABC): remote: NodeAPI + read_state: TransportState @property @abstractmethod @@ -243,8 +251,9 @@ def __init__(self, transport: TransportAPI, cmd_id_offset: int, snappy_support: def send_request(self, request: RequestAPI[Payload]) -> None: ... + @classmethod @abstractmethod - def supports_command(self, cmd_type: Type[CommandAPI]) -> bool: + def supports_command(cls, cmd_type: Type[CommandAPI]) -> bool: ... @classmethod @@ -325,16 +334,84 @@ def multiplex(self) -> AsyncContextManager[None]: ... -class HandlerSubscriptionAPI: +class ServiceEventsAPI(ABC): + started: asyncio.Event + stopped: asyncio.Event + cleaned_up: asyncio.Event + cancelled: asyncio.Event + finished: asyncio.Event + + +TReturn = TypeVar('TReturn') + + +class AsyncioServiceAPI(ABC): + events: ServiceEventsAPI + cancel_token: CancelToken + + @property + @abstractmethod + def logger(self) -> ExtendedDebugLogger: + ... + + @abstractmethod + def cancel_nowait(self) -> None: + ... + + @property + @abstractmethod + def is_cancelled(self) -> bool: + ... + + @property + @abstractmethod + def is_running(self) -> bool: + ... + + @abstractmethod + async def run( + self, + finished_callback: Optional[Callable[['AsyncioServiceAPI'], None]] = None) -> None: + ... + + @abstractmethod + async def cancel(self) -> None: + ... + + @abstractmethod + def run_daemon(self, service: 'AsyncioServiceAPI') -> None: + ... + + @abstractmethod + async def wait(self, + awaitable: Awaitable[TReturn], + token: CancelToken = None, + timeout: float = None) -> TReturn: + ... + + +class HandshakeReceiptAPI(ABC): + protocol: ProtocolAPI + + +class HandlerSubscriptionAPI(ContextManager['HandlerSubscriptionAPI']): @abstractmethod def cancel(self) -> None: ... -class ConnectionAPI(ABC): +ProtocolHandlerFn = Callable[['ConnectionAPI', CommandAPI, Payload], Awaitable[Any]] +CommandHandlerFn = Callable[['ConnectionAPI', Payload], Awaitable[Any]] + + +class ConnectionAPI(AsyncioServiceAPI): + protocol_receipts: Tuple[HandshakeReceiptAPI, ...] + # # Primary properties of the connection # + is_dial_out: bool + @property @abstractmethod def is_dial_in(self) -> bool: @@ -348,17 +425,21 @@ def remote(self) -> NodeAPI: # # Subscriptions/Handler API # + @abstractmethod + def start_protocol_streams(self) -> None: + ... + @abstractmethod def add_protocol_handler(self, protocol_type: Type[ProtocolAPI], - handler_fn: Callable[[CommandAPI, Payload], Awaitable[Any]], + handler_fn: ProtocolHandlerFn, ) -> HandlerSubscriptionAPI: ... @abstractmethod def add_command_handler(self, command_type: Type[CommandAPI], - handler_fn: Callable[[Payload], Awaitable[Any]], + handler_fn: CommandHandlerFn, ) -> HandlerSubscriptionAPI: ... @@ -376,6 +457,13 @@ def get_multiplexer(self) -> MultiplexerAPI: def get_base_protocol(self) -> 'BaseP2PProtocol': ... + @abstractmethod + def get_p2p_receipt(self) -> 'DevP2PReceipt': + ... + + # + # Connection Metadata + # @property @abstractmethod def remote_capabilities(self) -> Capabilities: diff --git a/p2p/connection.py b/p2p/connection.py index a564afa418..0773a9ec80 100644 --- a/p2p/connection.py +++ b/p2p/connection.py @@ -1,60 +1,52 @@ import asyncio import collections import functools -from typing import Any, Awaitable, Callable, DefaultDict, Set, Tuple, Type +from typing import DefaultDict, Sequence, Set, Type from eth_keys import keys from p2p.abc import ( CommandAPI, HandlerSubscriptionAPI, + HandshakeReceiptAPI, MultiplexerAPI, NodeAPI, ProtocolAPI, + ProtocolHandlerFn, + CommandHandlerFn, ConnectionAPI, ) from p2p.exceptions import ( + PeerConnectionLost, UnknownProtocol, UnknownProtocolCommand, ) -from p2p.handshake import ( - DevP2PReceipt, - HandshakeReceipt, -) +from p2p.handshake import DevP2PReceipt +from p2p.handler_subscription import HandlerSubscription from p2p.service import BaseService -from p2p.p2p_proto import ( - BaseP2PProtocol, -) -from p2p.typing import Capabilities, Payload - - -class HandlerSubscription(HandlerSubscriptionAPI): - def __init__(self, remove_fn: Callable[[], Any]) -> None: - self._remove_fn = remove_fn - - def cancel(self) -> None: - self._remove_fn() +from p2p.p2p_proto import BaseP2PProtocol +from p2p.typing import Capabilities class Connection(ConnectionAPI, BaseService): _protocol_handlers: DefaultDict[ Type[ProtocolAPI], - Set[Callable[[CommandAPI, Payload], Awaitable[Any]]] + Set[ProtocolHandlerFn] ] _command_handlers: DefaultDict[ Type[CommandAPI], - Set[Callable[[Payload], Awaitable[Any]]] + Set[CommandHandlerFn] ] def __init__(self, multiplexer: MultiplexerAPI, devp2p_receipt: DevP2PReceipt, - protocol_receipts: Tuple[HandshakeReceipt, ...], + protocol_receipts: Sequence[HandshakeReceiptAPI], is_dial_out: bool) -> None: super().__init__(token=multiplexer.cancel_token, loop=multiplexer.cancel_token.loop) self._multiplexer = multiplexer self._devp2p_receipt = devp2p_receipt - self._protocol_receipts = protocol_receipts + self.protocol_receipts = tuple(protocol_receipts) self.is_dial_out = is_dial_out self._protocol_handlers = collections.defaultdict(set) @@ -81,11 +73,17 @@ def remote(self) -> NodeAPI: return self._multiplexer.remote async def _run(self) -> None: - async with self._multiplexer.multiplex(): - for protocol in self._multiplexer.get_protocols(): - self.run_daemon_task(self._feed_protocol_handlers(protocol)) + try: + async with self._multiplexer.multiplex(): + for protocol in self._multiplexer.get_protocols(): + self.run_daemon_task(self._feed_protocol_handlers(protocol)) + + await self.cancellation() + except (PeerConnectionLost, asyncio.CancelledError): + pass - await self.cancellation() + async def _cleanup(self) -> None: + self._multiplexer.close() # # Subscriptions/Handler API @@ -116,7 +114,7 @@ async def _feed_protocol_handlers(self, protocol: ProtocolAPI) -> None: protocol, type(cmd), ) - self.run_task(proto_handler_fn(cmd, msg)) + self.run_task(proto_handler_fn(self, cmd, msg)) command_handlers = set(self._command_handlers[type(cmd)]) for cmd_handler_fn in command_handlers: self.logger.debug2( @@ -125,12 +123,12 @@ async def _feed_protocol_handlers(self, protocol: ProtocolAPI) -> None: protocol, type(cmd), ) - self.run_task(cmd_handler_fn(msg)) + self.run_task(cmd_handler_fn(self, msg)) def add_protocol_handler(self, protocol_class: Type[ProtocolAPI], - handler_fn: Callable[[CommandAPI, Payload], Awaitable[Any]], - ) -> HandlerSubscription: + handler_fn: ProtocolHandlerFn, + ) -> HandlerSubscriptionAPI: if not self._multiplexer.has_protocol(protocol_class): raise UnknownProtocol( f"Protocol {protocol_class} was not found int he connected " @@ -145,8 +143,8 @@ def add_protocol_handler(self, def add_command_handler(self, command_type: Type[CommandAPI], - handler_fn: Callable[[Payload], Awaitable[Any]], - ) -> HandlerSubscription: + handler_fn: CommandHandlerFn, + ) -> HandlerSubscriptionAPI: for protocol in self._multiplexer.get_protocols(): if protocol.supports_command(command_type): self._command_handlers[command_type].add(handler_fn) @@ -173,6 +171,12 @@ def get_multiplexer(self) -> MultiplexerAPI: def get_base_protocol(self) -> BaseP2PProtocol: return self._multiplexer.get_base_protocol() + def get_p2p_receipt(self) -> DevP2PReceipt: + return self._devp2p_receipt + + # + # Connection Metadata + # @property def remote_capabilities(self) -> Capabilities: return self._devp2p_receipt.capabilities diff --git a/p2p/discovery.py b/p2p/discovery.py index 295cab7359..a63390358a 100644 --- a/p2p/discovery.py +++ b/p2p/discovery.py @@ -441,7 +441,7 @@ def datagram_received(self, data: Union[bytes, Text], addr: Tuple[str, int]) -> ip_address, udp_port = addr address = Address(ip_address, udp_port) # The prefix below is what geth uses to identify discv5 msgs. - # https://github.com/ethereum/go-ethereum/blob/c4712bf96bc1bae4a5ad4600e9719e4a74bde7d5/p2p/discv5/udp.go#L149 # noqa: E501 + # https://github.com/ethereum/go-ethereum/blob/c4712bf96bc1bae4a5ad4600e9719e4a74bde7d5/p2p/discv5/udp.go#L149 if text_if_str(to_bytes, data).startswith(V5_ID_STRING): self.receive_v5(address, cast(bytes, data)) else: diff --git a/p2p/handler_subscription.py b/p2p/handler_subscription.py new file mode 100644 index 0000000000..2d80f1f68d --- /dev/null +++ b/p2p/handler_subscription.py @@ -0,0 +1,25 @@ +from typing import ( + Any, + Callable, + Type, +) +from types import TracebackType + +from p2p.abc import HandlerSubscriptionAPI + + +class HandlerSubscription(HandlerSubscriptionAPI): + def __init__(self, remove_fn: Callable[[], Any]) -> None: + self._remove_fn = remove_fn + + def cancel(self) -> None: + self._remove_fn() + + def __enter__(self) -> HandlerSubscriptionAPI: + return self + + def __exit__(self, + exc_type: Type[BaseException], + exc_value: BaseException, + exc_tb: TracebackType) -> None: + self._remove_fn() diff --git a/p2p/handshake.py b/p2p/handshake.py index 13789a6eae..41043beed2 100644 --- a/p2p/handshake.py +++ b/p2p/handshake.py @@ -9,6 +9,7 @@ Dict, Iterable, NamedTuple, + Sequence, Type, Tuple, ) @@ -21,7 +22,12 @@ from eth.tools.logging import ExtendedDebugLogger from p2p._utils import duplicates -from p2p.abc import TransportAPI, MultiplexerAPI, ProtocolAPI +from p2p.abc import ( + HandshakeReceiptAPI, + MultiplexerAPI, + ProtocolAPI, + TransportAPI, +) from p2p.constants import ( DEVP2P_V5, ) @@ -46,7 +52,7 @@ ) -class HandshakeReceipt: +class HandshakeReceipt(HandshakeReceiptAPI): """ Data storage object for ephemeral data exchanged during protocol handshakes. @@ -200,7 +206,7 @@ async def _do_p2p_handshake(transport: TransportAPI, async def negotiate_protocol_handshakes(transport: TransportAPI, p2p_handshake_params: DevP2PHandshakeParams, - protocol_handshakers: Tuple[Handshaker, ...], + protocol_handshakers: Sequence[Handshaker], token: CancelToken, ) -> Tuple[MultiplexerAPI, DevP2PReceipt, Tuple[HandshakeReceipt, ...]]: # noqa: E501 """ diff --git a/p2p/multiplexer.py b/p2p/multiplexer.py index 4ffeaf9ef8..c75fe8fc1d 100644 --- a/p2p/multiplexer.py +++ b/p2p/multiplexer.py @@ -39,6 +39,7 @@ from p2p.p2p_proto import BaseP2PProtocol from p2p.protocol import Protocol from p2p.resource_lock import ResourceLock +from p2p.transport_state import TransportState from p2p.typing import Payload @@ -291,14 +292,53 @@ async def multiplex(self) -> AsyncIterator[None]: 'multiplex', loop=self.cancel_token.loop, ).chain(self.cancel_token) + + stop = asyncio.Event() self._multiplex_token = multiplex_token - fut = asyncio.ensure_future(self._do_multiplexing(multiplex_token)) + fut = asyncio.ensure_future(self._do_multiplexing(stop, multiplex_token)) # wait for the multiplexing to actually start try: yield finally: - multiplex_token.trigger() - del self._multiplex_token + # + # Prevent corruption of the Transport: + # + # On exit the `Transport` can be in a few states: + # + # 1. IDLE: between reads + # 2. HEADER: waiting to read the bytes for the message header + # 3. BODY: already read the header, waiting for body bytes. + # + # In the IDLE case we get a clean shutdown by simply signaling + # to `_do_multiplexing` that it should exit which is done with + # an `asyncio.EVent` + # + # In the HEADER case we can issue a hard stop either via + # cancellation or the cancel token. The read *should* be + # interrupted without consuming any bytes from the + # `StreamReader`. + # + # In the BODY case we want to give the `Transport.recv` call a + # moment to finish reading the body after which it will be IDLE + # and will exit via the IDLE exit mechanism. + stop.set() + + # If the transport is waiting to read the body of the message + # we want to give it a moment to finish that read. Otherwise + # this leaves the transport in a corrupt state. + if self._transport.read_state is TransportState.BODY: + try: + await asyncio.wait_for(fut, timeout=1) + except asyncio.TimeoutError: + pass + + # After giving the transport an opportunity to shutdown + # cleanly, we issue a hard shutdown, first via cancellation and + # then via the cancel token. This should only end up + # corrupting the transport in the case where the header data is + # read but the body data takes too long to arrive which should + # be very rare and would likely indicate a malicious or broken + # peer. if fut.done(): fut.result() else: @@ -308,7 +348,10 @@ async def multiplex(self) -> AsyncIterator[None]: except asyncio.CancelledError: pass - async def _do_multiplexing(self, token: CancelToken) -> None: + multiplex_token.trigger() + del self._multiplex_token + + async def _do_multiplexing(self, stop: asyncio.Event, token: CancelToken) -> None: """ Background task that reads messages from the transport and feeds them into individual queues for each of the protocols. @@ -338,3 +381,6 @@ async def _do_multiplexing(self, token: CancelToken) -> None: protocol, cmd, ) + + if stop.is_set(): + break diff --git a/p2p/p2p_proto.py b/p2p/p2p_proto.py index 9c79bed56f..001434d07f 100644 --- a/p2p/p2p_proto.py +++ b/p2p/p2p_proto.py @@ -118,7 +118,7 @@ def send_hello(self, self.transport.send(header, body) def send_disconnect(self, reason: _DisconnectReason) -> None: - msg: Dict[str, Any] = {"reason": reason} + msg: Dict[str, Any] = {"reason": reason.value} header, body = Disconnect( self.cmd_id_offset, self.snappy_support diff --git a/p2p/peer.py b/p2p/peer.py index 4378a3e5cb..6f1d82ca7d 100644 --- a/p2p/peer.py +++ b/p2p/peer.py @@ -28,20 +28,23 @@ from cancel_token import CancelToken -from p2p.abc import CommandAPI, MultiplexerAPI, NodeAPI, ProtocolAPI +from p2p.abc import ( + CommandAPI, + ConnectionAPI, + HandshakeReceiptAPI, + NodeAPI, + ProtocolAPI, +) from p2p.constants import BLACKLIST_SECONDS_BAD_PROTOCOL from p2p.disconnect import DisconnectReason from p2p.exceptions import ( - MalformedMessage, - PeerConnectionLost, - UnexpectedMessage, UnknownProtocol, ) +from p2p.connection import Connection from p2p.handshake import ( negotiate_protocol_handshakes, DevP2PHandshakeParams, DevP2PReceipt, - HandshakeReceipt, Handshaker, ) from p2p.service import BaseService @@ -49,7 +52,6 @@ BaseP2PProtocol, Disconnect, Ping, - Pong, ) from p2p.protocol import ( Command, @@ -61,7 +63,6 @@ NoopConnectionTracker, ) - if TYPE_CHECKING: from p2p.peer_pool import BasePeerPool # noqa: F401 @@ -70,12 +71,11 @@ async def handshake(remote: NodeAPI, private_key: datatypes.PrivateKey, p2p_handshake_params: DevP2PHandshakeParams, protocol_handshakers: Tuple[Handshaker, ...], - token: CancelToken, - ) -> Tuple[MultiplexerAPI, DevP2PReceipt, Tuple[HandshakeReceipt, ...]]: + token: CancelToken) -> ConnectionAPI: """ Perform the auth and P2P handshakes with the given remote. - Return a `Multiplexer` object along with the handshake receipts. + Return a `Connection` object housing all of the negotiated sub protocols. Raises UnreachablePeer if we cannot connect to the peer or HandshakeFailure if the remote disconnects before completing the @@ -104,7 +104,13 @@ async def handshake(remote: NodeAPI, await asyncio.sleep(0) raise - return multiplexer, devp2p_receipt, protocol_receipts + connection = Connection( + multiplexer=multiplexer, + devp2p_receipt=devp2p_receipt, + protocol_receipts=protocol_receipts, + is_dial_out=True, + ) + return connection async def receive_handshake(reader: asyncio.StreamReader, @@ -112,8 +118,7 @@ async def receive_handshake(reader: asyncio.StreamReader, private_key: datatypes.PrivateKey, p2p_handshake_params: DevP2PHandshakeParams, protocol_handshakers: Tuple[Handshaker, ...], - token: CancelToken, - ) -> Tuple[MultiplexerAPI, DevP2PReceipt, Tuple[HandshakeReceipt, ...]]: + token: CancelToken) -> Connection: transport = await Transport.receive_connection( reader=reader, writer=writer, @@ -136,7 +141,13 @@ async def receive_handshake(reader: asyncio.StreamReader, await asyncio.sleep(0) raise - return multiplexer, devp2p_receipt, protocol_receipts + connection = Connection( + multiplexer=multiplexer, + devp2p_receipt=devp2p_receipt, + protocol_receipts=protocol_receipts, + is_dial_out=False, + ) + return connection class BasePeerBootManager(BaseService): @@ -181,23 +192,20 @@ class BasePeer(BaseService): base_protocol: BaseP2PProtocol def __init__(self, - multiplexer: MultiplexerAPI, - devp2p_receipt: DevP2PReceipt, - protocol_receipts: Sequence[HandshakeReceipt], + connection: ConnectionAPI, context: BasePeerContext, - inbound: bool, event_bus: EndpointAPI = None, ) -> None: - super().__init__(token=multiplexer.cancel_token, loop=multiplexer.cancel_token.loop) + super().__init__(token=connection.cancel_token, loop=connection.cancel_token.loop) - # This is currently only used to have access to the `vm_configuration` - # for ETH/LES peers to do their DAO fork check. + # Peer context object self.context = context # Connection instance - self.multiplexer = multiplexer + self.connection = connection + self.multiplexer = connection.get_multiplexer() - self.base_protocol = self.multiplexer.get_base_protocol() + self.base_protocol = self.connection.get_base_protocol() # TODO: need to remove this property but for now it is here to support # backwards compat @@ -210,10 +218,9 @@ def __init__(self, break else: raise ValidationError("No supported subprotocols found in multiplexer") - self.sub_proto = self.multiplexer.get_protocols()[1] # The self-identifying string that the remote names itself. - self.client_version_string = devp2p_receipt.client_version_string + self.client_version_string = self.connection.safe_client_version_string # Optional event bus handle self._event_bus = event_bus @@ -222,7 +229,7 @@ def __init__(self, # established from a dial-out or dial-in (True: dial-in, False: # dial-out) # TODO: rename to `dial_in` and have a computed property for `dial_out` - self.inbound = inbound + self.inbound = connection.is_dial_in self._subscribers: List[PeerSubscriber] = [] # A counter of the number of messages this peer has received for each @@ -233,13 +240,16 @@ def __init__(self, self.boot_manager = self.get_boot_manager() self.connection_tracker = self.setup_connection_tracker() - self.process_handshake_receipts(devp2p_receipt, protocol_receipts) + self.process_handshake_receipts( + connection.get_p2p_receipt(), + connection.protocol_receipts, + ) def process_handshake_receipts(self, devp2p_receipt: DevP2PReceipt, - protocol_receipts: Sequence[HandshakeReceipt]) -> None: + protocol_receipts: Sequence[HandshakeReceiptAPI]) -> None: """ - Hook for subclasses to initialize data based on the protocol handshake. + Noop implementation for subclasses to override. """ pass @@ -270,7 +280,7 @@ def __repr__(self) -> str: # @cached_property def remote(self) -> NodeAPI: - return self.multiplexer.remote + return self.connection.remote @property def is_closing(self) -> bool: @@ -298,84 +308,60 @@ def remove_subscriber(self, subscriber: 'PeerSubscriber') -> None: self._subscribers.remove(subscriber) async def _cleanup(self) -> None: - self.multiplexer.close() + self.connection.cancel_nowait() + + def setup_protocol_handlers(self) -> None: + """ + Hook for subclasses to setup handlers for protocols specific messages. + """ + pass async def _run(self) -> None: + # setup handler to respond to ping messages + self.connection.add_command_handler(Ping, self._ping_handler) + + # setup handler for disconnect messages + self.connection.add_command_handler(Disconnect, self._disconnect_handler) + + # setup handler for protocol messages to pass messages to subscribers + for protocol in self.multiplexer.get_protocols(): + self.connection.add_protocol_handler(type(protocol), self._handle_subscriber_message) + + self.setup_protocol_handlers() + # The `boot` process is run in the background to allow the `run` loop # to continue so that all of the Peer APIs can be used within the # `boot` task. self.run_child_service(self.boot_manager) - try: - async with self.multiplexer.multiplex(): - self.run_daemon_task(self.handle_p2p_proto_stream()) - self.run_daemon_task(self.handle_sub_proto_stream()) - await self.cancellation() - except PeerConnectionLost as err: - self.logger.debug('Peer connection lost: %s: %r', self, err) - self.cancel_nowait() - except MalformedMessage as err: - self.logger.debug('MalformedMessage error with peer: %s: %r', self, err) - await self.disconnect(DisconnectReason.subprotocol_error) - except TimeoutError as err: - # TODO: we should send a ping and see if we get back a pong... - self.logger.debug('TimeoutError error with peer: %s: %r', self, err) - await self.disconnect(DisconnectReason.timeout) - - async def handle_p2p_proto_stream(self) -> None: - """Handle the base protocol (P2P) messages.""" - async for cmd, msg in self.multiplexer.stream_protocol_messages(self.base_protocol): - self.handle_p2p_msg(cmd, msg) - - def handle_p2p_msg(self, cmd: CommandAPI, msg: Payload) -> None: - """Handle the base protocol (P2P) messages.""" - if isinstance(cmd, Disconnect): - msg = cast(Dict[str, Any], msg) - try: - reason = DisconnectReason(msg['reason']) - except TypeError: - self.logger.info('Unrecognized reason: %s', msg['reason']) - else: - self.disconnect_reason = reason - self.cancel_nowait() - return - elif isinstance(cmd, Ping): - self.base_protocol.send_pong() - elif isinstance(cmd, Pong): - # Currently we don't do anything when we get a pong, but eventually we should - # update the last time we heard from a peer in our DB (which doesn't exist yet). - pass - else: - raise UnexpectedMessage(f"Unexpected msg: {cmd} ({msg})") - async def handle_sub_proto_stream(self) -> None: - async for cmd, msg in self.multiplexer.stream_protocol_messages(self.sub_proto): - self.handle_sub_proto_msg(cmd, msg) + # Trigger the connection to start feeding messages though the handlers + self.connection.start_protocol_streams() - def handle_sub_proto_msg(self, cmd: CommandAPI, msg: Payload) -> None: - cmd_type = type(cmd) + await self.cancellation() - if self._subscribers: - was_added = tuple( - subscriber.add_msg(PeerMessage(self, cmd, msg)) - for subscriber - in self._subscribers - ) - if not any(was_added): - self.logger.warning( - "Peer %s has no subscribers for msg type %s", - self, - cmd_type.__name__, - ) + async def _ping_handler(self, connection: ConnectionAPI, msg: Payload) -> None: + self.base_protocol.send_pong() + + async def _disconnect_handler(self, connection: ConnectionAPI, msg: Payload) -> None: + msg = cast(Dict[str, Any], msg) + try: + reason = DisconnectReason(msg['reason']) + except TypeError: + self.logger.info('Unrecognized reason: %s', msg['reason_name']) else: - self.logger.warning("Peer %s has no subscribers, discarding %s msg", self, cmd) + self.disconnect_reason = reason - def _disconnect(self, reason: DisconnectReason) -> None: - if not isinstance(reason, DisconnectReason): - raise ValueError( - f"Reason must be an item of DisconnectReason, got {reason}" - ) + self.cancel_nowait() - self.disconnect_reason = reason + async def _handle_subscriber_message(self, + connection: ConnectionAPI, + cmd: CommandAPI, + msg: Payload) -> None: + subscriber_msg = PeerMessage(self, cmd, msg) + for subscriber in self._subscribers: + subscriber.add_msg(subscriber_msg) + + def _disconnect(self, reason: DisconnectReason) -> None: if reason is DisconnectReason.bad_protocol: self.connection_tracker.record_blacklist( self.remote, @@ -384,8 +370,8 @@ def _disconnect(self, reason: DisconnectReason) -> None: ) self.logger.debug("Disconnecting from remote peer %s; reason: %s", self.remote, reason.name) - self.base_protocol.send_disconnect(reason.value) - self.multiplexer.close() + self.base_protocol.send_disconnect(reason) + self.cancel_nowait() async def disconnect(self, reason: DisconnectReason) -> None: """Send a disconnect msg to the remote node and stop this Peer. @@ -592,30 +578,19 @@ async def handshake(self, remote: NodeAPI) -> BasePeer: self.context.p2p_version, ) handshakers = await self.get_handshakers() - multiplexer, devp2p_receipt, protocol_receipts = await handshake( + connection = await handshake( remote=remote, private_key=self.privkey, p2p_handshake_params=p2p_handshake_params, protocol_handshakers=handshakers, token=self.cancel_token ) - return self.create_peer( - multiplexer=multiplexer, - devp2p_receipt=devp2p_receipt, - protocol_receipts=protocol_receipts, - inbound=False, - ) + return self.create_peer(connection) def create_peer(self, - multiplexer: MultiplexerAPI, - devp2p_receipt: DevP2PReceipt, - protocol_receipts: Sequence[HandshakeReceipt], - inbound: bool) -> BasePeer: + connection: ConnectionAPI) -> BasePeer: return self.peer_class( - multiplexer=multiplexer, - devp2p_receipt=devp2p_receipt, - protocol_receipts=protocol_receipts, + connection=connection, context=self.context, - inbound=False, event_bus=self.event_bus, ) diff --git a/p2p/peer_pool.py b/p2p/peer_pool.py index 6070c8beb5..4043870d58 100644 --- a/p2p/peer_pool.py +++ b/p2p/peer_pool.py @@ -34,7 +34,7 @@ EndpointAPI, ) -from p2p.abc import NodeAPI +from p2p.abc import AsyncioServiceAPI, NodeAPI from p2p.constants import ( DEFAULT_MAX_PEERS, DEFAULT_PEER_BOOT_TIMEOUT, @@ -252,6 +252,9 @@ def unsubscribe(self, subscriber: PeerSubscriber) -> None: peer.remove_subscriber(subscriber) async def start_peer(self, peer: BasePeer) -> None: + self.run_child_service(peer.connection) + await self.wait(peer.connection.events.started.wait(), timeout=1) + self.run_child_service(peer) await self.wait(peer.events.started.wait(), timeout=1) if peer.is_operational: @@ -439,7 +442,7 @@ async def connect_to_node(self, node: NodeAPI) -> None: else: await self.start_peer(peer) - def _peer_finished(self, peer: BaseService) -> None: + def _peer_finished(self, peer: AsyncioServiceAPI) -> None: """ Remove the given peer from our list of connected nodes. This is passed as a callback to be called when a peer finishes. @@ -463,7 +466,7 @@ async def __aiter__(self) -> AsyncIterator[BasePeer]: # Yield control to ensure we process any disconnection requests from peers. Otherwise # we could return peers that should have been disconnected already. await asyncio.sleep(0) - if not peer.is_closing: + if peer.is_operational and not peer.is_closing: yield peer async def _periodically_report_stats(self) -> None: diff --git a/p2p/protocol.py b/p2p/protocol.py index 5fa06f8b51..50ec82624d 100644 --- a/p2p/protocol.py +++ b/p2p/protocol.py @@ -172,8 +172,9 @@ def send_request(self, request: RequestAPI[Payload]) -> None: header, body = command.encode(request.command_payload) self.transport.send(header, body) - def supports_command(self, cmd_type: Type[CommandAPI]) -> bool: - return cmd_type in self.cmd_by_type + @classmethod + def supports_command(cls, cmd_type: Type[CommandAPI]) -> bool: + return cmd_type in cls._commands @classmethod def as_capability(cls) -> Capability: diff --git a/p2p/service.py b/p2p/service.py index df369462f8..6de662fb65 100644 --- a/p2p/service.py +++ b/p2p/service.py @@ -1,4 +1,4 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod import asyncio import concurrent import functools @@ -25,10 +25,11 @@ from eth.tools.logging import ExtendedDebugLogger +from p2p.abc import ServiceEventsAPI, AsyncioServiceAPI from p2p.cancellable import CancellableMixin -class ServiceEvents: +class ServiceEvents(ServiceEventsAPI): def __init__(self) -> None: self.started = asyncio.Event() self.stopped = asyncio.Event() @@ -37,12 +38,12 @@ def __init__(self) -> None: self.finished = asyncio.Event() -class BaseService(ABC, CancellableMixin): +class BaseService(CancellableMixin, AsyncioServiceAPI): logger: ExtendedDebugLogger = None # Use a WeakSet so that we don't have to bother updating it when tasks finish. - _child_services: 'WeakSet[BaseService]' + _child_services: 'WeakSet[AsyncioServiceAPI]' _tasks: 'WeakSet[asyncio.Future[Any]]' - _finished_callbacks: List[Callable[['BaseService'], None]] + _finished_callbacks: List[Callable[[AsyncioServiceAPI], None]] # Number of seconds cancel() will wait for run() to finish. _wait_until_finished_timeout = 5 @@ -95,7 +96,7 @@ def get_event_loop(self) -> asyncio.AbstractEventLoop: async def run( self, - finished_callback: Optional[Callable[['BaseService'], None]] = None) -> None: + finished_callback: Optional[Callable[[AsyncioServiceAPI], None]] = None) -> None: """Await for the service's _run() coroutine. Once _run() returns, triggers the cancel token, call cleanup() and @@ -141,7 +142,7 @@ async def run( self.events.finished.set() self.logger.debug("%s halted cleanly", self) - def add_finished_callback(self, finished_callback: Callable[['BaseService'], None]) -> None: + def add_finished_callback(self, finished_callback: Callable[[AsyncioServiceAPI], None]) -> None: self._finished_callbacks.append(finished_callback) def run_task(self, awaitable: Awaitable[Any]) -> None: @@ -185,7 +186,7 @@ async def _run_daemon_task_wrapper() -> None: self.cancel_nowait() self.run_task(_run_daemon_task_wrapper()) - def run_child_service(self, child_service: 'BaseService') -> None: + def run_child_service(self, child_service: AsyncioServiceAPI) -> None: """ Run a child service and keep a reference to it to be considered during the cleanup. """ @@ -201,7 +202,7 @@ def run_child_service(self, child_service: 'BaseService') -> None: self._child_services.add(child_service) self.run_task(child_service.run()) - def run_daemon(self, service: 'BaseService') -> None: + def run_daemon(self, service: AsyncioServiceAPI) -> None: """ Run a service and keep a reference to it to be considered during the cleanup. @@ -396,7 +397,7 @@ def service_timeout(timeout: int) -> Callable[..., Any]: """ def decorator(func: Callable[..., Any]) -> Callable[..., Any]: @functools.wraps(func) - async def wrapped(service: BaseService, *args: Any, **kwargs: Any) -> Any: + async def wrapped(service: AsyncioServiceAPI, *args: Any, **kwargs: Any) -> Any: return await service.wait( func(service, *args, **kwargs), timeout=timeout, @@ -413,7 +414,7 @@ async def _cleanup(self) -> None: pass -TService = TypeVar('TService', bound=BaseService) +TService = TypeVar('TService', bound=AsyncioServiceAPI) @asynccontextmanager diff --git a/p2p/tools/connection.py b/p2p/tools/connection.py index 26d66cf58f..bd858ab759 100644 --- a/p2p/tools/connection.py +++ b/p2p/tools/connection.py @@ -9,11 +9,11 @@ async def do_ping_pong_test(alice_connection: ConnectionAPI, bob_connection: Con got_ping = asyncio.Event() got_pong = asyncio.Event() - async def _handle_ping(msg: Any) -> None: + async def _handle_ping(connection: ConnectionAPI, msg: Any) -> None: got_ping.set() bob_connection.get_base_protocol().send_pong() - async def _handle_pong(msg: Any) -> None: + async def _handle_pong(connection: ConnectionAPI, msg: Any) -> None: got_pong.set() alice_connection.add_command_handler(Pong, _handle_pong) diff --git a/p2p/tools/factories/peer.py b/p2p/tools/factories/peer.py index b36197300f..67f8924e48 100644 --- a/p2p/tools/factories/peer.py +++ b/p2p/tools/factories/peer.py @@ -1,4 +1,3 @@ -import asyncio from typing import cast, AsyncContextManager, AsyncIterator, Tuple, Type from async_generator import asynccontextmanager @@ -10,15 +9,13 @@ from eth_keys import keys from p2p.abc import NodeAPI -from p2p.handshake import negotiate_protocol_handshakes from p2p.peer import BasePeer, BasePeerContext, BasePeerFactory from p2p.service import run_service from p2p.tools.paragon import ParagonPeer, ParagonContext, ParagonPeerFactory from .cancel_token import CancelTokenFactory -from .p2p_proto import DevP2PHandshakeParamsFactory -from .transport import MemoryTransportPairFactory +from .connection import ConnectionPairFactory @asynccontextmanager @@ -56,65 +53,28 @@ async def PeerPairFactory(*, event_bus=event_bus, ) - # Establish linked transports for peer communication. - alice_transport, bob_transport = MemoryTransportPairFactory( + alice_handshakers = await alice_factory.get_handshakers() + bob_handshakers = await bob_factory.get_handshakers() + + connection_pair = ConnectionPairFactory( + alice_handshakers=alice_handshakers, + bob_handshakers=bob_handshakers, alice_remote=alice_remote, alice_private_key=alice_private_key, + alice_client_version=alice_client_version, + alice_p2p_version=alice_p2p_version, bob_remote=bob_remote, bob_private_key=bob_private_key, + bob_client_version=bob_client_version, + bob_p2p_version=bob_p2p_version, + cancel_token=cancel_token, ) + async with connection_pair as (alice_connection, bob_connection): + alice = alice_factory.create_peer(connection=alice_connection) + bob = bob_factory.create_peer(connection=bob_connection) - # Get their respective base DevP2P handshake parameters - alice_p2p_handshake_params = DevP2PHandshakeParamsFactory( - listen_port=alice_transport.remote.address.tcp_port, - client_version_string=alice_client_version, - version=alice_p2p_version, - ) - bob_p2p_handshake_params = DevP2PHandshakeParamsFactory( - listen_port=bob_transport.remote.address.tcp_port, - client_version_string=bob_client_version, - version=bob_p2p_version, - ) - - # Get their protocol handshakers - alice_handshakers = await alice_factory.get_handshakers() - bob_handshakers = await bob_factory.get_handshakers() - - # Perform the handshake between the two peers. - ( - (alice_multiplexer, alice_devp2p_receipt, alice_protocol_receipts), - (bob_multiplexer, bob_devp2p_receipt, bob_protocol_receipts), - ) = await asyncio.wait_for(asyncio.gather( - negotiate_protocol_handshakes( - transport=alice_transport, - p2p_handshake_params=alice_p2p_handshake_params, - protocol_handshakers=alice_handshakers, - token=cancel_token, - ), - negotiate_protocol_handshakes( - transport=bob_transport, - p2p_handshake_params=bob_p2p_handshake_params, - protocol_handshakers=bob_handshakers, - token=cancel_token, - ), - ), timeout=1) - - # Create the peer instances - alice = alice_factory.create_peer( - multiplexer=alice_multiplexer, - devp2p_receipt=alice_devp2p_receipt, - protocol_receipts=alice_protocol_receipts, - inbound=False, - ) - bob = bob_factory.create_peer( - multiplexer=bob_multiplexer, - devp2p_receipt=bob_devp2p_receipt, - protocol_receipts=bob_protocol_receipts, - inbound=True, - ) - - async with run_service(alice), run_service(bob): - yield alice, bob + async with run_service(alice), run_service(bob): + yield alice, bob def ParagonPeerPairFactory(*, diff --git a/p2p/tools/memory_transport.py b/p2p/tools/memory_transport.py index 88601ffd9c..e9a3b8e9cd 100644 --- a/p2p/tools/memory_transport.py +++ b/p2p/tools/memory_transport.py @@ -11,8 +11,9 @@ from p2p._utils import get_devp2p_cmd_id from p2p.abc import NodeAPI, TransportAPI -from p2p.tools.asyncio_streams import get_directly_connected_streams from p2p.exceptions import PeerConnectionLost +from p2p.tools.asyncio_streams import get_directly_connected_streams +from p2p.transport_state import TransportState CONNECTION_LOST_ERRORS = ( @@ -24,6 +25,7 @@ class MemoryTransport(TransportAPI): logger = logging.getLogger('p2p.tools.memory_transport.MemoryTransport') + read_state = TransportState.IDLE def __init__(self, remote: NodeAPI, @@ -72,9 +74,16 @@ def write(self, data: bytes) -> None: self._writer.write(data) async def recv(self, token: CancelToken) -> bytes: - encoded_size = await self.read(3, token) + self.read_state = TransportState.HEADER + try: + encoded_size = await self.read(3, token) + except asyncio.CancelledError: + self.read_state = TransportState.IDLE + raise (size,) = struct.unpack(b'>I', b'\x00' + encoded_size) + self.read_state = TransportState.BODY data = await self.read(size, token) + self.read_state = TransportState.IDLE return data def send(self, header: bytes, body: bytes) -> None: diff --git a/p2p/tools/paragon/__init__.py b/p2p/tools/paragon/__init__.py index 1edf3ea6a1..1ed6427b4d 100644 --- a/p2p/tools/paragon/__init__.py +++ b/p2p/tools/paragon/__init__.py @@ -7,6 +7,7 @@ ParagonProtocol, ) from .peer import ( # noqa: F401 + ParagonHandshaker, ParagonContext, ParagonMockPeerPoolWithConnectedPeers, ParagonPeer, diff --git a/p2p/transport.py b/p2p/transport.py index f8df127af9..50dc13e677 100644 --- a/p2p/transport.py +++ b/p2p/transport.py @@ -47,10 +47,17 @@ UnreachablePeer, ) from p2p.kademlia import Address, Node +from p2p.transport_state import TransportState class Transport(TransportAPI): - logger = cast(ExtendedDebugLogger, logging.getLogger('p2p.connection.Transport')) + logger = cast(ExtendedDebugLogger, logging.getLogger('p2p.transport.Transport')) + + # This status flag allows those managing a `Transport` to determine the + # proper cancellation strategy if the transport is mid-read. Hard + # cancellations are allowed for both `IDLE` and `HEADER`. A hard + # cancellation during `BODY` will leave the transport in a corrupt state. + read_state: TransportState = TransportState.IDLE def __init__(self, remote: NodeAPI, @@ -193,6 +200,9 @@ async def receive_connection(cls, auth_ack_msg = responder.create_auth_ack_message(responder_nonce) auth_ack_ciphertext = responder.encrypt_auth_ack_message(auth_ack_msg) + if writer.transport.is_closing() or reader.at_eof(): + raise HandshakeFailure("Connection is closing") + # Use the `writer` to send the reply to the remote writer.write(auth_ack_ciphertext) await token.cancellable_wait(writer.drain()) @@ -236,7 +246,30 @@ def write(self, data: bytes) -> None: self._writer.write(data) async def recv(self, token: CancelToken) -> bytes: - header_data = await self.read(HEADER_LEN + MAC_LEN, token) + # Check that Transport read state is IDLE. + if self.read_state is not TransportState.IDLE: + # This is logged at INFO level because it indicates we are not + # properly managing the Transport and are interrupting it mid-read + # somewhere. + self.logger.info( + 'Corrupted transport: %s - state=%s', + self, + self.read_state.name, + ) + raise Exception(f"Corrupted transport: {self} - state={self.read_state.name}") + + # Set status to indicate we are waiting to read the message header + self.read_state = TransportState.HEADER + + try: + header_data = await self.read(HEADER_LEN + MAC_LEN, token) + except asyncio.CancelledError: + self.logger.debug('Transport cancelled during header read. resetting to IDLE state') + self.read_state = TransportState.IDLE + raise + + # Set status to indicate we are waiting to read the message body + self.read_state = TransportState.BODY try: header = self._decrypt_header(header_data) except DecryptionError as err: @@ -259,6 +292,8 @@ async def recv(self, token: CancelToken) -> bytes: ) raise MalformedMessage from err + # Reset status back to IDLE + self.read_state = TransportState.IDLE return msg def send(self, header: bytes, body: bytes) -> None: diff --git a/p2p/transport_state.py b/p2p/transport_state.py new file mode 100644 index 0000000000..2943615cc1 --- /dev/null +++ b/p2p/transport_state.py @@ -0,0 +1,7 @@ +import enum + + +class TransportState(enum.Enum): + IDLE = enum.auto() + HEADER = enum.auto() + BODY = enum.auto() diff --git a/setup.py b/setup.py index 51f7c4323d..d134f871a4 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ import os from setuptools import setup, find_packages -PYEVM_DEPENDENCY = "py-evm==0.3.0a5" # noqa: E501 +PYEVM_DEPENDENCY = "py-evm==0.3.0a5" deps = { @@ -33,9 +33,7 @@ "plyvel==1.0.5", PYEVM_DEPENDENCY, "web3==4.4.1", - "lahja==0.14.0", - # Exact version pin until connection timeout issue is resolved - # "lahja>=0.14.0,<0.15.0", + "lahja>=0.14.2,<0.15.0", "termcolor>=1.1.0,<2.0.0", "uvloop==0.11.2;platform_system=='Linux' or platform_system=='Darwin' or platform_system=='FreeBSD'", # noqa: E501 "websockets==5.0.1", @@ -111,10 +109,14 @@ "py-ecc==1.7.1", "rlp>=1.1.0,<2.0.0", PYEVM_DEPENDENCY, - "ssz==0.1.3", + "ssz==0.1.4", "milagro-bls-binding==0.1.3", "blspy>=0.1.8,<1", # for `bls_chia` ], + 'eth2-lint': [ + "black==19.3b0", + "isort==4.3.21", + ], } # NOTE: Snappy breaks RTD builds. Until we have a more mature solution @@ -130,7 +132,8 @@ deps['test'] + deps['doc'] + deps['lint'] + - deps['eth2'] + deps['eth2'] + + deps['eth2-lint'] ) diff --git a/tests/conftest.py b/tests/conftest.py index 203735fb02..3cb75fa689 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -243,7 +243,7 @@ def _chain_with_block_validation(base_db, genesis_state, chain_cls=Chain): "extra_data": b"B", "gas_limit": 3141592, "gas_used": 0, - "mix_hash": decode_hex("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421"), # noqa: E501 + "mix_hash": decode_hex("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421"), "nonce": decode_hex("0102030405060708"), "block_number": 0, "parent_hash": decode_hex("0000000000000000000000000000000000000000000000000000000000000000"), # noqa: E501 @@ -335,7 +335,9 @@ def chain_without_block_validation( @pytest.fixture() def rpc(chain_with_block_validation, event_bus): return RPCServer( - initialize_eth1_modules(chain_with_block_validation, event_bus), event_bus, + initialize_eth1_modules(chain_with_block_validation, event_bus), + chain_with_block_validation, + event_bus, ) diff --git a/tests/core/json-rpc/test_rpc_during_beam_sync.py b/tests/core/json-rpc/test_rpc_during_beam_sync.py index a93d837edb..311ec34ec6 100644 --- a/tests/core/json-rpc/test_rpc_during_beam_sync.py +++ b/tests/core/json-rpc/test_rpc_during_beam_sync.py @@ -3,6 +3,9 @@ import os import pytest import time +from typing import Dict + +from async_generator import asynccontextmanager from eth_hash.auto import keccak from eth_utils.toolz import ( @@ -10,6 +13,8 @@ ) from eth_utils import ( decode_hex, + function_signature_to_4byte_selector, + to_hex, ) from eth.db.account import AccountDB @@ -88,8 +93,8 @@ async def get_ipc_response( @pytest.fixture -def chain(chain_with_block_validation): - return chain_with_block_validation +def chain(chain_without_block_validation): + return chain_without_block_validation @pytest.fixture @@ -151,61 +156,81 @@ async def make_request(*args): return make_request +@pytest.fixture +def fake_beam_syncer(chain, event_bus): + @asynccontextmanager + async def fake_beam_sync(removed_nodes: Dict): + # beam sync starts, it fetches requested nodes from remote peers + + def replace_missing_node(missing_node_hash): + if missing_node_hash not in removed_nodes: + raise Exception(f'An unexpected node was requested: {missing_node_hash}') + chain.chaindb.db[missing_node_hash] = removed_nodes.pop(missing_node_hash) + + async def collect_accounts(event: CollectMissingAccount): + replace_missing_node(event.missing_node_hash) + await event_bus.broadcast( + MissingAccountCollected(1), event.broadcast_config() + ) + accounts_sub = event_bus.subscribe(CollectMissingAccount, collect_accounts) + + async def collect_bytecodes(event: CollectMissingBytecode): + replace_missing_node(event.bytecode_hash) + await event_bus.broadcast( + MissingBytecodeCollected(), event.broadcast_config() + ) + bytecode_sub = event_bus.subscribe(CollectMissingBytecode, collect_bytecodes) + + async def collect_storage(event: CollectMissingStorage): + replace_missing_node(event.missing_node_hash) + await event_bus.broadcast( + MissingStorageCollected(1), event.broadcast_config() + ) + storage_sub = event_bus.subscribe(CollectMissingStorage, collect_storage) + + await event_bus.wait_until_any_endpoint_subscribed_to(CollectMissingAccount) + await event_bus.wait_until_any_endpoint_subscribed_to(CollectMissingBytecode) + await event_bus.wait_until_any_endpoint_subscribed_to(CollectMissingStorage) + + try: + yield + finally: + accounts_sub.unsubscribe() + bytecode_sub.unsubscribe() + storage_sub.unsubscribe() + + return fake_beam_sync + + # Test that eth_getBalance works during beam sync @pytest.mark.asyncio -async def test_get_balance_works( - ipc_request, funded_address, funded_address_initial_balance): +async def test_getBalance_during_beam_sync( + chain, ipc_request, funded_address, funded_address_initial_balance, + fake_beam_syncer): """ Sanity check, if we call eth_getBalance we get back the expected response. """ + + # sanity check, by default it works response = await ipc_request('eth_getBalance', [funded_address.hex(), 'latest']) assert 'error' not in response assert response['result'] == hex(funded_address_initial_balance) + state_root_hash = chain.get_canonical_head().state_root + state_root = chain.chaindb.db.pop(state_root_hash) -@pytest.fixture -def missing_node(chain): - state_root = chain.get_canonical_head().state_root - return chain.chaindb.db.pop(state_root) - - -@pytest.mark.asyncio -async def test_fails_when_state_is_missing(ipc_request, funded_address, missing_node): - """ - If the state root is missing then eth_getBalance throws an error. - """ - response = await ipc_request('eth_getBalance', [funded_address.hex(), 'latest']) - assert 'error' in response - assert response['error'].startswith('State trie database is missing node for hash') - - -@pytest.mark.asyncio -async def test_missing_state_is_fetched_if_fetcher_exists( - ipc_request, funded_address, funded_address_initial_balance, - missing_node, chain, event_bus): - - # beam sync is not running, so we receive an error + # now that the hash is missing we should receive an error response = await ipc_request('eth_getBalance', [funded_address.hex(), 'latest']) assert 'error' in response assert response['error'].startswith('State trie database is missing node for hash') - # beam sync starts, it fetches requested nodes from remote peers - async def find_and_insert_node(event: CollectMissingAccount): - state_root = chain.get_canonical_head().state_root - chain.chaindb.db[state_root] = missing_node - await event_bus.broadcast(MissingAccountCollected(1), event.broadcast_config()) - event_bus.subscribe(CollectMissingAccount, find_and_insert_node) - await event_bus.wait_until_any_endpoint_subscribed_to(CollectMissingAccount) - - # beam sync fetches the missing node so no error is returned - response = await ipc_request('eth_getBalance', [funded_address.hex(), 'latest']) - assert 'error' not in response - assert response['result'] == hex(funded_address_initial_balance) - - -# Test that eth_getCode works during beam sync + # with a beam syncer running it should work again! It sends requests to the syncer + async with fake_beam_syncer({state_root_hash: state_root}): + response = await ipc_request('eth_getBalance', [funded_address.hex(), 'latest']) + assert 'error' not in response + assert response['result'] == hex(funded_address_initial_balance) @pytest.fixture @@ -214,107 +239,164 @@ async def contract_code_hash(genesis_state, simple_contract_address): @pytest.mark.asyncio -async def test_getCode(ipc_request, simple_contract_address, contract_code_hash): - """ - Sanity check, if we call eth_getBalance we get back the expected response. - """ +async def test_getCode_during_beam_sync( + chain, ipc_request, simple_contract_address, contract_code_hash, + fake_beam_syncer): + + # sanity check, by default it works response = await ipc_request('eth_getCode', [simple_contract_address.hex(), 'latest']) assert 'error' not in response assert keccak(decode_hex(response['result'])) == contract_code_hash + missing_bytecode = chain.chaindb.db.pop(contract_code_hash) -@pytest.fixture -def missing_bytecode(chain, contract_code_hash): - return chain.chaindb.db.pop(contract_code_hash) - - -@pytest.mark.asyncio -async def test_getCode_fails_when_state_is_missing( - ipc_request, simple_contract_address, missing_bytecode): - """ - If the state root is missing then eth_getBalance throws an error. - """ + # now that the hash is missing we should receive an error response = await ipc_request('eth_getCode', [simple_contract_address.hex(), 'latest']) assert 'error' in response assert response['error'].startswith('Database is missing bytecode for code hash') + # with a beam syncer running it should work again! It sends requests to the syncer + async with fake_beam_syncer({contract_code_hash: missing_bytecode}): + response = await ipc_request('eth_getCode', [simple_contract_address.hex(), 'latest']) + assert 'error' not in response + assert keccak(decode_hex(response['result'])) == contract_code_hash -@pytest.mark.asyncio -async def test_missing_code_is_fetched_if_fetcher_exists( - ipc_request, simple_contract_address, contract_code_hash, missing_bytecode, chain, - event_bus): - # beam sync is not running, so we receive an error - response = await ipc_request('eth_getCode', [simple_contract_address.hex(), 'latest']) - assert 'error' in response - assert response['error'].startswith('Database is missing bytecode for code hash') +# Test that eth_getStorageAt works during Beam Sync - # beam sync starts, it fetches requested nodes from remote peers - async def find_and_insert_node(event: CollectMissingBytecode): - chain.chaindb.db[contract_code_hash] = missing_bytecode - await event_bus.broadcast(MissingBytecodeCollected(), event.broadcast_config()) - event_bus.subscribe(CollectMissingBytecode, find_and_insert_node) - await event_bus.wait_until_any_endpoint_subscribed_to(CollectMissingBytecode) - # beam sync fetches the missing node so no error is returned - response = await ipc_request('eth_getCode', [simple_contract_address.hex(), 'latest']) - assert 'error' not in response - assert keccak(decode_hex(response['result'])) == contract_code_hash +@pytest.fixture +def storage_root(chain, simple_contract_address): + state_root = chain.get_canonical_head().state_root + account_db = AccountDB(chain.chaindb.db, state_root) + return account_db._get_storage_root(simple_contract_address) -# Test that eth_getStorageAt works during Beam Sync +@pytest.mark.asyncio +async def test_getStorageAt_during_beam_sync( + ipc_request, simple_contract_address, storage_root, chain, fake_beam_syncer): + params = [simple_contract_address.hex(), 1, 'latest'] -@pytest.mark.asyncio -async def test_getStorageAt(ipc_request, simple_contract_address): - """ - Sanity check, if we call eth_getBalance we get back the expected response. - """ - response = await ipc_request('eth_getStorageAt', [simple_contract_address.hex(), 1, 'latest']) + # sanity check, by default it works + response = await ipc_request('eth_getStorageAt', params) assert 'error' not in response assert response['result'] == '0x01' # this was set in the genesis_state fixture + missing_node = chain.chaindb.db.pop(storage_root) -@pytest.fixture -def storage_root(chain, simple_contract_address): - state_root = chain.get_canonical_head().state_root - account_db = AccountDB(chain.chaindb.db, state_root) - return account_db._get_storage_root(simple_contract_address) + # now that the hash is missing we should receive an error + response = await ipc_request('eth_getStorageAt', params) + assert 'error' in response + assert response['error'].startswith('Storage trie database is missing hash') + + # with a beam syncer running it should work again! It sends requests to the syncer + async with fake_beam_syncer({storage_root: missing_node}): + response = await ipc_request('eth_getStorageAt', params) + assert 'error' not in response + assert response['result'] == '0x01' # this was set in the genesis_state fixture @pytest.fixture -def missing_storage_root(chain, storage_root): - return chain.chaindb.db.pop(storage_root) +def transaction(simple_contract_address): + function_selector = function_signature_to_4byte_selector('getMeaningOfLife()') + return { + 'from': '0x' + 'ff' * 20, # unfunded address + 'to': to_hex(simple_contract_address), + 'gasPrice': to_hex(0), + 'data': to_hex(function_selector), + } @pytest.mark.asyncio -async def test_missing_root_get_storage(ipc_request, simple_contract_address, missing_storage_root): - """ - Sanity check, if we call eth_getBalance we get back the expected response. - """ - response = await ipc_request('eth_getStorageAt', [simple_contract_address.hex(), 1, 'latest']) +async def test_eth_call( + ipc_request, contract_code_hash, chain, transaction, fake_beam_syncer): + + # sanity check, by default it works + response = await ipc_request('eth_call', [transaction, 'latest']) + assert 'error' not in response + assert response['result'].endswith('002a') + + bytecode = chain.chaindb.db.pop(contract_code_hash) + + # now that the hash is missing we should receive an error + response = await ipc_request('eth_call', [transaction, 'latest']) assert 'error' in response - assert response['error'].startswith('Storage trie database is missing hash') + assert response['error'].startswith('Database is missing bytecode for code hash') + + # with a beam syncer running it should work again! It sends requests to the syncer + async with fake_beam_syncer({contract_code_hash: bytecode}): + response = await ipc_request('eth_call', [transaction, 'latest']) + assert 'error' not in response + assert response['result'].endswith('002a') @pytest.mark.asyncio -async def test_missing_storage_is_fetched_if_fetcher_exists( - ipc_request, simple_contract_address, storage_root, missing_storage_root, chain, - event_bus): +async def test_eth_call_multiple_missing_nodes( + ipc_request, contract_code_hash, storage_root, + chain, transaction, fake_beam_syncer): + + state_root_hash = chain.get_canonical_head().state_root + missing_nodes = { + state_root_hash: chain.chaindb.db.pop(state_root_hash), + contract_code_hash: chain.chaindb.db.pop(contract_code_hash), + storage_root: chain.chaindb.db.pop(storage_root), + } - # beam sync is not running, so we receive an error - response = await ipc_request('eth_getStorageAt', [simple_contract_address.hex(), 1, 'latest']) + # now that the hash is missing we should receive an error + response = await ipc_request('eth_call', [transaction, 'latest']) assert 'error' in response - assert response['error'].startswith('Storage trie database is missing hash') + assert 'missing' in response['error'] + + # with a beam syncer running it should work again! It sends requests to the syncer + async with fake_beam_syncer(missing_nodes): + response = await ipc_request('eth_call', [transaction, 'latest']) + assert 'error' not in response + assert response['result'].endswith('002a') - # beam sync starts, it fetches requested nodes from remote peers - async def find_and_insert_node(event: CollectMissingStorage): - chain.chaindb.db[storage_root] = missing_storage_root - await event_bus.broadcast(MissingStorageCollected(1), event.broadcast_config()) - event_bus.subscribe(CollectMissingStorage, find_and_insert_node) - await event_bus.wait_until_any_endpoint_subscribed_to(CollectMissingStorage) - # beam sync fetches the missing node so no error is returned - response = await ipc_request('eth_getStorageAt', [simple_contract_address.hex(), 1, 'latest']) +@pytest.mark.asyncio +async def test_eth_estimateGas( + ipc_request, contract_code_hash, chain, transaction, fake_beam_syncer): + + # sanity check, by default it works + response = await ipc_request('eth_estimateGas', [transaction, 'latest']) assert 'error' not in response - assert response['result'] == '0x01' # this was set in the genesis_state fixture + assert response['result'] == '0x82a8' + + bytecode = chain.chaindb.db.pop(contract_code_hash) + + # now that the hash is missing we should receive an error + response = await ipc_request('eth_estimateGas', [transaction, 'latest']) + assert 'error' in response + assert response['error'].startswith('Database is missing bytecode for code hash') + + # with a beam syncer running it should work again! It sends requests to the syncer + async with fake_beam_syncer({contract_code_hash: bytecode}): + response = await ipc_request('eth_estimateGas', [transaction, 'latest']) + assert 'error' not in response + assert response['result'] == '0x82a8' + + +@pytest.mark.asyncio +async def test_rpc_with_old_block( + ipc_request, contract_code_hash, transaction, chain, fake_beam_syncer): + response = await ipc_request('eth_estimateGas', [transaction, 'latest']) + assert 'error' not in response + assert response['result'] == '0x82a8' + + for _ in range(65): + chain.mine_block() + + bytecode = chain.chaindb.db.pop(contract_code_hash) + + # if there is no beam syncer we return the original error + response = await ipc_request('eth_estimateGas', [transaction, 'earliest']) + assert 'error' in response + assert response['error'].startswith('Database is missing bytecode for code hash') + + # if there is a beam syncer we return a more useful error + async with fake_beam_syncer({contract_code_hash: bytecode}): + response = await ipc_request('eth_estimateGas', [transaction, 'earliest']) + assert 'error' in response + assert response['error'].startswith('block "earliest" is too old to be fetched') diff --git a/tests/core/p2p-proto/test_les_protocol_commands.py b/tests/core/p2p-proto/test_les_protocol_commands.py index 8d945ed5af..feb623c7d6 100644 --- a/tests/core/p2p-proto/test_les_protocol_commands.py +++ b/tests/core/p2p-proto/test_les_protocol_commands.py @@ -1,11 +1,11 @@ import asyncio import pytest -from p2p.peer import MsgBuffer - from trinity.protocol.les.proto import ( LESProtocol, + LESProtocolV2, ) + from trinity.tools.factories import LESV2PeerPairFactory @@ -35,25 +35,24 @@ async def test_les_protocol_methods_request_id( assert isinstance(peer.sub_proto, LESProtocol) assert isinstance(remote.sub_proto, LESProtocol) - collector = MsgBuffer() - remote.add_subscriber(collector) + # setup message collection + messages = [] + got_message = asyncio.Event() + + async def collect_messages(conn, cmd, msg): + messages.append((cmd, msg)) + got_message.set() + + peer.connection.add_protocol_handler(LESProtocolV2, collect_messages) # Test for get_block_headers - generated_request_id = peer.sub_proto.send_get_block_headers( + generated_request_id = remote.sub_proto.send_get_block_headers( b'1234', 1, 0, False, request_id=request_id ) + await asyncio.wait_for(got_message.wait(), timeout=1) - # yield to let remote and peer transmit messages. This can take a - # small amount of time so we give it a few rounds of the event loop to - # finish transmitting. - for _ in range(10): - await asyncio.sleep(0.01) - if collector.msg_queue.qsize() >= 1: - break - - messages = collector.get_messages() assert len(messages) == 1 - peer, cmd, msg = messages[0] + cmd, msg = messages[0] # Asserted that the reply message has the request_id as that which was generated assert generated_request_id == msg['request_id'] diff --git a/tests/core/p2p-proto/test_server.py b/tests/core/p2p-proto/test_server.py index ebaedd1a91..fc6a84c12c 100644 --- a/tests/core/p2p-proto/test_server.py +++ b/tests/core/p2p-proto/test_server.py @@ -10,6 +10,7 @@ from eth.db.chain import ChainDB from p2p.auth import HandshakeInitiator, _handshake +from p2p.connection import Connection from p2p.kademlia import ( Node, Address, @@ -137,12 +138,13 @@ async def test_server_incoming_connection(monkeypatch, server, event_loop): protocol_handshakers=handshakers, token=token, ) - initiator_peer = factory.create_peer( + connection = Connection( multiplexer=multiplexer, devp2p_receipt=devp2p_receipt, protocol_receipts=protocol_receipts, - inbound=False, + is_dial_out=False, ) + initiator_peer = factory.create_peer(connection=connection) # wait for peer to be processed for _ in range(100): diff --git a/tests/eth2/conftest.py b/tests/eth2/conftest.py index acb0135c73..399bae55f3 100644 --- a/tests/eth2/conftest.py +++ b/tests/eth2/conftest.py @@ -1,13 +1,10 @@ import functools import eth_utils.toolz as toolz - import pytest from eth2._utils.bls import bls -from eth2._utils.hash import ( - hash_eth2, -) +from eth2._utils.hash import hash_eth2 def _serialize_bls_pubkeys(key): @@ -24,10 +21,7 @@ def _deserialize_bls_pubkey(key_data): def _deserialize_pair(pair): index, pubkey = pair - return ( - int(index), - _deserialize_bls_pubkey(pubkey), - ) + return (int(index), _deserialize_bls_pubkey(pubkey)) class privkey_view: @@ -47,10 +41,7 @@ def __init__(self, key_cache): def __getitem__(self, index): if isinstance(index, slice): - return list( - self.key_cache._get_pubkey_at(i) - for i in range(index.stop) - ) + return list(self.key_cache._get_pubkey_at(i) for i in range(index.stop)) return self.key_cache._get_pubkey_at(index) @@ -69,8 +60,7 @@ def __init__(self, backing_cache_reader, backing_cache_writer): def _restore_from_cache(self, cached_data): self.all_pubkeys_by_index = toolz.itemmap( - _deserialize_pair, - cached_data["pubkeys_by_index"], + _deserialize_pair, cached_data["pubkeys_by_index"] ) for index, pubkey in self.all_pubkeys_by_index.items(): privkey = self._get_privkey_for(index) @@ -86,9 +76,8 @@ def _serialize(self): """ return { "pubkeys_by_index": toolz.valmap( - _serialize_bls_pubkeys, - self.all_pubkeys_by_index, - ), + _serialize_bls_pubkeys, self.all_pubkeys_by_index + ) } def _privkey_view(self): @@ -113,7 +102,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): def _get_privkey_for(self, index): # Want privkey an intger slightly less than the curve order - privkey = int.from_bytes(hash_eth2(index.to_bytes(32, 'little')), 'little') % 2**254 + privkey = ( + int.from_bytes(hash_eth2(index.to_bytes(32, "little")), "little") % 2 ** 254 + ) self.all_privkeys_by_index[index] = privkey return privkey @@ -173,14 +164,8 @@ def _key_cache(request, _should_persist_bls_keys): cache_key = f"eth2/bls/key-cache/{bls.backend.__name__}" if _should_persist_bls_keys: - backing_cache_reader = functools.partial( - request.config.cache.get, - cache_key, - ) - backing_cache_writer = functools.partial( - request.config.cache.set, - cache_key, - ) + backing_cache_reader = functools.partial(request.config.cache.get, cache_key) + backing_cache_writer = functools.partial(request.config.cache.set, cache_key) else: backing_cache_reader = None backing_cache_writer = None diff --git a/tests/eth2/core/beacon/chains/conftest.py b/tests/eth2/core/beacon/chains/conftest.py index b17c00f18c..7a4add3ea6 100644 --- a/tests/eth2/core/beacon/chains/conftest.py +++ b/tests/eth2/core/beacon/chains/conftest.py @@ -1,17 +1,16 @@ import pytest -from eth2.beacon.chains.base import ( - BeaconChain, -) +from eth2.beacon.chains.base import BeaconChain def _beacon_chain_with_block_validation( - base_db, - genesis_block, - genesis_state, - fixture_sm_class, - config, - chain_cls=BeaconChain): + base_db, + genesis_block, + genesis_state, + fixture_sm_class, + config, + chain_cls=BeaconChain, +): """ Return a Chain object containing just the genesis block. @@ -25,34 +24,21 @@ def _beacon_chain_with_block_validation( """ klass = chain_cls.configure( - __name__='TestChain', - sm_configuration=( - (genesis_state.slot, fixture_sm_class), - ), + __name__="TestChain", + sm_configuration=((genesis_state.slot, fixture_sm_class),), chain_id=5566, ) - chain = klass.from_genesis( - base_db, - genesis_state, - genesis_block, - config, - ) + chain = klass.from_genesis(base_db, genesis_state, genesis_block, config) return chain @pytest.fixture -def beacon_chain_with_block_validation(base_db, - genesis_block, - genesis_state, - fixture_sm_class, - config): +def beacon_chain_with_block_validation( + base_db, genesis_block, genesis_state, fixture_sm_class, config +): return _beacon_chain_with_block_validation( - base_db, - genesis_block, - genesis_state, - fixture_sm_class, - config, + base_db, genesis_block, genesis_state, fixture_sm_class, config ) @@ -62,11 +48,8 @@ def import_block_without_validation(chain, block): @pytest.fixture(params=[BeaconChain]) def beacon_chain_without_block_validation( - request, - base_db, - genesis_state, - genesis_block, - fixture_sm_class): + request, base_db, genesis_state, genesis_block, fixture_sm_class +): """ Return a Chain object containing just the genesis block. @@ -77,22 +60,14 @@ def beacon_chain_without_block_validation( chain itself. """ # Disable block validation so that we don't need to construct finalized blocks. - overrides = { - 'import_block': import_block_without_validation, - } + overrides = {"import_block": import_block_without_validation} chain_class = request.param klass = chain_class.configure( - __name__='TestChainWithoutBlockValidation', - sm_configuration=( - (0, fixture_sm_class), - ), + __name__="TestChainWithoutBlockValidation", + sm_configuration=((0, fixture_sm_class),), chain_id=5566, **overrides, ) - chain = klass.from_genesis( - base_db, - genesis_state, - genesis_block, - ) + chain = klass.from_genesis(base_db, genesis_state, genesis_block) return chain diff --git a/tests/eth2/core/beacon/chains/test_beacon_chain.py b/tests/eth2/core/beacon/chains/test_beacon_chain.py index 66c8578858..775db39cf3 100644 --- a/tests/eth2/core/beacon/chains/test_beacon_chain.py +++ b/tests/eth2/core/beacon/chains/test_beacon_chain.py @@ -2,30 +2,13 @@ import pytest - -from eth2.beacon.exceptions import ( - BlockClassError, -) -from eth2.beacon.chains.base import ( - BeaconChain, -) -from eth2.beacon.db.exceptions import ( - AttestationRootNotFound, - StateSlotNotFound, -) -from eth2.beacon.types.blocks import ( - BeaconBlock, -) -from eth2.beacon.tools.builder.proposer import ( - create_mock_block, - -) -from eth2.beacon.tools.builder.validator import ( - create_mock_signed_attestations_at_slot, -) -from eth2.beacon.state_machines.forks.serenity.blocks import ( - SerenityBeaconBlock, -) +from eth2.beacon.chains.base import BeaconChain +from eth2.beacon.db.exceptions import AttestationRootNotFound, StateNotFound +from eth2.beacon.exceptions import BlockClassError +from eth2.beacon.state_machines.forks.serenity.blocks import SerenityBeaconBlock +from eth2.beacon.tools.builder.proposer import create_mock_block +from eth2.beacon.tools.builder.validator import create_mock_signed_attestations_at_slot +from eth2.beacon.types.blocks import BeaconBlock @pytest.fixture @@ -39,12 +22,8 @@ def valid_chain(beacon_chain_with_block_validation): @pytest.mark.parametrize( - ( - 'validator_count,slots_per_epoch,target_committee_size,shard_count' - ), - [ - (100, 20, 10, 20), - ] + ("validator_count,slots_per_epoch,target_committee_size,shard_count"), + [(100, 20, 10, 20)], ) def test_canonical_chain(valid_chain, genesis_slot, fork_choice_scoring): genesis_block = valid_chain.get_canonical_block_by_slot(genesis_slot) @@ -56,8 +35,7 @@ def test_canonical_chain(valid_chain, genesis_slot, fork_choice_scoring): assert valid_chain.get_score(genesis_block.signing_root) == 0 block = genesis_block.copy( - slot=genesis_block.slot + 1, - parent_root=genesis_block.signing_root, + slot=genesis_block.slot + 1, parent_root=genesis_block.signing_root ) valid_chain.chaindb.persist_block(block, block.__class__, fork_choice_scoring) @@ -68,9 +46,7 @@ def test_canonical_chain(valid_chain, genesis_slot, fork_choice_scoring): assert valid_chain.get_score(block.signing_root) == scoring_fn(block) assert scoring_fn(block) != 0 - canonical_block_1 = valid_chain.get_canonical_block_by_slot( - genesis_block.slot + 1, - ) + canonical_block_1 = valid_chain.get_canonical_block_by_slot(genesis_block.slot + 1) assert canonical_block_1 == block result_block = valid_chain.get_block_by_root(block.signing_root) @@ -78,35 +54,24 @@ def test_canonical_chain(valid_chain, genesis_slot, fork_choice_scoring): @pytest.mark.parametrize( - ( - 'validator_count,' - 'slots_per_epoch,' - 'target_committee_size,' - 'shard_count,' - ), - [ - (100, 16, 10, 16), - ] + ("validator_count," "slots_per_epoch," "target_committee_size," "shard_count,"), + [(100, 16, 10, 16)], ) -def test_get_state_by_slot(valid_chain, - genesis_block, - genesis_state, - config, - keymap): +def test_get_state_by_slot(valid_chain, genesis_block, genesis_state, config, keymap): # Fisrt, skip block and check if `get_state_by_slot` returns the expected state state_machine = valid_chain.get_state_machine(genesis_block.slot) - state = state_machine.state + state = valid_chain.get_head_state() block_skipped_slot = genesis_block.slot + 1 block_skipped_state = state_machine.state_transition.apply_state_transition( - state, - future_slot=block_skipped_slot, + state, future_slot=block_skipped_slot ) - with pytest.raises(StateSlotNotFound): + with pytest.raises(StateNotFound): valid_chain.get_state_by_slot(block_skipped_slot) valid_chain.chaindb.persist_state(block_skipped_state) - assert valid_chain.get_state_by_slot( - block_skipped_slot - ).hash_tree_root == block_skipped_state.hash_tree_root + assert ( + valid_chain.get_state_by_slot(block_skipped_slot).hash_tree_root + == block_skipped_state.hash_tree_root + ) # Next, import proposed block and check if `get_state_by_slot` returns the expected state proposed_slot = block_skipped_slot + 1 @@ -121,24 +86,19 @@ def test_get_state_by_slot(valid_chain, attestations=(), ) valid_chain.import_block(block) - state = valid_chain.get_state_machine().state - assert valid_chain.get_state_by_slot(proposed_slot).hash_tree_root == state.hash_tree_root + state = valid_chain.get_head_state() + assert ( + valid_chain.get_state_by_slot(proposed_slot).hash_tree_root + == state.hash_tree_root + ) @pytest.mark.long @pytest.mark.parametrize( - ( - 'validator_count,slots_per_epoch,target_committee_size,shard_count' - ), - [ - (100, 16, 10, 16), - ] + ("validator_count,slots_per_epoch,target_committee_size,shard_count"), + [(100, 16, 10, 16)], ) -def test_import_blocks(valid_chain, - genesis_block, - genesis_state, - config, - keymap): +def test_import_blocks(valid_chain, genesis_block, genesis_state, config, keymap): state = genesis_state blocks = (genesis_block,) valid_chain_2 = copy.deepcopy(valid_chain) @@ -158,12 +118,8 @@ def test_import_blocks(valid_chain, state = valid_chain.get_state_by_slot(block.slot) - assert block == valid_chain.get_canonical_block_by_slot( - block.slot - ) - assert block.signing_root == valid_chain.get_canonical_block_root( - block.slot - ) + assert block == valid_chain.get_canonical_block_by_slot(block.slot) + assert block.signing_root == valid_chain.get_canonical_block_root(block.slot) blocks += (block,) assert valid_chain.get_canonical_head() != valid_chain_2.get_canonical_head() @@ -173,23 +129,14 @@ def test_import_blocks(valid_chain, assert valid_chain.get_canonical_head() == valid_chain_2.get_canonical_head() assert valid_chain.get_state_by_slot(blocks[-1].slot).slot != 0 - assert ( - valid_chain.get_state_by_slot(blocks[-1].slot) == - valid_chain_2.get_state_by_slot(blocks[-1].slot) - ) + assert valid_chain.get_state_by_slot( + blocks[-1].slot + ) == valid_chain_2.get_state_by_slot(blocks[-1].slot) -def test_from_genesis(base_db, - genesis_block, - genesis_state, - fixture_sm_class, - config): +def test_from_genesis(base_db, genesis_block, genesis_state, fixture_sm_class, config): klass = BeaconChain.configure( - __name__='TestChain', - sm_configuration=( - (0, fixture_sm_class), - ), - chain_id=5566, + __name__="TestChain", sm_configuration=((0, fixture_sm_class),), chain_id=5566 ) assert type(genesis_block) == SerenityBeaconBlock @@ -197,33 +144,28 @@ def test_from_genesis(base_db, assert type(block) == BeaconBlock with pytest.raises(BlockClassError): - klass.from_genesis( - base_db, - genesis_state, - block, - config, - ) + klass.from_genesis(base_db, genesis_state, block, config) @pytest.mark.long @pytest.mark.parametrize( ( - 'validator_count,' - 'slots_per_epoch,' - 'target_committee_size,' - 'shard_count,' - 'min_attestation_inclusion_delay,' + "validator_count," + "slots_per_epoch," + "target_committee_size," + "shard_count," + "min_attestation_inclusion_delay," ), - [ - (100, 16, 10, 16, 0), - ] + [(100, 16, 10, 16, 0)], ) -def test_get_attestation_root(valid_chain, - genesis_block, - genesis_state, - config, - keymap, - min_attestation_inclusion_delay): +def test_get_attestation_root( + valid_chain, + genesis_block, + genesis_state, + config, + keymap, + min_attestation_inclusion_delay, +): state_machine = valid_chain.get_state_machine() attestations = create_mock_signed_attestations_at_slot( state=genesis_state, @@ -248,9 +190,7 @@ def test_get_attestation_root(valid_chain, a0 = attestations[0] assert valid_chain.get_attestation_by_root(a0.hash_tree_root) == a0 assert valid_chain.attestation_exists(a0.hash_tree_root) - fake_attestation = a0.copy( - signature=b'\x78' * 96, - ) + fake_attestation = a0.copy(signature=b"\x78" * 96) with pytest.raises(AttestationRootNotFound): valid_chain.get_attestation_by_root(fake_attestation.hash_tree_root) assert not valid_chain.attestation_exists(fake_attestation.hash_tree_root) diff --git a/tests/eth2/core/beacon/chains/test_chains.py b/tests/eth2/core/beacon/chains/test_chains.py index 3cc2eeccdf..806e7d153c 100644 --- a/tests/eth2/core/beacon/chains/test_chains.py +++ b/tests/eth2/core/beacon/chains/test_chains.py @@ -1,16 +1,9 @@ import pytest + from eth2.beacon.chains.testnet import TestnetChain as _TestnetChain -@pytest.mark.parametrize( - "chain_klass", - ( - _TestnetChain, - ) -) -def test_chain_class_well_defined(base_db, - chain_klass, - empty_attestation_pool, - config): +@pytest.mark.parametrize("chain_klass", (_TestnetChain,)) +def test_chain_class_well_defined(base_db, chain_klass, empty_attestation_pool, config): chain = chain_klass(base_db, empty_attestation_pool, config) assert chain.sm_configuration is not () and chain.sm_configuration is not None diff --git a/tests/eth2/core/beacon/conftest.py b/tests/eth2/core/beacon/conftest.py index b4322d27a4..a51511aaa0 100644 --- a/tests/eth2/core/beacon/conftest.py +++ b/tests/eth2/core/beacon/conftest.py @@ -1,73 +1,34 @@ +from eth.constants import ZERO_HASH32 +from eth_typing import BLSPubkey import pytest -from eth.constants import ( - ZERO_HASH32, -) -from eth_typing import ( - BLSPubkey, -) - -from eth2.configs import ( - Eth2Config, - CommitteeConfig, - Eth2GenesisConfig, -) from eth2.beacon.constants import ( DEPOSIT_CONTRACT_TREE_DEPTH, FAR_FUTURE_EPOCH, GWEI_PER_ETH, JUSTIFICATION_BITS_LENGTH, ) -from eth2.beacon.fork_choice.higher_slot import ( - higher_slot_scoring, -) +from eth2.beacon.db.chain import BeaconChainDB +from eth2.beacon.fork_choice.higher_slot import higher_slot_scoring +from eth2.beacon.genesis import get_genesis_block from eth2.beacon.operations.attestation_pool import AttestationPool -from eth2.beacon.types.attestations import IndexedAttestation +from eth2.beacon.state_machines.forks.serenity import SerenityStateMachine +from eth2.beacon.state_machines.forks.serenity.blocks import SerenityBeaconBlock +from eth2.beacon.state_machines.forks.serenity.configs import SERENITY_CONFIG +from eth2.beacon.tools.builder.initializer import create_mock_validator +from eth2.beacon.tools.builder.state import create_mock_genesis_state_from_validators +from eth2.beacon.tools.misc.ssz_vector import override_lengths from eth2.beacon.types.attestation_data import AttestationData -from eth2.beacon.types.blocks import BeaconBlock +from eth2.beacon.types.attestations import IndexedAttestation +from eth2.beacon.types.blocks import BeaconBlock, BeaconBlockBody, BeaconBlockHeader from eth2.beacon.types.checkpoints import Checkpoint from eth2.beacon.types.crosslinks import Crosslink from eth2.beacon.types.deposit_data import DepositData from eth2.beacon.types.eth1_data import Eth1Data +from eth2.beacon.types.forks import Fork from eth2.beacon.types.states import BeaconState - -from eth2.beacon.genesis import ( - get_genesis_block, -) -from eth2.beacon.tools.misc.ssz_vector import ( - override_lengths, -) -from eth2.beacon.tools.builder.state import ( - create_mock_genesis_state_from_validators, -) -from eth2.beacon.types.blocks import ( - BeaconBlockBody, - BeaconBlockHeader, -) -from eth2.beacon.types.forks import ( - Fork, -) -from eth2.beacon.typing import ( - Gwei, - ValidatorIndex, - Timestamp, - Version, -) -from eth2.beacon.state_machines.forks.serenity import ( - SerenityStateMachine, -) -from eth2.beacon.state_machines.forks.serenity.blocks import ( - SerenityBeaconBlock, -) -from eth2.beacon.state_machines.forks.serenity.configs import SERENITY_CONFIG - -from eth2.beacon.tools.builder.initializer import ( - create_mock_validator, -) - -from eth2.beacon.db.chain import ( - BeaconChainDB, -) +from eth2.beacon.typing import Gwei, Timestamp, ValidatorIndex, Version +from eth2.configs import CommitteeConfig, Eth2Config, Eth2GenesisConfig # SSZ @@ -295,48 +256,50 @@ def deposit_contract_address(): @pytest.fixture -def config(shard_count, - target_committee_size, - max_validators_per_committee, - min_per_epoch_churn_limit, - churn_limit_quotient, - shuffle_round_count, - min_genesis_active_validator_count, - min_genesis_time, - min_deposit_amount, - max_effective_balance, - ejection_balance, - effective_balance_increment, - genesis_slot, - genesis_epoch, - bls_withdrawal_prefix, - seconds_per_slot, - min_attestation_inclusion_delay, - slots_per_epoch, - min_seed_lookahead, - activation_exit_delay, - slots_per_eth1_voting_period, - slots_per_historical_root, - min_validator_withdrawability_delay, - persistent_committee_period, - max_epochs_per_crosslink, - min_epochs_to_inactivity_penalty, - epochs_per_historical_vector, - epochs_per_slashings_vector, - historical_roots_limit, - validator_registry_limit, - base_reward_factor, - whistleblower_reward_quotient, - proposer_reward_quotient, - inactivity_penalty_quotient, - min_slashing_penalty_quotient, - max_proposer_slashings, - max_attester_slashings, - max_attestations, - max_deposits, - max_voluntary_exits, - max_transfers, - deposit_contract_address): +def config( + shard_count, + target_committee_size, + max_validators_per_committee, + min_per_epoch_churn_limit, + churn_limit_quotient, + shuffle_round_count, + min_genesis_active_validator_count, + min_genesis_time, + min_deposit_amount, + max_effective_balance, + ejection_balance, + effective_balance_increment, + genesis_slot, + genesis_epoch, + bls_withdrawal_prefix, + seconds_per_slot, + min_attestation_inclusion_delay, + slots_per_epoch, + min_seed_lookahead, + activation_exit_delay, + slots_per_eth1_voting_period, + slots_per_historical_root, + min_validator_withdrawability_delay, + persistent_committee_period, + max_epochs_per_crosslink, + min_epochs_to_inactivity_penalty, + epochs_per_historical_vector, + epochs_per_slashings_vector, + historical_roots_limit, + validator_registry_limit, + base_reward_factor, + whistleblower_reward_quotient, + proposer_reward_quotient, + inactivity_penalty_quotient, + min_slashing_penalty_quotient, + max_proposer_slashings, + max_attester_slashings, + max_attestations, + max_deposits, + max_voluntary_exits, + max_transfers, + deposit_contract_address, +): # adding some config validity conditions here # abstract out into the config object? assert shard_count >= slots_per_epoch @@ -402,126 +365,120 @@ def genesis_config(config): # @pytest.fixture def sample_signature(): - return b'\56' * 96 + return b"\56" * 96 @pytest.fixture def sample_fork_params(): return { - 'previous_version': Version((0).to_bytes(4, 'little')), - 'current_version': Version((0).to_bytes(4, 'little')), - 'epoch': 2**32, + "previous_version": Version((0).to_bytes(4, "little")), + "current_version": Version((0).to_bytes(4, "little")), + "epoch": 2 ** 32, } @pytest.fixture def sample_validator_record_params(): return { - 'pubkey': b'\x67' * 48, - 'withdrawal_credentials': b'\x01' * 32, - 'effective_balance': Gwei(32 * GWEI_PER_ETH), - 'slashed': False, - 'activation_eligibility_epoch': FAR_FUTURE_EPOCH, - 'activation_epoch': FAR_FUTURE_EPOCH, - 'exit_epoch': FAR_FUTURE_EPOCH, - 'withdrawable_epoch': FAR_FUTURE_EPOCH, + "pubkey": b"\x67" * 48, + "withdrawal_credentials": b"\x01" * 32, + "effective_balance": Gwei(32 * GWEI_PER_ETH), + "slashed": False, + "activation_eligibility_epoch": FAR_FUTURE_EPOCH, + "activation_epoch": FAR_FUTURE_EPOCH, + "exit_epoch": FAR_FUTURE_EPOCH, + "withdrawable_epoch": FAR_FUTURE_EPOCH, } @pytest.fixture def sample_crosslink_record_params(): return { - 'shard': 0, - 'parent_root': b'\x34' * 32, - 'start_epoch': 0, - 'end_epoch': 1, - 'data_root': b'\x43' * 32, + "shard": 0, + "parent_root": b"\x34" * 32, + "start_epoch": 0, + "end_epoch": 1, + "data_root": b"\x43" * 32, } @pytest.fixture def sample_attestation_data_params(sample_crosslink_record_params): return { - 'beacon_block_root': b'\x11' * 32, - 'source': Checkpoint( - epoch=11, - root=b'\x22' * 32, - ), - 'target': Checkpoint( - epoch=12, - root=b'\x33' * 32, - ), - 'crosslink': Crosslink(**sample_crosslink_record_params), + "beacon_block_root": b"\x11" * 32, + "source": Checkpoint(epoch=11, root=b"\x22" * 32), + "target": Checkpoint(epoch=12, root=b"\x33" * 32), + "crosslink": Crosslink(**sample_crosslink_record_params), } @pytest.fixture def sample_attestation_data_and_custody_bit_params(sample_attestation_data_params): return { - 'data': AttestationData(**sample_attestation_data_params), - 'custody_bit': False, + "data": AttestationData(**sample_attestation_data_params), + "custody_bit": False, } @pytest.fixture def sample_indexed_attestation_params(sample_signature, sample_attestation_data_params): return { - 'custody_bit_0_indices': (10, 11, 12, 15, 28), - 'custody_bit_1_indices': tuple(), - 'data': AttestationData(**sample_attestation_data_params), - 'signature': sample_signature, + "custody_bit_0_indices": (10, 11, 12, 15, 28), + "custody_bit_1_indices": tuple(), + "data": AttestationData(**sample_attestation_data_params), + "signature": sample_signature, } @pytest.fixture def sample_pending_attestation_record_params(sample_attestation_data_params): return { - 'aggregation_bits': (True, False) * 4 * 16, - 'data': AttestationData(**sample_attestation_data_params), - 'inclusion_delay': 1, - 'proposer_index': ValidatorIndex(12), + "aggregation_bits": (True, False) * 4 * 16, + "data": AttestationData(**sample_attestation_data_params), + "inclusion_delay": 1, + "proposer_index": ValidatorIndex(12), } @pytest.fixture def sample_eth1_data_params(): return { - 'deposit_root': b'\x43' * 32, - 'deposit_count': 22, - 'block_hash': b'\x46' * 32, + "deposit_root": b"\x43" * 32, + "deposit_count": 22, + "block_hash": b"\x46" * 32, } @pytest.fixture def sample_historical_batch_params(config): return { - 'block_roots': tuple( + "block_roots": tuple( (bytes([i] * 32) for i in range(config.SLOTS_PER_HISTORICAL_ROOT)) ), - 'state_roots': tuple( + "state_roots": tuple( (bytes([i] * 32) for i in range(config.SLOTS_PER_HISTORICAL_ROOT)) - ) + ), } @pytest.fixture def sample_deposit_data_params(sample_signature): return { - 'pubkey': BLSPubkey(b'\x67' * 48), - 'withdrawal_credentials': b'\11' * 32, - 'amount': Gwei(56), - 'signature': sample_signature, + "pubkey": BLSPubkey(b"\x67" * 48), + "withdrawal_credentials": b"\11" * 32, + "amount": Gwei(56), + "signature": sample_signature, } @pytest.fixture def sample_block_header_params(): return { - 'slot': 10, - 'parent_root': b'\x22' * 32, - 'state_root': b'\x33' * 32, - 'body_root': b'\x43' * 32, - 'signature': b'\x56' * 96, + "slot": 10, + "parent_root": b"\x22" * 32, + "state_root": b"\x33" * 32, + "body_root": b"\x43" * 32, + "signature": b"\x56" * 96, } @@ -529,146 +486,133 @@ def sample_block_header_params(): def sample_proposer_slashing_params(sample_block_header_params): block_header_data = BeaconBlockHeader(**sample_block_header_params) return { - 'proposer_index': 1, - 'header_1': block_header_data, - 'header_2': block_header_data, + "proposer_index": 1, + "header_1": block_header_data, + "header_2": block_header_data, } @pytest.fixture def sample_attester_slashing_params(sample_indexed_attestation_params): - indexed_attestation = IndexedAttestation( - **sample_indexed_attestation_params - ) - return { - 'attestation_1': indexed_attestation, - 'attestation_2': indexed_attestation, - } + indexed_attestation = IndexedAttestation(**sample_indexed_attestation_params) + return {"attestation_1": indexed_attestation, "attestation_2": indexed_attestation} @pytest.fixture def sample_attestation_params(sample_signature, sample_attestation_data_params): return { - 'aggregation_bits': (True,) * 16, - 'data': AttestationData(**sample_attestation_data_params), - 'custody_bits': (False,) * 16, - 'signature': sample_signature, + "aggregation_bits": (True,) * 16, + "data": AttestationData(**sample_attestation_data_params), + "custody_bits": (False,) * 16, + "signature": sample_signature, } @pytest.fixture def sample_deposit_params(sample_deposit_data_params, deposit_contract_tree_depth): return { - 'proof': (b'\x22' * 32,) * (deposit_contract_tree_depth + 1), - 'data': DepositData(**sample_deposit_data_params) + "proof": (b"\x22" * 32,) * (deposit_contract_tree_depth + 1), + "data": DepositData(**sample_deposit_data_params), } @pytest.fixture def sample_voluntary_exit_params(sample_signature): - return { - 'epoch': 123, - 'validator_index': 15, - 'signature': sample_signature, - } + return {"epoch": 123, "validator_index": 15, "signature": sample_signature} @pytest.fixture def sample_transfer_params(): return { - 'sender': 10, - 'recipient': 12, - 'amount': 10 * 10**9, - 'fee': 5 * 10**9, - 'slot': 5, - 'pubkey': b'\x67' * 48, - 'signature': b'\x43' * 96, + "sender": 10, + "recipient": 12, + "amount": 10 * 10 ** 9, + "fee": 5 * 10 ** 9, + "slot": 5, + "pubkey": b"\x67" * 48, + "signature": b"\x43" * 96, } @pytest.fixture def sample_beacon_block_body_params(sample_signature, sample_eth1_data_params): return { - 'randao_reveal': sample_signature, - 'eth1_data': Eth1Data(**sample_eth1_data_params), - 'graffiti': ZERO_HASH32, - 'proposer_slashings': (), - 'attester_slashings': (), - 'attestations': (), - 'deposits': (), - 'voluntary_exits': (), - 'transfers': (), + "randao_reveal": sample_signature, + "eth1_data": Eth1Data(**sample_eth1_data_params), + "graffiti": ZERO_HASH32, + "proposer_slashings": (), + "attester_slashings": (), + "attestations": (), + "deposits": (), + "voluntary_exits": (), + "transfers": (), } @pytest.fixture -def sample_beacon_block_params(sample_signature, sample_beacon_block_body_params, genesis_slot): +def sample_beacon_block_params( + sample_signature, sample_beacon_block_body_params, genesis_slot +): return { - 'slot': genesis_slot + 10, - 'parent_root': ZERO_HASH32, - 'state_root': b'\x55' * 32, - 'body': BeaconBlockBody(**sample_beacon_block_body_params), - 'signature': sample_signature, + "slot": genesis_slot + 10, + "parent_root": ZERO_HASH32, + "state_root": b"\x55" * 32, + "body": BeaconBlockBody(**sample_beacon_block_body_params), + "signature": sample_signature, } @pytest.fixture -def sample_beacon_state_params(config, - genesis_slot, - genesis_epoch, - sample_fork_params, - sample_eth1_data_params, - sample_block_header_params, - sample_crosslink_record_params): +def sample_beacon_state_params( + config, + genesis_slot, + genesis_epoch, + sample_fork_params, + sample_eth1_data_params, + sample_block_header_params, + sample_crosslink_record_params, +): return { # Versioning - 'genesis_time': 0, - 'slot': genesis_slot + 100, - 'fork': Fork(**sample_fork_params), + "genesis_time": 0, + "slot": genesis_slot + 100, + "fork": Fork(**sample_fork_params), # History - 'latest_block_header': BeaconBlockHeader(**sample_block_header_params), - 'block_roots': (ZERO_HASH32,) * config.SLOTS_PER_HISTORICAL_ROOT, - 'state_roots': (ZERO_HASH32,) * config.SLOTS_PER_HISTORICAL_ROOT, - 'historical_roots': (), + "latest_block_header": BeaconBlockHeader(**sample_block_header_params), + "block_roots": (ZERO_HASH32,) * config.SLOTS_PER_HISTORICAL_ROOT, + "state_roots": (ZERO_HASH32,) * config.SLOTS_PER_HISTORICAL_ROOT, + "historical_roots": (), # Eth1 - 'eth1_data': Eth1Data(**sample_eth1_data_params), - 'eth1_data_votes': (), - 'eth1_deposit_index': 0, + "eth1_data": Eth1Data(**sample_eth1_data_params), + "eth1_data_votes": (), + "eth1_deposit_index": 0, # Registry - 'validators': (), - 'balances': (), + "validators": (), + "balances": (), # Shuffling - 'start_shard': 1, - 'randao_mixes': (ZERO_HASH32,) * config.EPOCHS_PER_HISTORICAL_VECTOR, - 'active_index_roots': (ZERO_HASH32,) * config.EPOCHS_PER_HISTORICAL_VECTOR, - 'compact_committees_roots': (ZERO_HASH32,) * config.EPOCHS_PER_HISTORICAL_VECTOR, + "start_shard": 1, + "randao_mixes": (ZERO_HASH32,) * config.EPOCHS_PER_HISTORICAL_VECTOR, + "active_index_roots": (ZERO_HASH32,) * config.EPOCHS_PER_HISTORICAL_VECTOR, + "compact_committees_roots": (ZERO_HASH32,) + * config.EPOCHS_PER_HISTORICAL_VECTOR, # Slashings - 'slashings': (0,) * config.EPOCHS_PER_SLASHINGS_VECTOR, + "slashings": (0,) * config.EPOCHS_PER_SLASHINGS_VECTOR, # Attestations - 'previous_epoch_attestations': (), - 'current_epoch_attestations': (), + "previous_epoch_attestations": (), + "current_epoch_attestations": (), # Crosslinks - 'previous_crosslinks': ( + "previous_crosslinks": ( (Crosslink(**sample_crosslink_record_params),) * config.SHARD_COUNT ), - 'current_crosslinks': ( + "current_crosslinks": ( (Crosslink(**sample_crosslink_record_params),) * config.SHARD_COUNT ), # Justification - 'justification_bits': (False,) * JUSTIFICATION_BITS_LENGTH, - 'previous_justified_checkpoint': Checkpoint( - epoch=0, - root=b'\x99' * 32, - ), - 'current_justified_checkpoint': Checkpoint( - epoch=0, - root=b'\x55' * 32, - ), + "justification_bits": (False,) * JUSTIFICATION_BITS_LENGTH, + "previous_justified_checkpoint": Checkpoint(epoch=0, root=b"\x99" * 32), + "current_justified_checkpoint": Checkpoint(epoch=0, root=b"\x55" * 32), # Finality - 'finalized_checkpoint': Checkpoint( - epoch=0, - root=b'\x33' * 32, - ) + "finalized_checkpoint": Checkpoint(epoch=0, root=b"\x33" * 32), } @@ -696,10 +640,8 @@ def genesis_validators(validator_count, pubkeys, config): Returns ``validator_count`` number of activated validators. """ return tuple( - create_mock_validator( - pubkey=pubkey, - config=config, - ) for pubkey in pubkeys[:validator_count] + create_mock_validator(pubkey=pubkey, config=config) + for pubkey in pubkeys[:validator_count] ) @@ -709,30 +651,21 @@ def genesis_balances(validator_count, max_effective_balance): @pytest.fixture -def genesis_state(genesis_validators, - genesis_balances, - genesis_time, - sample_eth1_data_params, - config): +def genesis_state( + genesis_validators, genesis_balances, genesis_time, sample_eth1_data_params, config +): genesis_eth1_data = Eth1Data(**sample_eth1_data_params).copy( - deposit_count=len(genesis_validators), + deposit_count=len(genesis_validators) ) return create_mock_genesis_state_from_validators( - genesis_time, - genesis_eth1_data, - genesis_validators, - genesis_balances, - config, + genesis_time, genesis_eth1_data, genesis_validators, genesis_balances, config ) @pytest.fixture def genesis_block(genesis_state): - return get_genesis_block( - genesis_state.hash_tree_root, - SerenityBeaconBlock, - ) + return get_genesis_block(genesis_state.hash_tree_root, SerenityBeaconBlock) # @@ -741,9 +674,9 @@ def genesis_block(genesis_state): @pytest.fixture def fixture_sm_class(config, fork_choice_scoring): return SerenityStateMachine.configure( - __name__='SerenityStateMachineForTesting', + __name__="SerenityStateMachineForTesting", config=config, - get_fork_choice_scoring=lambda self: fork_choice_scoring + get_fork_choice_scoring=lambda self: fork_choice_scoring, ) diff --git a/tests/eth2/core/beacon/db/test_beacon_chaindb.py b/tests/eth2/core/beacon/db/test_beacon_chaindb.py index b8e3549834..3a633da14a 100644 --- a/tests/eth2/core/beacon/db/test_beacon_chaindb.py +++ b/tests/eth2/core/beacon/db/test_beacon_chaindb.py @@ -1,28 +1,14 @@ import random +from eth.constants import GENESIS_PARENT_HASH +from eth.exceptions import BlockNotFound, ParentNotFound +from hypothesis import given +from hypothesis import strategies as st import pytest - -from hypothesis import ( - given, - strategies as st, -) - import ssz -from eth.constants import ( - GENESIS_PARENT_HASH, -) -from eth.exceptions import ( - BlockNotFound, - ParentNotFound, -) -from eth2._utils.hash import ( - hash_eth2, -) -from eth2._utils.ssz import ( - validate_ssz_equal, -) - +from eth2._utils.hash import hash_eth2 +from eth2._utils.ssz import validate_ssz_equal from eth2.beacon.db.exceptions import ( AttestationRootNotFound, FinalizedHeadNotFound, @@ -30,19 +16,16 @@ JustifiedHeadNotFound, ) from eth2.beacon.db.schema import SchemaV1 -from eth2.beacon.state_machines.forks.serenity.blocks import ( - BeaconBlock, -) +from eth2.beacon.state_machines.forks.serenity.blocks import BeaconBlock from eth2.beacon.types.attestations import Attestation -from eth2.beacon.types.states import BeaconState from eth2.beacon.types.checkpoints import Checkpoint +from eth2.beacon.types.states import BeaconState @pytest.fixture(params=[1, 10, 999]) def block(request, sample_beacon_block_params): return BeaconBlock(**sample_beacon_block_params).copy( - parent_root=GENESIS_PARENT_HASH, - slot=request.param, + parent_root=GENESIS_PARENT_HASH, slot=request.param ) @@ -52,7 +35,9 @@ def state(sample_beacon_state_params): @pytest.fixture() -def block_with_attestation(chaindb, sample_block, sample_attestation_params, fork_choice_scoring): +def block_with_attestation( + chaindb, sample_block, sample_attestation_params, fork_choice_scoring +): sample_attestation = Attestation(**sample_attestation_params) genesis = sample_block @@ -60,16 +45,14 @@ def block_with_attestation(chaindb, sample_block, sample_attestation_params, for block1 = genesis.copy( parent_root=genesis.signing_root, slot=genesis.slot + 1, - body=genesis.body.copy( - attestations=(sample_attestation,), - ) + body=genesis.body.copy(attestations=(sample_attestation,)), ) return block1, sample_attestation @pytest.fixture() def maximum_score_value(): - return 2**64 - 1 + return 2 ** 64 - 1 def test_chaindb_add_block_number_to_root_lookup(chaindb, block, fork_choice_scoring): @@ -92,7 +75,9 @@ def test_chaindb_persist_block_and_slot_to_root(chaindb, block, fork_choice_scor @given(seed=st.binary(min_size=32, max_size=32)) -def test_chaindb_persist_block_and_unknown_parent(chaindb, block, fork_choice_scoring, seed): +def test_chaindb_persist_block_and_unknown_parent( + chaindb, block, fork_choice_scoring, seed +): n_block = block.copy(parent_root=hash_eth2(seed)) with pytest.raises(ParentNotFound): chaindb.persist_block(n_block, n_block.__class__, fork_choice_scoring) @@ -107,19 +92,21 @@ def test_chaindb_persist_block_and_block_to_root(chaindb, block, fork_choice_sco def test_chaindb_get_score(chaindb, sample_beacon_block_params, fork_choice_scoring): genesis = BeaconBlock(**sample_beacon_block_params).copy( - parent_root=GENESIS_PARENT_HASH, - slot=0, + parent_root=GENESIS_PARENT_HASH, slot=0 ) chaindb.persist_block(genesis, genesis.__class__, fork_choice_scoring) - genesis_score_key = SchemaV1.make_block_root_to_score_lookup_key(genesis.signing_root) - genesis_score = ssz.decode(chaindb.db.get(genesis_score_key), sedes=ssz.sedes.uint64) + genesis_score_key = SchemaV1.make_block_root_to_score_lookup_key( + genesis.signing_root + ) + genesis_score = ssz.decode( + chaindb.db.get(genesis_score_key), sedes=ssz.sedes.uint64 + ) assert genesis_score == 0 assert chaindb.get_score(genesis.signing_root) == 0 block1 = BeaconBlock(**sample_beacon_block_params).copy( - parent_root=genesis.signing_root, - slot=1, + parent_root=genesis.signing_root, slot=1 ) chaindb.persist_block(block1, block1.__class__, fork_choice_scoring) @@ -161,9 +148,7 @@ def test_chaindb_get_head_state_slot(chaindb, state): with pytest.raises(HeadStateSlotNotFound): chaindb.get_head_state_slot() current_slot = state.slot + 10 - current_state = state.copy( - slot=current_slot, - ) + current_state = state.copy(slot=current_slot) chaindb.persist_state(current_state) assert chaindb.get_head_state_slot() == current_state.slot past_state = state @@ -174,37 +159,41 @@ def test_chaindb_get_head_state_slot(chaindb, state): def test_chaindb_state(chaindb, state): chaindb.persist_state(state) state_class = BeaconState + result_state_root = chaindb.get_state_root_by_slot(state.slot) + assert result_state_root == state.hash_tree_root result_state = chaindb.get_state_by_root(state.hash_tree_root, state_class) assert result_state.hash_tree_root == state.hash_tree_root - result_state = chaindb.get_state_by_slot(state.slot, state_class) - assert result_state.hash_tree_root == state.hash_tree_root def test_chaindb_get_finalized_head_at_genesis(chaindb_at_genesis, genesis_block): - assert chaindb_at_genesis.get_finalized_head(genesis_block.__class__) == genesis_block + assert ( + chaindb_at_genesis.get_finalized_head(genesis_block.__class__) == genesis_block + ) def test_chaindb_get_justified_head_at_genesis(chaindb_at_genesis, genesis_block): - assert chaindb_at_genesis.get_justified_head(genesis_block.__class__) == genesis_block + assert ( + chaindb_at_genesis.get_justified_head(genesis_block.__class__) == genesis_block + ) -def test_chaindb_get_finalized_head(chaindb_at_genesis, - genesis_block, - genesis_state, - sample_beacon_block_params, - fork_choice_scoring): +def test_chaindb_get_finalized_head( + chaindb_at_genesis, + genesis_block, + genesis_state, + sample_beacon_block_params, + fork_choice_scoring, +): chaindb = chaindb_at_genesis block = BeaconBlock(**sample_beacon_block_params).copy( - parent_root=genesis_block.signing_root, + parent_root=genesis_block.signing_root ) assert chaindb.get_finalized_head(genesis_block.__class__) == genesis_block assert chaindb.get_justified_head(genesis_block.__class__) == genesis_block state_with_finalized_block = genesis_state.copy( - finalized_checkpoint=Checkpoint( - root=block.signing_root, - ) + finalized_checkpoint=Checkpoint(root=block.signing_root) ) chaindb.persist_state(state_with_finalized_block) chaindb.persist_block(block, BeaconBlock, fork_choice_scoring) @@ -213,15 +202,17 @@ def test_chaindb_get_finalized_head(chaindb_at_genesis, assert chaindb.get_justified_head(genesis_block.__class__) == genesis_block -def test_chaindb_get_justified_head(chaindb_at_genesis, - genesis_block, - genesis_state, - sample_beacon_block_params, - fork_choice_scoring, - config): +def test_chaindb_get_justified_head( + chaindb_at_genesis, + genesis_block, + genesis_state, + sample_beacon_block_params, + fork_choice_scoring, + config, +): chaindb = chaindb_at_genesis block = BeaconBlock(**sample_beacon_block_params).copy( - parent_root=genesis_block.signing_root, + parent_root=genesis_block.signing_root ) assert chaindb.get_finalized_head(genesis_block.__class__) == genesis_block @@ -230,8 +221,7 @@ def test_chaindb_get_justified_head(chaindb_at_genesis, # test that there is only one justified head per epoch state_with_bad_epoch = genesis_state.copy( current_justified_checkpoint=Checkpoint( - root=block.signing_root, - epoch=config.GENESIS_EPOCH, + root=block.signing_root, epoch=config.GENESIS_EPOCH ) ) chaindb.persist_state(state_with_bad_epoch) @@ -243,8 +233,7 @@ def test_chaindb_get_justified_head(chaindb_at_genesis, # test that the we can update justified head if we satisfy the invariants state_with_justified_block = genesis_state.copy( current_justified_checkpoint=Checkpoint( - root=block.signing_root, - epoch=config.GENESIS_EPOCH + 1, + root=block.signing_root, epoch=config.GENESIS_EPOCH + 1 ) ) chaindb.persist_state(state_with_justified_block) @@ -272,18 +261,12 @@ def test_chaindb_get_canonical_head(chaindb, block, fork_choice_scoring): result_block = chaindb.get_canonical_head(block.__class__) assert result_block == block - block_2 = block.copy( - slot=block.slot + 1, - parent_root=block.signing_root, - ) + block_2 = block.copy(slot=block.slot + 1, parent_root=block.signing_root) chaindb.persist_block(block_2, block_2.__class__, fork_choice_scoring) result_block = chaindb.get_canonical_head(block.__class__) assert result_block == block_2 - block_3 = block.copy( - slot=block_2.slot + 1, - parent_root=block_2.signing_root, - ) + block_3 = block.copy(slot=block_2.slot + 1, parent_root=block_2.signing_root) chaindb.persist_block(block_3, block_3.__class__, fork_choice_scoring) result_block = chaindb.get_canonical_head(block.__class__) assert result_block == block_3 @@ -296,20 +279,23 @@ def test_get_slot_by_root(chaindb, block, fork_choice_scoring): assert result_slot == block_slot -def test_chaindb_add_attestations_root_to_block_lookup(chaindb, - block_with_attestation, - fork_choice_scoring): +def test_chaindb_add_attestations_root_to_block_lookup( + chaindb, block_with_attestation, fork_choice_scoring +): block, attestation = block_with_attestation assert not chaindb.attestation_exists(attestation.hash_tree_root) chaindb.persist_block(block, block.__class__, fork_choice_scoring) assert chaindb.attestation_exists(attestation.hash_tree_root) -def test_chaindb_get_attestation_key_by_root(chaindb, block_with_attestation, fork_choice_scoring): +def test_chaindb_get_attestation_key_by_root( + chaindb, block_with_attestation, fork_choice_scoring +): block, attestation = block_with_attestation with pytest.raises(AttestationRootNotFound): chaindb.get_attestation_key_by_root(attestation.hash_tree_root) chaindb.persist_block(block, block.__class__, fork_choice_scoring) - assert chaindb.get_attestation_key_by_root( - attestation.hash_tree_root - ) == (block.signing_root, 0) + assert chaindb.get_attestation_key_by_root(attestation.hash_tree_root) == ( + block.signing_root, + 0, + ) diff --git a/tests/eth2/core/beacon/fork_choice/test_higher_slot.py b/tests/eth2/core/beacon/fork_choice/test_higher_slot.py index 1db0fcd8a5..62792b1621 100644 --- a/tests/eth2/core/beacon/fork_choice/test_higher_slot.py +++ b/tests/eth2/core/beacon/fork_choice/test_higher_slot.py @@ -4,14 +4,9 @@ from eth2.beacon.types.blocks import BeaconBlock -@pytest.mark.parametrize( - "slot", - (i for i in range(10)), -) +@pytest.mark.parametrize("slot", (i for i in range(10))) def test_higher_slot_fork_choice_scoring(sample_beacon_block_params, slot): - block = BeaconBlock(**sample_beacon_block_params).copy( - slot=slot, - ) + block = BeaconBlock(**sample_beacon_block_params).copy(slot=slot) expected_score = slot diff --git a/tests/eth2/core/beacon/fork_choice/test_lmd_ghost.py b/tests/eth2/core/beacon/fork_choice/test_lmd_ghost.py index c994830a61..6c3f73c4a1 100644 --- a/tests/eth2/core/beacon/fork_choice/test_lmd_ghost.py +++ b/tests/eth2/core/beacon/fork_choice/test_lmd_ghost.py @@ -1,6 +1,5 @@ import random -import pytest from eth_utils import to_dict from eth_utils.toolz import ( first, @@ -11,11 +10,10 @@ second, sliding_window, ) +import pytest from eth2._utils import bitfield -from eth2.beacon.attestation_helpers import ( - get_attestation_data_slot, -) +from eth2.beacon.attestation_helpers import get_attestation_data_slot from eth2.beacon.committee_helpers import ( get_committee_count, get_crosslink_committee, @@ -29,13 +27,11 @@ score_block_by_root, ) from eth2.beacon.helpers import ( - compute_start_slot_of_epoch, compute_epoch_of_slot, + compute_start_slot_of_epoch, get_active_validator_indices, ) -from eth2.beacon.tools.builder.validator import ( - get_crosslink_committees_at_slot, -) +from eth2.beacon.tools.builder.validator import get_crosslink_committees_at_slot from eth2.beacon.types.attestation_data import AttestationData from eth2.beacon.types.attestations import Attestation from eth2.beacon.types.blocks import BeaconBlock @@ -55,18 +51,13 @@ def _mk_attestation_inputs_in_epoch(epoch, state, config): config.SLOTS_PER_EPOCH, config.TARGET_COMMITTEE_SIZE, ) - epoch_start_shard = get_start_shard( - state, - epoch, - CommitteeConfig(config), - ) - for shard_offset in random.sample(range(epoch_committee_count), epoch_committee_count): + epoch_start_shard = get_start_shard(state, epoch, CommitteeConfig(config)) + for shard_offset in random.sample( + range(epoch_committee_count), epoch_committee_count + ): shard = Shard((epoch_start_shard + shard_offset) % config.SHARD_COUNT) committee = get_crosslink_committee( - state, - epoch, - shard, - CommitteeConfig(config), + state, epoch, shard, CommitteeConfig(config) ) if not committee: @@ -74,12 +65,7 @@ def _mk_attestation_inputs_in_epoch(epoch, state, config): continue attestation_data = AttestationData( - target=Checkpoint( - epoch=epoch, - ), - crosslink=Crosslink( - shard=shard, - ), + target=Checkpoint(epoch=epoch), crosslink=Crosslink(shard=shard) ) committee_count = len(committee) aggregation_bits = bitfield.get_empty_bitfield(committee_count) @@ -90,30 +76,18 @@ def _mk_attestation_inputs_in_epoch(epoch, state, config): yield ( index, ( - get_attestation_data_slot( - state, - attestation_data, - config, - ), - ( - aggregation_bits, - attestation_data, - ), + get_attestation_data_slot(state, attestation_data, config), + (aggregation_bits, attestation_data), ), ) -def _mk_attestations_for_epoch_by_count(number_of_committee_samples, - epoch, - state, - config): +def _mk_attestations_for_epoch_by_count( + number_of_committee_samples, epoch, state, config +): results = {} for _ in range(number_of_committee_samples): - sample = _mk_attestation_inputs_in_epoch( - epoch, - state, - config, - ) + sample = _mk_attestation_inputs_in_epoch(epoch, state, config) results = merge(results, sample) return results @@ -122,10 +96,7 @@ def _extract_attestations_from_index_keying(values): results = () for value in values: aggregation_bits, data = second(value) - attestation = Attestation( - aggregation_bits=aggregation_bits, - data=data, - ) + attestation = Attestation(aggregation_bits=aggregation_bits, data=data) if attestation not in results: results += (attestation,) return results @@ -145,16 +116,16 @@ def _find_collision(state, config, index, epoch): validator w/ the given index. """ active_validators = get_active_validator_indices(state.validators, epoch) - committees_per_slot = get_committee_count( - len(active_validators), - config.SHARD_COUNT, - config.SLOTS_PER_EPOCH, - config.TARGET_COMMITTEE_SIZE, - ) // config.SLOTS_PER_EPOCH - epoch_start_slot = compute_start_slot_of_epoch( - epoch, - config.SLOTS_PER_EPOCH, + committees_per_slot = ( + get_committee_count( + len(active_validators), + config.SHARD_COUNT, + config.SLOTS_PER_EPOCH, + config.TARGET_COMMITTEE_SIZE, + ) + // config.SLOTS_PER_EPOCH ) + epoch_start_slot = compute_start_slot_of_epoch(epoch, config.SLOTS_PER_EPOCH) epoch_start_shard = get_start_shard(state, epoch, CommitteeConfig(config)) for slot in range(epoch_start_slot, epoch_start_slot + config.SLOTS_PER_EPOCH): @@ -162,16 +133,13 @@ def _find_collision(state, config, index, epoch): slot_start_shard = (epoch_start_shard + offset) % config.SHARD_COUNT for i in range(committees_per_slot): shard = Shard((slot_start_shard + i) % config.SHARD_COUNT) - committee = get_crosslink_committee(state, epoch, shard, CommitteeConfig(config)) + committee = get_crosslink_committee( + state, epoch, shard, CommitteeConfig(config) + ) if index in committee: # TODO(ralexstokes) refactor w/ tools/builder attestation_data = AttestationData( - target=Checkpoint( - epoch=epoch, - ), - crosslink=Crosslink( - shard=shard, - ), + target=Checkpoint(epoch=epoch), crosslink=Crosslink(shard=shard) ) committee_count = len(committee) aggregation_bits = bitfield.get_empty_bitfield(committee_count) @@ -179,18 +147,14 @@ def _find_collision(state, config, index, epoch): aggregation_bits = bitfield.set_voted(aggregation_bits, i) return { - index: ( - slot, (aggregation_bits, attestation_data) - ) + index: (slot, (aggregation_bits, attestation_data)) for index in committee } else: raise Exception("should have found a duplicate validator") -def _introduce_collisions(all_attestations_by_index, - state, - config): +def _introduce_collisions(all_attestations_by_index, state, config): """ Find some attestations for later epochs for the validators that are current attesting in each source of attestation. @@ -223,57 +187,34 @@ def _get_committee_count(state, epoch, config): @pytest.mark.parametrize( - ( - "validator_count", - ), + ("validator_count",), [ - (8,), # low number of validators - (128,), # medium number of validators + (8,), # low number of validators + (128,), # medium number of validators # NOTE: the test at 1024 count takes too long :( (256,), # high number of validators - ] + ], ) -@pytest.mark.parametrize( - ( - "collisions_from_another_epoch", - ), - [ - (True,), - (False,), - ] -) -def test_store_get_latest_attestation(genesis_state, - empty_attestation_pool, - config, - collisions_from_another_epoch): +@pytest.mark.parametrize(("collisions_from_another_epoch",), [(True,), (False,)]) +def test_store_get_latest_attestation( + genesis_state, empty_attestation_pool, config, collisions_from_another_epoch +): """ Given some attestations across the various sources, can we find the latest ones for each validator? """ some_epoch = 3 state = genesis_state.copy( - slot=compute_start_slot_of_epoch(some_epoch, config.SLOTS_PER_EPOCH), + slot=compute_start_slot_of_epoch(some_epoch, config.SLOTS_PER_EPOCH) ) previous_epoch = state.previous_epoch(config.SLOTS_PER_EPOCH, config.GENESIS_EPOCH) - previous_epoch_committee_count = _get_committee_count( - state, - previous_epoch, - config, - ) + previous_epoch_committee_count = _get_committee_count(state, previous_epoch, config) current_epoch = state.current_epoch(config.SLOTS_PER_EPOCH) - current_epoch_committee_count = _get_committee_count( - state, - current_epoch, - config, - ) + current_epoch_committee_count = _get_committee_count(state, current_epoch, config) next_epoch = state.next_epoch(config.SLOTS_PER_EPOCH) - next_epoch_committee_count = _get_committee_count( - state, - next_epoch, - config, - ) + next_epoch_committee_count = _get_committee_count(state, next_epoch, config) number_of_committee_samples = 4 assert number_of_committee_samples <= previous_epoch_committee_count @@ -282,42 +223,30 @@ def test_store_get_latest_attestation(genesis_state, # prepare samples from previous epoch previous_epoch_attestations_by_index = _mk_attestations_for_epoch_by_count( - number_of_committee_samples, - previous_epoch, - state, - config, + number_of_committee_samples, previous_epoch, state, config ) previous_epoch_attestations = _extract_attestations_from_index_keying( - previous_epoch_attestations_by_index.values(), + previous_epoch_attestations_by_index.values() ) # prepare samples from current epoch current_epoch_attestations_by_index = _mk_attestations_for_epoch_by_count( - number_of_committee_samples, - current_epoch, - state, - config, + number_of_committee_samples, current_epoch, state, config ) current_epoch_attestations_by_index = keyfilter( lambda index: index not in previous_epoch_attestations_by_index, current_epoch_attestations_by_index, ) current_epoch_attestations = _extract_attestations_from_index_keying( - current_epoch_attestations_by_index.values(), + current_epoch_attestations_by_index.values() ) # prepare samples for pool, taking half from the current epoch and half from the next epoch pool_attestations_in_current_epoch_by_index = _mk_attestations_for_epoch_by_count( - number_of_committee_samples // 2, - current_epoch, - state, - config, + number_of_committee_samples // 2, current_epoch, state, config ) pool_attestations_in_next_epoch_by_index = _mk_attestations_for_epoch_by_count( - number_of_committee_samples // 2, - next_epoch, - state, - config, + number_of_committee_samples // 2, next_epoch, state, config ) pool_attestations_by_index = merge( pool_attestations_in_current_epoch_by_index, @@ -325,13 +254,13 @@ def test_store_get_latest_attestation(genesis_state, ) pool_attestations_by_index = keyfilter( lambda index: ( - index not in previous_epoch_attestations_by_index or - index not in current_epoch_attestations_by_index + index not in previous_epoch_attestations_by_index + or index not in current_epoch_attestations_by_index ), pool_attestations_by_index, ) pool_attestations = _extract_attestations_from_index_keying( - pool_attestations_by_index.values(), + pool_attestations_by_index.values() ) all_attestations_by_index = ( @@ -345,20 +274,16 @@ def test_store_get_latest_attestation(genesis_state, previous_epoch_attestations_by_index, current_epoch_attestations_by_index, pool_attestations_by_index, - ) = _introduce_collisions( - all_attestations_by_index, - state, - config, - ) + ) = _introduce_collisions(all_attestations_by_index, state, config) previous_epoch_attestations = _extract_attestations_from_index_keying( - previous_epoch_attestations_by_index.values(), + previous_epoch_attestations_by_index.values() ) current_epoch_attestations = _extract_attestations_from_index_keying( - current_epoch_attestations_by_index.values(), + current_epoch_attestations_by_index.values() ) pool_attestations = _extract_attestations_from_index_keying( - pool_attestations_by_index.values(), + pool_attestations_by_index.values() ) # build expected results @@ -400,12 +325,9 @@ def _mk_block(block_params, slot, parent, block_offset): ) -def _build_block_tree(block_params, - root_block, - base_slot, - forking_descriptor, - forking_asymmetry, - config): +def _build_block_tree( + block_params, root_block, base_slot, forking_descriptor, forking_asymmetry, config +): """ build a block tree according to the data in ``forking_descriptor``, starting at the block with root ``base_root``. @@ -419,12 +341,7 @@ def _build_block_tree(block_params, if random.choice([True, False]): continue for block_offset in range(block_count): - block = _mk_block( - block_params, - slot, - parent, - block_offset, - ) + block = _mk_block(block_params, slot, parent, block_offset) blocks.append(block) tree.append(blocks) # other code written w/ expectation that root is not in the tree @@ -448,19 +365,14 @@ def _iter_block_tree_by_block(tree): yield block -def _get_committees(state, - target_slot, - config, - sampling_fraction): +def _get_committees(state, target_slot, config, sampling_fraction): crosslink_committees_at_slot = get_crosslink_committees_at_slot( - state, - target_slot, - config=config, + state, target_slot, config=config ) return tuple( random.sample( crosslink_committees_at_slot, - int((sampling_fraction * len(crosslink_committees_at_slot))) + int((sampling_fraction * len(crosslink_committees_at_slot))), ) ) @@ -470,7 +382,7 @@ def _attach_committee_to_block(block, committee): def _get_committee_from_block(block): - return getattr(block, '_committee_data', None) + return getattr(block, "_committee_data", None) def _attach_attestation_to_block(block, attestation): @@ -478,24 +390,18 @@ def _attach_attestation_to_block(block, attestation): def _get_attestation_from_block(block): - return getattr(block, '_attestation', None) + return getattr(block, "_attestation", None) -def _attach_committees_to_block_tree(state, - block_tree, - committees_by_slot, - config, - forking_asymmetry): +def _attach_committees_to_block_tree( + state, block_tree, committees_by_slot, config, forking_asymmetry +): for level, committees in zip( - _iter_block_tree_by_slot(block_tree), - committees_by_slot, + _iter_block_tree_by_slot(block_tree), committees_by_slot ): block_count = len(level) partitions = partition(block_count, committees) - for block, committee in zip( - _iter_block_level_by_block(level), - partitions, - ): + for block, committee in zip(_iter_block_level_by_block(level), partitions): if forking_asymmetry: if random.choice([True, False]): # random drop out @@ -515,11 +421,9 @@ def _mk_attestation_for_block_with_committee(block, committee, shard, config): data=AttestationData( beacon_block_root=block.signing_root, target=Checkpoint( - epoch=compute_epoch_of_slot(block.slot, config.SLOTS_PER_EPOCH), - ), - crosslink=Crosslink( - shard=shard, + epoch=compute_epoch_of_slot(block.slot, config.SLOTS_PER_EPOCH) ), + crosslink=Crosslink(shard=shard), ), ) return attestation @@ -532,7 +436,9 @@ def _attach_attestations_to_block_tree_with_committees(block_tree, config): # w/ asymmetry in forking we may need to skip this step continue committee, shard = committee_data - attestation = _mk_attestation_for_block_with_committee(block, committee, shard, config) + attestation = _mk_attestation_for_block_with_committee( + block, committee, shard, config + ) _attach_attestation_to_block(block, attestation) @@ -553,10 +459,7 @@ def _build_score_index_from_decorated_block_tree(block_tree, store, state, confi def _iter_attestation_by_validator_index(state, attestation, config): for index in get_attesting_indices( - state, - attestation.data, - attestation.aggregation_bits, - config, + state, attestation.data, attestation.aggregation_bits, config ): yield index @@ -573,8 +476,7 @@ def __init__(self, state, root_block, block_tree, attestation_pool, config): self._config = config self._latest_attestations = self._find_attestation_targets() self._block_index = { - block.signing_root: block - for block in _iter_block_tree_by_block(block_tree) + block.signing_root: block for block in _iter_block_tree_by_block(block_tree) } self._block_index[root_block.signing_root] = root_block self._blocks_by_parent_root = { @@ -586,20 +488,15 @@ def _find_attestation_targets(self): result = {} for _, attestation in self._attestation_pool: target_slot = get_attestation_data_slot( - self._state, - attestation.data, - self._config, + self._state, attestation.data, self._config ) for validator_index in _iter_attestation_by_validator_index( - self._state, - attestation, - self._config): + self._state, attestation, self._config + ): if validator_index in result: existing = result[validator_index] existing_slot = get_attestation_data_slot( - self._state, - existing.data, - self._config + self._state, existing.data, self._config ) if existing_slot > target_slot: continue @@ -623,14 +520,12 @@ def _get_ancestor(self, block, slot): @pytest.mark.parametrize( - ( - "validator_count", - ), + ("validator_count",), [ - (8,), # low number of validators - (128,), # medium number of validators + (8,), # low number of validators + (128,), # medium number of validators (1024,), # high number of validators - ] + ], ) @pytest.mark.parametrize( ( @@ -646,7 +541,7 @@ def _get_ancestor(self, block, slot): ((3, 2),), ((1, 4),), ((1, 2, 1),), - ] + ], ) @pytest.mark.parametrize( ( @@ -660,17 +555,19 @@ def _get_ancestor(self, block, slot): # the number of children prescribed in ``forking_descriptor``. # => randomly drop some blocks from receiving attestations (False,), - ] + ], ) -def test_lmd_ghost_fork_choice_scoring(sample_beacon_block_params, - chaindb_at_genesis, - # see note below on how this is used - fork_choice_scoring, - forking_descriptor, - forking_asymmetry, - genesis_state, - empty_attestation_pool, - config): +def test_lmd_ghost_fork_choice_scoring( + sample_beacon_block_params, + chaindb_at_genesis, + # see note below on how this is used + fork_choice_scoring, + forking_descriptor, + forking_asymmetry, + genesis_state, + empty_attestation_pool, + config, +): """ Given some blocks and some attestations, can we score them correctly? """ @@ -681,11 +578,11 @@ def test_lmd_ghost_fork_choice_scoring(sample_beacon_block_params, some_slot_offset = 10 state = genesis_state.copy( - slot=compute_start_slot_of_epoch(some_epoch, config.SLOTS_PER_EPOCH) + some_slot_offset, + slot=compute_start_slot_of_epoch(some_epoch, config.SLOTS_PER_EPOCH) + + some_slot_offset, current_justified_checkpoint=Checkpoint( - epoch=some_epoch, - root=root_block.signing_root, - ) + epoch=some_epoch, root=root_block.signing_root + ), ) assert some_epoch >= state.current_justified_checkpoint.epoch @@ -704,31 +601,21 @@ def test_lmd_ghost_fork_choice_scoring(sample_beacon_block_params, committee_sampling_fraction = 1 committees_by_slot = tuple( _get_committees( - state, - base_slot + slot_offset, - config, - committee_sampling_fraction, + state, base_slot + slot_offset, config, committee_sampling_fraction ) for slot_offset in range(slot_count) ) _attach_committees_to_block_tree( - state, - block_tree, - committees_by_slot, - config, - forking_asymmetry, + state, block_tree, committees_by_slot, config, forking_asymmetry ) - _attach_attestations_to_block_tree_with_committees( - block_tree, - config, - ) + _attach_attestations_to_block_tree_with_committees(block_tree, config) attestations = tuple( - _get_attestation_from_block(block) for block in _iter_block_tree_by_block( - block_tree, - ) if _get_attestation_from_block(block) + _get_attestation_from_block(block) + for block in _iter_block_tree_by_block(block_tree) + if _get_attestation_from_block(block) ) attestation_pool = empty_attestation_pool @@ -738,17 +625,16 @@ def test_lmd_ghost_fork_choice_scoring(sample_beacon_block_params, store = _store(state, root_block, block_tree, attestation_pool, config) score_index = _build_score_index_from_decorated_block_tree( - block_tree, - store, - state, - config, + block_tree, store, state, config ) for block in _iter_block_tree_by_block(block_tree): # NOTE: we use the ``fork_choice_scoring`` fixture, it doesn't matter for this test chain_db.persist_block(block, BeaconBlock, fork_choice_scoring) - scoring_fn = lmd_ghost_scoring(chain_db, attestation_pool, state, config, BeaconBlock) + scoring_fn = lmd_ghost_scoring( + chain_db, attestation_pool, state, config, BeaconBlock + ) for block in _iter_block_tree_by_block(block_tree): score = scoring_fn(block) diff --git a/tests/eth2/core/beacon/operations/test_pool.py b/tests/eth2/core/beacon/operations/test_pool.py index ca32ac02bb..971122314d 100644 --- a/tests/eth2/core/beacon/operations/test_pool.py +++ b/tests/eth2/core/beacon/operations/test_pool.py @@ -8,9 +8,7 @@ def mk_attestation(index, sample_attestation_params): - return Attestation(**sample_attestation_params).copy( - custody_bits=(True,) * 128, - ) + return Attestation(**sample_attestation_params).copy(custody_bits=(True,) * 128) def test_iterating_operation_pool(sample_attestation_params): diff --git a/tests/eth2/core/beacon/state_machines/forks/test_serenity_block_attestation_validation.py b/tests/eth2/core/beacon/state_machines/forks/test_serenity_block_attestation_validation.py index 9b0e07cf21..938a3ebd20 100644 --- a/tests/eth2/core/beacon/state_machines/forks/test_serenity_block_attestation_validation.py +++ b/tests/eth2/core/beacon/state_machines/forks/test_serenity_block_attestation_validation.py @@ -1,42 +1,25 @@ +from eth.constants import ZERO_HASH32 +from eth_utils import ValidationError import pytest -from eth_utils import ( - ValidationError, -) - -from eth.constants import ( - ZERO_HASH32, -) -from eth2.beacon.committee_helpers import ( - get_start_shard, -) -from eth2.beacon.helpers import ( - compute_start_slot_of_epoch, -) +from eth2.beacon.committee_helpers import get_start_shard +from eth2.beacon.helpers import compute_start_slot_of_epoch from eth2.beacon.state_machines.forks.serenity.block_validation import ( - validate_attestation_slot, _validate_attestation_data, _validate_crosslink, + validate_attestation_slot, ) from eth2.beacon.types.attestation_data import AttestationData from eth2.beacon.types.checkpoints import Checkpoint from eth2.beacon.types.crosslinks import Crosslink - from eth2.configs import CommitteeConfig @pytest.mark.parametrize( - ('slots_per_epoch', 'min_attestation_inclusion_delay'), - [ - (4, 2), - ] + ("slots_per_epoch", "min_attestation_inclusion_delay"), [(4, 2)] ) @pytest.mark.parametrize( - ( - 'attestation_slot,' - 'state_slot,' - 'is_valid,' - ), + ("attestation_slot," "state_slot," "is_valid,"), [ # in bounds at lower end (8, 2 + 8, True), @@ -46,13 +29,15 @@ (8, 8 + 4 + 1, False), # attestation_slot + min_attestation_inclusion_delay > state_slot (8, 8 - 2, False), - ] + ], ) -def test_validate_attestation_slot(attestation_slot, - state_slot, - slots_per_epoch, - min_attestation_inclusion_delay, - is_valid): +def test_validate_attestation_slot( + attestation_slot, + state_slot, + slots_per_epoch, + min_attestation_inclusion_delay, + is_valid, +): if is_valid: validate_attestation_slot( @@ -73,53 +58,41 @@ def test_validate_attestation_slot(attestation_slot, @pytest.mark.parametrize( ( - 'current_epoch', - 'previous_justified_epoch', - 'current_justified_epoch', - 'slots_per_epoch', + "current_epoch", + "previous_justified_epoch", + "current_justified_epoch", + "slots_per_epoch", ), - [ - (3, 1, 2, 8) - ] + [(3, 1, 2, 8)], ) @pytest.mark.parametrize( - ( - 'attestation_source_epoch', - 'attestation_target_epoch', - 'is_valid', - ), + ("attestation_source_epoch", "attestation_target_epoch", "is_valid"), [ (2, 3, True), # wrong target_epoch (0, 1, False), # wrong source checkpoint (1, 3, False), - ] + ], ) -def test_validate_attestation_data(genesis_state, - sample_attestation_data_params, - attestation_source_epoch, - attestation_target_epoch, - current_epoch, - previous_justified_epoch, - current_justified_epoch, - slots_per_epoch, - config, - is_valid): +def test_validate_attestation_data( + genesis_state, + sample_attestation_data_params, + attestation_source_epoch, + attestation_target_epoch, + current_epoch, + previous_justified_epoch, + current_justified_epoch, + slots_per_epoch, + config, + is_valid, +): state = genesis_state.copy( slot=compute_start_slot_of_epoch(current_epoch, slots_per_epoch) + 5, - previous_justified_checkpoint=Checkpoint( - epoch=previous_justified_epoch, - ), - current_justified_checkpoint=Checkpoint( - epoch=current_justified_epoch, - ), - ) - start_shard = get_start_shard( - state, - current_epoch, - CommitteeConfig(config), + previous_justified_checkpoint=Checkpoint(epoch=previous_justified_epoch), + current_justified_checkpoint=Checkpoint(epoch=current_justified_epoch), ) + start_shard = get_start_shard(state, current_epoch, CommitteeConfig(config)) if attestation_target_epoch == current_epoch: crosslinks = state.current_crosslinks else: @@ -127,12 +100,8 @@ def test_validate_attestation_data(genesis_state, parent_crosslink = crosslinks[start_shard] attestation_data = AttestationData(**sample_attestation_data_params).copy( - source=Checkpoint( - epoch=attestation_source_epoch, - ), - target=Checkpoint( - epoch=attestation_target_epoch, - ), + source=Checkpoint(epoch=attestation_source_epoch), + target=Checkpoint(epoch=attestation_target_epoch), crosslink=Crosslink( start_epoch=parent_crosslink.end_epoch, end_epoch=attestation_target_epoch, @@ -142,49 +111,27 @@ def test_validate_attestation_data(genesis_state, ) if is_valid: - _validate_attestation_data( - state, - attestation_data, - config, - ) + _validate_attestation_data(state, attestation_data, config) else: with pytest.raises(ValidationError): - _validate_attestation_data( - state, - attestation_data, - config, - ) + _validate_attestation_data(state, attestation_data, config) @pytest.mark.parametrize( - ( - 'mutator', - 'is_valid', - ), + ("mutator", "is_valid"), [ (lambda c: c, True), # crosslink.start_epoch != end_epoch - (lambda c: c.copy( - start_epoch=c.start_epoch + 1, - ), False), + (lambda c: c.copy(start_epoch=c.start_epoch + 1), False), # end_epoch does not match expected - (lambda c: c.copy( - end_epoch=c.start_epoch + 10, - ), False), + (lambda c: c.copy(end_epoch=c.start_epoch + 10), False), # parent_root does not match - (lambda c: c.copy( - parent_root=b'\x33' * 32, - ), False), + (lambda c: c.copy(parent_root=b"\x33" * 32), False), # data_root is nonzero - (lambda c: c.copy( - data_root=b'\x33' * 32, - ), False), - ] + (lambda c: c.copy(data_root=b"\x33" * 32), False), + ], ) -def test_validate_crosslink(genesis_state, - mutator, - is_valid, - config): +def test_validate_crosslink(genesis_state, mutator, is_valid, config): some_shard = 3 parent = genesis_state.current_crosslinks[some_shard] target_epoch = config.GENESIS_EPOCH + 1 @@ -200,10 +147,7 @@ def test_validate_crosslink(genesis_state, if is_valid: _validate_crosslink( - candidate_crosslink, - target_epoch, - parent, - config.MAX_EPOCHS_PER_CROSSLINK, + candidate_crosslink, target_epoch, parent, config.MAX_EPOCHS_PER_CROSSLINK ) else: with pytest.raises(ValidationError): diff --git a/tests/eth2/core/beacon/state_machines/forks/test_serenity_block_processing.py b/tests/eth2/core/beacon/state_machines/forks/test_serenity_block_processing.py index 08831afb99..d946ddfbee 100644 --- a/tests/eth2/core/beacon/state_machines/forks/test_serenity_block_processing.py +++ b/tests/eth2/core/beacon/state_machines/forks/test_serenity_block_processing.py @@ -1,54 +1,31 @@ +from eth.constants import ZERO_HASH32 +from eth_utils import ValidationError +from eth_utils.toolz import concat, first, mapcat import pytest -from eth.constants import ( - ZERO_HASH32, -) -from eth_utils import ( - ValidationError, -) -from eth_utils.toolz import ( - first, - mapcat, - concat, -) - from eth2._utils.bls import bls - -from eth2.beacon.types.blocks import BeaconBlock, BeaconBlockBody -from eth2.beacon.types.eth1_data import Eth1Data -from eth2.beacon.types.states import BeaconState +from eth2.beacon.helpers import compute_start_slot_of_epoch, get_domain from eth2.beacon.signature_domain import SignatureDomain - -from eth2.beacon.helpers import ( - get_domain, - compute_start_slot_of_epoch, -) - -from eth2.beacon.state_machines.forks.serenity.blocks import ( - SerenityBeaconBlock, -) -from eth2.beacon.state_machines.forks.serenity.states import ( - SerenityBeaconState, -) - from eth2.beacon.state_machines.forks.serenity.block_processing import ( process_eth1_data, process_randao, ) -from eth2.beacon.tools.builder.proposer import ( - _generate_randao_reveal, -) - -from eth2.beacon.tools.builder.initializer import ( - create_mock_validator, -) +from eth2.beacon.state_machines.forks.serenity.blocks import SerenityBeaconBlock +from eth2.beacon.state_machines.forks.serenity.states import SerenityBeaconState +from eth2.beacon.tools.builder.initializer import create_mock_validator +from eth2.beacon.tools.builder.proposer import _generate_randao_reveal +from eth2.beacon.types.blocks import BeaconBlock, BeaconBlockBody +from eth2.beacon.types.eth1_data import Eth1Data +from eth2.beacon.types.states import BeaconState -def test_randao_processing(sample_beacon_block_params, - sample_beacon_block_body_params, - sample_beacon_state_params, - keymap, - config): +def test_randao_processing( + sample_beacon_block_params, + sample_beacon_block_body_params, + sample_beacon_state_params, + keymap, + config, +): proposer_pubkey, proposer_privkey = first(keymap.items()) state = SerenityBeaconState(**sample_beacon_state_params).copy( validators=tuple( @@ -56,10 +33,8 @@ def test_randao_processing(sample_beacon_block_params, for _ in range(config.TARGET_COMMITTEE_SIZE) ), balances=(config.MAX_EFFECTIVE_BALANCE,) * config.TARGET_COMMITTEE_SIZE, - randao_mixes=tuple( - ZERO_HASH32 - for _ in range(config.EPOCHS_PER_HISTORICAL_VECTOR) + ZERO_HASH32 for _ in range(config.EPOCHS_PER_HISTORICAL_VECTOR) ), ) @@ -67,19 +42,14 @@ def test_randao_processing(sample_beacon_block_params, slot = compute_start_slot_of_epoch(epoch, config.SLOTS_PER_EPOCH) randao_reveal = _generate_randao_reveal( - privkey=proposer_privkey, - slot=slot, - state=state, - config=config, + privkey=proposer_privkey, slot=slot, state=state, config=config ) block_body = BeaconBlockBody(**sample_beacon_block_body_params).copy( - randao_reveal=randao_reveal, + randao_reveal=randao_reveal ) - block = SerenityBeaconBlock(**sample_beacon_block_params).copy( - body=block_body, - ) + block = SerenityBeaconBlock(**sample_beacon_block_params).copy(body=block_body) new_state = process_randao(state, block, config) @@ -93,12 +63,14 @@ def test_randao_processing(sample_beacon_block_params, ) -def test_randao_processing_validates_randao_reveal(sample_beacon_block_params, - sample_beacon_block_body_params, - sample_beacon_state_params, - sample_fork_params, - keymap, - config): +def test_randao_processing_validates_randao_reveal( + sample_beacon_block_params, + sample_beacon_block_body_params, + sample_beacon_state_params, + sample_fork_params, + keymap, + config, +): proposer_pubkey, proposer_privkey = first(keymap.items()) state = SerenityBeaconState(**sample_beacon_state_params).copy( validators=tuple( @@ -106,10 +78,8 @@ def test_randao_processing_validates_randao_reveal(sample_beacon_block_params, for _ in range(config.TARGET_COMMITTEE_SIZE) ), balances=(config.MAX_EFFECTIVE_BALANCE,) * config.TARGET_COMMITTEE_SIZE, - randao_mixes=tuple( - ZERO_HASH32 - for _ in range(config.EPOCHS_PER_HISTORICAL_VECTOR) + ZERO_HASH32 for _ in range(config.EPOCHS_PER_HISTORICAL_VECTOR) ), ) @@ -119,12 +89,10 @@ def test_randao_processing_validates_randao_reveal(sample_beacon_block_params, randao_reveal = bls.sign(message_hash, proposer_privkey, domain) block_body = BeaconBlockBody(**sample_beacon_block_body_params).copy( - randao_reveal=randao_reveal, + randao_reveal=randao_reveal ) - block = SerenityBeaconBlock(**sample_beacon_block_params).copy( - body=block_body, - ) + block = SerenityBeaconBlock(**sample_beacon_block_params).copy(body=block_body) with pytest.raises(ValidationError): process_randao(state, block, config) @@ -136,64 +104,48 @@ def test_randao_processing_validates_randao_reveal(sample_beacon_block_params, def _expand_eth1_votes(args): block_hash, vote_count = args - return (Eth1Data( - block_hash=block_hash, - ),) * vote_count - - -@pytest.mark.parametrize(("original_votes", "block_data", "expected_votes"), ( - ((), HASH1, ((HASH1, 1),)), - (((HASH1, 5),), HASH1, ((HASH1, 6),)), - (((HASH2, 5),), HASH1, ((HASH2, 5), (HASH1, 1))), - (((HASH1, 10), (HASH2, 2)), HASH2, ((HASH1, 10), (HASH2, 3))), -)) -def test_process_eth1_data(original_votes, - block_data, - expected_votes, - sample_beacon_state_params, - sample_beacon_block_params, - sample_beacon_block_body_params, - config): - eth1_data_votes = tuple(mapcat( - _expand_eth1_votes, - original_votes, - )) + return (Eth1Data(block_hash=block_hash),) * vote_count + + +@pytest.mark.parametrize( + ("original_votes", "block_data", "expected_votes"), + ( + ((), HASH1, ((HASH1, 1),)), + (((HASH1, 5),), HASH1, ((HASH1, 6),)), + (((HASH2, 5),), HASH1, ((HASH2, 5), (HASH1, 1))), + (((HASH1, 10), (HASH2, 2)), HASH2, ((HASH1, 10), (HASH2, 3))), + ), +) +def test_process_eth1_data( + original_votes, + block_data, + expected_votes, + sample_beacon_state_params, + sample_beacon_block_params, + sample_beacon_block_body_params, + config, +): + eth1_data_votes = tuple(mapcat(_expand_eth1_votes, original_votes)) state = BeaconState(**sample_beacon_state_params).copy( - eth1_data_votes=eth1_data_votes, + eth1_data_votes=eth1_data_votes ) block_body = BeaconBlockBody(**sample_beacon_block_body_params).copy( - eth1_data=Eth1Data( - block_hash=block_data, - ), + eth1_data=Eth1Data(block_hash=block_data) ) - block = BeaconBlock(**sample_beacon_block_params).copy( - body=block_body, - ) + block = BeaconBlock(**sample_beacon_block_params).copy(body=block_body) updated_state = process_eth1_data(state, block, config) updated_votes = updated_state.eth1_data_votes - expanded_expected_votes = tuple(mapcat( - _expand_eth1_votes, - expected_votes, - )) + expanded_expected_votes = tuple(mapcat(_expand_eth1_votes, expected_votes)) assert updated_votes == expanded_expected_votes +@pytest.mark.parametrize(("slots_per_eth1_voting_period"), ((16),)) @pytest.mark.parametrize( - ( - 'slots_per_eth1_voting_period' - ), - ( - (16), - ) -) -@pytest.mark.parametrize( - ( - 'vote_offsets' # a tuple of offsets against the majority threshold - ), + ("vote_offsets"), # a tuple of offsets against the majority threshold ( # no eth1_data_votes (), @@ -209,31 +161,23 @@ def test_process_eth1_data(original_votes, (12,), # NOTE: we are accepting more than one block per slot if # there are multiple majorities so no need to test this - ) + ), ) -def test_ensure_update_eth1_vote_if_exists(genesis_state, - config, - vote_offsets): +def test_ensure_update_eth1_vote_if_exists(genesis_state, config, vote_offsets): # one less than a majority is the majority divided by 2 threshold = config.SLOTS_PER_ETH1_VOTING_PERIOD // 2 data_votes = tuple( concat( - ( - Eth1Data( - block_hash=(i).to_bytes(32, "little"), - ), - ) * (threshold + offset) + (Eth1Data(block_hash=(i).to_bytes(32, "little")),) * (threshold + offset) for i, offset in enumerate(vote_offsets) ) ) state = genesis_state for vote in data_votes: - state = process_eth1_data(state, BeaconBlock( - body=BeaconBlockBody( - eth1_data=vote, - ) - ), config) + state = process_eth1_data( + state, BeaconBlock(body=BeaconBlockBody(eth1_data=vote)), config + ) if not vote_offsets: assert state.eth1_data == genesis_state.eth1_data diff --git a/tests/eth2/core/beacon/state_machines/forks/test_serenity_block_proposer_slashing_validation.py b/tests/eth2/core/beacon/state_machines/forks/test_serenity_block_proposer_slashing_validation.py index f6393bfd4b..8f3c6639b7 100644 --- a/tests/eth2/core/beacon/state_machines/forks/test_serenity_block_proposer_slashing_validation.py +++ b/tests/eth2/core/beacon/state_machines/forks/test_serenity_block_proposer_slashing_validation.py @@ -1,86 +1,61 @@ +from eth_utils import ValidationError import pytest -from eth_utils import ( - ValidationError, -) - from eth2.beacon.state_machines.forks.serenity.block_validation import ( + validate_block_header_signature, validate_proposer_slashing, validate_proposer_slashing_epoch, validate_proposer_slashing_headers, - validate_block_header_signature, -) -from eth2.beacon.tools.builder.validator import ( - create_mock_proposer_slashing_at_block, ) +from eth2.beacon.tools.builder.validator import create_mock_proposer_slashing_at_block -def get_valid_proposer_slashing(state, - keymap, - config, - proposer_index=0): +def get_valid_proposer_slashing(state, keymap, config, proposer_index=0): return create_mock_proposer_slashing_at_block( state, config, keymap, - block_root_1=b'\x11' * 32, - block_root_2=b'\x22' * 32, + block_root_1=b"\x11" * 32, + block_root_2=b"\x22" * 32, proposer_index=proposer_index, ) -def test_validate_proposer_slashing_valid(genesis_state, - keymap, - slots_per_epoch, - config): +def test_validate_proposer_slashing_valid( + genesis_state, keymap, slots_per_epoch, config +): state = genesis_state - valid_proposer_slashing = get_valid_proposer_slashing( - state, - keymap, - config, - ) + valid_proposer_slashing = get_valid_proposer_slashing(state, keymap, config) validate_proposer_slashing(state, valid_proposer_slashing, slots_per_epoch) -def test_validate_proposer_slashing_epoch(genesis_state, - keymap, - config): +def test_validate_proposer_slashing_epoch(genesis_state, keymap, config): state = genesis_state - valid_proposer_slashing = get_valid_proposer_slashing( - state, - keymap, - config, - ) + valid_proposer_slashing = get_valid_proposer_slashing(state, keymap, config) # Valid validate_proposer_slashing_epoch(valid_proposer_slashing, config.SLOTS_PER_EPOCH) header_1 = valid_proposer_slashing.header_1.copy( slot=valid_proposer_slashing.header_2.slot + 2 * config.SLOTS_PER_EPOCH ) - invalid_proposer_slashing = valid_proposer_slashing.copy( - header_1=header_1, - ) + invalid_proposer_slashing = valid_proposer_slashing.copy(header_1=header_1) # Invalid with pytest.raises(ValidationError): - validate_proposer_slashing_epoch(invalid_proposer_slashing, config.SLOTS_PER_EPOCH) + validate_proposer_slashing_epoch( + invalid_proposer_slashing, config.SLOTS_PER_EPOCH + ) -def test_validate_proposer_slashing_headers(genesis_state, - keymap, - config): +def test_validate_proposer_slashing_headers(genesis_state, keymap, config): state = genesis_state - valid_proposer_slashing = get_valid_proposer_slashing( - state, - keymap, - config, - ) + valid_proposer_slashing = get_valid_proposer_slashing(state, keymap, config) # Valid validate_proposer_slashing_headers(valid_proposer_slashing) invalid_proposer_slashing = valid_proposer_slashing.copy( - header_1=valid_proposer_slashing.header_2, + header_1=valid_proposer_slashing.header_2 ) # Invalid @@ -88,17 +63,12 @@ def test_validate_proposer_slashing_headers(genesis_state, validate_proposer_slashing_headers(invalid_proposer_slashing) -def test_validate_block_header_signature(slots_per_epoch, - genesis_state, - keymap, - config): +def test_validate_block_header_signature( + slots_per_epoch, genesis_state, keymap, config +): state = genesis_state proposer_index = 0 - valid_proposer_slashing = get_valid_proposer_slashing( - state, - keymap, - config, - ) + valid_proposer_slashing = get_valid_proposer_slashing(state, keymap, config) proposer = state.validators[proposer_index] # Valid diff --git a/tests/eth2/core/beacon/state_machines/forks/test_serenity_block_validation.py b/tests/eth2/core/beacon/state_machines/forks/test_serenity_block_validation.py index db84ccca8a..a770a635a2 100644 --- a/tests/eth2/core/beacon/state_machines/forks/test_serenity_block_validation.py +++ b/tests/eth2/core/beacon/state_machines/forks/test_serenity_block_validation.py @@ -1,50 +1,33 @@ +from eth_utils import ValidationError import pytest -from eth_utils import ( - ValidationError, -) + from eth2._utils.bls import bls -from eth2.configs import ( - CommitteeConfig, -) -from eth2.beacon.signature_domain import ( - SignatureDomain, -) -from eth2.beacon.helpers import ( - get_domain, - compute_start_slot_of_epoch, -) +from eth2.beacon.helpers import compute_start_slot_of_epoch, get_domain +from eth2.beacon.signature_domain import SignatureDomain from eth2.beacon.state_machines.forks.serenity.block_validation import ( validate_block_slot, validate_proposer_signature, validate_randao_reveal, ) +from eth2.beacon.tools.builder.initializer import create_mock_validator from eth2.beacon.types.blocks import BeaconBlock from eth2.beacon.types.states import BeaconState - -from eth2.beacon.tools.builder.initializer import create_mock_validator +from eth2.configs import CommitteeConfig @pytest.mark.parametrize( - 'state_slot,' - 'block_slot,' - 'expected', - ( - (10, 10, None), - (1, 10, ValidationError()), - (10, 1, ValidationError()), - ), + "state_slot," "block_slot," "expected", + ((10, 10, None), (1, 10, ValidationError()), (10, 1, ValidationError())), ) -def test_validate_block_slot(sample_beacon_state_params, - sample_beacon_block_params, - state_slot, - block_slot, - expected): - state = BeaconState(**sample_beacon_state_params).copy( - slot=state_slot, - ) - block = BeaconBlock(**sample_beacon_block_params).copy( - slot=block_slot, - ) +def test_validate_block_slot( + sample_beacon_state_params, + sample_beacon_block_params, + state_slot, + block_slot, + expected, +): + state = BeaconState(**sample_beacon_state_params).copy(slot=state_slot) + block = BeaconBlock(**sample_beacon_block_params).copy(slot=block_slot) if isinstance(expected, Exception): with pytest.raises(ValidationError): validate_block_slot(state, block) @@ -53,31 +36,31 @@ def test_validate_block_slot(sample_beacon_state_params, @pytest.mark.parametrize( - 'slots_per_epoch, shard_count,' - 'proposer_privkey, proposer_pubkey, is_valid_signature', + "slots_per_epoch, shard_count," + "proposer_privkey, proposer_pubkey, is_valid_signature", ( - (5, 5, 56, bls.privtopub(56), True, ), - (5, 5, 56, bls.privtopub(56)[1:] + b'\x01', False), + (5, 5, 56, bls.privtopub(56), True), + (5, 5, 56, bls.privtopub(56)[1:] + b"\x01", False), (5, 5, 123, bls.privtopub(123), True), - (5, 5, 123, bls.privtopub(123)[1:] + b'\x01', False), - ) + (5, 5, 123, bls.privtopub(123)[1:] + b"\x01", False), + ), ) def test_validate_proposer_signature( - slots_per_epoch, - shard_count, - proposer_privkey, - proposer_pubkey, - is_valid_signature, - sample_beacon_block_params, - sample_beacon_state_params, - target_committee_size, - max_effective_balance, - config): + slots_per_epoch, + shard_count, + proposer_privkey, + proposer_pubkey, + is_valid_signature, + sample_beacon_block_params, + sample_beacon_state_params, + target_committee_size, + max_effective_balance, + config, +): state = BeaconState(**sample_beacon_state_params).copy( validators=tuple( - create_mock_validator(proposer_pubkey, config) - for _ in range(10) + create_mock_validator(proposer_pubkey, config) for _ in range(10) ), balances=(max_effective_balance,) * 10, ) @@ -90,63 +73,50 @@ def test_validate_proposer_signature( message_hash=header.signing_root, privkey=proposer_privkey, domain=get_domain( - state, - SignatureDomain.DOMAIN_BEACON_PROPOSER, - slots_per_epoch, + state, SignatureDomain.DOMAIN_BEACON_PROPOSER, slots_per_epoch ), - ), + ) ) if is_valid_signature: - validate_proposer_signature( - state, - proposed_block, - CommitteeConfig(config), - ) + validate_proposer_signature(state, proposed_block, CommitteeConfig(config)) else: with pytest.raises(ValidationError): - validate_proposer_signature( - state, - proposed_block, - CommitteeConfig(config), - ) + validate_proposer_signature(state, proposed_block, CommitteeConfig(config)) @pytest.mark.parametrize( - ["is_valid", "epoch", "expected_epoch", "proposer_key_index", "expected_proposer_key_index"], - ( - (True, 0, 0, 0, 0), - (True, 1, 1, 1, 1), - (False, 0, 1, 0, 0), - (False, 0, 0, 0, 1), - ) + [ + "is_valid", + "epoch", + "expected_epoch", + "proposer_key_index", + "expected_proposer_key_index", + ], + ((True, 0, 0, 0, 0), (True, 1, 1, 1, 1), (False, 0, 1, 0, 0), (False, 0, 0, 0, 1)), ) -def test_randao_reveal_validation(is_valid, - epoch, - expected_epoch, - proposer_key_index, - expected_proposer_key_index, - privkeys, - pubkeys, - sample_fork_params, - genesis_state, - config): +def test_randao_reveal_validation( + is_valid, + epoch, + expected_epoch, + proposer_key_index, + expected_proposer_key_index, + privkeys, + pubkeys, + sample_fork_params, + genesis_state, + config, +): state = genesis_state.copy( - slot=compute_start_slot_of_epoch(epoch, config.SLOTS_PER_EPOCH), + slot=compute_start_slot_of_epoch(epoch, config.SLOTS_PER_EPOCH) ) message_hash = epoch.to_bytes(32, byteorder="little") slots_per_epoch = config.SLOTS_PER_EPOCH - domain = get_domain( - state, - SignatureDomain.DOMAIN_RANDAO, - slots_per_epoch, - ) + domain = get_domain(state, SignatureDomain.DOMAIN_RANDAO, slots_per_epoch) proposer_privkey = privkeys[proposer_key_index] randao_reveal = bls.sign( - message_hash=message_hash, - privkey=proposer_privkey, - domain=domain, + message_hash=message_hash, privkey=proposer_privkey, domain=domain ) try: diff --git a/tests/eth2/core/beacon/state_machines/forks/test_serenity_block_voluntary_exit_validation.py b/tests/eth2/core/beacon/state_machines/forks/test_serenity_block_voluntary_exit_validation.py index d98403ee6e..f915507630 100644 --- a/tests/eth2/core/beacon/state_machines/forks/test_serenity_block_voluntary_exit_validation.py +++ b/tests/eth2/core/beacon/state_machines/forks/test_serenity_block_voluntary_exit_validation.py @@ -1,94 +1,59 @@ +from eth_utils import ValidationError import pytest -from eth_utils import ( - ValidationError, -) - -from eth2.beacon.constants import ( - FAR_FUTURE_EPOCH, -) -from eth2.beacon.helpers import ( - compute_start_slot_of_epoch, -) +from eth2.beacon.constants import FAR_FUTURE_EPOCH +from eth2.beacon.helpers import compute_start_slot_of_epoch from eth2.beacon.state_machines.forks.serenity.block_validation import ( - _validate_validator_has_not_exited, _validate_eligible_exit_epoch, + _validate_validator_has_not_exited, _validate_validator_minimum_lifespan, _validate_voluntary_exit_signature, validate_voluntary_exit, ) -from eth2.beacon.tools.builder.validator import ( - create_mock_voluntary_exit, -) +from eth2.beacon.tools.builder.validator import create_mock_voluntary_exit @pytest.mark.parametrize( ( - 'validator_count', - 'slots_per_epoch', - 'target_committee_size', - 'persistent_committee_period', + "validator_count", + "slots_per_epoch", + "target_committee_size", + "persistent_committee_period", ), - [ - (40, 2, 2, 16), - ] + [(40, 2, 2, 16)], ) -def test_validate_voluntary_exit(genesis_state, - keymap, - slots_per_epoch, - persistent_committee_period, - config): +def test_validate_voluntary_exit( + genesis_state, keymap, slots_per_epoch, persistent_committee_period, config +): state = genesis_state.copy( slot=compute_start_slot_of_epoch( - config.GENESIS_EPOCH + persistent_committee_period, - slots_per_epoch, - ), + config.GENESIS_EPOCH + persistent_committee_period, slots_per_epoch + ) ) validator_index = 0 valid_voluntary_exit = create_mock_voluntary_exit( - state, - config, - keymap, - validator_index, + state, config, keymap, validator_index ) validate_voluntary_exit( - state, - valid_voluntary_exit, - slots_per_epoch, - persistent_committee_period, + state, valid_voluntary_exit, slots_per_epoch, persistent_committee_period ) @pytest.mark.parametrize( - ( - 'validator_count', - 'slots_per_epoch', - 'target_committee_size', - ), - [ - (40, 2, 2), - ] + ("validator_count", "slots_per_epoch", "target_committee_size"), [(40, 2, 2)] ) @pytest.mark.parametrize( - ( - 'validator_exit_epoch', - 'success', - ), - [ - (FAR_FUTURE_EPOCH, True), - (FAR_FUTURE_EPOCH - 1, False), - ] + ("validator_exit_epoch", "success"), + [(FAR_FUTURE_EPOCH, True), (FAR_FUTURE_EPOCH - 1, False)], ) -def test_validate_validator_has_not_exited(genesis_state, - validator_exit_epoch, - success): +def test_validate_validator_has_not_exited( + genesis_state, validator_exit_epoch, success +): state = genesis_state validator_index = 0 - validator = state.validators[validator_index].copy( - exit_epoch=validator_exit_epoch, - ) + validator = state.validators[validator_index].copy(exit_epoch=validator_exit_epoch) if success: _validate_validator_has_not_exited(validator) @@ -98,95 +63,65 @@ def test_validate_validator_has_not_exited(genesis_state, @pytest.mark.parametrize( - ( - 'validator_count', - 'slots_per_epoch', - 'target_committee_size', - ), - [ - (40, 2, 2), - ] + ("validator_count", "slots_per_epoch", "target_committee_size"), [(40, 2, 2)] ) @pytest.mark.parametrize( - ( - 'activation_exit_delay', - 'current_epoch', - 'voluntary_exit_epoch', - 'success', - ), - [ - (4, 8, 8, True), - (4, 8, 8 + 1, False), - ] + ("activation_exit_delay", "current_epoch", "voluntary_exit_epoch", "success"), + [(4, 8, 8, True), (4, 8, 8 + 1, False)], ) -def test_validate_eligible_exit_epoch(genesis_state, - keymap, - current_epoch, - voluntary_exit_epoch, - slots_per_epoch, - config, - success): +def test_validate_eligible_exit_epoch( + genesis_state, + keymap, + current_epoch, + voluntary_exit_epoch, + slots_per_epoch, + config, + success, +): state = genesis_state.copy( - slot=compute_start_slot_of_epoch(current_epoch, slots_per_epoch), + slot=compute_start_slot_of_epoch(current_epoch, slots_per_epoch) ) validator_index = 0 voluntary_exit = create_mock_voluntary_exit( - state, - config, - keymap, - validator_index, - exit_epoch=voluntary_exit_epoch, + state, config, keymap, validator_index, exit_epoch=voluntary_exit_epoch ) if success: _validate_eligible_exit_epoch( - voluntary_exit.epoch, - state.current_epoch(slots_per_epoch), + voluntary_exit.epoch, state.current_epoch(slots_per_epoch) ) else: with pytest.raises(ValidationError): _validate_eligible_exit_epoch( - voluntary_exit.epoch, - state.current_epoch(slots_per_epoch), + voluntary_exit.epoch, state.current_epoch(slots_per_epoch) ) @pytest.mark.parametrize( - ( - 'current_epoch', - 'persistent_committee_period', - 'activation_epoch', - 'success', - ), - [ - (16, 4, 16 - 4, True), - (16, 4, 16 - 4 + 1, False), - ] + ("current_epoch", "persistent_committee_period", "activation_epoch", "success"), + [(16, 4, 16 - 4, True), (16, 4, 16 - 4 + 1, False)], ) -def test_validate_validator_minimum_lifespan(genesis_state, - keymap, - current_epoch, - activation_epoch, - slots_per_epoch, - persistent_committee_period, - success): +def test_validate_validator_minimum_lifespan( + genesis_state, + keymap, + current_epoch, + activation_epoch, + slots_per_epoch, + persistent_committee_period, + success, +): state = genesis_state.copy( - slot=compute_start_slot_of_epoch( - current_epoch, - slots_per_epoch - ), + slot=compute_start_slot_of_epoch(current_epoch, slots_per_epoch) ) validator_index = 0 validator = state.validators[validator_index].copy( - activation_epoch=activation_epoch, + activation_epoch=activation_epoch ) state = state.update_validator(validator_index, validator) if success: _validate_validator_minimum_lifespan( - validator, - state.current_epoch(slots_per_epoch), - persistent_committee_period, + validator, state.current_epoch(slots_per_epoch), persistent_committee_period ) else: with pytest.raises(ValidationError): @@ -199,44 +134,28 @@ def test_validate_validator_minimum_lifespan(genesis_state, @pytest.mark.parametrize( ( - 'validator_count', - 'slots_per_epoch', - 'target_committee_size', - 'activation_exit_delay', + "validator_count", + "slots_per_epoch", + "target_committee_size", + "activation_exit_delay", ), - [ - (40, 2, 2, 2), - ] + [(40, 2, 2, 2)], ) -@pytest.mark.parametrize( - ( - 'success', - ), - [ - (True,), - (False,), - ] -) -def test_validate_voluntary_exit_signature(genesis_state, - keymap, - config, - success): +@pytest.mark.parametrize(("success",), [(True,), (False,)]) +def test_validate_voluntary_exit_signature(genesis_state, keymap, config, success): slots_per_epoch = config.SLOTS_PER_EPOCH state = genesis_state validator_index = 0 - voluntary_exit = create_mock_voluntary_exit( - state, - config, - keymap, - validator_index, - ) + voluntary_exit = create_mock_voluntary_exit(state, config, keymap, validator_index) validator = state.validators[validator_index] if success: - _validate_voluntary_exit_signature(state, voluntary_exit, validator, slots_per_epoch) + _validate_voluntary_exit_signature( + state, voluntary_exit, validator, slots_per_epoch + ) else: # Use wrong signature - voluntary_exit = voluntary_exit.copy( - signature=b'\x12' * 96, # wrong signature - ) + voluntary_exit = voluntary_exit.copy(signature=b"\x12" * 96) # wrong signature with pytest.raises(ValidationError): - _validate_voluntary_exit_signature(state, voluntary_exit, validator, slots_per_epoch) + _validate_voluntary_exit_signature( + state, voluntary_exit, validator, slots_per_epoch + ) diff --git a/tests/eth2/core/beacon/state_machines/forks/test_serenity_epoch_processing.py b/tests/eth2/core/beacon/state_machines/forks/test_serenity_epoch_processing.py index 3f018e2342..cf9e372b24 100644 --- a/tests/eth2/core/beacon/state_machines/forks/test_serenity_epoch_processing.py +++ b/tests/eth2/core/beacon/state_machines/forks/test_serenity_epoch_processing.py @@ -1,78 +1,62 @@ import random -import pytest +import pytest import ssz -from eth2._utils.bitfield import ( - set_voted, - get_empty_bitfield, -) -from eth2.configs import ( - CommitteeConfig, -) +from eth2._utils.bitfield import get_empty_bitfield, set_voted from eth2.beacon.committee_helpers import ( get_beacon_proposer_index, - get_start_shard, get_shard_delta, + get_start_shard, ) from eth2.beacon.constants import ( FAR_FUTURE_EPOCH, GWEI_PER_ETH, JUSTIFICATION_BITS_LENGTH, ) +from eth2.beacon.epoch_processing_helpers import get_base_reward from eth2.beacon.helpers import ( + compute_epoch_of_slot, + compute_start_slot_of_epoch, get_active_validator_indices, get_block_root, get_block_root_at_slot, - compute_start_slot_of_epoch, - compute_epoch_of_slot, ) -from eth2.beacon.epoch_processing_helpers import ( - get_base_reward, -) -from eth2.beacon.types.attestation_data import AttestationData -from eth2.beacon.types.checkpoints import Checkpoint -from eth2.beacon.types.crosslinks import Crosslink -from eth2.beacon.types.pending_attestations import PendingAttestation -from eth2.beacon.typing import Gwei from eth2.beacon.state_machines.forks.serenity.epoch_processing import ( _bft_threshold_met, + _compute_next_active_index_roots, _determine_new_finalized_epoch, _determine_slashing_penalty, compute_activation_exit_epoch, + get_attestation_deltas, + get_crosslink_deltas, process_crosslinks, process_justification_and_finalization, - process_slashings, - _compute_next_active_index_roots, process_registry_updates, - get_crosslink_deltas, - get_attestation_deltas, + process_slashings, ) - -from eth2.beacon.types.validators import Validator from eth2.beacon.tools.builder.validator import ( + get_crosslink_committees_at_slot, mk_all_pending_attestations_with_full_participation_in_epoch, mk_all_pending_attestations_with_some_participation_in_epoch, - get_crosslink_committees_at_slot, ) +from eth2.beacon.types.attestation_data import AttestationData +from eth2.beacon.types.checkpoints import Checkpoint +from eth2.beacon.types.crosslinks import Crosslink +from eth2.beacon.types.pending_attestations import PendingAttestation +from eth2.beacon.types.validators import Validator +from eth2.beacon.typing import Gwei +from eth2.configs import CommitteeConfig @pytest.mark.parametrize( - "total_balance," - "attesting_balance," - "expected,", + "total_balance," "attesting_balance," "expected,", ( - ( - 1500 * GWEI_PER_ETH, 1000 * GWEI_PER_ETH, True, - ), - ( - 1500 * GWEI_PER_ETH, 999 * GWEI_PER_ETH, False, - ), - ) + (1500 * GWEI_PER_ETH, 1000 * GWEI_PER_ETH, True), + (1500 * GWEI_PER_ETH, 999 * GWEI_PER_ETH, False), + ), ) -def test_bft_threshold_met(attesting_balance, - total_balance, - expected): +def test_bft_threshold_met(attesting_balance, total_balance, expected): assert _bft_threshold_met(attesting_balance, total_balance) == expected @@ -93,26 +77,26 @@ def test_bft_threshold_met(attesting_balance, # No finalize ((False, False, False, False), 2, 2, 1), ((True, True, True, True), 2, 2, 1), - ) + ), ) -def test_get_finalized_epoch(justification_bits, - previous_justified_epoch, - current_justified_epoch, - expected): +def test_get_finalized_epoch( + justification_bits, previous_justified_epoch, current_justified_epoch, expected +): current_epoch = 6 finalized_epoch = 1 - assert _determine_new_finalized_epoch( - finalized_epoch, - previous_justified_epoch, - current_justified_epoch, - current_epoch, - justification_bits, - ) == expected + assert ( + _determine_new_finalized_epoch( + finalized_epoch, + previous_justified_epoch, + current_justified_epoch, + current_epoch, + justification_bits, + ) + == expected + ) -def test_justification_without_mock(genesis_state, - slots_per_historical_root, - config): +def test_justification_without_mock(genesis_state, slots_per_historical_root, config): state = genesis_state state = process_justification_and_finalization(state, config) @@ -161,18 +145,20 @@ def _convert_to_bitfield(bits): (5, False, True, 2, 3, 0b11110, 1, 4, 0b111110, 2), # R1 finalize 2 ), ) -def test_process_justification_and_finalization(genesis_state, - current_epoch, - current_epoch_justifiable, - previous_epoch_justifiable, - previous_justified_epoch, - current_justified_epoch, - justification_bits, - finalized_epoch, - justified_epoch_after, - justification_bits_after, - finalized_epoch_after, - config): +def test_process_justification_and_finalization( + genesis_state, + current_epoch, + current_epoch_justifiable, + previous_epoch_justifiable, + previous_justified_epoch, + current_justified_epoch, + justification_bits, + finalized_epoch, + justified_epoch_after, + justification_bits_after, + finalized_epoch_after, + config, +): justification_bits = _convert_to_bitfield(justification_bits) justification_bits_after = _convert_to_bitfield(justification_bits_after) previous_epoch = max(current_epoch - 1, 0) @@ -180,99 +166,53 @@ def test_process_justification_and_finalization(genesis_state, state = genesis_state.copy( slot=slot, - previous_justified_checkpoint=Checkpoint( - epoch=previous_justified_epoch, - ), - current_justified_checkpoint=Checkpoint( - epoch=current_justified_epoch, - ), + previous_justified_checkpoint=Checkpoint(epoch=previous_justified_epoch), + current_justified_checkpoint=Checkpoint(epoch=current_justified_epoch), justification_bits=justification_bits, - finalized_checkpoint=Checkpoint( - epoch=finalized_epoch, - ), + finalized_checkpoint=Checkpoint(epoch=finalized_epoch), block_roots=tuple( - i.to_bytes(32, "little") - for i in range(config.SLOTS_PER_HISTORICAL_ROOT) + i.to_bytes(32, "little") for i in range(config.SLOTS_PER_HISTORICAL_ROOT) ), ) if previous_epoch_justifiable: attestations = mk_all_pending_attestations_with_full_participation_in_epoch( - state, - previous_epoch, - config, - ) - state = state.copy( - previous_epoch_attestations=attestations, + state, previous_epoch, config ) + state = state.copy(previous_epoch_attestations=attestations) if current_epoch_justifiable: attestations = mk_all_pending_attestations_with_full_participation_in_epoch( - state, - current_epoch, - config, - ) - state = state.copy( - current_epoch_attestations=attestations, + state, current_epoch, config ) + state = state.copy(current_epoch_attestations=attestations) post_state = process_justification_and_finalization(state, config) assert ( - post_state.previous_justified_checkpoint.epoch == state.current_justified_checkpoint.epoch + post_state.previous_justified_checkpoint.epoch + == state.current_justified_checkpoint.epoch ) assert post_state.current_justified_checkpoint.epoch == justified_epoch_after assert post_state.justification_bits == justification_bits_after assert post_state.finalized_checkpoint.epoch == finalized_epoch_after +@pytest.mark.parametrize(("slots_per_epoch," "shard_count,"), [(10, 10)]) @pytest.mark.parametrize( - ( - 'slots_per_epoch,' - 'shard_count,' - ), - [ - ( - 10, - 10, - ), - ] -) -@pytest.mark.parametrize( - ( - 'success_in_previous_epoch,' - 'success_in_current_epoch,' - ), - [ - ( - False, - False, - ), - ( - True, - False, - ), - ( - False, - True, - ), - ] + ("success_in_previous_epoch," "success_in_current_epoch,"), + [(False, False), (True, False), (False, True)], ) -def test_process_crosslinks(genesis_state, - config, - success_in_previous_epoch, - success_in_current_epoch): +def test_process_crosslinks( + genesis_state, config, success_in_previous_epoch, success_in_current_epoch +): shard_count = config.SHARD_COUNT current_slot = config.SLOTS_PER_EPOCH * 5 - 1 current_epoch = compute_epoch_of_slot(current_slot, config.SLOTS_PER_EPOCH) assert current_epoch - 4 >= 0 previous_crosslinks = tuple( - Crosslink( - shard=i, - start_epoch=current_epoch - 4, - end_epoch=current_epoch - 3, - ) + Crosslink(shard=i, start_epoch=current_epoch - 4, end_epoch=current_epoch - 3) for i in range(shard_count) ) parent_crosslinks = tuple( @@ -306,10 +246,7 @@ def test_process_crosslinks(genesis_state, expected_success_shards = set() previous_epoch_attestations = tuple( mk_all_pending_attestations_with_some_participation_in_epoch( - state, - previous_epoch, - config, - 0.7 if success_in_previous_epoch else 0, + state, previous_epoch, config, 0.7 if success_in_previous_epoch else 0 ) ) if success_in_previous_epoch: @@ -318,10 +255,7 @@ def test_process_crosslinks(genesis_state, current_epoch_attestations = tuple( mk_all_pending_attestations_with_some_participation_in_epoch( - state, - current_epoch, - config, - 0.7 if success_in_current_epoch else 0, + state, current_epoch, config, 0.7 if success_in_current_epoch else 0 ) ) if success_in_current_epoch: @@ -351,60 +285,30 @@ def test_process_crosslinks(genesis_state, # TODO better testing on attestation deltas +@pytest.mark.parametrize(("validator_count,"), [(10)]) @pytest.mark.parametrize( - ( - 'validator_count,' - ), - [ - ( - 10 - ) - ] -) -@pytest.mark.parametrize( - ( - "finalized_epoch", - "current_slot", - ), - [ - ( - 4, - 384, # epochs_since_finality <= 4 - ), - ( - 3, - 512, # epochs_since_finality > 4 - ), - ] + ("finalized_epoch", "current_slot"), + [(4, 384), (3, 512)], # epochs_since_finality <= 4 # epochs_since_finality > 4 ) -def test_get_attestation_deltas(genesis_state, - config, - slots_per_epoch, - target_committee_size, - shard_count, - min_attestation_inclusion_delay, - inactivity_penalty_quotient, - finalized_epoch, - current_slot, - sample_pending_attestation_record_params, - sample_attestation_data_params): +def test_get_attestation_deltas( + genesis_state, + config, + slots_per_epoch, + target_committee_size, + shard_count, + min_attestation_inclusion_delay, + inactivity_penalty_quotient, + finalized_epoch, + current_slot, + sample_pending_attestation_record_params, + sample_attestation_data_params, +): state = genesis_state.copy( - slot=current_slot, - finalized_checkpoint=Checkpoint( - epoch=finalized_epoch, - ) + slot=current_slot, finalized_checkpoint=Checkpoint(epoch=finalized_epoch) ) previous_epoch = state.previous_epoch(config.SLOTS_PER_EPOCH, config.GENESIS_EPOCH) - epoch_start_shard = get_start_shard( - state, - previous_epoch, - CommitteeConfig(config), - ) - shard_delta = get_shard_delta( - state, - previous_epoch, - CommitteeConfig(config), - ) + epoch_start_shard = get_start_shard(state, previous_epoch, CommitteeConfig(config)) + shard_delta = get_shard_delta(state, previous_epoch, CommitteeConfig(config)) a = epoch_start_shard b = epoch_start_shard + shard_delta @@ -419,9 +323,7 @@ def test_get_attestation_deltas(genesis_state, prev_epoch_attestations = tuple() for slot in range(prev_epoch_start_slot, prev_epoch_start_slot + slots_per_epoch): committee, shard = get_crosslink_committees_at_slot( - state, - slot, - CommitteeConfig(config), + state, slot, CommitteeConfig(config) )[0] if not committee: continue @@ -436,15 +338,10 @@ def test_get_attestation_deltas(genesis_state, aggregation_bits=participants_bitfield, inclusion_delay=min_attestation_inclusion_delay, proposer_index=get_beacon_proposer_index( - state.copy( - slot=slot, - ), - CommitteeConfig(config), + state.copy(slot=slot), CommitteeConfig(config) ), data=AttestationData(**sample_attestation_data_params).copy( - crosslink=Crosslink( - shard=shard, - ), + crosslink=Crosslink(shard=shard), target=Checkpoint( epoch=previous_epoch, root=get_block_root( @@ -455,98 +352,63 @@ def test_get_attestation_deltas(genesis_state, ), ), beacon_block_root=get_block_root_at_slot( - state, - slot, - config.SLOTS_PER_HISTORICAL_ROOT, + state, slot, config.SLOTS_PER_HISTORICAL_ROOT ), ), ), ) - state = state.copy( - previous_epoch_attestations=prev_epoch_attestations, - ) + state = state.copy(previous_epoch_attestations=prev_epoch_attestations) - rewards_received, penalties_received = get_attestation_deltas( - state, - config, - ) + rewards_received, penalties_received = get_attestation_deltas(state, config) # everyone attested, no penalties - assert(sum(penalties_received) == 0) + assert sum(penalties_received) == 0 the_reward = rewards_received[0] # everyone performed the same, equal rewards - assert(sum(rewards_received) // len(rewards_received) == the_reward) + assert sum(rewards_received) // len(rewards_received) == the_reward @pytest.mark.parametrize( ( - 'validator_count,' - 'slots_per_epoch,' - 'target_committee_size,' - 'shard_count,' - 'current_slot,' - 'num_attesting_validators,' - 'genesis_slot,' + "validator_count," + "slots_per_epoch," + "target_committee_size," + "shard_count," + "current_slot," + "num_attesting_validators," + "genesis_slot," ), - [ - ( - 50, - 10, - 5, - 10, - 100, - 3, - 0, - ), - ( - 50, - 10, - 5, - 10, - 100, - 4, - 0, - ), - ] + [(50, 10, 5, 10, 100, 3, 0), (50, 10, 5, 10, 100, 4, 0)], ) -def test_process_rewards_and_penalties_for_crosslinks(genesis_state, - config, - slots_per_epoch, - target_committee_size, - shard_count, - current_slot, - num_attesting_validators, - max_effective_balance, - min_attestation_inclusion_delay, - sample_attestation_data_params, - sample_pending_attestation_record_params): - state = genesis_state.copy( - slot=current_slot, - ) +def test_process_rewards_and_penalties_for_crosslinks( + genesis_state, + config, + slots_per_epoch, + target_committee_size, + shard_count, + current_slot, + num_attesting_validators, + max_effective_balance, + min_attestation_inclusion_delay, + sample_attestation_data_params, + sample_pending_attestation_record_params, +): + state = genesis_state.copy(slot=current_slot) previous_epoch = state.previous_epoch(config.SLOTS_PER_EPOCH, config.GENESIS_EPOCH) prev_epoch_start_slot = compute_start_slot_of_epoch(previous_epoch, slots_per_epoch) prev_epoch_crosslink_committees = [ - get_crosslink_committees_at_slot( - state, - slot, - CommitteeConfig(config), - )[0] for slot in range(prev_epoch_start_slot, prev_epoch_start_slot + slots_per_epoch) + get_crosslink_committees_at_slot(state, slot, CommitteeConfig(config))[0] + for slot in range( + prev_epoch_start_slot, prev_epoch_start_slot + slots_per_epoch + ) ] # Record which validators attest during each slot for reward collation. each_slot_attestion_validators_list = [] - epoch_start_shard = get_start_shard( - state, - previous_epoch, - CommitteeConfig(config), - ) - shard_delta = get_shard_delta( - state, - previous_epoch, - CommitteeConfig(config), - ) + epoch_start_shard = get_start_shard(state, previous_epoch, CommitteeConfig(config)) + shard_delta = get_shard_delta(state, previous_epoch, CommitteeConfig(config)) a = epoch_start_shard b = epoch_start_shard + shard_delta @@ -566,40 +428,30 @@ def test_process_rewards_and_penalties_for_crosslinks(genesis_state, # Randomly sample `num_attesting_validators` validators # from the committee to attest in this slot. crosslink_attesting_validators = random.sample( - committee, - num_attesting_validators, + committee, num_attesting_validators ) each_slot_attestion_validators_list.append(crosslink_attesting_validators) participants_bitfield = get_empty_bitfield(len(committee)) for index in crosslink_attesting_validators: - participants_bitfield = set_voted(participants_bitfield, committee.index(index)) + participants_bitfield = set_voted( + participants_bitfield, committee.index(index) + ) previous_epoch_attestations.append( PendingAttestation(**sample_pending_attestation_record_params).copy( aggregation_bits=participants_bitfield, data=AttestationData(**sample_attestation_data_params).copy( - target=Checkpoint( - epoch=previous_epoch, - ), + target=Checkpoint(epoch=previous_epoch), crosslink=Crosslink( - shard=shard, - parent_root=Crosslink().hash_tree_root, + shard=shard, parent_root=Crosslink().hash_tree_root ), ), ) ) - state = state.copy( - previous_epoch_attestations=tuple(previous_epoch_attestations), - ) + state = state.copy(previous_epoch_attestations=tuple(previous_epoch_attestations)) - rewards_received, penalties_received = get_crosslink_deltas( - state, - config, - ) + rewards_received, penalties_received = get_crosslink_deltas(state, config) - expected_rewards_received = { - index: 0 - for index in range(len(state.validators)) - } + expected_rewards_received = {index: 0 for index in range(len(state.validators))} validator_balance = max_effective_balance for i in range(slots_per_epoch): crosslink_committee, shard = prev_epoch_crosslink_committees[i] @@ -610,18 +462,14 @@ def test_process_rewards_and_penalties_for_crosslinks(genesis_state, total_committee_balance = len(crosslink_committee) * validator_balance for index in crosslink_committee: if index in attesting_validators: - reward = get_base_reward( - state=state, - index=index, - config=config, - ) * total_attesting_balance // total_committee_balance + reward = ( + get_base_reward(state=state, index=index, config=config) + * total_attesting_balance + // total_committee_balance + ) expected_rewards_received[index] += reward else: - penalty = get_base_reward( - state=state, - index=index, - config=config, - ) + penalty = get_base_reward(state=state, index=index, config=config) expected_rewards_received[index] -= penalty # Check the rewards/penalties match @@ -629,46 +477,36 @@ def test_process_rewards_and_penalties_for_crosslinks(genesis_state, if index not in indices_to_check: continue assert ( - rewards_received[index] - penalties_received[index] == expected_rewards_received[index] + rewards_received[index] - penalties_received[index] + == expected_rewards_received[index] ) @pytest.mark.parametrize( - ( - 'validator_count', - 'slots_per_epoch', - 'target_committee_size', - 'shard_count', - ), - [ - ( - 10, - 10, - 9, - 10, - ), - ] + ("validator_count", "slots_per_epoch", "target_committee_size", "shard_count"), + [(10, 10, 9, 10)], ) -def test_process_registry_updates(validator_count, - genesis_state, - config, - slots_per_epoch): +def test_process_registry_updates( + validator_count, genesis_state, config, slots_per_epoch +): activation_index = len(genesis_state.validators) exiting_index = len(genesis_state.validators) - 1 activating_validator = Validator.create_pending_validator( - pubkey=b'\x10' * 48, - withdrawal_credentials=b'\x11' * 32, + pubkey=b"\x10" * 48, + withdrawal_credentials=b"\x11" * 32, amount=Gwei(32 * GWEI_PER_ETH), config=config, ) state = genesis_state.copy( - validators=genesis_state.validators[:exiting_index] + ( + validators=genesis_state.validators[:exiting_index] + + ( genesis_state.validators[exiting_index].copy( - effective_balance=config.EJECTION_BALANCE - 1, + effective_balance=config.EJECTION_BALANCE - 1 ), - ) + (activating_validator,), + ) + + (activating_validator,), balances=genesis_state.balances + (config.MAX_EFFECTIVE_BALANCE,), ) @@ -682,8 +520,7 @@ def test_process_registry_updates(validator_count, assert pre_activation_validator.activation_epoch == FAR_FUTURE_EPOCH assert post_activation_validator.activation_eligibility_epoch != FAR_FUTURE_EPOCH activation_epoch = compute_activation_exit_epoch( - state.current_epoch(config.SLOTS_PER_EPOCH), - config.ACTIVATION_EXIT_DELAY, + state.current_epoch(config.SLOTS_PER_EPOCH), config.ACTIVATION_EXIT_DELAY ) assert post_activation_validator.is_active(activation_epoch) # Check if the activating_validator is exited @@ -699,51 +536,52 @@ def test_process_registry_updates(validator_count, @pytest.mark.parametrize( ( - 'validator_count', - 'slots_per_epoch', - 'genesis_slot', - 'current_epoch', - 'epochs_per_slashings_vector', + "validator_count", + "slots_per_epoch", + "genesis_slot", + "current_epoch", + "epochs_per_slashings_vector", ), - [ - ( - 10, 4, 8, 8, 8, - ) - ] + [(10, 4, 8, 8, 8)], ) @pytest.mark.parametrize( - ( - 'total_penalties', - 'total_balance', - 'expected_penalty', - ), + ("total_penalties", "total_balance", "expected_penalty"), [ # total_penalties * 3 is less than total_balance ( - 32 * 10**9, # 1 ETH - (32 * 10**9 * 10), + 32 * 10 ** 9, # 1 ETH + (32 * 10 ** 9 * 10), # effective_balance * total_penalties * 3 // total_balance - ((32 * 10**9) // 10**9) * (3 * 32 * 10**9) // (32 * 10**9 * 10) * 10**9, + ((32 * 10 ** 9) // 10 ** 9) + * (3 * 32 * 10 ** 9) + // (32 * 10 ** 9 * 10) + * 10 ** 9, ), # total_balance is less than total_penalties * 3 ( - 32 * 4 * 10**9, - (32 * 10**9 * 10), + 32 * 4 * 10 ** 9, + (32 * 10 ** 9 * 10), # effective_balance * total_balance // total_balance, - (32 * 10**9) // 10**9 * (32 * 10**9 * 10) // (32 * 10**9 * 10) * 10**9, + (32 * 10 ** 9) + // 10 ** 9 + * (32 * 10 ** 9 * 10) + // (32 * 10 ** 9 * 10) + * 10 ** 9, ), - ] + ], ) -def test_determine_slashing_penalty(genesis_state, - config, - slots_per_epoch, - current_epoch, - epochs_per_slashings_vector, - total_penalties, - total_balance, - expected_penalty): +def test_determine_slashing_penalty( + genesis_state, + config, + slots_per_epoch, + current_epoch, + epochs_per_slashings_vector, + total_penalties, + total_balance, + expected_penalty, +): state = genesis_state.copy( - slot=compute_start_slot_of_epoch(current_epoch, slots_per_epoch), + slot=compute_start_slot_of_epoch(current_epoch, slots_per_epoch) ) # if the size of the v-set changes then update the parameters above assert len(state.validators) == 10 @@ -759,12 +597,12 @@ def test_determine_slashing_penalty(genesis_state, @pytest.mark.parametrize( ( - 'validator_count', - 'slots_per_epoch', - 'current_epoch', - 'epochs_per_slashings_vector', - 'slashings', - 'expected_penalty', + "validator_count", + "slots_per_epoch", + "current_epoch", + "epochs_per_slashings_vector", + "slashings", + "expected_penalty", ), [ ( @@ -772,18 +610,20 @@ def test_determine_slashing_penalty(genesis_state, 4, 8, 8, - (19 * 10**9, 10**9) + (0,) * 6, - (32 * 10**9 // 10**9 * 60 * 10**9) // (320 * 10**9) * 10**9, - ), - ] + (19 * 10 ** 9, 10 ** 9) + (0,) * 6, + (32 * 10 ** 9 // 10 ** 9 * 60 * 10 ** 9) // (320 * 10 ** 9) * 10 ** 9, + ) + ], ) -def test_process_slashings(genesis_state, - config, - current_epoch, - slashings, - slots_per_epoch, - epochs_per_slashings_vector, - expected_penalty): +def test_process_slashings( + genesis_state, + config, + current_epoch, + slashings, + slots_per_epoch, + epochs_per_slashings_vector, + expected_penalty, +): state = genesis_state.copy( slot=compute_start_slot_of_epoch(current_epoch, slots_per_epoch), slashings=slashings, @@ -791,50 +631,40 @@ def test_process_slashings(genesis_state, slashing_validator_index = 0 validator = state.validators[slashing_validator_index].copy( slashed=True, - withdrawable_epoch=current_epoch + epochs_per_slashings_vector // 2 + withdrawable_epoch=current_epoch + epochs_per_slashings_vector // 2, ) state = state.update_validator(slashing_validator_index, validator) result_state = process_slashings(state, config) penalty = ( - state.balances[slashing_validator_index] - - result_state.balances[slashing_validator_index] + state.balances[slashing_validator_index] + - result_state.balances[slashing_validator_index] ) assert penalty == expected_penalty @pytest.mark.parametrize( - ( - 'slots_per_epoch,' - 'epochs_per_historical_vector,' - 'state_slot,' - ), - [ - (4, 16, 4), - (4, 16, 64), - ] + ("slots_per_epoch," "epochs_per_historical_vector," "state_slot,"), + [(4, 16, 4), (4, 16, 64)], ) -def test_update_active_index_roots(genesis_state, - config, - state_slot, - slots_per_epoch, - epochs_per_historical_vector, - activation_exit_delay): - state = genesis_state.copy( - slot=state_slot, - ) +def test_update_active_index_roots( + genesis_state, + config, + state_slot, + slots_per_epoch, + epochs_per_historical_vector, + activation_exit_delay, +): + state = genesis_state.copy(slot=state_slot) result = _compute_next_active_index_roots(state, config) index_root = ssz.get_hash_tree_root( get_active_validator_indices( - state.validators, - compute_epoch_of_slot(state.slot, slots_per_epoch), + state.validators, compute_epoch_of_slot(state.slot, slots_per_epoch) ), ssz.sedes.List(ssz.uint64, config.VALIDATOR_REGISTRY_LIMIT), ) target_epoch = state.next_epoch(slots_per_epoch) + activation_exit_delay - assert result[ - target_epoch % epochs_per_historical_vector - ] == index_root + assert result[target_epoch % epochs_per_historical_vector] == index_root diff --git a/tests/eth2/core/beacon/state_machines/forks/test_serenity_operation_processing.py b/tests/eth2/core/beacon/state_machines/forks/test_serenity_operation_processing.py index 0b72ce4132..828406b870 100644 --- a/tests/eth2/core/beacon/state_machines/forks/test_serenity_operation_processing.py +++ b/tests/eth2/core/beacon/state_machines/forks/test_serenity_operation_processing.py @@ -1,73 +1,47 @@ +from eth_utils import ValidationError import pytest -from eth_utils import ( - ValidationError, -) - -from eth2.configs import ( - CommitteeConfig, -) -from eth2.beacon.committee_helpers import ( - get_beacon_proposer_index, -) -from eth2.beacon.constants import ( - FAR_FUTURE_EPOCH, -) -from eth2.beacon.helpers import ( - compute_start_slot_of_epoch, -) -from eth2.beacon.types.blocks import ( - BeaconBlockBody, -) -from eth2.beacon.types.crosslinks import Crosslink -from eth2.beacon.state_machines.forks.serenity.blocks import ( - SerenityBeaconBlock, -) +from eth2.beacon.committee_helpers import get_beacon_proposer_index +from eth2.beacon.constants import FAR_FUTURE_EPOCH +from eth2.beacon.helpers import compute_start_slot_of_epoch +from eth2.beacon.state_machines.forks.serenity.blocks import SerenityBeaconBlock from eth2.beacon.state_machines.forks.serenity.operation_processing import ( process_attestations, - process_proposer_slashings, process_attester_slashings, + process_proposer_slashings, process_voluntary_exits, ) from eth2.beacon.tools.builder.validator import ( create_mock_attester_slashing_is_double_vote, - create_mock_signed_attestations_at_slot, create_mock_proposer_slashing_at_block, + create_mock_signed_attestations_at_slot, create_mock_voluntary_exit, ) - - -@pytest.mark.parametrize( - ( - 'validator_count,' - ), - [ - (100), - ], -) -def test_process_max_attestations(genesis_state, - genesis_block, - sample_beacon_block_params, - sample_beacon_block_body_params, - config, - keymap, - fixture_sm_class, - chaindb, - empty_attestation_pool): +from eth2.beacon.types.blocks import BeaconBlockBody +from eth2.beacon.types.crosslinks import Crosslink +from eth2.configs import CommitteeConfig + + +@pytest.mark.parametrize(("validator_count,"), [(100)]) +def test_process_max_attestations( + genesis_state, + genesis_block, + sample_beacon_block_params, + sample_beacon_block_body_params, + config, + keymap, + fixture_sm_class, + chaindb, + empty_attestation_pool, +): attestation_slot = config.GENESIS_SLOT current_slot = attestation_slot + config.MIN_ATTESTATION_INCLUSION_DELAY - state = genesis_state.copy( - slot=current_slot, - ) + state = genesis_state.copy(slot=current_slot) attestations = create_mock_signed_attestations_at_slot( state=state, config=config, - state_machine=fixture_sm_class( - chaindb, - empty_attestation_pool, - current_slot, - ), + state_machine=fixture_sm_class(chaindb, empty_attestation_pool, current_slot), attestation_slot=attestation_slot, beacon_block_root=genesis_block.signing_root, keymap=keymap, @@ -78,52 +52,44 @@ def test_process_max_attestations(genesis_state, assert attestations_count > 0 block_body = BeaconBlockBody(**sample_beacon_block_body_params).copy( - attestations=attestations * (config.MAX_ATTESTATIONS // attestations_count + 1), + attestations=attestations * (config.MAX_ATTESTATIONS // attestations_count + 1) ) block = SerenityBeaconBlock(**sample_beacon_block_params).copy( - slot=current_slot, - body=block_body, + slot=current_slot, body=block_body ) with pytest.raises(ValidationError): - process_attestations( - state, - block, - config, - ) + process_attestations(state, block, config) @pytest.mark.parametrize( ( - 'validator_count', - 'slots_per_epoch', - 'target_committee_size', - 'shard_count', - 'block_root_1', - 'block_root_2', - 'success' + "validator_count", + "slots_per_epoch", + "target_committee_size", + "shard_count", + "block_root_1", + "block_root_2", + "success", ), [ - (10, 2, 2, 2, b'\x11' * 32, b'\x22' * 32, True), - (10, 2, 2, 2, b'\x11' * 32, b'\x11' * 32, False), - ] + (10, 2, 2, 2, b"\x11" * 32, b"\x22" * 32, True), + (10, 2, 2, 2, b"\x11" * 32, b"\x11" * 32, False), + ], ) -def test_process_proposer_slashings(genesis_state, - sample_beacon_block_params, - sample_beacon_block_body_params, - config, - keymap, - block_root_1, - block_root_2, - success): +def test_process_proposer_slashings( + genesis_state, + sample_beacon_block_params, + sample_beacon_block_body_params, + config, + keymap, + block_root_1, + block_root_2, + success, +): current_slot = config.GENESIS_SLOT + 1 - state = genesis_state.copy( - slot=current_slot, - ) - whistleblower_index = get_beacon_proposer_index( - state, - CommitteeConfig(config), - ) + state = genesis_state.copy(slot=current_slot) + whistleblower_index = get_beacon_proposer_index(state, CommitteeConfig(config)) slashing_proposer_index = (whistleblower_index + 1) % len(state.validators) proposer_slashing = create_mock_proposer_slashing_at_block( state, @@ -136,156 +102,121 @@ def test_process_proposer_slashings(genesis_state, proposer_slashings = (proposer_slashing,) block_body = BeaconBlockBody(**sample_beacon_block_body_params).copy( - proposer_slashings=proposer_slashings, + proposer_slashings=proposer_slashings ) block = SerenityBeaconBlock(**sample_beacon_block_params).copy( - slot=current_slot, - body=block_body, + slot=current_slot, body=block_body ) if success: - new_state = process_proposer_slashings( - state, - block, - config, - ) + new_state = process_proposer_slashings(state, block, config) # Check if slashed assert ( - new_state.balances[slashing_proposer_index] < - state.balances[slashing_proposer_index] + new_state.balances[slashing_proposer_index] + < state.balances[slashing_proposer_index] ) else: with pytest.raises(ValidationError): - process_proposer_slashings( - state, - block, - config, - ) + process_proposer_slashings(state, block, config) @pytest.mark.parametrize( ( - 'validator_count', - 'slots_per_epoch', - 'target_committee_size', - 'shard_count', - 'min_attestation_inclusion_delay', + "validator_count", + "slots_per_epoch", + "target_committee_size", + "shard_count", + "min_attestation_inclusion_delay", ), - [ - (100, 2, 2, 2, 1), - ] -) -@pytest.mark.parametrize( - ('success'), - [ - (True), - (False), - ] + [(100, 2, 2, 2, 1)], ) -def test_process_attester_slashings(genesis_state, - sample_beacon_block_params, - sample_beacon_block_body_params, - config, - keymap, - min_attestation_inclusion_delay, - success): +@pytest.mark.parametrize(("success"), [(True), (False)]) +def test_process_attester_slashings( + genesis_state, + sample_beacon_block_params, + sample_beacon_block_body_params, + config, + keymap, + min_attestation_inclusion_delay, + success, +): attesting_state = genesis_state.copy( slot=genesis_state.slot + config.SLOTS_PER_EPOCH, block_roots=tuple( - i.to_bytes(32, "little") - for i in range(config.SLOTS_PER_HISTORICAL_ROOT) - ) + i.to_bytes(32, "little") for i in range(config.SLOTS_PER_HISTORICAL_ROOT) + ), ) valid_attester_slashing = create_mock_attester_slashing_is_double_vote( - attesting_state, - config, - keymap, - attestation_epoch=0, + attesting_state, config, keymap, attestation_epoch=0 ) state = attesting_state.copy( - slot=attesting_state.slot + min_attestation_inclusion_delay, + slot=attesting_state.slot + min_attestation_inclusion_delay ) if success: block_body = BeaconBlockBody(**sample_beacon_block_body_params).copy( - attester_slashings=(valid_attester_slashing,), + attester_slashings=(valid_attester_slashing,) ) block = SerenityBeaconBlock(**sample_beacon_block_params).copy( - slot=state.slot, - body=block_body, + slot=state.slot, body=block_body ) attester_index = valid_attester_slashing.attestation_1.custody_bit_0_indices[0] - new_state = process_attester_slashings( - state, - block, - config, - ) + new_state = process_attester_slashings(state, block, config) # Check if slashed assert not state.validators[attester_index].slashed assert new_state.validators[attester_index].slashed else: invalid_attester_slashing = valid_attester_slashing.copy( attestation_2=valid_attester_slashing.attestation_2.copy( - data=valid_attester_slashing.attestation_1.data, + data=valid_attester_slashing.attestation_1.data ) ) block_body = BeaconBlockBody(**sample_beacon_block_body_params).copy( - attester_slashings=(invalid_attester_slashing,), + attester_slashings=(invalid_attester_slashing,) ) block = SerenityBeaconBlock(**sample_beacon_block_params).copy( - slot=state.slot, - body=block_body, + slot=state.slot, body=block_body ) with pytest.raises(ValidationError): - process_attester_slashings( - state, - block, - config, - ) + process_attester_slashings(state, block, config) @pytest.mark.parametrize( ( - 'validator_count,' - 'slots_per_epoch,' - 'min_attestation_inclusion_delay,' - 'target_committee_size,' - 'shard_count,' - 'success,' + "validator_count," + "slots_per_epoch," + "min_attestation_inclusion_delay," + "target_committee_size," + "shard_count," + "success," ), - [ - (10, 2, 1, 2, 2, True), - (10, 2, 1, 2, 2, False), - (40, 4, 2, 3, 5, True), - ] + [(10, 2, 1, 2, 2, True), (10, 2, 1, 2, 2, False), (40, 4, 2, 3, 5, True)], ) -def test_process_attestations(genesis_state, - genesis_block, - sample_beacon_block_params, - sample_beacon_block_body_params, - config, - keymap, - fixture_sm_class, - chaindb, - empty_attestation_pool, - success): +def test_process_attestations( + genesis_state, + genesis_block, + sample_beacon_block_params, + sample_beacon_block_body_params, + config, + keymap, + fixture_sm_class, + chaindb, + empty_attestation_pool, + success, +): attestation_slot = 0 current_slot = attestation_slot + config.MIN_ATTESTATION_INCLUSION_DELAY - state = genesis_state.copy( - slot=current_slot, - ) + state = genesis_state.copy(slot=current_slot) attestations = create_mock_signed_attestations_at_slot( state=state, config=config, state_machine=fixture_sm_class( - chaindb, - empty_attestation_pool, - genesis_block.slot, + chaindb, empty_attestation_pool, genesis_block.slot ), attestation_slot=attestation_slot, beacon_block_root=genesis_block.signing_root, @@ -300,120 +231,88 @@ def test_process_attestations(genesis_state, # i.e. wrong parent invalid_attestation_data = attestations[-1].data.copy( crosslink=attestations[-1].data.crosslink.copy( - parent_root=Crosslink( - shard=333, - ).hash_tree_root, + parent_root=Crosslink(shard=333).hash_tree_root ) ) - invalid_attestation = attestations[-1].copy( - data=invalid_attestation_data, - ) + invalid_attestation = attestations[-1].copy(data=invalid_attestation_data) attestations = attestations[:-1] + (invalid_attestation,) block_body = BeaconBlockBody(**sample_beacon_block_body_params).copy( - attestations=attestations, + attestations=attestations ) block = SerenityBeaconBlock(**sample_beacon_block_params).copy( - slot=current_slot, - body=block_body, + slot=current_slot, body=block_body ) if success: - new_state = process_attestations( - state, - block, - config, - ) + new_state = process_attestations(state, block, config) assert len(new_state.current_epoch_attestations) == len(attestations) else: with pytest.raises(ValidationError): - process_attestations( - state, - block, - config, - ) + process_attestations(state, block, config) @pytest.mark.parametrize( ( - 'validator_count', - 'slots_per_epoch', - 'target_committee_size', - 'activation_exit_delay', - ), - [ - (40, 2, 2, 2), - ] -) -@pytest.mark.parametrize( - ( - 'success', + "validator_count", + "slots_per_epoch", + "target_committee_size", + "activation_exit_delay", ), - [ - (True,), - (False,), - ] + [(40, 2, 2, 2)], ) -def test_process_voluntary_exits(genesis_state, - sample_beacon_block_params, - sample_beacon_block_body_params, - config, - keymap, - success): +@pytest.mark.parametrize(("success",), [(True,), (False,)]) +def test_process_voluntary_exits( + genesis_state, + sample_beacon_block_params, + sample_beacon_block_body_params, + config, + keymap, + success, +): state = genesis_state.copy( slot=compute_start_slot_of_epoch( config.GENESIS_EPOCH + config.PERSISTENT_COMMITTEE_PERIOD, config.SLOTS_PER_EPOCH, - ), + ) ) validator_index = 0 validator = state.validators[validator_index].copy( - activation_epoch=config.GENESIS_EPOCH, + activation_epoch=config.GENESIS_EPOCH ) state = state.update_validator(validator_index, validator) valid_voluntary_exit = create_mock_voluntary_exit( - state, - config, - keymap, - validator_index, + state, config, keymap, validator_index ) if success: block_body = BeaconBlockBody(**sample_beacon_block_body_params).copy( - voluntary_exits=(valid_voluntary_exit,), + voluntary_exits=(valid_voluntary_exit,) ) block = SerenityBeaconBlock(**sample_beacon_block_params).copy( - slot=state.slot, - body=block_body, + slot=state.slot, body=block_body ) - new_state = process_voluntary_exits( - state, - block, - config, - ) + new_state = process_voluntary_exits(state, block, config) updated_validator = new_state.validators[validator_index] assert updated_validator.exit_epoch != FAR_FUTURE_EPOCH - assert updated_validator.exit_epoch > state.current_epoch(config.SLOTS_PER_EPOCH) + assert updated_validator.exit_epoch > state.current_epoch( + config.SLOTS_PER_EPOCH + ) assert updated_validator.withdrawable_epoch == ( updated_validator.exit_epoch + config.MIN_VALIDATOR_WITHDRAWABILITY_DELAY ) else: invalid_voluntary_exit = valid_voluntary_exit.copy( - signature=b'\x12' * 96, # Put wrong signature + signature=b"\x12" * 96 # Put wrong signature ) block_body = BeaconBlockBody(**sample_beacon_block_body_params).copy( - voluntary_exits=(invalid_voluntary_exit,), + voluntary_exits=(invalid_voluntary_exit,) ) block = SerenityBeaconBlock(**sample_beacon_block_params).copy( - slot=state.slot, - body=block_body, + slot=state.slot, body=block_body ) with pytest.raises(ValidationError): - process_voluntary_exits( - state, - block, - config, - ) + process_voluntary_exits(state, block, config) diff --git a/tests/eth2/core/beacon/state_machines/forks/test_state_machine.py b/tests/eth2/core/beacon/state_machines/forks/test_state_machine.py index 135fdf7a4c..8ea8e8f4ad 100644 --- a/tests/eth2/core/beacon/state_machines/forks/test_state_machine.py +++ b/tests/eth2/core/beacon/state_machines/forks/test_state_machine.py @@ -1,20 +1,10 @@ import pytest -from eth2.beacon.state_machines.forks.serenity import ( - SerenityStateMachine, -) -from eth2.beacon.state_machines.forks.xiao_long_bao import ( - XiaoLongBaoStateMachine, -) +from eth2.beacon.state_machines.forks.serenity import SerenityStateMachine +from eth2.beacon.state_machines.forks.xiao_long_bao import XiaoLongBaoStateMachine -@pytest.mark.parametrize( - "sm_klass", - ( - SerenityStateMachine, - XiaoLongBaoStateMachine, - ) -) +@pytest.mark.parametrize("sm_klass", (SerenityStateMachine, XiaoLongBaoStateMachine)) def test_sm_class_well_defined(sm_klass): state_machine = sm_klass(chaindb=None, attestation_pool=None, slot=None) assert state_machine.get_block_class() diff --git a/tests/eth2/core/beacon/state_machines/test_state_transition.py b/tests/eth2/core/beacon/state_machines/test_state_transition.py index a434ebcc4e..4e537cd886 100644 --- a/tests/eth2/core/beacon/state_machines/test_state_transition.py +++ b/tests/eth2/core/beacon/state_machines/test_state_transition.py @@ -1,23 +1,19 @@ import pytest -from eth2.beacon.state_machines.forks.serenity.blocks import ( - SerenityBeaconBlock, -) -from eth2.beacon.tools.builder.proposer import ( - create_mock_block, -) +from eth2.beacon.state_machines.forks.serenity.blocks import SerenityBeaconBlock +from eth2.beacon.tools.builder.proposer import create_mock_block from eth2.beacon.types.historical_batch import HistoricalBatch @pytest.mark.parametrize( ( - 'validator_count,' - 'slots_per_epoch,' - 'min_attestation_inclusion_delay,' - 'target_committee_size,' - 'shard_count,' - 'state_slot,' - 'slots_per_historical_root' + "validator_count," + "slots_per_epoch," + "min_attestation_inclusion_delay," + "target_committee_size," + "shard_count," + "state_slot," + "slots_per_historical_root" ), [ (10, 10, 1, 2, 10, 2, 8192), @@ -35,17 +31,19 @@ # updated_state.slot % SLOTS_PER_HISTORICAL_ROOT = 0 # (11, 4, 1, 2, 4, 15, 8), # (16, 4, 1, 2, 4, 31, 8), - ] + ], ) -def test_per_slot_transition(chaindb, - genesis_block, - genesis_state, - fixture_sm_class, - config, - state_slot, - fork_choice_scoring, - empty_attestation_pool, - keymap): +def test_per_slot_transition( + chaindb, + genesis_block, + genesis_state, + fixture_sm_class, + config, + state_slot, + fork_choice_scoring, + empty_attestation_pool, + keymap, +): chaindb.persist_block(genesis_block, SerenityBeaconBlock, fork_choice_scoring) chaindb.persist_state(genesis_state) @@ -56,9 +54,7 @@ def test_per_slot_transition(chaindb, state=state, config=config, state_machine=fixture_sm_class( - chaindb, - empty_attestation_pool, - genesis_block.slot, + chaindb, empty_attestation_pool, genesis_block.slot ), block_class=SerenityBeaconBlock, parent_block=genesis_block, @@ -70,11 +66,7 @@ def test_per_slot_transition(chaindb, chaindb.persist_block(block, SerenityBeaconBlock, fork_choice_scoring) # Get state machine instance - sm = fixture_sm_class( - chaindb, - empty_attestation_pool, - block.slot, - ) + sm = fixture_sm_class(chaindb, empty_attestation_pool, block.slot) # Get state transition instance st = sm.state_transition_class(sm.config) @@ -94,8 +86,7 @@ def test_per_slot_transition(chaindb, # historical_roots if updated_state.slot % st.config.SLOTS_PER_HISTORICAL_ROOT == 0: historical_batch = HistoricalBatch( - block_roots=state.block_roots, - state_roots=state.state_roots, + block_roots=state.block_roots, state_roots=state.state_roots ) assert updated_state.historical_roots[-1] == historical_batch.hash_tree_root else: diff --git a/tests/eth2/core/beacon/test_attestation_helpers.py b/tests/eth2/core/beacon/test_attestation_helpers.py index 93fbb01ac5..8eb70cd19e 100644 --- a/tests/eth2/core/beacon/test_attestation_helpers.py +++ b/tests/eth2/core/beacon/test_attestation_helpers.py @@ -1,58 +1,46 @@ import copy import random +from eth_utils import ValidationError +from eth_utils.toolz import assoc import pytest from eth2._utils.bls import bls - -from eth_utils import ( - ValidationError, -) -from eth_utils.toolz import ( - assoc, -) - -from eth2.beacon.helpers import ( - get_domain, -) -from eth2.beacon.signature_domain import SignatureDomain from eth2.beacon.attestation_helpers import ( is_slashable_attestation_data, - validate_indexed_attestation_aggregate_signature, validate_indexed_attestation, + validate_indexed_attestation_aggregate_signature, ) +from eth2.beacon.helpers import get_domain +from eth2.beacon.signature_domain import SignatureDomain from eth2.beacon.types.attestation_data import AttestationData -from eth2.beacon.types.attestation_data_and_custody_bits import AttestationDataAndCustodyBit +from eth2.beacon.types.attestation_data_and_custody_bits import ( + AttestationDataAndCustodyBit, +) from eth2.beacon.types.attestations import IndexedAttestation from eth2.beacon.types.forks import Fork -@pytest.mark.parametrize( - ( - 'validator_count', - ), - [ - (40,), - ] -) +@pytest.mark.parametrize(("validator_count",), [(40,)]) def test_verify_indexed_attestation_signature( - slots_per_epoch, - validator_count, - genesis_state, - config, - privkeys, - sample_beacon_state_params, - genesis_validators, - genesis_balances, - sample_indexed_attestation_params, - sample_fork_params): - state = genesis_state.copy( - fork=Fork(**sample_fork_params), - ) + slots_per_epoch, + validator_count, + genesis_state, + config, + privkeys, + sample_beacon_state_params, + genesis_validators, + genesis_balances, + sample_indexed_attestation_params, + sample_fork_params, +): + state = genesis_state.copy(fork=Fork(**sample_fork_params)) # NOTE: we can do this before "correcting" the params as they # touch disjoint subsets of the provided params - message_hashes = _create_indexed_attestation_messages(sample_indexed_attestation_params) + message_hashes = _create_indexed_attestation_messages( + sample_indexed_attestation_params + ) valid_params = _correct_indexed_attestation_params( validator_count, @@ -64,12 +52,16 @@ def test_verify_indexed_attestation_signature( ) valid_votes = IndexedAttestation(**valid_params) - validate_indexed_attestation_aggregate_signature(state, valid_votes, slots_per_epoch) + validate_indexed_attestation_aggregate_signature( + state, valid_votes, slots_per_epoch + ) invalid_params = _corrupt_signature(slots_per_epoch, valid_params, state.fork) invalid_votes = IndexedAttestation(**invalid_params) with pytest.raises(ValidationError): - validate_indexed_attestation_aggregate_signature(state, invalid_votes, slots_per_epoch) + validate_indexed_attestation_aggregate_signature( + state, invalid_votes, slots_per_epoch + ) def _get_indices_and_signatures(validator_count, state, config, message_hash, privkeys): @@ -85,41 +77,28 @@ def _get_indices_and_signatures(validator_count, state, config, message_hash, pr signature_domain=signature_domain, slots_per_epoch=config.SLOTS_PER_EPOCH, ) - signatures = tuple( - map(lambda key: bls.sign(message_hash, key, domain), privkeys) - ) + signatures = tuple(map(lambda key: bls.sign(message_hash, key, domain), privkeys)) return (indices, signatures) -def _run_verify_indexed_vote(slots_per_epoch, - params, - state, - max_validators_per_committee, - should_succeed): +def _run_verify_indexed_vote( + slots_per_epoch, params, state, max_validators_per_committee, should_succeed +): votes = IndexedAttestation(**params) if should_succeed: validate_indexed_attestation( - state, - votes, - max_validators_per_committee, - slots_per_epoch, + state, votes, max_validators_per_committee, slots_per_epoch ) else: with pytest.raises(ValidationError): validate_indexed_attestation( - state, - votes, - max_validators_per_committee, - slots_per_epoch, + state, votes, max_validators_per_committee, slots_per_epoch ) -def _correct_indexed_attestation_params(validator_count, - message_hashes, - params, - privkeys, - state, - config): +def _correct_indexed_attestation_params( + validator_count, message_hashes, params, privkeys, state, config +): valid_params = copy.deepcopy(params) (custody_bit_0_indices, signatures) = _get_indices_and_signatures( @@ -154,47 +133,26 @@ def _corrupt_custody_bit_0_indices(params): def _corrupt_custody_bit_0_indices_max(max_validators_per_committee, params): - corrupt_custody_bit_0_indices = [ - i - for i in range(max_validators_per_committee + 1) - ] + corrupt_custody_bit_0_indices = [i for i in range(max_validators_per_committee + 1)] return assoc(params, "custody_bit_0_indices", corrupt_custody_bit_0_indices) def _corrupt_signature(slots_per_epoch, params, fork): - return assoc(params, "signature", b'\x12' * 96) + return assoc(params, "signature", b"\x12" * 96) def _create_indexed_attestation_messages(params): attestation = IndexedAttestation(**params) data = attestation.data return ( - AttestationDataAndCustodyBit( - data=data, - custody_bit=False, - ).hash_tree_root, - AttestationDataAndCustodyBit( - data=data, - custody_bit=True, - ).hash_tree_root, + AttestationDataAndCustodyBit(data=data, custody_bit=False).hash_tree_root, + AttestationDataAndCustodyBit(data=data, custody_bit=True).hash_tree_root, ) +@pytest.mark.parametrize(("validator_count",), [(40,)]) @pytest.mark.parametrize( - ( - 'validator_count', - ), - [ - (40,), - ] -) -@pytest.mark.parametrize( - ( - 'param_mapper', - 'should_succeed', - 'needs_fork', - 'is_testing_max_length', - ), + ("param_mapper", "should_succeed", "needs_fork", "is_testing_max_length"), [ (lambda params: params, True, False, False), (_corrupt_custody_bit_1_indices_not_empty, False, False, False), @@ -203,28 +161,30 @@ def _create_indexed_attestation_messages(params): (_corrupt_signature, False, True, False), ], ) -def test_validate_indexed_attestation(slots_per_epoch, - validator_count, - genesis_state, - param_mapper, - should_succeed, - needs_fork, - is_testing_max_length, - privkeys, - sample_beacon_state_params, - genesis_validators, - genesis_balances, - sample_indexed_attestation_params, - sample_fork_params, - max_validators_per_committee, - config): - state = genesis_state.copy( - fork=Fork(**sample_fork_params), - ) +def test_validate_indexed_attestation( + slots_per_epoch, + validator_count, + genesis_state, + param_mapper, + should_succeed, + needs_fork, + is_testing_max_length, + privkeys, + sample_beacon_state_params, + genesis_validators, + genesis_balances, + sample_indexed_attestation_params, + sample_fork_params, + max_validators_per_committee, + config, +): + state = genesis_state.copy(fork=Fork(**sample_fork_params)) # NOTE: we can do this before "correcting" the params as they # touch disjoint subsets of the provided params - message_hashes = _create_indexed_attestation_messages(sample_indexed_attestation_params) + message_hashes = _create_indexed_attestation_messages( + sample_indexed_attestation_params + ) params = _correct_indexed_attestation_params( validator_count, @@ -242,47 +202,37 @@ def test_validate_indexed_attestation(slots_per_epoch, else: params = param_mapper(params) _run_verify_indexed_vote( - slots_per_epoch, - params, - state, - max_validators_per_committee, - should_succeed, + slots_per_epoch, params, state, max_validators_per_committee, should_succeed ) -@pytest.mark.parametrize( - ( - 'validator_count', - ), - [ - (40,), - ] -) -def test_verify_indexed_attestation_after_fork(genesis_state, - slots_per_epoch, - validator_count, - privkeys, - sample_beacon_state_params, - genesis_validators, - genesis_balances, - sample_indexed_attestation_params, - sample_fork_params, - config, - max_validators_per_committee): +@pytest.mark.parametrize(("validator_count",), [(40,)]) +def test_verify_indexed_attestation_after_fork( + genesis_state, + slots_per_epoch, + validator_count, + privkeys, + sample_beacon_state_params, + genesis_validators, + genesis_balances, + sample_indexed_attestation_params, + sample_fork_params, + config, + max_validators_per_committee, +): # Test that indexed data is still valid after fork # Indexed data slot = 10, fork slot = 15, current slot = 20 past_fork_params = { - 'previous_version': (0).to_bytes(4, 'little'), - 'current_version': (1).to_bytes(4, 'little'), - 'epoch': 15, + "previous_version": (0).to_bytes(4, "little"), + "current_version": (1).to_bytes(4, "little"), + "epoch": 15, } - state = genesis_state.copy( - slot=20, - fork=Fork(**past_fork_params), - ) + state = genesis_state.copy(slot=20, fork=Fork(**past_fork_params)) - message_hashes = _create_indexed_attestation_messages(sample_indexed_attestation_params) + message_hashes = _create_indexed_attestation_messages( + sample_indexed_attestation_params + ) valid_params = _correct_indexed_attestation_params( validator_count, @@ -293,53 +243,33 @@ def test_verify_indexed_attestation_after_fork(genesis_state, config, ) _run_verify_indexed_vote( - slots_per_epoch, - valid_params, - state, - max_validators_per_committee, - True, + slots_per_epoch, valid_params, state, max_validators_per_committee, True ) @pytest.mark.parametrize( - ( - 'is_double_vote,' - 'is_surround_vote,' - ), - [ - (False, False), - (False, True), - (True, False), - (True, True), - ], + ("is_double_vote," "is_surround_vote,"), + [(False, False), (False, True), (True, False), (True, True)], ) -def test_is_slashable_attestation_data(sample_attestation_data_params, - is_double_vote, - is_surround_vote): +def test_is_slashable_attestation_data( + sample_attestation_data_params, is_double_vote, is_surround_vote +): data_1 = AttestationData(**sample_attestation_data_params) data_2 = AttestationData(**sample_attestation_data_params) if is_double_vote: data_2 = data_2.copy( beacon_block_root=( - int.from_bytes( - data_1.beacon_block_root, - "little", - ) + 1 + int.from_bytes(data_1.beacon_block_root, "little") + 1 ).to_bytes(32, "little") ) if is_surround_vote: data_1 = data_1.copy( - source=data_1.source.copy( - epoch=data_2.source.epoch - 1, - ), - target=data_1.target.copy( - epoch=data_2.target.epoch + 1, - ), + source=data_1.source.copy(epoch=data_2.source.epoch - 1), + target=data_1.target.copy(epoch=data_2.target.epoch + 1), ) - assert is_slashable_attestation_data( - data_1, - data_2 - ) == (is_double_vote or is_surround_vote) + assert is_slashable_attestation_data(data_1, data_2) == ( + is_double_vote or is_surround_vote + ) diff --git a/tests/eth2/core/beacon/test_committee_helpers.py b/tests/eth2/core/beacon/test_committee_helpers.py index e71d54a326..df2a515b3c 100644 --- a/tests/eth2/core/beacon/test_committee_helpers.py +++ b/tests/eth2/core/beacon/test_committee_helpers.py @@ -1,37 +1,32 @@ import random +from eth_utils import ValidationError import pytest -from eth_utils import ( - ValidationError, -) - from eth2.beacon.committee_helpers import ( - get_committees_per_slot, - get_committee_count, - get_shard_delta, - get_start_shard, - _find_proposer_in_committee, _calculate_first_committee_at_slot, + _find_proposer_in_committee, get_beacon_proposer_index, + get_committee_count, + get_committees_per_slot, get_crosslink_committee, + get_shard_delta, + get_start_shard, ) from eth2.beacon.helpers import ( - get_active_validator_indices, compute_start_slot_of_epoch, + get_active_validator_indices, ) -from eth2.configs import ( - CommitteeConfig, -) +from eth2.configs import CommitteeConfig @pytest.mark.parametrize( ( - 'active_validator_count,' - 'slots_per_epoch,' - 'target_committee_size,' - 'shard_count,' - 'expected_committee_count' + "active_validator_count," + "slots_per_epoch," + "target_committee_size," + "shard_count," + "expected_committee_count" ), [ # SHARD_COUNT // SLOTS_PER_EPOCH @@ -44,11 +39,13 @@ (40, 5, 10, 5, 5), ], ) -def test_get_committees_per_slot(active_validator_count, - slots_per_epoch, - target_committee_size, - shard_count, - expected_committee_count): +def test_get_committees_per_slot( + active_validator_count, + slots_per_epoch, + target_committee_size, + shard_count, + expected_committee_count, +): assert expected_committee_count // slots_per_epoch == get_committees_per_slot( active_validator_count=active_validator_count, shard_count=shard_count, @@ -59,11 +56,11 @@ def test_get_committees_per_slot(active_validator_count, @pytest.mark.parametrize( ( - 'active_validator_count,' - 'slots_per_epoch,' - 'target_committee_size,' - 'shard_count,' - 'expected_committee_count' + "active_validator_count," + "slots_per_epoch," + "target_committee_size," + "shard_count," + "expected_committee_count" ), [ # SHARD_COUNT // SLOTS_PER_EPOCH @@ -76,11 +73,13 @@ def test_get_committees_per_slot(active_validator_count, (40, 5, 10, 5, 5), ], ) -def test_get_committee_count(active_validator_count, - slots_per_epoch, - target_committee_size, - shard_count, - expected_committee_count): +def test_get_committee_count( + active_validator_count, + slots_per_epoch, + target_committee_size, + shard_count, + expected_committee_count, +): assert expected_committee_count == get_committee_count( active_validator_count=active_validator_count, shard_count=shard_count, @@ -91,11 +90,11 @@ def test_get_committee_count(active_validator_count, @pytest.mark.parametrize( ( - 'validator_count,' - 'slots_per_epoch,' - 'target_committee_size,' - 'shard_count,' - 'expected_shard_delta,' + "validator_count," + "slots_per_epoch," + "target_committee_size," + "shard_count," + "expected_shard_delta," ), [ # SHARD_COUNT - SHARD_COUNT // SLOTS_PER_EPOCH @@ -104,9 +103,7 @@ def test_get_committee_count(active_validator_count, (500, 20, 10, 100, 40), ], ) -def test_get_shard_delta(genesis_state, - expected_shard_delta, - config): +def test_get_shard_delta(genesis_state, expected_shard_delta, config): state = genesis_state epoch = state.current_epoch(config.SLOTS_PER_EPOCH) @@ -115,13 +112,13 @@ def test_get_shard_delta(genesis_state, @pytest.mark.parametrize( ( - 'validator_count,' - 'slots_per_epoch,' - 'target_committee_size,' - 'shard_count,' - 'current_epoch,' - 'target_epoch,' - 'expected_epoch_start_shard,' + "validator_count," + "slots_per_epoch," + "target_committee_size," + "shard_count," + "current_epoch," + "target_epoch," + "expected_epoch_start_shard," ), [ (1000, 25, 5, 50, 3, 2, 2), @@ -130,29 +127,28 @@ def test_get_shard_delta(genesis_state, (1000, 25, 5, 50, 3, 5, None), ], ) -def test_get_start_shard(genesis_state, - current_epoch, - target_epoch, - expected_epoch_start_shard, - config): +def test_get_start_shard( + genesis_state, current_epoch, target_epoch, expected_epoch_start_shard, config +): state = genesis_state.copy( - slot=compute_start_slot_of_epoch(current_epoch, config.SLOTS_PER_EPOCH), + slot=compute_start_slot_of_epoch(current_epoch, config.SLOTS_PER_EPOCH) ) if expected_epoch_start_shard is None: with pytest.raises(ValidationError): get_start_shard(state, target_epoch, CommitteeConfig(config)) else: - epoch_start_shard = get_start_shard(state, target_epoch, CommitteeConfig(config)) + epoch_start_shard = get_start_shard( + state, target_epoch, CommitteeConfig(config) + ) assert epoch_start_shard == expected_epoch_start_shard -SOME_SEED = b'\x33' * 32 +SOME_SEED = b"\x33" * 32 -def test_find_proposer_in_committee(genesis_validators, - config): - epoch = random.randrange(config.GENESIS_EPOCH, 2**64) +def test_find_proposer_in_committee(genesis_validators, config): + epoch = random.randrange(config.GENESIS_EPOCH, 2 ** 64) proposer_index = random.randrange(0, len(genesis_validators)) validators = tuple() @@ -160,26 +156,26 @@ def test_find_proposer_in_committee(genesis_validators, # should at a minimum have 17 ETH as ``effective_balance``. # Using 1 ETH should maintain the same spirit of the test and # ensure we can know the likely candidate ahead of time. - one_eth_in_gwei = 1 * 10**9 + one_eth_in_gwei = 1 * 10 ** 9 for index, validator in enumerate(genesis_validators): if index == proposer_index: validators += (validator,) else: - validators += (validator.copy( - effective_balance=one_eth_in_gwei, - ),) - - assert _find_proposer_in_committee( - validators, - range(len(validators)), - epoch, - SOME_SEED, - config.MAX_EFFECTIVE_BALANCE - ) == proposer_index + validators += (validator.copy(effective_balance=one_eth_in_gwei),) + + assert ( + _find_proposer_in_committee( + validators, + range(len(validators)), + epoch, + SOME_SEED, + config.MAX_EFFECTIVE_BALANCE, + ) + == proposer_index + ) -def test_calculate_first_committee_at_slot(genesis_state, - config): +def test_calculate_first_committee_at_slot(genesis_state, config): state = genesis_state slots_per_epoch = config.SLOTS_PER_EPOCH shard_count = config.SHARD_COUNT @@ -187,7 +183,9 @@ def test_calculate_first_committee_at_slot(genesis_state, current_epoch = state.current_epoch(slots_per_epoch) - active_validator_indices = get_active_validator_indices(state.validators, current_epoch) + active_validator_indices = get_active_validator_indices( + state.validators, current_epoch + ) committees_per_slot = get_committees_per_slot( len(active_validator_indices), @@ -199,73 +197,48 @@ def test_calculate_first_committee_at_slot(genesis_state, assert state.slot % config.SLOTS_PER_EPOCH == 0 for slot in range(state.slot, state.slot + config.SLOTS_PER_EPOCH): offset = committees_per_slot * (slot % slots_per_epoch) - shard = ( - get_start_shard(state, current_epoch, config) + offset - ) % shard_count - committee = get_crosslink_committee( - state, - current_epoch, - shard, - config, - ) + shard = (get_start_shard(state, current_epoch, config) + offset) % shard_count + committee = get_crosslink_committee(state, current_epoch, shard, config) - assert committee == _calculate_first_committee_at_slot(state, slot, CommitteeConfig(config)) + assert committee == _calculate_first_committee_at_slot( + state, slot, CommitteeConfig(config) + ) def _invalidate_all_but_proposer(proposer_index, index, validator): if proposer_index == index: return validator else: - return validator.copy( - effective_balance=-1, - ) + return validator.copy(effective_balance=-1) -@pytest.mark.parametrize( - ( - 'validator_count,' - ), - [ - (1000), - ], -) -def test_get_beacon_proposer_index(genesis_state, - config): +@pytest.mark.parametrize(("validator_count,"), [(1000)]) +def test_get_beacon_proposer_index(genesis_state, config): state = genesis_state first_committee = _calculate_first_committee_at_slot( - state, - state.slot, - CommitteeConfig(config), + state, state.slot, CommitteeConfig(config) ) some_validator_index = random.sample(first_committee, 1)[0] state = state.copy( validators=tuple( - _invalidate_all_but_proposer( - some_validator_index, - index, - validator, - ) for index, validator in enumerate(state.validators) - ), + _invalidate_all_but_proposer(some_validator_index, index, validator) + for index, validator in enumerate(state.validators) + ) ) - assert get_beacon_proposer_index(state, CommitteeConfig(config)) == some_validator_index + assert ( + get_beacon_proposer_index(state, CommitteeConfig(config)) + == some_validator_index + ) -@pytest.mark.parametrize( - ( - 'validator_count,' - ), - [ - (1000), - ], -) -def test_get_crosslink_committee(genesis_state, - config): +@pytest.mark.parametrize(("validator_count,"), [(1000)]) +def test_get_crosslink_committee(genesis_state, config): indices = tuple() - for shard in range(get_shard_delta(genesis_state, - config.GENESIS_EPOCH, - CommitteeConfig(config))): + for shard in range( + get_shard_delta(genesis_state, config.GENESIS_EPOCH, CommitteeConfig(config)) + ): some_committee = get_crosslink_committee( genesis_state, genesis_state.current_epoch(config.SLOTS_PER_EPOCH), diff --git a/tests/eth2/core/beacon/test_deposit_helpers.py b/tests/eth2/core/beacon/test_deposit_helpers.py index dcb61e72a3..f69ba5a084 100644 --- a/tests/eth2/core/beacon/test_deposit_helpers.py +++ b/tests/eth2/core/beacon/test_deposit_helpers.py @@ -1,74 +1,42 @@ +from eth_utils import ValidationError import pytest -from eth_utils import ( - ValidationError, -) +from eth2.beacon.deposit_helpers import process_deposit, validate_deposit_proof +from eth2.beacon.tools.builder.initializer import create_mock_deposit -from eth2.beacon.deposit_helpers import ( - process_deposit, - validate_deposit_proof, -) -from eth2.beacon.tools.builder.initializer import ( - create_mock_deposit, -) - - -@pytest.mark.parametrize( - ( - 'success', - ), - [ - (True,), - (False,), - ] -) -def test_validate_deposit_proof(config, - keymap, - pubkeys, - deposit_contract_tree_depth, - genesis_state, - success): +@pytest.mark.parametrize(("success",), [(True,), (False,)]) +def test_validate_deposit_proof( + config, keymap, pubkeys, deposit_contract_tree_depth, genesis_state, success +): state = genesis_state - withdrawal_credentials = b'\x34' * 32 + withdrawal_credentials = b"\x34" * 32 state, deposit = create_mock_deposit( - state, - pubkeys[0], - keymap, - withdrawal_credentials, - config, + state, pubkeys[0], keymap, withdrawal_credentials, config ) if success: validate_deposit_proof(state, deposit, deposit_contract_tree_depth) else: deposit = deposit.copy( - data=deposit.data.copy( - withdrawal_credentials=b'\x23' * 32, - ) + data=deposit.data.copy(withdrawal_credentials=b"\x23" * 32) ) with pytest.raises(ValidationError): validate_deposit_proof(state, deposit, deposit_contract_tree_depth) -@pytest.mark.parametrize( - ( - 'is_new_validator', - ), - [ - (True,), - (False,), - ] -) -def test_process_deposit(config, - sample_beacon_state_params, - keymap, - genesis_state, - validator_count, - is_new_validator, - pubkeys): +@pytest.mark.parametrize(("is_new_validator",), [(True,), (False,)]) +def test_process_deposit( + config, + sample_beacon_state_params, + keymap, + genesis_state, + validator_count, + is_new_validator, + pubkeys, +): state = genesis_state - withdrawal_credentials = b'\x34' * 32 + withdrawal_credentials = b"\x34" * 32 if is_new_validator: validator_index = validator_count else: @@ -77,20 +45,12 @@ def test_process_deposit(config, pubkey = pubkeys[validator_index] state, deposit = create_mock_deposit( - state, - pubkey, - keymap, - withdrawal_credentials, - config, + state, pubkey, keymap, withdrawal_credentials, config ) validator_count_before_deposit = state.validator_count - result_state = process_deposit( - state=state, - deposit=deposit, - config=config, - ) + result_state = process_deposit(state=state, deposit=deposit, config=config) # test immutability assert len(state.validators) == validator_count_before_deposit @@ -104,4 +64,6 @@ def test_process_deposit(config, else: assert len(result_state.validators) == len(state.validators) assert validator.pubkey == pubkeys[validator_index] - assert result_state.balances[validator_index] == 2 * config.MAX_EFFECTIVE_BALANCE + assert ( + result_state.balances[validator_index] == 2 * config.MAX_EFFECTIVE_BALANCE + ) diff --git a/tests/eth2/core/beacon/test_epoch_processing_helpers.py b/tests/eth2/core/beacon/test_epoch_processing_helpers.py index 996e0cfd5b..6b2d73baac 100644 --- a/tests/eth2/core/beacon/test_epoch_processing_helpers.py +++ b/tests/eth2/core/beacon/test_epoch_processing_helpers.py @@ -1,72 +1,43 @@ import random -import pytest -from eth_utils.toolz import ( - random_sample, -) +from eth_utils.toolz import random_sample +import pytest -from eth2._utils.bitfield import ( - set_voted, - get_empty_bitfield, -) -from eth2._utils.tuple import ( - update_tuple_item, -) -from eth2.configs import CommitteeConfig -from eth2.beacon.constants import ( - FAR_FUTURE_EPOCH, - GWEI_PER_ETH, -) -from eth2.beacon.exceptions import InvalidEpochError -from eth2.beacon.committee_helpers import ( - get_crosslink_committee, -) +from eth2._utils.bitfield import get_empty_bitfield, set_voted +from eth2._utils.tuple import update_tuple_item +from eth2.beacon.committee_helpers import get_crosslink_committee +from eth2.beacon.constants import FAR_FUTURE_EPOCH, GWEI_PER_ETH from eth2.beacon.epoch_processing_helpers import ( - increase_balance, + _find_winning_crosslink_and_attesting_indices_from_candidates, + _get_attestations_for_shard, + _get_attestations_for_valid_crosslink, + compute_activation_exit_epoch, decrease_balance, get_attesting_indices, - compute_activation_exit_epoch, - get_validator_churn_limit, + get_base_reward, + get_matching_head_attestations, get_matching_source_attestations, get_matching_target_attestations, - get_matching_head_attestations, get_unslashed_attesting_indices, - _get_attestations_for_shard, - _get_attestations_for_valid_crosslink, - _find_winning_crosslink_and_attesting_indices_from_candidates, - get_base_reward, -) -from eth2.beacon.helpers import ( - compute_start_slot_of_epoch, -) -from eth2.beacon.types.attestation_data import ( - AttestationData, + get_validator_churn_limit, + increase_balance, ) +from eth2.beacon.exceptions import InvalidEpochError +from eth2.beacon.helpers import compute_start_slot_of_epoch +from eth2.beacon.tools.builder.validator import mk_pending_attestation_from_committee +from eth2.beacon.types.attestation_data import AttestationData from eth2.beacon.types.checkpoints import Checkpoint from eth2.beacon.types.crosslinks import Crosslink from eth2.beacon.types.pending_attestations import PendingAttestation -from eth2.beacon.typing import ( - Gwei, -) -from eth2.beacon.tools.builder.validator import ( - mk_pending_attestation_from_committee, -) +from eth2.beacon.typing import Gwei +from eth2.configs import CommitteeConfig @pytest.mark.parametrize( - ( - "delta," - ), - [ - (1), - (GWEI_PER_ETH), - (2 * GWEI_PER_ETH), - (32 * GWEI_PER_ETH), - (33 * GWEI_PER_ETH), - ], + ("delta,"), + [(1), (GWEI_PER_ETH), (2 * GWEI_PER_ETH), (32 * GWEI_PER_ETH), (33 * GWEI_PER_ETH)], ) -def test_increase_balance(genesis_state, - delta): +def test_increase_balance(genesis_state, delta): index = random.sample(range(len(genesis_state.validators)), 1)[0] prior_balance = genesis_state.balances[index] state = increase_balance(genesis_state, index, delta) @@ -74,9 +45,7 @@ def test_increase_balance(genesis_state, @pytest.mark.parametrize( - ( - "delta," - ), + ("delta,"), [ (1), (GWEI_PER_ETH), @@ -86,43 +55,26 @@ def test_increase_balance(genesis_state, (100 * GWEI_PER_ETH), ], ) -def test_decrease_balance(genesis_state, - delta): +def test_decrease_balance(genesis_state, delta): index = random.sample(range(len(genesis_state.validators)), 1)[0] prior_balance = genesis_state.balances[index] state = decrease_balance(genesis_state, index, delta) assert state.balances[index] == Gwei(max(prior_balance - delta, 0)) -@pytest.mark.parametrize( - ( - 'validator_count,' - ), - [ - (1000), - ], -) -def test_get_attesting_indices(genesis_state, - config): +@pytest.mark.parametrize(("validator_count,"), [(1000)]) +def test_get_attesting_indices(genesis_state, config): state = genesis_state.copy( slot=compute_start_slot_of_epoch(3, config.SLOTS_PER_EPOCH) ) target_epoch = state.current_epoch(config.SLOTS_PER_EPOCH) target_shard = (state.start_shard + 3) % config.SHARD_COUNT some_committee = get_crosslink_committee( - state, - target_epoch, - target_shard, - CommitteeConfig(config), + state, target_epoch, target_shard, CommitteeConfig(config) ) data = AttestationData( - target=Checkpoint( - epoch=target_epoch, - ), - crosslink=Crosslink( - shard=target_shard, - ), + target=Checkpoint(epoch=target_epoch), crosslink=Crosslink(shard=target_shard) ) some_subset_count = random.randrange(1, len(some_committee) // 2) some_subset = random.sample(some_committee, some_subset_count) @@ -132,12 +84,7 @@ def test_get_attesting_indices(genesis_state, if index in some_subset: bitfield = set_voted(bitfield, i) - indices = get_attesting_indices( - state, - data, - bitfield, - CommitteeConfig(config), - ) + indices = get_attesting_indices(state, data, bitfield, CommitteeConfig(config)) assert set(indices) == set(some_subset) assert len(indices) == len(some_subset) @@ -146,8 +93,7 @@ def test_get_attesting_indices(genesis_state, def test_compute_activation_exit_epoch(activation_exit_delay): epoch = random.randrange(0, FAR_FUTURE_EPOCH) entry_exit_effect_epoch = compute_activation_exit_epoch( - epoch, - activation_exit_delay, + epoch, activation_exit_delay ) assert entry_exit_effect_epoch == (epoch + 1 + activation_exit_delay) @@ -166,61 +112,40 @@ def test_compute_activation_exit_epoch(activation_exit_delay): (100, 1, 5, 100), ], ) -def test_get_validator_churn_limit(genesis_state, - expected_churn_limit, - config): +def test_get_validator_churn_limit(genesis_state, expected_churn_limit, config): assert get_validator_churn_limit(genesis_state, config) == expected_churn_limit @pytest.mark.parametrize( - ( - "current_epoch," - "target_epoch," - "success," - ), - [ - (40, 40, True), - (40, 39, True), - (40, 38, False), - (40, 41, False), - ], + ("current_epoch," "target_epoch," "success,"), + [(40, 40, True), (40, 39, True), (40, 38, False), (40, 41, False)], ) -def test_get_matching_source_attestations(genesis_state, - current_epoch, - target_epoch, - success, - config): +def test_get_matching_source_attestations( + genesis_state, current_epoch, target_epoch, success, config +): state = genesis_state.copy( slot=compute_start_slot_of_epoch(current_epoch, config.SLOTS_PER_EPOCH), current_epoch_attestations=tuple( PendingAttestation( data=AttestationData( - beacon_block_root=current_epoch.to_bytes(32, "little"), + beacon_block_root=current_epoch.to_bytes(32, "little") ) ) ), previous_epoch_attestations=tuple( PendingAttestation( data=AttestationData( - beacon_block_root=(current_epoch - 1).to_bytes(32, "little"), + beacon_block_root=(current_epoch - 1).to_bytes(32, "little") ) ) - ) + ), ) if success: - attestations = get_matching_source_attestations( - state, - target_epoch, - config, - ) + attestations = get_matching_source_attestations(state, target_epoch, config) else: with pytest.raises(InvalidEpochError): - get_matching_source_attestations( - state, - target_epoch, - config, - ) + get_matching_source_attestations(state, target_epoch, config) return if current_epoch == target_epoch: @@ -229,31 +154,24 @@ def test_get_matching_source_attestations(genesis_state, assert attestations == state.previous_epoch_attestations -def test_get_matching_target_attestations(genesis_state, - config): +def test_get_matching_target_attestations(genesis_state, config): some_epoch = config.GENESIS_EPOCH + 20 some_slot = compute_start_slot_of_epoch(some_epoch, config.SLOTS_PER_EPOCH) - some_target_root = b'\x33' * 32 + some_target_root = b"\x33" * 32 target_attestations = tuple( ( PendingAttestation( - data=AttestationData( - target=Checkpoint( - root=some_target_root, - ), - ), - ) for _ in range(3) + data=AttestationData(target=Checkpoint(root=some_target_root)) + ) + for _ in range(3) ) ) current_epoch_attestations = target_attestations + tuple( ( PendingAttestation( - data=AttestationData( - target=Checkpoint( - root=b'\x44' * 32, - ), - ), - ) for _ in range(3) + data=AttestationData(target=Checkpoint(root=b"\x44" * 32)) + ) + for _ in range(3) ) ) state = genesis_state.copy( @@ -266,48 +184,39 @@ def test_get_matching_target_attestations(genesis_state, current_epoch_attestations=current_epoch_attestations, ) - attestations = get_matching_target_attestations( - state, - some_epoch, - config, - ) + attestations = get_matching_target_attestations(state, some_epoch, config) assert attestations == target_attestations -def test_get_matching_head_attestations(genesis_state, - config): +def test_get_matching_head_attestations(genesis_state, config): some_epoch = config.GENESIS_EPOCH + 20 - some_slot = compute_start_slot_of_epoch( - some_epoch, - config.SLOTS_PER_EPOCH - ) + config.SLOTS_PER_EPOCH // 4 - some_target_root = b'\x33' * 32 + some_slot = ( + compute_start_slot_of_epoch(some_epoch, config.SLOTS_PER_EPOCH) + + config.SLOTS_PER_EPOCH // 4 + ) + some_target_root = b"\x33" * 32 target_attestations = tuple( ( PendingAttestation( data=AttestationData( beacon_block_root=some_target_root, - target=Checkpoint( - epoch=some_epoch - 1, - ), - crosslink=Crosslink( - shard=i, - ) - ), - ) for i in range(3) + target=Checkpoint(epoch=some_epoch - 1), + crosslink=Crosslink(shard=i), + ) + ) + for i in range(3) ) ) current_epoch_attestations = target_attestations + tuple( ( PendingAttestation( data=AttestationData( - beacon_block_root=b'\x44' * 32, - target=Checkpoint( - epoch=some_epoch - 1, - ), - ), - ) for _ in range(3) + beacon_block_root=b"\x44" * 32, + target=Checkpoint(epoch=some_epoch - 1), + ) + ) + for _ in range(3) ) ) state = genesis_state.copy( @@ -318,44 +227,24 @@ def test_get_matching_head_attestations(genesis_state, current_epoch_attestations=current_epoch_attestations, ) - attestations = get_matching_head_attestations( - state, - some_epoch, - config, - ) + attestations = get_matching_head_attestations(state, some_epoch, config) assert attestations == target_attestations -@pytest.mark.parametrize( - ( - 'validator_count,' - ), - [ - (1000), - ], -) -def test_get_unslashed_attesting_indices(genesis_state, - config): +@pytest.mark.parametrize(("validator_count,"), [(1000)]) +def test_get_unslashed_attesting_indices(genesis_state, config): state = genesis_state.copy( slot=compute_start_slot_of_epoch(3, config.SLOTS_PER_EPOCH) ) target_epoch = state.current_epoch(config.SLOTS_PER_EPOCH) target_shard = (state.start_shard + 3) % config.SHARD_COUNT some_committee = get_crosslink_committee( - state, - target_epoch, - target_shard, - CommitteeConfig(config), + state, target_epoch, target_shard, CommitteeConfig(config) ) data = AttestationData( - target=Checkpoint( - epoch=target_epoch, - ), - crosslink=Crosslink( - shard=target_shard, - ), + target=Checkpoint(epoch=target_epoch), crosslink=Crosslink(shard=target_shard) ) some_subset_count = random.randrange(1, len(some_committee) // 2) some_subset = random.sample(some_committee, some_subset_count) @@ -365,26 +254,17 @@ def test_get_unslashed_attesting_indices(genesis_state, if index in some_subset: if random.choice([True, False]): state = state.update_validator_with_fn( - index, - lambda v, *_: v.copy( - slashed=True, - ) + index, lambda v, *_: v.copy(slashed=True) ) bitfield = set_voted(bitfield, i) - some_subset = tuple(filter( - lambda index: not state.validators[index].slashed, - some_subset, - )) + some_subset = tuple( + filter(lambda index: not state.validators[index].slashed, some_subset) + ) indices = get_unslashed_attesting_indices( state, - ( - PendingAttestation( - data=data, - aggregation_bits=bitfield, - ), - ), + (PendingAttestation(data=data, aggregation_bits=bitfield),), CommitteeConfig(config), ) @@ -392,16 +272,8 @@ def test_get_unslashed_attesting_indices(genesis_state, assert len(indices) == len(some_subset) -@pytest.mark.parametrize( - ( - 'validator_count,' - ), - [ - (1000), - ], -) -def test_find_candidate_attestations_for_shard(genesis_state, - config): +@pytest.mark.parametrize(("validator_count,"), [(1000)]) +def test_find_candidate_attestations_for_shard(genesis_state, config): some_epoch = config.GENESIS_EPOCH + 20 # start on some shard and walk a subset of them some_shard = 3 @@ -411,48 +283,41 @@ def test_find_candidate_attestations_for_shard(genesis_state, slot=compute_start_slot_of_epoch(some_epoch, config.SLOTS_PER_EPOCH), start_shard=some_shard, current_crosslinks=tuple( - Crosslink( - shard=i, - data_root=(i).to_bytes(32, "little"), - ) + Crosslink(shard=i, data_root=(i).to_bytes(32, "little")) for i in range(config.SHARD_COUNT) ), ) # sample a subset of the shards to make attestations for some_shards_with_attestations = random.sample( - range(some_shard, some_shard + shard_offset), - shard_offset // 2, + range(some_shard, some_shard + shard_offset), shard_offset // 2 ) committee_and_shard_pairs = tuple( ( get_crosslink_committee( - state, - some_epoch, - some_shard + i, - CommitteeConfig(config), - ), some_shard + i - ) for i in range(shard_offset) + state, some_epoch, some_shard + i, CommitteeConfig(config) + ), + some_shard + i, + ) + for i in range(shard_offset) if some_shard + i in some_shards_with_attestations ) pending_attestations = { shard: mk_pending_attestation_from_committee( - state.current_crosslinks[shard], - len(committee), - shard, - ) for committee, shard in committee_and_shard_pairs + state.current_crosslinks[shard], len(committee), shard + ) + for committee, shard in committee_and_shard_pairs } # invalidate some crosslinks to test the crosslink filter some_crosslinks_to_mangle = random.sample( - some_shards_with_attestations, - len(some_shards_with_attestations) // 2, + some_shards_with_attestations, len(some_shards_with_attestations) // 2 ) - shards_with_valid_crosslinks = ( - set(some_shards_with_attestations) - set(some_crosslinks_to_mangle) + shards_with_valid_crosslinks = set(some_shards_with_attestations) - set( + some_crosslinks_to_mangle ) crosslinks = tuple() @@ -462,68 +327,37 @@ def test_find_candidate_attestations_for_shard(genesis_state, else: crosslinks += (Crosslink(),) - state = state.copy( - current_crosslinks=crosslinks, - ) + state = state.copy(current_crosslinks=crosslinks) # check around the range of shards we built up for shard in range(0, some_shard + shard_offset + 3): if shard in some_shards_with_attestations: attestations = _get_attestations_for_shard( - pending_attestations.values(), - shard, + pending_attestations.values(), shard ) assert attestations == (pending_attestations[shard],) if shard in some_crosslinks_to_mangle: assert not _get_attestations_for_valid_crosslink( - pending_attestations.values(), - state, - shard, - config, + pending_attestations.values(), state, shard, config ) else: attestations = _get_attestations_for_valid_crosslink( - pending_attestations.values(), - state, - shard, - config, + pending_attestations.values(), state, shard, config ) assert attestations == (pending_attestations[shard],) else: - assert not _get_attestations_for_shard( - pending_attestations.values(), - shard, - ) + assert not _get_attestations_for_shard(pending_attestations.values(), shard) assert not _get_attestations_for_valid_crosslink( - pending_attestations.values(), - state, - shard, - config, + pending_attestations.values(), state, shard, config ) -@pytest.mark.parametrize( - ( - 'validator_count,' - ), - [ - (1000), - ], -) -@pytest.mark.parametrize( - ( - 'number_of_candidates,' - ), - [ - (0), - (1), - (3), - ], -) -def test_find_winning_crosslink_and_attesting_indices_from_candidates(genesis_state, - number_of_candidates, - config): +@pytest.mark.parametrize(("validator_count,"), [(1000)]) +@pytest.mark.parametrize(("number_of_candidates,"), [(0), (1), (3)]) +def test_find_winning_crosslink_and_attesting_indices_from_candidates( + genesis_state, number_of_candidates, config +): some_epoch = config.GENESIS_EPOCH + 20 some_shard = 3 @@ -531,19 +365,13 @@ def test_find_winning_crosslink_and_attesting_indices_from_candidates(genesis_st slot=compute_start_slot_of_epoch(some_epoch, config.SLOTS_PER_EPOCH), start_shard=some_shard, current_crosslinks=tuple( - Crosslink( - shard=i, - data_root=(i).to_bytes(32, "little"), - ) + Crosslink(shard=i, data_root=(i).to_bytes(32, "little")) for i in range(config.SHARD_COUNT) ), ) full_committee = get_crosslink_committee( - state, - some_epoch, - some_shard, - CommitteeConfig(config), + state, some_epoch, some_shard, CommitteeConfig(config) ) # break the committees up into different subsets to simulate different @@ -570,25 +398,20 @@ def test_find_winning_crosslink_and_attesting_indices_from_candidates(genesis_st len(full_committee), some_shard, target_epoch=some_epoch, - ) for committee in filtered_committees + ) + for committee in filtered_committees ) if number_of_candidates == 0: expected_result = (Crosslink(), set()) else: - expected_result = ( - candidates[0].data.crosslink, - set(sorted(full_committee)), - ) + expected_result = (candidates[0].data.crosslink, set(sorted(full_committee))) result = _find_winning_crosslink_and_attesting_indices_from_candidates( - state, - candidates, - config, + state, candidates, config ) assert result == expected_result -def test_get_base_reward(genesis_state, - config): +def test_get_base_reward(genesis_state, config): assert get_base_reward(genesis_state, 0, config) == 724077 diff --git a/tests/eth2/core/beacon/test_genesis.py b/tests/eth2/core/beacon/test_genesis.py index 27b127b19a..16b8a0cc0d 100644 --- a/tests/eth2/core/beacon/test_genesis.py +++ b/tests/eth2/core/beacon/test_genesis.py @@ -1,33 +1,23 @@ +from eth.constants import ZERO_HASH32 import pytest -from eth.constants import ( - ZERO_HASH32, -) - -from eth2.beacon.constants import ( - EMPTY_SIGNATURE, - JUSTIFICATION_BITS_LENGTH, +from eth2.beacon.constants import EMPTY_SIGNATURE, JUSTIFICATION_BITS_LENGTH +from eth2.beacon.genesis import ( + _genesis_time_from_eth1_timestamp, + get_genesis_block, + initialize_beacon_state_from_eth1, ) -from eth2.beacon.types.blocks import BeaconBlock, BeaconBlockBody +from eth2.beacon.tools.builder.initializer import create_mock_deposits_and_root from eth2.beacon.types.block_headers import BeaconBlockHeader +from eth2.beacon.types.blocks import BeaconBlock, BeaconBlockBody from eth2.beacon.types.crosslinks import Crosslink from eth2.beacon.types.eth1_data import Eth1Data from eth2.beacon.types.forks import Fork -from eth2.beacon.genesis import ( - get_genesis_block, - initialize_beacon_state_from_eth1, - _genesis_time_from_eth1_timestamp, -) -from eth2.beacon.tools.builder.initializer import ( - create_mock_deposits_and_root, -) -from eth2.beacon.typing import ( - Gwei, -) +from eth2.beacon.typing import Gwei def test_get_genesis_block(): - genesis_state_root = b'\x10' * 32 + genesis_state_root = b"\x10" * 32 genesis_slot = 0 genesis_block = get_genesis_block(genesis_state_root, BeaconBlock) assert genesis_block.slot == genesis_slot @@ -37,29 +27,21 @@ def test_get_genesis_block(): assert genesis_block.body.is_empty -@pytest.mark.parametrize( - ( - 'validator_count,' - ), - [ - (10) - ] -) +@pytest.mark.parametrize(("validator_count,"), [(10)]) def test_get_genesis_beacon_state( - validator_count, - pubkeys, - genesis_epoch, - genesis_slot, - shard_count, - slots_per_historical_root, - epochs_per_slashings_vector, - epochs_per_historical_vector, - config, - keymap): + validator_count, + pubkeys, + genesis_epoch, + genesis_slot, + shard_count, + slots_per_historical_root, + epochs_per_slashings_vector, + epochs_per_historical_vector, + config, + keymap, +): genesis_deposits, deposit_root = create_mock_deposits_and_root( - pubkeys=pubkeys[:validator_count], - keymap=keymap, - config=config, + pubkeys=pubkeys[:validator_count], keymap=keymap, config=config ) genesis_eth1_data = Eth1Data( @@ -83,7 +65,7 @@ def test_get_genesis_beacon_state( # History assert state.latest_block_header == BeaconBlockHeader( - body_root=BeaconBlockBody().hash_tree_root, + body_root=BeaconBlockBody().hash_tree_root ) assert len(state.block_roots) == slots_per_historical_root assert state.block_roots == (ZERO_HASH32,) * slots_per_historical_root diff --git a/tests/eth2/core/beacon/test_helpers.py b/tests/eth2/core/beacon/test_helpers.py index f449daa901..8dae1f4d8c 100644 --- a/tests/eth2/core/beacon/test_helpers.py +++ b/tests/eth2/core/beacon/test_helpers.py @@ -1,35 +1,21 @@ +from eth.constants import ZERO_HASH32 +from eth_utils import ValidationError, to_tuple import pytest -from eth_utils import ( - ValidationError, - to_tuple, -) - -from eth.constants import ( - ZERO_HASH32, -) - -from eth2._utils.hash import ( - hash_eth2, -) -from eth2.beacon.constants import ( - GWEI_PER_ETH, - FAR_FUTURE_EPOCH, -) - -from eth2.beacon.types.states import BeaconState -from eth2.beacon.types.forks import Fork -from eth2.beacon.types.validators import Validator - +from eth2._utils.hash import hash_eth2 +from eth2.beacon.constants import FAR_FUTURE_EPOCH, GWEI_PER_ETH from eth2.beacon.helpers import ( + _get_fork_version, _get_seed, + compute_start_slot_of_epoch, get_active_validator_indices, get_block_root_at_slot, - compute_start_slot_of_epoch, get_domain, - _get_fork_version, get_total_balance, ) +from eth2.beacon.types.forks import Fork +from eth2.beacon.types.states import BeaconState +from eth2.beacon.types.validators import Validator @to_tuple @@ -40,28 +26,19 @@ def get_pseudo_chain(length, genesis_block): block = genesis_block.copy() yield block for slot in range(1, length * 3): - block = genesis_block.copy( - slot=slot, - parent_root=block.signing_root - ) + block = genesis_block.copy(slot=slot, parent_root=block.signing_root) yield block def generate_mock_latest_historical_roots( - genesis_block, - current_slot, - slots_per_epoch, - slots_per_historical_root): + genesis_block, current_slot, slots_per_epoch, slots_per_historical_root +): assert current_slot < slots_per_historical_root chain_length = (current_slot // slots_per_epoch + 1) * slots_per_epoch blocks = get_pseudo_chain(chain_length, genesis_block) - block_roots = [ - block.signing_root - for block in blocks[:current_slot] - ] + [ - ZERO_HASH32 - for _ in range(slots_per_historical_root - current_slot) + block_roots = [block.signing_root for block in blocks[:current_slot]] + [ + ZERO_HASH32 for _ in range(slots_per_historical_root - current_slot) ] return blocks, block_roots @@ -70,9 +47,7 @@ def generate_mock_latest_historical_roots( # Get historical roots # @pytest.mark.parametrize( - ( - 'current_slot,target_slot,success' - ), + ("current_slot,target_slot,success"), [ (10, 0, True), (10, 9, True), @@ -82,49 +57,38 @@ def generate_mock_latest_historical_roots( (128, 128, False), ], ) -def test_get_block_root_at_slot(sample_beacon_state_params, - current_slot, - target_slot, - success, - slots_per_epoch, - slots_per_historical_root, - sample_block): +def test_get_block_root_at_slot( + sample_beacon_state_params, + current_slot, + target_slot, + success, + slots_per_epoch, + slots_per_historical_root, + sample_block, +): blocks, block_roots = generate_mock_latest_historical_roots( - sample_block, - current_slot, - slots_per_epoch, - slots_per_historical_root, + sample_block, current_slot, slots_per_epoch, slots_per_historical_root ) state = BeaconState(**sample_beacon_state_params).copy( - slot=current_slot, - block_roots=block_roots, + slot=current_slot, block_roots=block_roots ) if success: block_root = get_block_root_at_slot( - state, - target_slot, - slots_per_historical_root, + state, target_slot, slots_per_historical_root ) assert block_root == blocks[target_slot].signing_root else: with pytest.raises(ValidationError): - get_block_root_at_slot( - state, - target_slot, - slots_per_historical_root, - ) + get_block_root_at_slot(state, target_slot, slots_per_historical_root) def test_get_active_validator_indices(sample_validator_record_params): current_epoch = 1 # 3 validators are ACTIVE validators = [ - Validator( - **sample_validator_record_params, - ).copy( - activation_epoch=0, - exit_epoch=FAR_FUTURE_EPOCH, + Validator(**sample_validator_record_params).copy( + activation_epoch=0, exit_epoch=FAR_FUTURE_EPOCH ) for i in range(3) ] @@ -132,172 +96,119 @@ def test_get_active_validator_indices(sample_validator_record_params): assert len(active_validator_indices) == 3 validators[0] = validators[0].copy( - activation_epoch=current_epoch + 1, # activation_epoch > current_epoch + activation_epoch=current_epoch + 1 # activation_epoch > current_epoch ) active_validator_indices = get_active_validator_indices(validators, current_epoch) assert len(active_validator_indices) == 2 validators[1] = validators[1].copy( - exit_epoch=current_epoch, # current_epoch == exit_epoch + exit_epoch=current_epoch # current_epoch == exit_epoch ) active_validator_indices = get_active_validator_indices(validators, current_epoch) assert len(active_validator_indices) == 1 @pytest.mark.parametrize( - ( - 'balances,' - 'validator_indices,' - 'expected' - ), + ("balances," "validator_indices," "expected"), [ - ( - tuple(), - tuple(), - 1, - ), - ( - (32 * GWEI_PER_ETH, 32 * GWEI_PER_ETH), - (0, 1), - 64 * GWEI_PER_ETH, - ), - ( - (32 * GWEI_PER_ETH, 32 * GWEI_PER_ETH), - (1,), - 32 * GWEI_PER_ETH, - ), - ] + (tuple(), tuple(), 1), + ((32 * GWEI_PER_ETH, 32 * GWEI_PER_ETH), (0, 1), 64 * GWEI_PER_ETH), + ((32 * GWEI_PER_ETH, 32 * GWEI_PER_ETH), (1,), 32 * GWEI_PER_ETH), + ], ) -def test_get_total_balance(genesis_state, - balances, - validator_indices, - expected): +def test_get_total_balance(genesis_state, balances, validator_indices, expected): state = genesis_state for i, index in enumerate(validator_indices): - state = state._update_validator_balance( - index, - balances[i], - ) + state = state._update_validator_balance(index, balances[i]) total_balance = get_total_balance(state, validator_indices) assert total_balance == expected @pytest.mark.parametrize( - ( - 'previous_version,' - 'current_version,' - 'epoch,' - 'current_epoch,' - 'expected' - ), + ("previous_version," "current_version," "epoch," "current_epoch," "expected"), [ - (b'\x00' * 4, b'\x00' * 4, 0, 0, b'\x00' * 4), - (b'\x00' * 4, b'\x00' * 4, 0, 1, b'\x00' * 4), - (b'\x00' * 4, b'\x11' * 4, 20, 10, b'\x00' * 4), - (b'\x00' * 4, b'\x11' * 4, 20, 20, b'\x11' * 4), - (b'\x00' * 4, b'\x11' * 4, 10, 20, b'\x11' * 4), - ] + (b"\x00" * 4, b"\x00" * 4, 0, 0, b"\x00" * 4), + (b"\x00" * 4, b"\x00" * 4, 0, 1, b"\x00" * 4), + (b"\x00" * 4, b"\x11" * 4, 20, 10, b"\x00" * 4), + (b"\x00" * 4, b"\x11" * 4, 20, 20, b"\x11" * 4), + (b"\x00" * 4, b"\x11" * 4, 10, 20, b"\x11" * 4), + ], ) -def test_get_fork_version(previous_version, - current_version, - epoch, - current_epoch, - expected): +def test_get_fork_version( + previous_version, current_version, epoch, current_epoch, expected +): fork = Fork( - previous_version=previous_version, - current_version=current_version, - epoch=epoch, - ) - assert expected == _get_fork_version( - fork, - current_epoch, + previous_version=previous_version, current_version=current_version, epoch=epoch ) + assert expected == _get_fork_version(fork, current_epoch) @pytest.mark.parametrize( ( - 'previous_version,' - 'current_version,' - 'epoch,' - 'current_epoch,' - 'signature_domain,' - 'expected' + "previous_version," + "current_version," + "epoch," + "current_epoch," + "signature_domain," + "expected" ), [ - ( - b'\x11' * 4, - b'\x22' * 4, - 4, - 4, - 1, - b'\x01\x00\x00\x00' + b'\x22' * 4, - ), - ( - b'\x11' * 4, - b'\x22' * 4, - 4, - 4 - 1, - 1, - b'\x01\x00\x00\x00' + b'\x11' * 4, - ), - ] + (b"\x11" * 4, b"\x22" * 4, 4, 4, 1, b"\x01\x00\x00\x00" + b"\x22" * 4), + (b"\x11" * 4, b"\x22" * 4, 4, 4 - 1, 1, b"\x01\x00\x00\x00" + b"\x11" * 4), + ], ) -def test_get_domain(previous_version, - current_version, - epoch, - current_epoch, - signature_domain, - genesis_state, - slots_per_epoch, - expected): +def test_get_domain( + previous_version, + current_version, + epoch, + current_epoch, + signature_domain, + genesis_state, + slots_per_epoch, + expected, +): state = genesis_state fork = Fork( - previous_version=previous_version, - current_version=current_version, - epoch=epoch, + previous_version=previous_version, current_version=current_version, epoch=epoch ) assert expected == get_domain( - state=state.copy( - fork=fork, - ), + state=state.copy(fork=fork), signature_domain=signature_domain, slots_per_epoch=slots_per_epoch, message_epoch=current_epoch, ) -def test_get_seed(genesis_state, - committee_config, - slots_per_epoch, - min_seed_lookahead, - activation_exit_delay, - epochs_per_historical_vector): - def mock_get_randao_mix(state, - epoch, - epochs_per_historical_vector): +def test_get_seed( + genesis_state, + committee_config, + slots_per_epoch, + min_seed_lookahead, + activation_exit_delay, + epochs_per_historical_vector, +): + def mock_get_randao_mix(state, epoch, epochs_per_historical_vector): return hash_eth2( - state.hash_tree_root + - epoch.to_bytes(32, byteorder='little') + - epochs_per_historical_vector.to_bytes(32, byteorder='little') + state.hash_tree_root + + epoch.to_bytes(32, byteorder="little") + + epochs_per_historical_vector.to_bytes(32, byteorder="little") ) - def mock_get_active_index_root(state, - epoch, - epochs_per_historical_vector): + def mock_get_active_index_root(state, epoch, epochs_per_historical_vector): return hash_eth2( - state.hash_tree_root + - epoch.to_bytes(32, byteorder='little') + - slots_per_epoch.to_bytes(32, byteorder='little') + - epochs_per_historical_vector.to_bytes(32, byteorder='little') + state.hash_tree_root + + epoch.to_bytes(32, byteorder="little") + + slots_per_epoch.to_bytes(32, byteorder="little") + + epochs_per_historical_vector.to_bytes(32, byteorder="little") ) state = genesis_state epoch = 1 state = state.copy( - slot=compute_start_slot_of_epoch(epoch, committee_config.SLOTS_PER_EPOCH), + slot=compute_start_slot_of_epoch(epoch, committee_config.SLOTS_PER_EPOCH) ) - epoch_as_bytes = epoch.to_bytes(32, 'little') + epoch_as_bytes = epoch.to_bytes(32, "little") seed = _get_seed( state=state, @@ -312,9 +223,11 @@ def mock_get_active_index_root(state, state=state, epoch=(epoch + epochs_per_historical_vector - min_seed_lookahead - 1), epochs_per_historical_vector=epochs_per_historical_vector, - ) + mock_get_active_index_root( + ) + + mock_get_active_index_root( state=state, epoch=epoch, epochs_per_historical_vector=epochs_per_historical_vector, - ) + epoch_as_bytes + ) + + epoch_as_bytes ) diff --git a/tests/eth2/core/beacon/test_validator_status_helpers.py b/tests/eth2/core/beacon/test_validator_status_helpers.py index a625b03df1..37b893bdb3 100644 --- a/tests/eth2/core/beacon/test_validator_status_helpers.py +++ b/tests/eth2/core/beacon/test_validator_status_helpers.py @@ -1,26 +1,16 @@ import random +from eth_utils.toolz import first, groupby, update_in import pytest -from eth_utils.toolz import ( - first, - update_in, - groupby, -) - -from eth2.beacon.constants import ( - FAR_FUTURE_EPOCH, -) -from eth2.beacon.committee_helpers import ( - get_beacon_proposer_index, -) -from eth2.beacon.helpers import ( - compute_start_slot_of_epoch, -) +from eth2.beacon.committee_helpers import get_beacon_proposer_index +from eth2.beacon.constants import FAR_FUTURE_EPOCH from eth2.beacon.epoch_processing_helpers import ( compute_activation_exit_epoch, get_validator_churn_limit, ) +from eth2.beacon.helpers import compute_start_slot_of_epoch +from eth2.beacon.tools.builder.initializer import create_mock_validator from eth2.beacon.validator_status_helpers import ( _compute_exit_queue_epoch, _set_validator_slashed, @@ -28,29 +18,14 @@ initiate_exit_for_validator, slash_validator, ) -from eth2.beacon.tools.builder.initializer import ( - create_mock_validator, -) -from eth2.configs import ( - CommitteeConfig, -) +from eth2.configs import CommitteeConfig -@pytest.mark.parametrize( - ( - 'is_already_activated,' - ), - [ - (True), - (False), - ] -) -def test_activate_validator(genesis_state, - is_already_activated, - validator_count, - pubkeys, - config): - some_future_epoch = config.GENESIS_EPOCH + random.randrange(1, 2**32) +@pytest.mark.parametrize(("is_already_activated,"), [(True), (False)]) +def test_activate_validator( + genesis_state, is_already_activated, validator_count, pubkeys, config +): + some_future_epoch = config.GENESIS_EPOCH + random.randrange(1, 2 ** 32) if is_already_activated: assert validator_count > 0 @@ -59,9 +34,7 @@ def test_activate_validator(genesis_state, assert some_validator.activation_epoch == config.GENESIS_EPOCH else: some_validator = create_mock_validator( - pubkeys[:validator_count + 1], - config, - is_active=is_already_activated, + pubkeys[: validator_count + 1], config, is_active=is_already_activated ) assert some_validator.activation_eligibility_epoch == FAR_FUTURE_EPOCH assert some_validator.activation_epoch == FAR_FUTURE_EPOCH @@ -74,54 +47,36 @@ def test_activate_validator(genesis_state, @pytest.mark.parametrize( - ( - "is_delayed_exit_epoch_the_maximum_exit_queue_epoch" - ), - [ - (True,), - (False,), - ] + ("is_delayed_exit_epoch_the_maximum_exit_queue_epoch"), [(True,), (False,)] ) -@pytest.mark.parametrize( - ( - "is_churn_limit_met" - ), - [ - (True,), - (False,), - ] -) -def test_compute_exit_queue_epoch(genesis_state, - is_delayed_exit_epoch_the_maximum_exit_queue_epoch, - is_churn_limit_met, - config): +@pytest.mark.parametrize(("is_churn_limit_met"), [(True,), (False,)]) +def test_compute_exit_queue_epoch( + genesis_state, + is_delayed_exit_epoch_the_maximum_exit_queue_epoch, + is_churn_limit_met, + config, +): state = genesis_state - for index in random.sample(range(len(state.validators)), len(state.validators) // 4): - some_future_epoch = config.GENESIS_EPOCH + random.randrange(1, 2**32) + for index in random.sample( + range(len(state.validators)), len(state.validators) // 4 + ): + some_future_epoch = config.GENESIS_EPOCH + random.randrange(1, 2 ** 32) state = state.update_validator_with_fn( - index, - lambda validator, *_: validator.copy( - exit_epoch=some_future_epoch, - ), + index, lambda validator, *_: validator.copy(exit_epoch=some_future_epoch) ) if is_delayed_exit_epoch_the_maximum_exit_queue_epoch: expected_candidate_exit_queue_epoch = compute_activation_exit_epoch( - state.current_epoch(config.SLOTS_PER_EPOCH), - config.ACTIVATION_EXIT_DELAY, + state.current_epoch(config.SLOTS_PER_EPOCH), config.ACTIVATION_EXIT_DELAY ) for index, validator in enumerate(state.validators): if validator.exit_epoch == FAR_FUTURE_EPOCH: continue some_prior_epoch = random.randrange( - config.GENESIS_EPOCH, - expected_candidate_exit_queue_epoch, + config.GENESIS_EPOCH, expected_candidate_exit_queue_epoch ) state = state.update_validator_with_fn( - index, - lambda validator, *_: validator.copy( - exit_epoch=some_prior_epoch, - ), + index, lambda validator, *_: validator.copy(exit_epoch=some_prior_epoch) ) validator = state.validators[index] assert expected_candidate_exit_queue_epoch >= validator.exit_epoch @@ -146,44 +101,37 @@ def test_compute_exit_queue_epoch(genesis_state, if validator.exit_epoch == expected_candidate_exit_queue_epoch } additional_queued_validator_count = random.randrange( - len(queued_validators), - len(state.validators), + len(queued_validators), len(state.validators) ) unqueued_validators = tuple( - v for v in state.validators - if v.exit_epoch == FAR_FUTURE_EPOCH + v for v in state.validators if v.exit_epoch == FAR_FUTURE_EPOCH ) for index in random.sample( - range(len(unqueued_validators)), - additional_queued_validator_count, + range(len(unqueued_validators)), additional_queued_validator_count ): state = state.update_validator_with_fn( index, lambda validator, *_: validator.copy( - exit_epoch=expected_candidate_exit_queue_epoch, + exit_epoch=expected_candidate_exit_queue_epoch ), ) all_queued_validators = tuple( - v for v in state.validators + v + for v in state.validators if v.exit_epoch == expected_candidate_exit_queue_epoch ) churn_limit = len(all_queued_validators) + 1 expected_exit_queue_epoch = expected_candidate_exit_queue_epoch - assert _compute_exit_queue_epoch(state, churn_limit, config) == expected_exit_queue_epoch + assert ( + _compute_exit_queue_epoch(state, churn_limit, config) + == expected_exit_queue_epoch + ) -@pytest.mark.parametrize( - ( - 'is_already_exited,' - ), - [ - (True), - (False), - ] -) +@pytest.mark.parametrize(("is_already_exited,"), [(True), (False)]) def test_initiate_validator_exit(genesis_state, is_already_exited, config): state = genesis_state index = random.choice(range(len(state.validators))) @@ -199,14 +147,11 @@ def test_initiate_validator_exit(genesis_state, is_already_exited, config): exit_queue_epoch = _compute_exit_queue_epoch(state, churn_limit, config) validator = validator.copy( exit_epoch=exit_queue_epoch, - withdrawable_epoch=exit_queue_epoch + config.MIN_VALIDATOR_WITHDRAWABILITY_DELAY, + withdrawable_epoch=exit_queue_epoch + + config.MIN_VALIDATOR_WITHDRAWABILITY_DELAY, ) - exited_validator = initiate_exit_for_validator( - validator, - state, - config, - ) + exited_validator = initiate_exit_for_validator(validator, state, config) if is_already_exited: assert exited_validator == validator @@ -219,29 +164,18 @@ def test_initiate_validator_exit(genesis_state, is_already_exited, config): ) -@pytest.mark.parametrize( - ( - 'is_already_slashed,' - ), - [ - (True), - (False), - ] -) -def test_set_validator_slashed(genesis_state, - is_already_slashed, - validator_count, - pubkeys, - config): - some_future_epoch = config.GENESIS_EPOCH + random.randrange(1, 2**32) +@pytest.mark.parametrize(("is_already_slashed,"), [(True), (False)]) +def test_set_validator_slashed( + genesis_state, is_already_slashed, validator_count, pubkeys, config +): + some_future_epoch = config.GENESIS_EPOCH + random.randrange(1, 2 ** 32) assert len(genesis_state.validators) > 0 some_validator = genesis_state.validators[0] if is_already_slashed: some_validator = some_validator.copy( - slashed=True, - withdrawable_epoch=some_future_epoch, + slashed=True, withdrawable_epoch=some_future_epoch ) assert some_validator.slashed assert some_validator.withdrawable_epoch == some_future_epoch @@ -249,50 +183,42 @@ def test_set_validator_slashed(genesis_state, assert not some_validator.slashed slashed_validator = _set_validator_slashed( - some_validator, - some_future_epoch, - config.EPOCHS_PER_SLASHINGS_VECTOR, + some_validator, some_future_epoch, config.EPOCHS_PER_SLASHINGS_VECTOR ) assert slashed_validator.slashed assert slashed_validator.withdrawable_epoch == max( slashed_validator.withdrawable_epoch, - some_future_epoch + config.EPOCHS_PER_SLASHINGS_VECTOR + some_future_epoch + config.EPOCHS_PER_SLASHINGS_VECTOR, ) -@pytest.mark.parametrize( - ( - 'validator_count' - ), - [ - (100), - ] -) -def test_slash_validator(genesis_state, - config): +@pytest.mark.parametrize(("validator_count"), [(100)]) +def test_slash_validator(genesis_state, config): some_epoch = ( - config.GENESIS_EPOCH + random.randrange(1, 2**32) + config.EPOCHS_PER_SLASHINGS_VECTOR + config.GENESIS_EPOCH + + random.randrange(1, 2 ** 32) + + config.EPOCHS_PER_SLASHINGS_VECTOR ) earliest_slashable_epoch = some_epoch - config.EPOCHS_PER_SLASHINGS_VECTOR slashable_range = range(earliest_slashable_epoch, some_epoch) sampling_quotient = 4 state = genesis_state.copy( - slot=compute_start_slot_of_epoch(earliest_slashable_epoch, config.SLOTS_PER_EPOCH), + slot=compute_start_slot_of_epoch( + earliest_slashable_epoch, config.SLOTS_PER_EPOCH + ) ) validator_count_to_slash = len(state.validators) // sampling_quotient assert validator_count_to_slash > 1 validator_indices_to_slash = random.sample( - range(len(state.validators)), - validator_count_to_slash, + range(len(state.validators)), validator_count_to_slash ) # ensure case w/ one slashing in an epoch # by ignoring the first set_of_colluding_validators = validator_indices_to_slash[1:] # simulate multiple slashings in an epoch validators_grouped_by_coalition = groupby( - lambda index: index % sampling_quotient, - set_of_colluding_validators, + lambda index: index % sampling_quotient, set_of_colluding_validators ) coalition_count = len(validators_grouped_by_coalition) slashings = { @@ -326,9 +252,10 @@ def test_slash_validator(genesis_state, expected_individual_penalties, [index], lambda penalty: ( - penalty + ( - state.validators[index].effective_balance // - config.MIN_SLASHING_PENALTY_QUOTIENT + penalty + + ( + state.validators[index].effective_balance + // config.MIN_SLASHING_PENALTY_QUOTIENT ) ), default=0, @@ -347,9 +274,8 @@ def test_slash_validator(genesis_state, expected_proposer_rewards = update_in( expected_proposer_rewards, [proposer_index], - lambda reward: reward + ( - expected_total_slashed_balance // config.WHISTLEBLOWER_REWARD_QUOTIENT - ), + lambda reward: reward + + (expected_total_slashed_balance // config.WHISTLEBLOWER_REWARD_QUOTIENT), default=0, ) for index in coalition: @@ -373,14 +299,14 @@ def test_slash_validator(genesis_state, slashed_balance = state.slashings[slashed_epoch_index] assert slashed_balance == expected_slashings[epoch] assert state.balances[index] == ( - config.MAX_EFFECTIVE_BALANCE - - expected_individual_penalties[index] + - expected_proposer_rewards.get(index, 0) + config.MAX_EFFECTIVE_BALANCE + - expected_individual_penalties[index] + + expected_proposer_rewards.get(index, 0) ) for proposer_index, total_reward in expected_proposer_rewards.items(): assert state.balances[proposer_index] == ( - total_reward + - config.MAX_EFFECTIVE_BALANCE - - expected_individual_penalties.get(proposer_index, 0) + total_reward + + config.MAX_EFFECTIVE_BALANCE + - expected_individual_penalties.get(proposer_index, 0) ) diff --git a/tests/eth2/core/beacon/tools/builder/test_committee_assignment.py b/tests/eth2/core/beacon/tools/builder/test_committee_assignment.py index 652dd22f48..084c5d229e 100644 --- a/tests/eth2/core/beacon/tools/builder/test_committee_assignment.py +++ b/tests/eth2/core/beacon/tools/builder/test_committee_assignment.py @@ -1,61 +1,46 @@ import pytest -from eth2.beacon.exceptions import ( - NoCommitteeAssignment, -) -from eth2.beacon.helpers import ( - compute_start_slot_of_epoch, -) - -from eth2.beacon.tools.builder.committee_assignment import ( - get_committee_assignment, -) +from eth2.beacon.exceptions import NoCommitteeAssignment +from eth2.beacon.helpers import compute_start_slot_of_epoch +from eth2.beacon.tools.builder.committee_assignment import get_committee_assignment @pytest.mark.parametrize( ( - 'validator_count,' - 'slots_per_epoch,' - 'target_committee_size,' - 'shard_count,' - 'state_epoch,' - 'epoch,' + "validator_count," + "slots_per_epoch," + "target_committee_size," + "shard_count," + "state_epoch," + "epoch," ), [ (40, 16, 1, 16, 0, 0), # genesis (40, 16, 1, 16, 1, 1), # current epoch (40, 16, 1, 16, 1, 0), # previous epoch (40, 16, 1, 16, 1, 2), # next epoch - ] + ], ) -def test_get_committee_assignment(genesis_state, - slots_per_epoch, - shard_count, - config, - validator_count, - state_epoch, - epoch, - fixture_sm_class): +def test_get_committee_assignment( + genesis_state, + slots_per_epoch, + shard_count, + config, + validator_count, + state_epoch, + epoch, + fixture_sm_class, +): state_slot = compute_start_slot_of_epoch(state_epoch, slots_per_epoch) - state = genesis_state.copy( - slot=state_slot, - ) + state = genesis_state.copy(slot=state_slot) proposer_count = 0 - shard_validator_count = [ - 0 - for _ in range(shard_count) - ] + shard_validator_count = [0 for _ in range(shard_count)] slots = [] epoch_start_slot = compute_start_slot_of_epoch(epoch, slots_per_epoch) for validator_index in range(validator_count): - assignment = get_committee_assignment( - state, - config, - epoch, - validator_index, - ) + assignment = get_committee_assignment(state, config, epoch, validator_index) assert assignment.slot >= epoch_start_slot assert assignment.slot < epoch_start_slot + slots_per_epoch if assignment.is_proposer: @@ -69,30 +54,17 @@ def test_get_committee_assignment(genesis_state, @pytest.mark.parametrize( - ( - 'validator_count,' - 'slots_per_epoch,' - 'target_committee_size,' - 'shard_count,' - ), - [ - (40, 16, 1, 16), - ] + ("validator_count," "slots_per_epoch," "target_committee_size," "shard_count,"), + [(40, 16, 1, 16)], ) -def test_get_committee_assignment_no_assignment(genesis_state, - genesis_epoch, - slots_per_epoch, - config): +def test_get_committee_assignment_no_assignment( + genesis_state, genesis_epoch, slots_per_epoch, config +): state = genesis_state validator_index = 1 current_epoch = state.current_epoch(slots_per_epoch) - validator = state.validators[validator_index].copy( - exit_epoch=genesis_epoch, - ) - state = state.update_validator( - validator_index, - validator, - ) + validator = state.validators[validator_index].copy(exit_epoch=genesis_epoch) + state = state.update_validator(validator_index, validator) assert not validator.is_active(current_epoch) with pytest.raises(NoCommitteeAssignment): diff --git a/tests/eth2/core/beacon/tools/builder/test_validator.py b/tests/eth2/core/beacon/tools/builder/test_validator.py index e60af3917d..a4cd158f1b 100644 --- a/tests/eth2/core/beacon/tools/builder/test_validator.py +++ b/tests/eth2/core/beacon/tools/builder/test_validator.py @@ -1,60 +1,29 @@ +from eth_utils import ValidationError +from hypothesis import given, settings +from hypothesis import strategies as st import pytest -from hypothesis import ( - given, - settings, - strategies as st, -) - -from eth_utils import ( - ValidationError, -) +from eth2._utils.bitfield import get_empty_bitfield, has_voted from eth2._utils.bls import bls -from eth2._utils.bls.backends.chia import ( - ChiaBackend, -) -from eth2._utils.bls.backends.milagro import ( - MilagroBackend, -) -from eth2._utils.bitfield import ( - get_empty_bitfield, - has_voted, -) -from eth2.beacon.helpers import ( - compute_domain, -) +from eth2._utils.bls.backends.chia import ChiaBackend +from eth2._utils.bls.backends.milagro import MilagroBackend +from eth2.beacon.helpers import compute_domain from eth2.beacon.signature_domain import SignatureDomain -from eth2.beacon.tools.builder.validator import ( - aggregate_votes, - verify_votes, -) +from eth2.beacon.tools.builder.validator import aggregate_votes, verify_votes @pytest.mark.slow -@settings( - max_examples=1, - deadline=None, -) +@settings(max_examples=1, deadline=None) @given(random=st.randoms()) -@pytest.mark.parametrize( - ( - 'votes_count' - ), - [ - (0), - (9), - ], -) +@pytest.mark.parametrize(("votes_count"), [(0), (9)]) def test_aggregate_votes(votes_count, random, privkeys, pubkeys): bit_count = 10 pre_bitfield = get_empty_bitfield(bit_count) pre_sigs = () - domain = compute_domain( - SignatureDomain.DOMAIN_ATTESTATION, - ) + domain = compute_domain(SignatureDomain.DOMAIN_ATTESTATION) random_votes = random.sample(range(bit_count), votes_count) - message_hash = b'\x12' * 32 + message_hash = b"\x12" * 32 # Get votes: (committee_index, sig, public_key) votes = [ @@ -74,7 +43,7 @@ def test_aggregate_votes(votes_count, random, privkeys, pubkeys): bitfield=pre_bitfield, sigs=pre_sigs, voting_sigs=sigs, - attesting_indices=committee_indices + attesting_indices=committee_indices, ) try: diff --git a/tests/eth2/core/beacon/types/test_attestation.py b/tests/eth2/core/beacon/types/test_attestation.py index 8a38b94062..043e3b10f2 100644 --- a/tests/eth2/core/beacon/types/test_attestation.py +++ b/tests/eth2/core/beacon/types/test_attestation.py @@ -1,18 +1,10 @@ import pytest - import ssz -from eth2.beacon.types.attestations import ( - Attestation, -) +from eth2.beacon.types.attestations import Attestation -@pytest.mark.parametrize( - 'param,default_value', - [ - ('signature', b'\x00' * 96), - ] -) +@pytest.mark.parametrize("param,default_value", [("signature", b"\x00" * 96)]) def test_defaults(param, default_value, sample_attestation_params): del sample_attestation_params[param] attestation = Attestation(**sample_attestation_params) diff --git a/tests/eth2/core/beacon/types/test_attestation_data.py b/tests/eth2/core/beacon/types/test_attestation_data.py index c81aadb5de..c3b6e82e31 100644 --- a/tests/eth2/core/beacon/types/test_attestation_data.py +++ b/tests/eth2/core/beacon/types/test_attestation_data.py @@ -1,9 +1,9 @@ -from eth2.beacon.types.attestation_data import ( - AttestationData, -) +from eth2.beacon.types.attestation_data import AttestationData def test_defaults(sample_attestation_data_params): attestation_data = AttestationData(**sample_attestation_data_params) - assert attestation_data.source.epoch == sample_attestation_data_params['source'].epoch + assert ( + attestation_data.source.epoch == sample_attestation_data_params["source"].epoch + ) diff --git a/tests/eth2/core/beacon/types/test_attestation_data_and_custody_bit.py b/tests/eth2/core/beacon/types/test_attestation_data_and_custody_bit.py index c3da1725dd..3adad08c38 100644 --- a/tests/eth2/core/beacon/types/test_attestation_data_and_custody_bit.py +++ b/tests/eth2/core/beacon/types/test_attestation_data_and_custody_bit.py @@ -7,5 +7,5 @@ def test_defaults(sample_attestation_data_and_custody_bit_params): params = sample_attestation_data_and_custody_bit_params attestation_data_and_custody_bit = AttestationDataAndCustodyBit(**params) - assert attestation_data_and_custody_bit.data == params['data'] - assert attestation_data_and_custody_bit.custody_bit == params['custody_bit'] + assert attestation_data_and_custody_bit.data == params["data"] + assert attestation_data_and_custody_bit.custody_bit == params["custody_bit"] diff --git a/tests/eth2/core/beacon/types/test_attester_slashings.py b/tests/eth2/core/beacon/types/test_attester_slashings.py index 672a534639..c6b10d3de1 100644 --- a/tests/eth2/core/beacon/types/test_attester_slashings.py +++ b/tests/eth2/core/beacon/types/test_attester_slashings.py @@ -7,12 +7,12 @@ def test_defaults(sample_attester_slashing_params): attester_slashing = AttesterSlashing(**sample_attester_slashing_params) assert ( - attester_slashing.attestation_1.custody_bit_0_indices == - sample_attester_slashing_params['attestation_1'].custody_bit_0_indices + attester_slashing.attestation_1.custody_bit_0_indices + == sample_attester_slashing_params["attestation_1"].custody_bit_0_indices ) assert ( - attester_slashing.attestation_2.data == - sample_attester_slashing_params['attestation_2'].data + attester_slashing.attestation_2.data + == sample_attester_slashing_params["attestation_2"].data ) assert ssz.encode(attester_slashing) diff --git a/tests/eth2/core/beacon/types/test_block.py b/tests/eth2/core/beacon/types/test_block.py index 183bbc6895..21b9050398 100644 --- a/tests/eth2/core/beacon/types/test_block.py +++ b/tests/eth2/core/beacon/types/test_block.py @@ -1,19 +1,11 @@ -from eth2.beacon.types.attestations import ( - Attestation, -) -from eth2.beacon.types.blocks import ( - BeaconBlock, - BeaconBlockBody, -) - -from eth2.beacon.typing import ( - FromBlockParams, -) +from eth2.beacon.types.attestations import Attestation +from eth2.beacon.types.blocks import BeaconBlock, BeaconBlockBody +from eth2.beacon.typing import FromBlockParams def test_defaults(sample_beacon_block_params): block = BeaconBlock(**sample_beacon_block_params) - assert block.slot == sample_beacon_block_params['slot'] + assert block.slot == sample_beacon_block_params["slot"] assert block.is_genesis @@ -29,12 +21,8 @@ def test_update_attestations(sample_attestation_params, sample_beacon_block_para attestations = block.body.attestations attestations = list(attestations) attestations.append(Attestation(**sample_attestation_params)) - body2 = block.body.copy( - attestations=attestations - ) - block2 = block.copy( - body=body2 - ) + body2 = block.body.copy(attestations=attestations) + block2 = block.copy(body=body2) assert len(block2.body.attestations) == 1 @@ -50,7 +38,7 @@ def test_block_body_empty(sample_attestation_params): assert block_body.is_empty block_body = block_body.copy( - attestations=(Attestation(**sample_attestation_params),), + attestations=(Attestation(**sample_attestation_params),) ) assert not block_body.is_empty diff --git a/tests/eth2/core/beacon/types/test_crosslink_record.py b/tests/eth2/core/beacon/types/test_crosslink_record.py index 604aa07307..2b31afd258 100644 --- a/tests/eth2/core/beacon/types/test_crosslink_record.py +++ b/tests/eth2/core/beacon/types/test_crosslink_record.py @@ -1,9 +1,7 @@ -from eth2.beacon.types.crosslinks import ( - Crosslink, -) +from eth2.beacon.types.crosslinks import Crosslink def test_defaults(sample_crosslink_record_params): crosslink = Crosslink(**sample_crosslink_record_params) - assert crosslink.start_epoch == sample_crosslink_record_params['start_epoch'] - assert crosslink.data_root == sample_crosslink_record_params['data_root'] + assert crosslink.start_epoch == sample_crosslink_record_params["start_epoch"] + assert crosslink.data_root == sample_crosslink_record_params["data_root"] diff --git a/tests/eth2/core/beacon/types/test_deposit_data.py b/tests/eth2/core/beacon/types/test_deposit_data.py index 9bdcac7fec..d1791b6f83 100644 --- a/tests/eth2/core/beacon/types/test_deposit_data.py +++ b/tests/eth2/core/beacon/types/test_deposit_data.py @@ -4,5 +4,5 @@ def test_defaults(sample_deposit_data_params): deposit_data = DepositData(**sample_deposit_data_params) - assert deposit_data.pubkey == sample_deposit_data_params['pubkey'] - assert deposit_data.amount == sample_deposit_data_params['amount'] + assert deposit_data.pubkey == sample_deposit_data_params["pubkey"] + assert deposit_data.amount == sample_deposit_data_params["amount"] diff --git a/tests/eth2/core/beacon/types/test_deposits.py b/tests/eth2/core/beacon/types/test_deposits.py index c1bfa7af18..f2212d3e20 100644 --- a/tests/eth2/core/beacon/types/test_deposits.py +++ b/tests/eth2/core/beacon/types/test_deposits.py @@ -4,4 +4,4 @@ def test_defaults(sample_deposit_params): deposit = Deposit(**sample_deposit_params) - assert deposit.data == sample_deposit_params['data'] + assert deposit.data == sample_deposit_params["data"] diff --git a/tests/eth2/core/beacon/types/test_eth1_data.py b/tests/eth2/core/beacon/types/test_eth1_data.py index 9d6b04b950..1be0ff266d 100644 --- a/tests/eth2/core/beacon/types/test_eth1_data.py +++ b/tests/eth2/core/beacon/types/test_eth1_data.py @@ -1,11 +1,7 @@ -from eth2.beacon.types.eth1_data import ( - Eth1Data, -) +from eth2.beacon.types.eth1_data import Eth1Data def test_defaults(sample_eth1_data_params): - eth1_data = Eth1Data( - **sample_eth1_data_params, - ) - assert eth1_data.deposit_root == sample_eth1_data_params['deposit_root'] - assert eth1_data.block_hash == sample_eth1_data_params['block_hash'] + eth1_data = Eth1Data(**sample_eth1_data_params) + assert eth1_data.deposit_root == sample_eth1_data_params["deposit_root"] + assert eth1_data.block_hash == sample_eth1_data_params["block_hash"] diff --git a/tests/eth2/core/beacon/types/test_fork.py b/tests/eth2/core/beacon/types/test_fork.py index 349565ffd4..e27df949b1 100644 --- a/tests/eth2/core/beacon/types/test_fork.py +++ b/tests/eth2/core/beacon/types/test_fork.py @@ -1,13 +1,11 @@ import ssz -from eth2.beacon.types.forks import ( - Fork, -) +from eth2.beacon.types.forks import Fork def test_defaults(sample_fork_params): fork = Fork(**sample_fork_params) - assert fork.previous_version == sample_fork_params['previous_version'] - assert fork.current_version == sample_fork_params['current_version'] - assert fork.epoch == sample_fork_params['epoch'] + assert fork.previous_version == sample_fork_params["previous_version"] + assert fork.current_version == sample_fork_params["current_version"] + assert fork.epoch == sample_fork_params["epoch"] assert ssz.encode(fork) diff --git a/tests/eth2/core/beacon/types/test_pending_attestation_record.py b/tests/eth2/core/beacon/types/test_pending_attestation_record.py index f8e690f6ad..cec6f27aae 100644 --- a/tests/eth2/core/beacon/types/test_pending_attestation_record.py +++ b/tests/eth2/core/beacon/types/test_pending_attestation_record.py @@ -1,15 +1,22 @@ import ssz -from eth2.beacon.types.pending_attestations import ( - PendingAttestation, -) +from eth2.beacon.types.pending_attestations import PendingAttestation def test_defaults(sample_pending_attestation_record_params): pending_attestation = PendingAttestation(**sample_pending_attestation_record_params) - assert pending_attestation.data == sample_pending_attestation_record_params['data'] - assert pending_attestation.aggregation_bits == sample_pending_attestation_record_params['aggregation_bits'] # noqa: E501 - assert pending_attestation.inclusion_delay == sample_pending_attestation_record_params['inclusion_delay'] # noqa: E501 - assert pending_attestation.proposer_index == sample_pending_attestation_record_params['proposer_index'] # noqa: E501 + assert pending_attestation.data == sample_pending_attestation_record_params["data"] + assert ( + pending_attestation.aggregation_bits + == sample_pending_attestation_record_params["aggregation_bits"] + ) # noqa: E501 + assert ( + pending_attestation.inclusion_delay + == sample_pending_attestation_record_params["inclusion_delay"] + ) # noqa: E501 + assert ( + pending_attestation.proposer_index + == sample_pending_attestation_record_params["proposer_index"] + ) # noqa: E501 assert ssz.encode(pending_attestation) diff --git a/tests/eth2/core/beacon/types/test_proposer_slashings.py b/tests/eth2/core/beacon/types/test_proposer_slashings.py index a36ac88217..df91740c95 100644 --- a/tests/eth2/core/beacon/types/test_proposer_slashings.py +++ b/tests/eth2/core/beacon/types/test_proposer_slashings.py @@ -1,13 +1,11 @@ import ssz -from eth2.beacon.types.proposer_slashings import ( - ProposerSlashing, -) +from eth2.beacon.types.proposer_slashings import ProposerSlashing def test_defaults(sample_proposer_slashing_params): slashing = ProposerSlashing(**sample_proposer_slashing_params) - assert slashing.proposer_index == sample_proposer_slashing_params['proposer_index'] - assert slashing.header_1 == sample_proposer_slashing_params['header_1'] - assert slashing.header_2 == sample_proposer_slashing_params['header_2'] + assert slashing.proposer_index == sample_proposer_slashing_params["proposer_index"] + assert slashing.header_1 == sample_proposer_slashing_params["header_1"] + assert slashing.header_2 == sample_proposer_slashing_params["header_2"] assert ssz.encode(slashing) diff --git a/tests/eth2/core/beacon/types/test_states.py b/tests/eth2/core/beacon/types/test_states.py index 3db6c6071e..df935600fc 100644 --- a/tests/eth2/core/beacon/types/test_states.py +++ b/tests/eth2/core/beacon/types/test_states.py @@ -1,19 +1,13 @@ import pytest - import ssz -from eth2.beacon.types.states import ( - BeaconState, -) - -from eth2.beacon.tools.builder.initializer import ( - create_mock_validator, -) +from eth2.beacon.tools.builder.initializer import create_mock_validator +from eth2.beacon.types.states import BeaconState def test_defaults(sample_beacon_state_params): state = BeaconState(**sample_beacon_state_params) - assert state.validators == sample_beacon_state_params['validators'] + assert state.validators == sample_beacon_state_params["validators"] assert ssz.encode(state) @@ -22,34 +16,30 @@ def test_validators_and_balances_length(sample_beacon_state_params, config): with pytest.raises(ValueError): BeaconState(**sample_beacon_state_params).copy( validators=tuple( - create_mock_validator(pubkey, config) - for pubkey in range(10) - ), + create_mock_validator(pubkey, config) for pubkey in range(10) + ) ) @pytest.mark.parametrize( - 'validator_index_offset, new_pubkey, new_balance', - [ - (0, 5566, 100), - (100, 5566, 100), - ] + "validator_index_offset, new_pubkey, new_balance", + [(0, 5566, 100), (100, 5566, 100)], ) -def test_update_validator(genesis_state, - validator_index_offset, - validator_count, - new_pubkey, - new_balance, - config): +def test_update_validator( + genesis_state, + validator_index_offset, + validator_count, + new_pubkey, + new_balance, + config, +): state = genesis_state validator = create_mock_validator(new_pubkey, config) validator_index = validator_count + validator_index_offset if validator_index < state.validator_count: result_state = state.update_validator( - validator_index=validator_index, - validator=validator, - balance=new_balance, + validator_index=validator_index, validator=validator, balance=new_balance ) assert result_state.balances[validator_index] == new_balance assert result_state.validators[validator_index].pubkey == new_pubkey diff --git a/tests/eth2/core/beacon/types/test_transfer.py b/tests/eth2/core/beacon/types/test_transfer.py index b3811570b4..166d0f9011 100644 --- a/tests/eth2/core/beacon/types/test_transfer.py +++ b/tests/eth2/core/beacon/types/test_transfer.py @@ -1,12 +1,10 @@ import ssz -from eth2.beacon.types.transfers import ( - Transfer, -) +from eth2.beacon.types.transfers import Transfer def test_defaults(sample_transfer_params): transfer = Transfer(**sample_transfer_params) - assert transfer.recipient == sample_transfer_params['recipient'] + assert transfer.recipient == sample_transfer_params["recipient"] assert ssz.encode(transfer) diff --git a/tests/eth2/core/beacon/types/test_validator_record.py b/tests/eth2/core/beacon/types/test_validator_record.py index 084f515909..6ab6f58ccb 100644 --- a/tests/eth2/core/beacon/types/test_validator_record.py +++ b/tests/eth2/core/beacon/types/test_validator_record.py @@ -1,39 +1,30 @@ import pytest -from eth2.beacon.constants import ( - FAR_FUTURE_EPOCH, - GWEI_PER_ETH, -) -from eth2.beacon.types.validators import ( - Validator, -) +from eth2.beacon.constants import FAR_FUTURE_EPOCH, GWEI_PER_ETH +from eth2.beacon.types.validators import Validator from eth2.beacon.typing import Gwei def test_defaults(sample_validator_record_params): validator = Validator(**sample_validator_record_params) - assert validator.pubkey == sample_validator_record_params['pubkey'] - assert validator.withdrawal_credentials == sample_validator_record_params['withdrawal_credentials'] # noqa: E501 + assert validator.pubkey == sample_validator_record_params["pubkey"] + assert ( + validator.withdrawal_credentials + == sample_validator_record_params["withdrawal_credentials"] + ) # noqa: E501 @pytest.mark.parametrize( - 'activation_epoch,exit_epoch,epoch,expected', - [ - (0, 1, 0, True), - (1, 1, 1, False), - (0, 1, 1, False), - (0, 1, 2, False), - ], + "activation_epoch,exit_epoch,epoch,expected", + [(0, 1, 0, True), (1, 1, 1, False), (0, 1, 1, False), (0, 1, 2, False)], ) -def test_is_active(sample_validator_record_params, - activation_epoch, - exit_epoch, - epoch, - expected): +def test_is_active( + sample_validator_record_params, activation_epoch, exit_epoch, epoch, expected +): validator_record_params = { **sample_validator_record_params, - 'activation_epoch': activation_epoch, - 'exit_epoch': exit_epoch, + "activation_epoch": activation_epoch, + "exit_epoch": exit_epoch, } validator = Validator(**validator_record_params) assert validator.is_active(epoch) == expected @@ -41,7 +32,7 @@ def test_is_active(sample_validator_record_params, def test_create_pending_validator(config): pubkey = 123 - withdrawal_credentials = b'\x11' * 32 + withdrawal_credentials = b"\x11" * 32 effective_balance = 22 * GWEI_PER_ETH amount = Gwei(effective_balance + config.EFFECTIVE_BALANCE_INCREMENT // 2) diff --git a/tests/eth2/core/beacon/types/test_voluntary_exits.py b/tests/eth2/core/beacon/types/test_voluntary_exits.py index 4a052e47a7..69d5c38274 100644 --- a/tests/eth2/core/beacon/types/test_voluntary_exits.py +++ b/tests/eth2/core/beacon/types/test_voluntary_exits.py @@ -1,12 +1,10 @@ import ssz -from eth2.beacon.types.voluntary_exits import ( - VoluntaryExit, -) +from eth2.beacon.types.voluntary_exits import VoluntaryExit def test_defaults(sample_voluntary_exit_params): exit = VoluntaryExit(**sample_voluntary_exit_params) - assert exit.signature[0] == sample_voluntary_exit_params['signature'][0] + assert exit.signature[0] == sample_voluntary_exit_params["signature"][0] assert ssz.encode(exit) diff --git a/tests/eth2/fixtures/bls-fixtures/test_bls.py b/tests/eth2/fixtures/bls-fixtures/test_bls.py deleted file mode 100644 index f6595194da..0000000000 --- a/tests/eth2/fixtures/bls-fixtures/test_bls.py +++ /dev/null @@ -1,138 +0,0 @@ -from typing import ( - Tuple, -) -from dataclasses import ( - dataclass, -) -import pytest - -from py_ecc.bls.typing import Domain - -from eth2._utils.bls import bls -from eth2._utils.bls.backends import ( - MilagroBackend, -) -from eth2.beacon.tools.fixtures.loading import ( - get_input_bls_privkey, - get_input_bls_pubkeys, - get_input_bls_signatures, - get_input_sign_message, - get_output_bls_pubkey, - get_output_bls_signature, -) -from eth2.beacon.tools.fixtures.test_case import ( - BaseTestCase, -) -from eth_typing import ( - BLSPubkey, - BLSSignature, - Hash32, -) - - -from tests.eth2.fixtures.helpers import ( - get_test_cases, -) -from tests.eth2.fixtures.path import ( - BASE_FIXTURE_PATH, - ROOT_PROJECT_DIR, -) - - -# Test files -RUNNER_FIXTURE_PATH = BASE_FIXTURE_PATH / 'bls' -HANDLER_FIXTURE_PATHES = ( - RUNNER_FIXTURE_PATH / 'aggregate_pubkeys', - RUNNER_FIXTURE_PATH / 'aggregate_sigs', - # RUNNER_FIXTURE_PATH / 'msg_hash_g2_compressed', # NOTE: No public API in PyEECBackend - # RUNNER_FIXTURE_PATH / 'msg_hash_g2_uncompressed', # NOTE: No public API in PyEECBackend - RUNNER_FIXTURE_PATH / 'priv_to_pub', - RUNNER_FIXTURE_PATH / 'sign_msg', -) -FILTERED_CONFIG_NAMES = () - - -# -# Test format -# -@dataclass -class BLSPubkeyAggregationTestCase(BaseTestCase): - input: Tuple[BLSPubkey, ...] - output: BLSPubkey - - -@dataclass -class BLSSignaturesAggregationTestCase(BaseTestCase): - input: Tuple[BLSPubkey, ...] - output: BLSSignature - - -@dataclass -class BLSPrivToPubTestCase(BaseTestCase): - input: int - output: BLSPubkey - - -@dataclass -class BLSSignMessageTestCase(BaseTestCase): - input: Tuple[bytes, Hash32, Domain] - output: BLSPubkey - - -handler_to_processing_call_map = { - 'aggregate_pubkeys': ( - bls.aggregate_pubkeys, - BLSPubkeyAggregationTestCase, - get_input_bls_pubkeys, - get_output_bls_pubkey, - ), - 'aggregate_sigs': ( - bls.aggregate_signatures, - BLSSignaturesAggregationTestCase, - get_input_bls_signatures, - get_output_bls_signature, - ), - 'priv_to_pub': ( - bls.privtopub, - BLSPrivToPubTestCase, - get_input_bls_privkey, - get_output_bls_pubkey, - ), - 'sign_msg': ( - bls.sign, - BLSSignMessageTestCase, - get_input_sign_message, - get_output_bls_pubkey, - ), -} - - -# -# Helpers for generating test suite -# -def parse_bls_test_case(test_case, handler, index, config=None): - _, test_case_class, input_fn, output_fn = handler_to_processing_call_map[handler] - return test_case_class( - handler=handler, - index=index, - input=input_fn(test_case), - output=output_fn(test_case), - ) - - -all_test_cases = get_test_cases( - root_project_dir=ROOT_PROJECT_DIR, - fixture_pathes=HANDLER_FIXTURE_PATHES, - config_names=FILTERED_CONFIG_NAMES, - parse_test_case_fn=parse_bls_test_case, -) - - -@pytest.mark.parametrize( - "test_case, config", - all_test_cases -) -def test_aggregate_pubkeys_fixture(config, test_case): - bls.use(MilagroBackend) - processing_call, _, _, _ = handler_to_processing_call_map[test_case.handler] - assert processing_call(**(test_case.input)) == test_case.output diff --git a/tests/eth2/fixtures/helpers.py b/tests/eth2/fixtures/helpers.py deleted file mode 100644 index 9deac2e395..0000000000 --- a/tests/eth2/fixtures/helpers.py +++ /dev/null @@ -1,69 +0,0 @@ -import pytest - -from eth_utils import ( - to_tuple, -) - -from eth2.configs import ( - Eth2GenesisConfig, -) -from eth2.beacon.db.chain import BeaconChainDB -from eth2.beacon.state_machines.forks.serenity import ( - SerenityStateMachine, -) -from eth2.beacon.tools.fixtures.loading import ( - get_all_test_files, -) - - -# -# pytest setting -# -def bls_setting_mark_fn(bls_setting): - if bls_setting: - return pytest.mark.noautofixture - return None - - -@to_tuple -def get_test_cases(root_project_dir, fixture_pathes, config_names, parse_test_case_fn): - # TODO: batch reading files - test_files = get_all_test_files( - root_project_dir, - fixture_pathes, - config_names, - parse_test_case_fn, - ) - for test_file in test_files: - for test_case in test_file.test_cases: - bls_setting = test_case.bls_setting if hasattr(test_case, 'bls_setting') else False - yield mark_test_case(test_file, test_case, bls_setting=bls_setting) - - -def get_test_id(test_file, test_case): - description = test_case.description if hasattr(test_case, 'description') else '' - return f"{test_file.file_name}:{test_case.index}:{description}" - - -def mark_test_case(test_file, test_case, bls_setting=False): - test_id = get_test_id(test_file, test_case) - - mark = bls_setting_mark_fn(bls_setting) - if mark: - return pytest.param(test_case, test_file.config, id=test_id, marks=(mark,)) - else: - return pytest.param(test_case, test_file.config, id=test_id) - - -# -# State execution -# -def get_sm_class_of_config(config): - return SerenityStateMachine.configure( - __name__='SerenityStateMachineForTesting', - config=config, - ) - - -def get_chaindb_of_config(base_db, config): - return BeaconChainDB(base_db, Eth2GenesisConfig(config)) diff --git a/tests/eth2/fixtures/path.py b/tests/eth2/fixtures/path.py deleted file mode 100644 index e504d2f2cd..0000000000 --- a/tests/eth2/fixtures/path.py +++ /dev/null @@ -1,9 +0,0 @@ -import os -from pathlib import Path - - -# ROOT_PROJECT_DIR = Path(__file__).cwd() -ROOT_PROJECT_DIR = Path( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) -) -BASE_FIXTURE_PATH = ROOT_PROJECT_DIR / 'eth2-fixtures' / 'tests' diff --git a/tests/eth2/fixtures/shuffling-fixtures/test_shuffling_core.py b/tests/eth2/fixtures/shuffling-fixtures/test_shuffling_core.py deleted file mode 100644 index 7b3d4cc19f..0000000000 --- a/tests/eth2/fixtures/shuffling-fixtures/test_shuffling_core.py +++ /dev/null @@ -1,84 +0,0 @@ -import pytest -from typing import ( - Tuple, -) - -from dataclasses import ( - dataclass, -) - -from eth_utils import ( - decode_hex, -) - -from eth2.beacon.tools.misc.ssz_vector import ( - override_lengths, -) -from eth2.beacon.committee_helpers import ( - compute_shuffled_index, -) -from eth2.beacon.tools.fixtures.config_name import ( - ONLY_MINIMAL, -) -from eth2.beacon.tools.fixtures.test_case import ( - BaseTestCase, -) - -from tests.eth2.fixtures.helpers import ( - get_test_cases, -) -from tests.eth2.fixtures.path import ( - BASE_FIXTURE_PATH, - ROOT_PROJECT_DIR, -) - - -# Test files -SHUFFLING_FIXTURE_PATH = BASE_FIXTURE_PATH / 'shuffling' -FIXTURE_PATHES = ( - SHUFFLING_FIXTURE_PATH, -) -FILTERED_CONFIG_NAMES = ONLY_MINIMAL - - -@dataclass -class ShufflingTestCase(BaseTestCase): - seed: bytes - count: int - shuffled: Tuple[int, ...] - - -# -# Helpers for generating test suite -# -def parse_shuffling_test_case(test_case, handler, index, config): - override_lengths(config) - - return ShufflingTestCase( - handler=handler, - index=index, - seed=decode_hex(test_case['seed']), - count=test_case['count'], - shuffled=tuple(test_case['shuffled']), - ) - - -all_test_cases = get_test_cases( - root_project_dir=ROOT_PROJECT_DIR, - fixture_pathes=FIXTURE_PATHES, - config_names=FILTERED_CONFIG_NAMES, - parse_test_case_fn=parse_shuffling_test_case, -) - - -@pytest.mark.parametrize( - "test_case, config", - all_test_cases -) -def test_shuffling_fixture(test_case, config): - - result = tuple( - compute_shuffled_index(index, test_case.count, test_case.seed, config.SHUFFLE_ROUND_COUNT) - for index in range(test_case.count) - ) - assert result == test_case.shuffled diff --git a/tests/eth2/fixtures/state-fixtures/__init__.py b/tests/eth2/fixtures/state-fixtures/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/eth2/fixtures/state-fixtures/conftest.py b/tests/eth2/fixtures/state-fixtures/conftest.py deleted file mode 100644 index b0d055f40d..0000000000 --- a/tests/eth2/fixtures/state-fixtures/conftest.py +++ /dev/null @@ -1,23 +0,0 @@ -import pytest -from eth2._utils.bls import bls -from eth2._utils.bls.backends import PyECCBackend -from eth2.beacon.operations.attestation_pool import AttestationPool - - -# -# BLS mock -# -@pytest.fixture(autouse=True) -def mock_bls(mocker, request): - if 'noautofixture' in request.keywords: - bls.use(PyECCBackend) - else: - bls.use_noop_backend() - - -# -# Attestation pool -# -@pytest.fixture -def empty_attestation_pool(): - return AttestationPool() diff --git a/tests/eth2/fixtures/state-fixtures/test_genesis_initialization.py b/tests/eth2/fixtures/state-fixtures/test_genesis_initialization.py deleted file mode 100644 index 5d6d690f99..0000000000 --- a/tests/eth2/fixtures/state-fixtures/test_genesis_initialization.py +++ /dev/null @@ -1,119 +0,0 @@ -from typing import ( - Tuple, -) -from dataclasses import ( - dataclass, - field, -) -import pytest - -from eth_utils import ( - decode_hex, -) -from eth_typing import ( - Hash32, -) -from ssz.tools import ( - from_formatted_dict, -) - -from eth2.beacon.genesis import ( - initialize_beacon_state_from_eth1, -) -from eth2.beacon.tools.misc.ssz_vector import ( - override_lengths, -) -from eth2.beacon.types.deposits import Deposit -from eth2.beacon.types.states import BeaconState -from eth2.beacon.tools.fixtures.config_name import ( - ONLY_MINIMAL, -) -from eth2.beacon.tools.fixtures.helpers import ( - validate_state, -) -from eth2.beacon.tools.fixtures.loading import ( - get_bls_setting, - get_deposits, -) -from eth2.beacon.tools.fixtures.test_case import ( - BaseTestCase, -) -from eth2.beacon.typing import ( - Timestamp, -) - -from tests.eth2.fixtures.helpers import ( - get_test_cases, -) -from tests.eth2.fixtures.path import ( - BASE_FIXTURE_PATH, - ROOT_PROJECT_DIR, -) - - -# Test files -RUNNER_FIXTURE_PATH = BASE_FIXTURE_PATH / 'genesis' -HANDLER_FIXTURE_PATHES = ( - RUNNER_FIXTURE_PATH / 'initialization', -) -FILTERED_CONFIG_NAMES = ONLY_MINIMAL - - -# -# Test format -# -@dataclass -class GenesisInitializationTestCase(BaseTestCase): - bls_setting: bool - description: str - eth1_block_hash: Hash32 - eth1_timestamp: Timestamp - state: BeaconState - deposits: Tuple[Deposit, ...] = field(default_factory=tuple) - - -# -# Helpers for generating test suite -# -def parse_genesis_initialization_test_case(test_case, handler, index, config): - override_lengths(config) - - bls_setting = get_bls_setting(test_case) - eth1_block_hash = decode_hex(test_case['eth1_block_hash']) - eth1_timestamp = test_case['eth1_timestamp'] - state = from_formatted_dict(test_case['state'], BeaconState) - deposits = get_deposits(test_case, Deposit) - - return GenesisInitializationTestCase( - handler=handler, - index=index, - bls_setting=bls_setting, - description=test_case['description'], - eth1_block_hash=eth1_block_hash, - eth1_timestamp=eth1_timestamp, - state=state, - deposits=deposits, - ) - - -all_test_cases = get_test_cases( - root_project_dir=ROOT_PROJECT_DIR, - fixture_pathes=HANDLER_FIXTURE_PATHES, - config_names=FILTERED_CONFIG_NAMES, - parse_test_case_fn=parse_genesis_initialization_test_case, -) - - -@pytest.mark.parametrize( - "test_case, config", - all_test_cases -) -def test_genesis_initialization_fixture(config, test_case): - result_state = initialize_beacon_state_from_eth1( - eth1_block_hash=test_case.eth1_block_hash, - eth1_timestamp=test_case.eth1_timestamp, - deposits=test_case.deposits, - config=config, - ) - - validate_state(test_case.state, result_state) diff --git a/tests/eth2/fixtures/state-fixtures/test_genesis_validity.py b/tests/eth2/fixtures/state-fixtures/test_genesis_validity.py deleted file mode 100644 index b341319b4b..0000000000 --- a/tests/eth2/fixtures/state-fixtures/test_genesis_validity.py +++ /dev/null @@ -1,86 +0,0 @@ -from dataclasses import ( - dataclass, -) -import pytest - -from ssz.tools import ( - from_formatted_dict, -) - -from eth2.beacon.genesis import ( - is_valid_genesis_state, -) -from eth2.beacon.tools.misc.ssz_vector import ( - override_lengths, -) -from eth2.beacon.types.states import BeaconState -from eth2.beacon.tools.fixtures.config_name import ( - ONLY_MINIMAL, -) - -from eth2.beacon.tools.fixtures.loading import ( - get_bls_setting, -) -from eth2.beacon.tools.fixtures.test_case import ( - BaseTestCase, -) - -from tests.eth2.fixtures.helpers import ( - get_test_cases, -) -from tests.eth2.fixtures.path import ( - BASE_FIXTURE_PATH, - ROOT_PROJECT_DIR, -) - - -# Test files -RUNNER_FIXTURE_PATH = BASE_FIXTURE_PATH / 'genesis' -HANDLER_FIXTURE_PATHES = ( - RUNNER_FIXTURE_PATH / 'validity', -) -FILTERED_CONFIG_NAMES = ONLY_MINIMAL - - -# -# Test format -# -@dataclass -class GenesisValidityTestCase(BaseTestCase): - bls_setting: bool - description: str - genesis: BeaconState - is_valid: bool - - -def parse_genesis_validity_test_case(test_case, handler, index, config): - override_lengths(config) - - bls_setting = get_bls_setting(test_case) - genesis = from_formatted_dict(test_case['genesis'], BeaconState) - is_valid = test_case['is_valid'] - - return GenesisValidityTestCase( - handler=handler, - index=index, - bls_setting=bls_setting, - description=test_case['description'], - genesis=genesis, - is_valid=is_valid, - ) - - -all_test_cases = get_test_cases( - root_project_dir=ROOT_PROJECT_DIR, - fixture_pathes=HANDLER_FIXTURE_PATHES, - config_names=FILTERED_CONFIG_NAMES, - parse_test_case_fn=parse_genesis_validity_test_case, -) - - -@pytest.mark.parametrize( - "test_case, config", - all_test_cases -) -def test_genesis_validity_fixture(config, test_case): - assert test_case.is_valid == is_valid_genesis_state(test_case.genesis, config) diff --git a/tests/eth2/fixtures/state-fixtures/test_operations.py b/tests/eth2/fixtures/state-fixtures/test_operations.py deleted file mode 100644 index af0c58145a..0000000000 --- a/tests/eth2/fixtures/state-fixtures/test_operations.py +++ /dev/null @@ -1,135 +0,0 @@ -import pytest - -from eth_utils import ( - ValidationError, -) - -from eth2.beacon.exceptions import ( - SignatureError, -) -from eth2.beacon.tools.misc.ssz_vector import ( - override_lengths, -) -from eth2.beacon.types.attestations import Attestation -from eth2.beacon.types.attester_slashings import AttesterSlashing -from eth2.beacon.types.blocks import ( - BeaconBlock, - BeaconBlockBody, -) -from eth2.beacon.types.deposits import Deposit -from eth2.beacon.types.proposer_slashings import ProposerSlashing -from eth2.beacon.types.states import BeaconState -from eth2.beacon.types.transfers import Transfer -from eth2.beacon.types.voluntary_exits import VoluntaryExit - - -from eth2.beacon.tools.fixtures.config_name import ( - ONLY_MINIMAL, -) -from eth2.beacon.tools.fixtures.helpers import ( - validate_state, -) -from eth2.beacon.state_machines.forks.serenity.operation_processing import ( - process_attestations, - process_attester_slashings, - process_deposits, - process_proposer_slashings, - process_transfers, - process_voluntary_exits, -) -from eth2.beacon.tools.fixtures.loading import ( - get_bls_setting, - get_operation_or_header, - get_states, -) -from eth2.beacon.tools.fixtures.test_case import ( - OperationCase, -) - -from tests.eth2.fixtures.helpers import ( - get_test_cases, -) -from tests.eth2.fixtures.path import ( - BASE_FIXTURE_PATH, - ROOT_PROJECT_DIR, -) - - -# Test files -RUNNER_FIXTURE_PATH = BASE_FIXTURE_PATH / 'operations' -HANDLER_FIXTURE_PATHES = ( - RUNNER_FIXTURE_PATH / 'proposer_slashing', - RUNNER_FIXTURE_PATH / 'attester_slashing', - RUNNER_FIXTURE_PATH / 'attestation', - RUNNER_FIXTURE_PATH / 'deposit', - RUNNER_FIXTURE_PATH / 'voluntary_exit', - RUNNER_FIXTURE_PATH / 'transfer', -) -FILTERED_CONFIG_NAMES = ONLY_MINIMAL - -handler_to_processing_call_map = { - 'proposer_slashing': (ProposerSlashing, process_proposer_slashings), - 'attester_slashing': (AttesterSlashing, process_attester_slashings), - 'attestation': (Attestation, process_attestations), - 'deposit': (Deposit, process_deposits), - 'voluntary_exit': (VoluntaryExit, process_voluntary_exits), - 'transfer': (Transfer, process_transfers), -} - - -# -# Helpers for generating test suite -# -def parse_operation_test_case(test_case, handler, index, config): - config = config._replace( - MAX_TRANSFERS=1, - ) - override_lengths(config) - - bls_setting = get_bls_setting(test_case) - pre, post, is_valid = get_states(test_case, BeaconState) - operation_class, _ = handler_to_processing_call_map[handler] - operation = get_operation_or_header(test_case, operation_class, handler) - - return OperationCase( - handler=handler, - index=index, - bls_setting=bls_setting, - description=test_case['description'], - pre=pre, - operation=operation, - post=post, - is_valid=is_valid, - ) - - -all_test_cases = get_test_cases( - root_project_dir=ROOT_PROJECT_DIR, - fixture_pathes=HANDLER_FIXTURE_PATHES, - config_names=FILTERED_CONFIG_NAMES, - parse_test_case_fn=parse_operation_test_case, -) - - -@pytest.mark.parametrize( - "test_case, config", - all_test_cases -) -def test_operation_fixture(config, test_case): - config = config._replace( - MAX_TRANSFERS=1, - ) - post_state = test_case.pre - block = BeaconBlock().copy( - body=BeaconBlockBody( - **{test_case.handler + 's': (test_case.operation,)} # TODO: it looks awful - ) - ) - _, operation_processing = handler_to_processing_call_map[test_case.handler] - - if test_case.is_valid: - post_state = operation_processing(post_state, block, config) - validate_state(test_case.post, post_state) - else: - with pytest.raises((ValidationError, IndexError, SignatureError)): - operation_processing(post_state, block, config) diff --git a/tests/eth2/fixtures/state-fixtures/test_sanity.py b/tests/eth2/fixtures/state-fixtures/test_sanity.py deleted file mode 100644 index 2e135743fa..0000000000 --- a/tests/eth2/fixtures/state-fixtures/test_sanity.py +++ /dev/null @@ -1,118 +0,0 @@ -from dataclasses import ( - dataclass, -) -import pytest - -from eth_utils import ( - ValidationError, -) - -from eth2.beacon.tools.misc.ssz_vector import ( - override_lengths, -) -from eth2.beacon.types.blocks import BeaconBlock -from eth2.beacon.types.states import BeaconState -from eth2.beacon.tools.fixtures.config_name import ( - ONLY_MINIMAL, -) -from eth2.beacon.tools.fixtures.helpers import ( - run_state_execution, - validate_state, -) -from eth2.beacon.tools.fixtures.loading import ( - get_bls_setting, - get_blocks, - get_slots, - get_states, -) -from eth2.beacon.tools.fixtures.test_case import ( - StateTestCase, -) - -from tests.eth2.fixtures.helpers import ( - get_test_cases, - get_chaindb_of_config, - get_sm_class_of_config, -) -from tests.eth2.fixtures.path import ( - BASE_FIXTURE_PATH, - ROOT_PROJECT_DIR, -) - - -# Test files -RUNNER_FIXTURE_PATH = BASE_FIXTURE_PATH / 'sanity' -HANDLER_FIXTURE_PATHES = ( - RUNNER_FIXTURE_PATH, -) -FILTERED_CONFIG_NAMES = ONLY_MINIMAL - - -# -# Test format -# -@dataclass -class SanityTestCase(StateTestCase): - pass - - -# -# Helpers for generating test suite -# -def parse_sanity_test_case(test_case, handler, index, config): - override_lengths(config) - - bls_setting = get_bls_setting(test_case) - pre, post, is_valid = get_states(test_case, BeaconState) - blocks = get_blocks(test_case, BeaconBlock) - slots = get_slots(test_case) - - return SanityTestCase( - handler=handler, - index=index, - bls_setting=bls_setting, - description=test_case['description'], - pre=pre, - post=post, - is_valid=is_valid, - slots=slots, - blocks=blocks, - ) - - -all_test_cases = get_test_cases( - root_project_dir=ROOT_PROJECT_DIR, - fixture_pathes=HANDLER_FIXTURE_PATHES, - config_names=FILTERED_CONFIG_NAMES, - parse_test_case_fn=parse_sanity_test_case, -) - - -@pytest.mark.parametrize( - "test_case, config", - all_test_cases -) -def test_sanity_fixture(base_db, config, test_case, empty_attestation_pool): - sm_class = get_sm_class_of_config(config) - chaindb = get_chaindb_of_config(base_db, config) - - post_state = test_case.pre - if test_case.is_valid: - post_state = run_state_execution( - test_case, - sm_class, - chaindb, - empty_attestation_pool, - post_state, - ) - - validate_state(test_case.post, post_state) - else: - with pytest.raises(ValidationError): - run_state_execution( - test_case, - sm_class, - chaindb, - empty_attestation_pool, - post_state, - ) diff --git a/tests/eth2/fixtures/test_bls.py b/tests/eth2/fixtures/test_bls.py new file mode 100644 index 0000000000..23fc98a325 --- /dev/null +++ b/tests/eth2/fixtures/test_bls.py @@ -0,0 +1,37 @@ +from eth2.beacon.tools.fixtures.test_gen import ( + generate_pytests_from_eth2_fixture, + pytest_from_eth2_fixture, +) +from eth2.beacon.tools.fixtures.test_types.bls import BLSTestType + + +def pytest_generate_tests(metafunc): + generate_pytests_from_eth2_fixture(metafunc) + + +@pytest_from_eth2_fixture( + {"test_types": {BLSTestType: lambda handler: handler.name == "aggregate_pubkeys"}} +) +def test_aggregate_pubkeys(test_case): + test_case.execute() + + +@pytest_from_eth2_fixture( + {"test_types": {BLSTestType: lambda handler: handler.name == "aggregate_sigs"}} +) +def test_aggregate_sigs(test_case): + test_case.execute() + + +@pytest_from_eth2_fixture( + {"test_types": {BLSTestType: lambda handler: handler.name == "priv_to_pub"}} +) +def test_priv_to_pub(test_case): + test_case.execute() + + +@pytest_from_eth2_fixture( + {"test_types": {BLSTestType: lambda handler: handler.name == "sign_msg"}} +) +def test_sign_msg(test_case): + test_case.execute() diff --git a/tests/eth2/fixtures/test_epoch_processing.py b/tests/eth2/fixtures/test_epoch_processing.py new file mode 100644 index 0000000000..58760637fb --- /dev/null +++ b/tests/eth2/fixtures/test_epoch_processing.py @@ -0,0 +1,73 @@ +from eth2.beacon.tools.fixtures.config_types import Minimal +from eth2.beacon.tools.fixtures.test_gen import ( + generate_pytests_from_eth2_fixture, + pytest_from_eth2_fixture, +) +from eth2.beacon.tools.fixtures.test_types.epoch_processing import ( + EpochProcessingTestType, +) + + +def pytest_generate_tests(metafunc): + generate_pytests_from_eth2_fixture(metafunc) + + +@pytest_from_eth2_fixture( + { + "config_types": (Minimal,), + "test_types": { + EpochProcessingTestType: lambda handler: handler.name == "crosslinks" + }, + } +) +def test_crosslinks(test_case): + test_case.execute() + + +@pytest_from_eth2_fixture( + { + "config_types": (Minimal,), + "test_types": { + EpochProcessingTestType: lambda handler: handler.name + == "justification_and_finalization" + }, + } +) +def test_justification_and_finalization(test_case): + test_case.execute() + + +@pytest_from_eth2_fixture( + { + "config_types": (Minimal,), + "test_types": { + EpochProcessingTestType: lambda handler: handler.name == "registry_updates" + }, + } +) +def test_registry_updates(test_case): + test_case.execute() + + +@pytest_from_eth2_fixture( + { + "config_types": (Minimal,), + "test_types": { + EpochProcessingTestType: lambda handler: handler.name == "slashings" + }, + } +) +def test_slashings(test_case): + test_case.execute() + + +@pytest_from_eth2_fixture( + { + "config_types": (Minimal,), + "test_types": { + EpochProcessingTestType: lambda handler: handler.name == "final_updates" + }, + } +) +def test_final_updates(test_case): + test_case.execute() diff --git a/tests/eth2/fixtures/test_genesis.py b/tests/eth2/fixtures/test_genesis.py new file mode 100644 index 0000000000..bae3982987 --- /dev/null +++ b/tests/eth2/fixtures/test_genesis.py @@ -0,0 +1,32 @@ +from eth2.beacon.tools.fixtures.config_types import Minimal +from eth2.beacon.tools.fixtures.test_gen import ( + generate_pytests_from_eth2_fixture, + pytest_from_eth2_fixture, +) +from eth2.beacon.tools.fixtures.test_types.genesis import GenesisTestType + + +def pytest_generate_tests(metafunc): + generate_pytests_from_eth2_fixture(metafunc) + + +@pytest_from_eth2_fixture( + { + "config_types": (Minimal,), + "test_types": {GenesisTestType: lambda handler: handler.name == "validity"}, + } +) +def test_validity(test_case): + test_case.execute() + + +@pytest_from_eth2_fixture( + { + "config_types": (Minimal,), + "test_types": { + GenesisTestType: lambda handler: handler.name == "initialization" + }, + } +) +def test_initialization(test_case): + test_case.execute() diff --git a/tests/eth2/fixtures/test_operations.py b/tests/eth2/fixtures/test_operations.py new file mode 100644 index 0000000000..ff2a797c8c --- /dev/null +++ b/tests/eth2/fixtures/test_operations.py @@ -0,0 +1,121 @@ +from eth_utils import ValidationError +import pytest + +from eth2.beacon.tools.fixtures.config_types import Minimal +from eth2.beacon.tools.fixtures.test_gen import ( + generate_pytests_from_eth2_fixture, + pytest_from_eth2_fixture, +) +from eth2.beacon.tools.fixtures.test_types.operations import OperationsTestType + + +def pytest_generate_tests(metafunc): + generate_pytests_from_eth2_fixture(metafunc) + + +@pytest_from_eth2_fixture( + { + "config_types": (Minimal,), + "test_types": { + OperationsTestType: lambda handler: handler.name == "attestation" + }, + } +) +def test_attestation(test_case): + if test_case.valid(): + test_case.execute() + else: + with pytest.raises(ValidationError): + test_case.execute() + + +@pytest_from_eth2_fixture( + { + "config_types": (Minimal,), + "test_types": { + OperationsTestType: lambda handler: handler.name == "attester_slashing" + }, + } +) +def test_attester_slashing(test_case): + if test_case.valid(): + test_case.execute() + else: + with pytest.raises(ValidationError): + test_case.execute() + + +@pytest_from_eth2_fixture( + { + "config_types": (Minimal,), + "test_types": { + OperationsTestType: lambda handler: handler.name == "block_header" + }, + } +) +def test_block_header(test_case): + if test_case.valid(): + test_case.execute() + else: + with pytest.raises(ValidationError): + test_case.execute() + + +@pytest_from_eth2_fixture( + { + "config_types": (Minimal,), + "test_types": {OperationsTestType: lambda handler: handler.name == "deposit"}, + } +) +def test_deposit(test_case): + if test_case.valid(): + test_case.execute() + else: + with pytest.raises(ValidationError): + test_case.execute() + + +@pytest_from_eth2_fixture( + { + "config_types": (Minimal,), + "test_types": { + OperationsTestType: lambda handler: handler.name == "proposer_slashing" + }, + } +) +def test_proposer_slashing(test_case): + if test_case.valid(): + test_case.execute() + else: + with pytest.raises(ValidationError): + test_case.execute() + + +@pytest_from_eth2_fixture( + { + "config_types": (Minimal,), + "test_types": {OperationsTestType: lambda handler: handler.name == "transfer"}, + } +) +def test_transfer(test_case): + if test_case.valid(): + test_case.execute() + else: + with pytest.raises(ValidationError): + test_case.execute() + + +@pytest_from_eth2_fixture( + { + "config_types": (Minimal,), + "test_types": { + OperationsTestType: lambda handler: handler.name == "voluntary_exit" + }, + } +) +def test_voluntary_exit(test_case): + if test_case.valid(): + test_case.execute() + else: + with pytest.raises(ValidationError): + test_case.execute() diff --git a/tests/eth2/fixtures/test_sanity.py b/tests/eth2/fixtures/test_sanity.py new file mode 100644 index 0000000000..20865cfa9a --- /dev/null +++ b/tests/eth2/fixtures/test_sanity.py @@ -0,0 +1,37 @@ +from eth_utils import ValidationError +import pytest + +from eth2.beacon.tools.fixtures.config_types import Minimal +from eth2.beacon.tools.fixtures.test_gen import ( + generate_pytests_from_eth2_fixture, + pytest_from_eth2_fixture, +) +from eth2.beacon.tools.fixtures.test_types.sanity import SanityTestType + + +def pytest_generate_tests(metafunc): + generate_pytests_from_eth2_fixture(metafunc) + + +@pytest_from_eth2_fixture( + { + "config_types": (Minimal,), + "test_types": {SanityTestType: lambda handler: handler.name == "slots"}, + } +) +def test_slots(test_case): + test_case.execute() + + +@pytest_from_eth2_fixture( + { + "config_types": (Minimal,), + "test_types": {SanityTestType: lambda handler: handler.name == "blocks"}, + } +) +def test_blocks(test_case): + if test_case.valid(): + test_case.execute() + else: + with pytest.raises(ValidationError): + test_case.execute() diff --git a/tests/eth2/fixtures/test_shuffling.py b/tests/eth2/fixtures/test_shuffling.py new file mode 100644 index 0000000000..7d365342b3 --- /dev/null +++ b/tests/eth2/fixtures/test_shuffling.py @@ -0,0 +1,22 @@ +from eth2.beacon.tools.fixtures.config_types import Minimal +from eth2.beacon.tools.fixtures.test_gen import ( + generate_pytests_from_eth2_fixture, + pytest_from_eth2_fixture, +) +from eth2.beacon.tools.fixtures.test_types.shuffling import ShufflingTestType + + +def pytest_generate_tests(metafunc): + generate_pytests_from_eth2_fixture(metafunc) + + +@pytest_from_eth2_fixture( + {"config_types": (Minimal,), "test_types": (ShufflingTestType,)} +) +def test_minimal(test_case): + test_case.execute() + + +# @pytest_from_eth2_fixture({"config_types": (Full,), "test_types": (ShufflingTestType,)}) +# def test_full(test_case): +# test_case.execute() diff --git a/tests/eth2/integration/test_demo.py b/tests/eth2/integration/test_demo.py index 69da92e516..e6d5060b0a 100644 --- a/tests/eth2/integration/test_demo.py +++ b/tests/eth2/integration/test_demo.py @@ -1,33 +1,17 @@ import pytest +from eth2._utils.bls import bls from eth2.beacon.db.chain import BeaconChainDB from eth2.beacon.fork_choice.higher_slot import higher_slot_scoring -from eth2.beacon.helpers import ( - compute_epoch_of_slot, -) +from eth2.beacon.helpers import compute_epoch_of_slot from eth2.beacon.operations.attestation_pool import AttestationPool -from eth2.beacon.state_machines.forks.serenity.blocks import ( - SerenityBeaconBlock, -) -from eth2.beacon.state_machines.forks.serenity.configs import ( - SERENITY_CONFIG, -) -from eth2.beacon.state_machines.forks.serenity import ( - SerenityStateMachine, -) -from eth2.beacon.tools.builder.initializer import ( - create_mock_genesis, -) -from eth2.beacon.tools.builder.proposer import ( - create_mock_block, -) -from eth2.beacon.tools.builder.validator import ( - create_mock_signed_attestations_at_slot, -) -from eth2.beacon.tools.misc.ssz_vector import ( - override_lengths, -) -from eth2._utils.bls import bls +from eth2.beacon.state_machines.forks.serenity import SerenityStateMachine +from eth2.beacon.state_machines.forks.serenity.blocks import SerenityBeaconBlock +from eth2.beacon.state_machines.forks.serenity.configs import SERENITY_CONFIG +from eth2.beacon.tools.builder.initializer import create_mock_genesis +from eth2.beacon.tools.builder.proposer import create_mock_block +from eth2.beacon.tools.builder.validator import create_mock_signed_attestations_at_slot +from eth2.beacon.tools.misc.ssz_vector import override_lengths @pytest.fixture @@ -35,32 +19,22 @@ def fork_choice_scoring(): return higher_slot_scoring -@pytest.mark.parametrize( - ( - "validator_count" - ), - ( - (40), - ) -) -def test_demo(base_db, - validator_count, - keymap, - pubkeys, - fork_choice_scoring): +@pytest.mark.parametrize(("validator_count"), ((40),)) +def test_demo(base_db, validator_count, keymap, pubkeys, fork_choice_scoring): bls.use_noop_backend() slots_per_epoch = 8 config = SERENITY_CONFIG._replace( SLOTS_PER_EPOCH=slots_per_epoch, - GENESIS_EPOCH=compute_epoch_of_slot(SERENITY_CONFIG.GENESIS_SLOT, slots_per_epoch), + GENESIS_EPOCH=compute_epoch_of_slot( + SERENITY_CONFIG.GENESIS_SLOT, slots_per_epoch + ), TARGET_COMMITTEE_SIZE=3, SHARD_COUNT=2, MIN_ATTESTATION_INCLUSION_DELAY=2, ) override_lengths(config) fixture_sm_class = SerenityStateMachine.configure( - __name__='SerenityStateMachineForTesting', - config=config, + __name__="SerenityStateMachineForTesting", config=config ) genesis_slot = config.GENESIS_SLOT @@ -90,18 +64,16 @@ def test_demo(base_db, for current_slot in range(genesis_slot + 1, genesis_slot + chain_length + 1): if current_slot > genesis_slot + config.MIN_ATTESTATION_INCLUSION_DELAY: - attestations = attestations_map[current_slot - config.MIN_ATTESTATION_INCLUSION_DELAY] + attestations = attestations_map[ + current_slot - config.MIN_ATTESTATION_INCLUSION_DELAY + ] else: attestations = () block = create_mock_block( state=state, config=config, - state_machine=fixture_sm_class( - chaindb, - attestation_pool, - blocks[-1].slot, - ), + state_machine=fixture_sm_class(chaindb, attestation_pool, blocks[-1].slot), block_class=SerenityBeaconBlock, parent_block=block, keymap=keymap, @@ -110,12 +82,8 @@ def test_demo(base_db, ) # Get state machine instance - sm = fixture_sm_class( - chaindb, - attestation_pool, - blocks[-1].slot, - ) - state, _ = sm.import_block(block) + sm = fixture_sm_class(chaindb, attestation_pool, blocks[-1].slot) + state, _ = sm.import_block(block, state) chaindb.persist_state(state) chaindb.persist_block(block, SerenityBeaconBlock, fork_choice_scoring) @@ -127,11 +95,7 @@ def test_demo(base_db, attestations = create_mock_signed_attestations_at_slot( state=state, config=config, - state_machine=fixture_sm_class( - chaindb, - attestation_pool, - block.slot, - ), + state_machine=fixture_sm_class(chaindb, attestation_pool, block.slot), attestation_slot=attestation_slot, beacon_block_root=block.signing_root, keymap=keymap, diff --git a/tests/eth2/utils-tests/bitfield-utils/test_bitfields.py b/tests/eth2/utils-tests/bitfield-utils/test_bitfields.py index 9b9734e8d9..c5e4ba3e8b 100644 --- a/tests/eth2/utils-tests/bitfield-utils/test_bitfields.py +++ b/tests/eth2/utils-tests/bitfield-utils/test_bitfields.py @@ -1,30 +1,21 @@ import random -from hypothesis import ( - given, - strategies as st, -) +from hypothesis import given +from hypothesis import strategies as st import pytest from eth2._utils.bitfield import ( - has_voted, - set_voted, get_bitfield_length, get_empty_bitfield, get_vote_count, + has_voted, + set_voted, ) @pytest.mark.parametrize( - 'attester_count, bitfield_length', - [ - (0, 0), - (1, 1), - (8, 1), - (9, 2), - (16, 2), - (17, 3), - ] + "attester_count, bitfield_length", + [(0, 0), (1, 1), (8, 1), (9, 2), (16, 2), (17, 3)], ) def test_bitfield_length(attester_count, bitfield_length): assert get_bitfield_length(attester_count) == bitfield_length @@ -42,24 +33,114 @@ def test_bitfield_single_votes(): attesters = list(range(10)) bitfield = get_empty_bitfield(len(attesters)) - assert set_voted(bitfield, 0) == (True, False, False, False, False, - False, False, False, False, False) - assert set_voted(bitfield, 1) == (False, True, False, False, False, - False, False, False, False, False) - assert set_voted(bitfield, 2) == (False, False, True, False, False, - False, False, False, False, False) - assert set_voted(bitfield, 4) == (False, False, False, False, True, - False, False, False, False, False) - assert set_voted(bitfield, 5) == (False, False, False, False, False, - True, False, False, False, False) - assert set_voted(bitfield, 6) == (False, False, False, False, False, - False, True, False, False, False) - assert set_voted(bitfield, 7) == (False, False, False, False, False, - False, False, True, False, False) - assert set_voted(bitfield, 8) == (False, False, False, False, False, - False, False, False, True, False) - assert set_voted(bitfield, 9) == (False, False, False, False, False, - False, False, False, False, True) + assert set_voted(bitfield, 0) == ( + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ) + assert set_voted(bitfield, 1) == ( + False, + True, + False, + False, + False, + False, + False, + False, + False, + False, + ) + assert set_voted(bitfield, 2) == ( + False, + False, + True, + False, + False, + False, + False, + False, + False, + False, + ) + assert set_voted(bitfield, 4) == ( + False, + False, + False, + False, + True, + False, + False, + False, + False, + False, + ) + assert set_voted(bitfield, 5) == ( + False, + False, + False, + False, + False, + True, + False, + False, + False, + False, + ) + assert set_voted(bitfield, 6) == ( + False, + False, + False, + False, + False, + False, + True, + False, + False, + False, + ) + assert set_voted(bitfield, 7) == ( + False, + False, + False, + False, + False, + False, + False, + True, + False, + False, + ) + assert set_voted(bitfield, 8) == ( + False, + False, + False, + False, + False, + False, + False, + False, + True, + False, + ) + assert set_voted(bitfield, 9) == ( + False, + False, + False, + False, + False, + False, + False, + False, + False, + True, + ) for voter in attesters: bitfield = set_voted((False,) * 16, voter) @@ -84,12 +165,7 @@ def test_bitfield_all_votes(): def test_bitfield_some_votes(): attesters = list(range(10)) - voters = [ - 0, # b'\x01\x00' - 4, # b'\x10\x00' - 5, # b'\x20\x00' - 9, # b'\x00\x02' - ] + voters = [0, 4, 5, 9] # b'\x01\x00' # b'\x10\x00' # b'\x20\x00' # b'\x00\x02' bitfield = get_empty_bitfield(len(attesters)) for voter in voters: diff --git a/tests/eth2/utils-tests/bls-utils/test_backends.py b/tests/eth2/utils-tests/bls-utils/test_backends.py index 843ed8278b..18667d1b47 100644 --- a/tests/eth2/utils-tests/bls-utils/test_backends.py +++ b/tests/eth2/utils-tests/bls-utils/test_backends.py @@ -1,24 +1,10 @@ +from eth_utils import ValidationError +from py_ecc.optimized_bls12_381 import curve_order import pytest -from py_ecc.optimized_bls12_381 import ( - curve_order, -) - -from eth2._utils.bls.backends import ( - AVAILABLE_BACKENDS, - NoOpBackend, -) -from eth2._utils.bls import ( - bls, -) - -from eth2.beacon.constants import ( - EMPTY_PUBKEY, - EMPTY_SIGNATURE, -) -from eth_utils import ( - ValidationError, -) +from eth2._utils.bls import bls +from eth2._utils.bls.backends import AVAILABLE_BACKENDS, NoOpBackend +from eth2.beacon.constants import EMPTY_PUBKEY, EMPTY_SIGNATURE def assert_pubkey(obj): @@ -31,12 +17,10 @@ def assert_signature(obj): @pytest.fixture def domain(): - return (123).to_bytes(8, 'big') + return (123).to_bytes(8, "big") -@pytest.mark.parametrize( - "backend", AVAILABLE_BACKENDS -) +@pytest.mark.parametrize("backend", AVAILABLE_BACKENDS) def test_sanity(backend, domain): bls.use(backend) msg_0 = b"\x32" * 32 @@ -81,11 +65,9 @@ def test_sanity(backend, domain): ) +@pytest.mark.parametrize("backend", AVAILABLE_BACKENDS) @pytest.mark.parametrize( - "backend", AVAILABLE_BACKENDS -) -@pytest.mark.parametrize( - 'privkey', + "privkey", [ (1), (5), @@ -94,67 +76,51 @@ def test_sanity(backend, domain): (127409812145), (90768492698215092512159), (curve_order - 1), - ] + ], ) def test_bls_core_succeed(backend, privkey, domain): bls.use(backend) - msg = str(privkey).encode('utf-8') + msg = str(privkey).encode("utf-8") sig = bls.sign(msg, privkey, domain=domain) pub = bls.privtopub(privkey) assert bls.verify(msg, pub, sig, domain=domain) -@pytest.mark.parametrize( - "backend", AVAILABLE_BACKENDS -) -@pytest.mark.parametrize( - 'privkey', - [ - (0), - (curve_order), - (curve_order + 1), - ] -) +@pytest.mark.parametrize("backend", AVAILABLE_BACKENDS) +@pytest.mark.parametrize("privkey", [(0), (curve_order), (curve_order + 1)]) def test_invalid_private_key(backend, privkey, domain): bls.use(backend) - msg = str(privkey).encode('utf-8') + msg = str(privkey).encode("utf-8") with pytest.raises(ValueError): bls.privtopub(privkey) with pytest.raises(ValueError): bls.sign(msg, privkey, domain=domain) -@pytest.mark.parametrize( - "backend", AVAILABLE_BACKENDS -) +@pytest.mark.parametrize("backend", AVAILABLE_BACKENDS) def test_empty_aggregation(backend): bls.use(backend) assert bls.aggregate_pubkeys([]) == EMPTY_PUBKEY assert bls.aggregate_signatures([]) == EMPTY_SIGNATURE -@pytest.mark.parametrize( - "backend", AVAILABLE_BACKENDS -) +@pytest.mark.parametrize("backend", AVAILABLE_BACKENDS) def test_verify_empty_signatures(backend, domain): # Want EMPTY_SIGNATURE to fail in Trinity bls.use(backend) def validate(): - bls.validate(b'\x11' * 32, EMPTY_PUBKEY, EMPTY_SIGNATURE, domain) + bls.validate(b"\x11" * 32, EMPTY_PUBKEY, EMPTY_SIGNATURE, domain) def validate_multiple_1(): bls.validate_multiple( - pubkeys=(), - message_hashes=(), - signature=EMPTY_SIGNATURE, - domain=domain, + pubkeys=(), message_hashes=(), signature=EMPTY_SIGNATURE, domain=domain ) def validate_multiple_2(): bls.validate_multiple( pubkeys=(EMPTY_PUBKEY, EMPTY_PUBKEY), - message_hashes=(b'\x11' * 32, b'\x12' * 32), + message_hashes=(b"\x11" * 32, b"\x12" * 32), signature=EMPTY_SIGNATURE, domain=domain, ) @@ -172,15 +138,16 @@ def validate_multiple_2(): validate_multiple_2() +@pytest.mark.parametrize("backend", AVAILABLE_BACKENDS) @pytest.mark.parametrize( - "backend", AVAILABLE_BACKENDS -) -@pytest.mark.parametrize( - 'msg, privkeys', + "msg, privkeys", [ - (b'\x12' * 32, [1, 5, 124, 735, 127409812145, 90768492698215092512159, curve_order - 1]), - (b'\x34' * 32, [42, 666, 1274099945, 4389392949595]), - ] + ( + b"\x12" * 32, + [1, 5, 124, 735, 127409812145, 90768492698215092512159, curve_order - 1], + ), + (b"\x34" * 32, [42, 666, 1274099945, 4389392949595]), + ], ) def test_signature_aggregation(backend, msg, privkeys, domain): bls.use(backend) @@ -191,33 +158,30 @@ def test_signature_aggregation(backend, msg, privkeys, domain): assert bls.verify(msg, aggpub, aggsig, domain=domain) +@pytest.mark.parametrize("backend", AVAILABLE_BACKENDS) +@pytest.mark.parametrize("msg_1, msg_2", [(b"\x12" * 32, b"\x34" * 32)]) @pytest.mark.parametrize( - "backend", AVAILABLE_BACKENDS -) -@pytest.mark.parametrize( - 'msg_1, msg_2', - [ - (b'\x12' * 32, b'\x34' * 32) - ] -) -@pytest.mark.parametrize( - 'privkeys_1, privkeys_2', + "privkeys_1, privkeys_2", [ (tuple(range(1, 11)), tuple(range(1, 11))), ((1, 2, 3), (4, 5, 6, 7)), ((1, 2, 3), (2, 3, 4, 5)), ((1, 2, 3), ()), ((), (2, 3, 4, 5)), - ] + ], ) def test_multi_aggregation(backend, msg_1, msg_2, privkeys_1, privkeys_2, domain): bls.use(backend) - sigs_1 = [bls.sign(msg_1, k, domain=domain) for k in privkeys_1] # signatures to msg_1 + sigs_1 = [ + bls.sign(msg_1, k, domain=domain) for k in privkeys_1 + ] # signatures to msg_1 pubs_1 = [bls.privtopub(k) for k in privkeys_1] aggpub_1 = bls.aggregate_pubkeys(pubs_1) # sig_1 to msg_1 - sigs_2 = [bls.sign(msg_2, k, domain=domain) for k in privkeys_2] # signatures to msg_2 + sigs_2 = [ + bls.sign(msg_2, k, domain=domain) for k in privkeys_2 + ] # signatures to msg_2 pubs_2 = [bls.privtopub(k) for k in privkeys_2] aggpub_2 = bls.aggregate_pubkeys(pubs_2) # sig_2 to msg_2 @@ -226,8 +190,5 @@ def test_multi_aggregation(backend, msg_1, msg_2, privkeys_1, privkeys_2, domain aggsig = bls.aggregate_signatures(sigs_1 + sigs_2) assert bls.verify_multiple( - pubkeys=pubs, - message_hashes=message_hashes, - signature=aggsig, - domain=domain, + pubkeys=pubs, message_hashes=message_hashes, signature=aggsig, domain=domain ) diff --git a/tests/eth2/utils-tests/hash-utils/test_hash.py b/tests/eth2/utils-tests/hash-utils/test_hash.py index dac4b39fcc..300cc1a60a 100644 --- a/tests/eth2/utils-tests/hash-utils/test_hash.py +++ b/tests/eth2/utils-tests/hash-utils/test_hash.py @@ -4,9 +4,9 @@ def test_hash(): - output = hash_eth2(b'helloworld') + output = hash_eth2(b"helloworld") assert len(output) == 32 def test_hash_is_keccak256(): - assert hash_eth2(b'foo') == sha256(b'foo').digest() + assert hash_eth2(b"foo") == sha256(b"foo").digest() diff --git a/tests/eth2/utils-tests/merkle-utils/test_merkle_trees.py b/tests/eth2/utils-tests/merkle-utils/test_merkle_trees.py index 8882024f66..61ecc174e7 100644 --- a/tests/eth2/utils-tests/merkle-utils/test_merkle_trees.py +++ b/tests/eth2/utils-tests/merkle-utils/test_merkle_trees.py @@ -1,65 +1,46 @@ +from eth_utils import ValidationError import pytest -from eth_utils import ( - ValidationError, -) - -from eth2._utils.hash import ( - hash_eth2, -) +from eth2._utils.hash import hash_eth2 from eth2._utils.merkle.normal import ( - get_merkle_root_from_items, calc_merkle_tree, - get_root, get_merkle_proof, get_merkle_root, + get_merkle_root_from_items, + get_root, verify_merkle_proof, ) -@pytest.mark.parametrize("leaves,tree", [ - ( - (b"single leaf",), +@pytest.mark.parametrize( + "leaves,tree", + [ + ((b"single leaf",), ((hash_eth2(b"single leaf"),),)), ( - (hash_eth2(b"single leaf"),), - ), - ), - ( - (b"left", b"right"), - ( - (hash_eth2(hash_eth2(b"left") + hash_eth2(b"right")),), - (hash_eth2(b"left"), hash_eth2(b"right")), + (b"left", b"right"), + ( + (hash_eth2(hash_eth2(b"left") + hash_eth2(b"right")),), + (hash_eth2(b"left"), hash_eth2(b"right")), + ), ), - ), - ( - (b"1", b"2", b"3", b"4"), ( + (b"1", b"2", b"3", b"4"), ( - hash_eth2( + ( hash_eth2( - hash_eth2(b"1") + hash_eth2(b"2") - ) + hash_eth2( - hash_eth2(b"3") + hash_eth2(b"4") - ) - ), - ), - ( - hash_eth2( - hash_eth2(b"1") + hash_eth2(b"2") + hash_eth2(hash_eth2(b"1") + hash_eth2(b"2")) + + hash_eth2(hash_eth2(b"3") + hash_eth2(b"4")) + ), ), - hash_eth2( - hash_eth2(b"3") + hash_eth2(b"4") + ( + hash_eth2(hash_eth2(b"1") + hash_eth2(b"2")), + hash_eth2(hash_eth2(b"3") + hash_eth2(b"4")), ), - ), - ( - hash_eth2(b"1"), - hash_eth2(b"2"), - hash_eth2(b"3"), - hash_eth2(b"4"), + (hash_eth2(b"1"), hash_eth2(b"2"), hash_eth2(b"3"), hash_eth2(b"4")), ), ), - ), -]) + ], +) def test_merkle_tree_calculation(leaves, tree): calculated_tree = calc_merkle_tree(leaves) assert calculated_tree == tree @@ -73,38 +54,33 @@ def test_invalid_merkle_root_calculation(leave_number): get_merkle_root_from_items((b"",) * leave_number) -@pytest.mark.parametrize("leaves,index,proof", [ - ( - (b"1", b"2"), - 0, - (hash_eth2(b"2"),), - ), - ( - (b"1", b"2"), - 1, - (hash_eth2(b"1"),), - ), - ( - (b"1", b"2", b"3", b"4"), - 0, - (hash_eth2(b"2"), hash_eth2(hash_eth2(b"3") + hash_eth2(b"4"))), - ), - ( - (b"1", b"2", b"3", b"4"), - 1, - (hash_eth2(b"1"), hash_eth2(hash_eth2(b"3") + hash_eth2(b"4"))), - ), - ( - (b"1", b"2", b"3", b"4"), - 2, - (hash_eth2(b"4"), hash_eth2(hash_eth2(b"1") + hash_eth2(b"2"))), - ), - ( - (b"1", b"2", b"3", b"4"), - 3, - (hash_eth2(b"3"), hash_eth2(hash_eth2(b"1") + hash_eth2(b"2"))), - ), -]) +@pytest.mark.parametrize( + "leaves,index,proof", + [ + ((b"1", b"2"), 0, (hash_eth2(b"2"),)), + ((b"1", b"2"), 1, (hash_eth2(b"1"),)), + ( + (b"1", b"2", b"3", b"4"), + 0, + (hash_eth2(b"2"), hash_eth2(hash_eth2(b"3") + hash_eth2(b"4"))), + ), + ( + (b"1", b"2", b"3", b"4"), + 1, + (hash_eth2(b"1"), hash_eth2(hash_eth2(b"3") + hash_eth2(b"4"))), + ), + ( + (b"1", b"2", b"3", b"4"), + 2, + (hash_eth2(b"4"), hash_eth2(hash_eth2(b"1") + hash_eth2(b"2"))), + ), + ( + (b"1", b"2", b"3", b"4"), + 3, + (hash_eth2(b"3"), hash_eth2(hash_eth2(b"1") + hash_eth2(b"2"))), + ), + ], +) def test_merkle_proofs(leaves, index, proof): tree = calc_merkle_tree(leaves) root = get_root(tree) @@ -117,7 +93,9 @@ def test_merkle_proofs(leaves, index, proof): assert not verify_merkle_proof(root, b"\x00" * 32, index, proof) assert not verify_merkle_proof(root, item, (index + 1) % len(leaves), proof) for replaced_index in range(len(proof)): - altered_proof = proof[:replaced_index] + (b"\x00" * 32,) + proof[replaced_index + 1:] + altered_proof = ( + proof[:replaced_index] + (b"\x00" * 32,) + proof[replaced_index + 1 :] + ) assert not verify_merkle_proof(root, item, index, altered_proof) @@ -132,11 +110,7 @@ def test_single_element_merkle_proof(): assert not verify_merkle_proof(root, b"1", 0, (b"\x00" * 32,)) -@pytest.mark.parametrize("leaves", [ - (b"1",), - (b"1", b"2"), - (b"1", b"2", b"3", b"4"), -]) +@pytest.mark.parametrize("leaves", [(b"1",), (b"1", b"2"), (b"1", b"2", b"3", b"4")]) def test_proof_generation_index_validation(leaves): tree = calc_merkle_tree(leaves) for invalid_index in [-1, len(leaves)]: diff --git a/tests/eth2/utils-tests/merkle-utils/test_sparse_merkle_trees.py b/tests/eth2/utils-tests/merkle-utils/test_sparse_merkle_trees.py index 201113e9a0..27d6681820 100644 --- a/tests/eth2/utils-tests/merkle-utils/test_sparse_merkle_trees.py +++ b/tests/eth2/utils-tests/merkle-utils/test_sparse_merkle_trees.py @@ -1,8 +1,6 @@ import pytest -from eth2._utils.hash import ( - hash_eth2, -) +from eth2._utils.hash import hash_eth2 from eth2._utils.merkle.sparse import ( calc_merkle_tree, get_merkle_proof, @@ -11,24 +9,27 @@ ) -@pytest.mark.parametrize("items,expected_root", [ - ( - (b"1",), - b'p\xb1\xcf\x0ej\x9b{\xf3\x16\xb2\x0f~,l\x15\xdc\xd3\xcdJ?K\x05GD`~9\xfe\x1e\xae\xf82', - ), - ( - (b"1", b"2",), - b'\x83H%\xac_\xbd\x03\xd7!\x95Z\x08\xa1\x0c\xe8/\x83\xfe\x8a\x9b\xe5fe\x94J\xd4\xf5\x1c&FE\xdd', # noqa: E501 - ), - ( - (b"1", b"2", b"3",), - b'\xc2\x95\xf3\xf8:\xc1" \xf1\xe4\x87b_\xa4\xdb\xa9\x14e\xd3\xa9D\x85j\x17\xf5R\xc4\xdd\x88"\x8aJ', # noqa: E501 - ), - ( - (b"1", b"2", b"3", b"4"), - b'\x81!\xeaI4\xfc4_\x15\x13b\xa7tT#i\x9fT5\x1fs\x83B\xbc\x9f\xeb\xa1\x9ekv\xc5g', - ), -]) +@pytest.mark.parametrize( + "items,expected_root", + [ + ( + (b"1",), + b"p\xb1\xcf\x0ej\x9b{\xf3\x16\xb2\x0f~,l\x15\xdc\xd3\xcdJ?K\x05GD`~9\xfe\x1e\xae\xf82", + ), + ( + (b"1", b"2"), + b"\x83H%\xac_\xbd\x03\xd7!\x95Z\x08\xa1\x0c\xe8/\x83\xfe\x8a\x9b\xe5fe\x94J\xd4\xf5\x1c&FE\xdd", # noqa: E501 + ), + ( + (b"1", b"2", b"3"), + b'\xc2\x95\xf3\xf8:\xc1" \xf1\xe4\x87b_\xa4\xdb\xa9\x14e\xd3\xa9D\x85j\x17\xf5R\xc4\xdd\x88"\x8aJ', # noqa: E501 + ), + ( + (b"1", b"2", b"3", b"4"), + b"\x81!\xeaI4\xfc4_\x15\x13b\xa7tT#i\x9fT5\x1fs\x83B\xbc\x9f\xeb\xa1\x9ekv\xc5g", + ), + ], +) def test_merkle_root_and_proofs(items, expected_root): tree = calc_merkle_tree(items) assert get_root(tree) == expected_root @@ -38,14 +39,17 @@ def test_merkle_root_and_proofs(items, expected_root): assert verify_merkle_proof(expected_root, hash_eth2(item), index, proof) assert not verify_merkle_proof(b"\x32" * 32, hash_eth2(item), index, proof) - assert not verify_merkle_proof(expected_root, hash_eth2(b"\x32" * 32), index, proof) + assert not verify_merkle_proof( + expected_root, hash_eth2(b"\x32" * 32), index, proof + ) if len(items) > 1: assert not verify_merkle_proof( - expected_root, - hash_eth2(item), - (index + 1) % len(items), - proof + expected_root, hash_eth2(item), (index + 1) % len(items), proof ) for replaced_index in range(len(proof)): - altered_proof = proof[:replaced_index] + (b"\x32" * 32,) + proof[replaced_index + 1:] - assert not verify_merkle_proof(expected_root, hash_eth2(item), index, altered_proof) + altered_proof = ( + proof[:replaced_index] + (b"\x32" * 32,) + proof[replaced_index + 1 :] + ) + assert not verify_merkle_proof( + expected_root, hash_eth2(item), index, altered_proof + ) diff --git a/tests/eth2/utils-tests/numeric-utils/test_bitwise_xor.py b/tests/eth2/utils-tests/numeric-utils/test_bitwise_xor.py index 595b426b1f..c548fdb109 100644 --- a/tests/eth2/utils-tests/numeric-utils/test_bitwise_xor.py +++ b/tests/eth2/utils-tests/numeric-utils/test_bitwise_xor.py @@ -3,11 +3,6 @@ from eth2._utils.numeric import bitwise_xor -@pytest.mark.parametrize( - 'a,b,result', - [ - (b'\x00' * 32, b'\x0a' * 32, b'\x0a' * 32), - ] -) +@pytest.mark.parametrize("a,b,result", [(b"\x00" * 32, b"\x0a" * 32, b"\x0a" * 32)]) def test_bitwise_xor_success(a, b, result): assert bitwise_xor(a, b) == result diff --git a/tests/eth2/utils-tests/numeric-utils/test_integer_squareroot.py b/tests/eth2/utils-tests/numeric-utils/test_integer_squareroot.py index f1cffcc7dd..12dd6d09b2 100644 --- a/tests/eth2/utils-tests/numeric-utils/test_integer_squareroot.py +++ b/tests/eth2/utils-tests/numeric-utils/test_integer_squareroot.py @@ -1,12 +1,8 @@ +from hypothesis import given +from hypothesis import strategies as st import pytest -from hypothesis import ( - given, - strategies as st, -) -from eth2._utils.numeric import ( - integer_squareroot, -) +from eth2._utils.numeric import integer_squareroot @given(st.integers(min_value=0, max_value=100)) @@ -17,7 +13,7 @@ def test_integer_squareroot_correct(value): @pytest.mark.parametrize( - 'value,expected', + "value,expected", ( (0, 0), (1, 1), @@ -28,20 +24,14 @@ def test_integer_squareroot_correct(value): (65535, 255), (65536, 256), (18446744073709551615, 4294967295), - ) + ), ) def test_integer_squareroot_success(value, expected): actual = integer_squareroot(value) assert actual == expected -@pytest.mark.parametrize( - 'value', - ( - (1.5), - (-1), - ) -) +@pytest.mark.parametrize("value", ((1.5), (-1))) def test_integer_squareroot_edge_cases(value): with pytest.raises(ValueError): integer_squareroot(value) diff --git a/tests/eth2/utils-tests/tuple-utils/test_tuple.py b/tests/eth2/utils-tests/tuple-utils/test_tuple.py index 30fd862564..c6435eaa90 100644 --- a/tests/eth2/utils-tests/tuple-utils/test_tuple.py +++ b/tests/eth2/utils-tests/tuple-utils/test_tuple.py @@ -1,57 +1,24 @@ +from eth_utils import ValidationError import pytest -from eth_utils import ( - ValidationError, -) - -from eth2._utils.tuple import ( - update_tuple_item, -) +from eth2._utils.tuple import update_tuple_item @pytest.mark.parametrize( - ( - 'tuple_data, index, new_value, expected' - ), + ("tuple_data, index, new_value, expected"), [ - ( - (1, ) * 10, - 0, - -99, - (-99,) + (1, ) * 9, - ), - ( - (1, ) * 10, - 5, - -99, - (1, ) * 5 + (-99,) + (1, ) * 4, - ), - ( - (1, ) * 10, - 9, - -99, - (1, ) * 9 + (-99,), - ), - ( - (1, ) * 10, - 10, - -99, - ValidationError(), - ) - ] + ((1,) * 10, 0, -99, (-99,) + (1,) * 9), + ((1,) * 10, 5, -99, (1,) * 5 + (-99,) + (1,) * 4), + ((1,) * 10, 9, -99, (1,) * 9 + (-99,)), + ((1,) * 10, 10, -99, ValidationError()), + ], ) def test_update_tuple_item(tuple_data, index, new_value, expected): if isinstance(expected, Exception): with pytest.raises(ValidationError): - update_tuple_item( - tuple_data=tuple_data, - index=index, - new_value=new_value, - ) + update_tuple_item(tuple_data=tuple_data, index=index, new_value=new_value) else: result = update_tuple_item( - tuple_data=tuple_data, - index=index, - new_value=new_value, + tuple_data=tuple_data, index=index, new_value=new_value ) assert result == expected diff --git a/tests/integration/test_etherscan_checkpoint_resolver.py b/tests/integration/test_etherscan_checkpoint_resolver.py new file mode 100644 index 0000000000..102a9c6b10 --- /dev/null +++ b/tests/integration/test_etherscan_checkpoint_resolver.py @@ -0,0 +1,23 @@ +import pytest + +from eth_utils import encode_hex +from trinity.plugins.builtin.syncer.cli import ( + parse_checkpoint_uri, + is_block_hash, +) + +# This is just the score at the tip as it was at some point on August 26th 2019 +# It serves as anchor so that we have *some* minimal expected score to test against. +MIN_EXPECTED_SCORE = 11631608640717612820968 + + +@pytest.mark.parametrize( + 'uri', + ( + 'eth://block/byetherscan/latest', + ) +) +def test_parse_checkpoint(uri): + checkpoint = parse_checkpoint_uri(uri) + assert checkpoint.score >= MIN_EXPECTED_SCORE + assert is_block_hash(encode_hex(checkpoint.block_hash)) diff --git a/tests/integration/test_lightchain_integration.py b/tests/integration/test_lightchain_integration.py index 73369acb57..7b8ee37e8a 100644 --- a/tests/integration/test_lightchain_integration.py +++ b/tests/integration/test_lightchain_integration.py @@ -1,8 +1,10 @@ import asyncio import logging from pathlib import Path +import shutil import socket import subprocess +import tempfile import time import pytest @@ -47,8 +49,11 @@ async def geth_port(unused_tcp_port): @pytest.fixture def geth_datadir(): - datadir = Path(__file__).parent / 'fixtures' / 'geth_lightchain_datadir' - return datadir.absolute() + fixture_datadir = Path(__file__).parent / 'fixtures' / 'geth_lightchain_datadir' + with tempfile.TemporaryDirectory() as temp_dir: + datadir = Path(temp_dir) / 'geth' + shutil.copytree(fixture_datadir, datadir) + yield datadir @pytest.fixture diff --git a/tests/integration/test_trinity_cli.py b/tests/integration/test_trinity_cli.py index ee6a96dacc..bab842bfc3 100644 --- a/tests/integration/test_trinity_cli.py +++ b/tests/integration/test_trinity_cli.py @@ -5,12 +5,8 @@ import pexpect import pytest -from eth_utils import ( - encode_hex, -) from eth.constants import ( GENESIS_BLOCK_NUMBER, - GENESIS_PARENT_HASH, ) from tests.integration.helpers import ( @@ -21,11 +17,17 @@ from trinity.config import ( TrinityConfig, ) +from trinity.constants import ( + ASSETS_DIR, +) from trinity._utils.async_iter import ( contains_all ) +ROPSTEN_GENESIS_HASH = '0x41941023680923e0fe4d74a34bdac8141f2540e3ae90623718e47d66d1ca4a2d' +MAINNET_GENESIS_HASH = '0xd4e56740f876aef8c010b86a40d5f56745a118d0906a34e69aec8c0db1cb8fa3' + # IMPORTANT: Test names are intentionally short here because they end up # in the path name of the isolated Trinity paths that pytest produces for # us. @@ -120,13 +122,35 @@ async def test_light_boot(async_process_runner, command): @pytest.mark.parametrize( - 'command', + 'command, expected_network_id, expected_genesis_hash', ( - ('trinity', ), + (('trinity',), 1, MAINNET_GENESIS_HASH), + (('trinity', '--ropsten'), 3, ROPSTEN_GENESIS_HASH), + ( + ( + 'trinity', + f'--genesis={ASSETS_DIR}/eip1085/devnet.json', + # We don't have a way to refer to the tmp xdg_trinity_root here so we + # make up this replacement marker + '--data-dir={trinity_root_path}/devnet', + '--network-id=5' + ), 5, '0x065fd78e53dcef113bf9d7732dac7c5132dcf85c9588a454d832722ceb097422'), ) ) @pytest.mark.asyncio -async def test_web3(command, async_process_runner): +async def test_web3(command, + expected_network_id, + expected_genesis_hash, + xdg_trinity_root, + async_process_runner): + + command = tuple( + fragment.replace('{trinity_root_path}', str(xdg_trinity_root)) + for fragment + in command + ) + attach_cmd = list(command[1:] + ('attach',)) + await async_process_runner.run(command, timeout_sec=40) assert await contains_all(async_process_runner.stderr, { "Started DB server process", @@ -139,17 +163,17 @@ async def test_web3(command, async_process_runner): "EventBus Endpoint bjson-rpc-api connecting to other Endpoints", }) - attached_trinity = pexpect.spawn('trinity', ['attach'], logfile=sys.stdout, encoding="utf-8") + attached_trinity = pexpect.spawn('trinity', attach_cmd, logfile=sys.stdout, encoding="utf-8") try: attached_trinity.expect("An instance of Web3 connected to the running chain") attached_trinity.sendline("w3.net.version") - attached_trinity.expect("'1'") + attached_trinity.expect(f"'{expected_network_id}'") attached_trinity.sendline("w3") attached_trinity.expect("web3.main.Web3") attached_trinity.sendline("w3.eth.getBlock('latest').blockNumber") attached_trinity.expect(str(GENESIS_BLOCK_NUMBER)) - attached_trinity.sendline("w3.eth.getBlock('latest').parentHash") - attached_trinity.expect(encode_hex(GENESIS_PARENT_HASH)) + attached_trinity.sendline("w3.eth.getBlock('latest').hash") + attached_trinity.expect(expected_genesis_hash) except pexpect.TIMEOUT: raise Exception("Trinity attach timeout") finally: diff --git a/tests/json-fixtures-over-rpc/test_rpc_fixtures.py b/tests/json-fixtures-over-rpc/test_rpc_fixtures.py index 645341ea95..2bee6445d9 100644 --- a/tests/json-fixtures-over-rpc/test_rpc_fixtures.py +++ b/tests/json-fixtures-over-rpc/test_rpc_fixtures.py @@ -434,9 +434,8 @@ class MainnetFullChain(FullChain): @pytest.mark.asyncio async def test_rpc_against_fixtures(event_bus, chain_fixture, fixture_data): - rpc = RPCServer( - initialize_eth1_modules(MainnetFullChain(None), event_bus), event_bus, - ) + chain = MainnetFullChain(None) + rpc = RPCServer(initialize_eth1_modules(chain, event_bus), chain, event_bus) setup_result, setup_error = await call_rpc(rpc, 'evm_resetToGenesisFixture', [chain_fixture]) # We need to advance the event loop for modules to be able to pickup the new chain diff --git a/tests/libp2p/bcc/conftest.py b/tests/libp2p/bcc/conftest.py index 94e8746f2d..933aa89e03 100644 --- a/tests/libp2p/bcc/conftest.py +++ b/tests/libp2p/bcc/conftest.py @@ -2,16 +2,9 @@ import pytest -from eth2.beacon.tools.factories import ( - BeaconChainFactory, -) - +from eth2.beacon.tools.factories import BeaconChainFactory from trinity.protocol.bcc_libp2p import utils - -from trinity.tools.bcc_factories import ( - NodeFactory, -) - +from trinity.tools.bcc_factories import NodeFactory MOCK_TIME = 0.01 @@ -35,10 +28,7 @@ async def nodes(num_nodes): @pytest.fixture async def nodes_with_chain(num_nodes): - chains = tuple([ - BeaconChainFactory() - for _ in range(num_nodes) - ]) + chains = tuple([BeaconChainFactory() for _ in range(num_nodes)]) async for _nodes in make_nodes(num_nodes, chains): yield _nodes @@ -48,10 +38,7 @@ async def make_nodes(num_nodes, chains=None): _nodes = NodeFactory.create_batch(num_nodes) else: assert num_nodes == len(chains) - _nodes = tuple( - NodeFactory(chain=chain) - for chain in chains - ) + _nodes = tuple(NodeFactory(chain=chain) for chain in chains) for n in _nodes: asyncio.ensure_future(n.run()) await n.events.started.wait() diff --git a/tests/libp2p/bcc/test_node.py b/tests/libp2p/bcc/test_node.py index 59ccb6cc95..ac44f5d99c 100644 --- a/tests/libp2p/bcc/test_node.py +++ b/tests/libp2p/bcc/test_node.py @@ -1,22 +1,13 @@ import asyncio -import pytest - from libp2p.peer.id import ID +import pytest from p2p.tools.factories import get_open_port +from trinity.tools.bcc_factories import NodeFactory -from trinity.tools.bcc_factories import ( - NodeFactory, -) - -@pytest.mark.parametrize( - "num_nodes", - ( - 1, - ) -) +@pytest.mark.parametrize("num_nodes", (1,)) @pytest.mark.asyncio async def test_node(nodes): node = nodes[0] @@ -24,54 +15,28 @@ async def test_node(nodes): assert node.host.get_addrs() == expected_addrs -@pytest.mark.parametrize( - "num_nodes", - ( - 3, - ) -) +@pytest.mark.parametrize("num_nodes", (3,)) @pytest.mark.asyncio async def test_node_dial_peer(nodes): # Test: Exception raised when dialing a wrong addr with pytest.raises(ConnectionRefusedError): - await nodes[0].dial_peer( - nodes[1].listen_ip, - get_open_port(), - ID("123"), - ) + await nodes[0].dial_peer(nodes[1].listen_ip, get_open_port(), ID("123")) # Test: 0 <-> 1 - await nodes[0].dial_peer( - nodes[1].listen_ip, - nodes[1].listen_port, - nodes[1].peer_id, - ) + await nodes[0].dial_peer(nodes[1].listen_ip, nodes[1].listen_port, nodes[1].peer_id) assert nodes[0].peer_id in nodes[1].host.get_network().connections assert nodes[1].peer_id in nodes[0].host.get_network().connections # Test: Second dial to a connected peer does not open a new connection original_conn = nodes[1].host.get_network().connections[nodes[0].peer_id] - await nodes[0].dial_peer( - nodes[1].listen_ip, - nodes[1].listen_port, - nodes[1].peer_id, - ) + await nodes[0].dial_peer(nodes[1].listen_ip, nodes[1].listen_port, nodes[1].peer_id) assert nodes[1].host.get_network().connections[nodes[0].peer_id] is original_conn # Test: 0 <-> 1 <-> 2 - await nodes[2].dial_peer( - nodes[1].listen_ip, - nodes[1].listen_port, - nodes[1].peer_id, - ) + await nodes[2].dial_peer(nodes[1].listen_ip, nodes[1].listen_port, nodes[1].peer_id) assert nodes[1].peer_id in nodes[2].host.get_network().connections assert nodes[2].peer_id in nodes[1].host.get_network().connections assert len(nodes[1].host.get_network().connections) == 2 -@pytest.mark.parametrize( - "num_nodes", - ( - 3, - ) -) +@pytest.mark.parametrize("num_nodes", (3,)) @pytest.mark.asyncio async def test_node_dial_peer_maddr(nodes): # Test: 0 <-> 1 <-> 2 @@ -83,12 +48,7 @@ async def test_node_dial_peer_maddr(nodes): assert nodes[0].peer_id in nodes[1].host.get_network().connections -@pytest.mark.parametrize( - "num_nodes", - ( - 2, - ) -) +@pytest.mark.parametrize("num_nodes", (2,)) @pytest.mark.asyncio async def test_node_connect_preferred_nodes(nodes): new_node = NodeFactory( diff --git a/tests/libp2p/bcc/test_req_resp.py b/tests/libp2p/bcc/test_req_resp.py index 60ba063d09..68b87d5dd6 100644 --- a/tests/libp2p/bcc/test_req_resp.py +++ b/tests/libp2p/bcc/test_req_resp.py @@ -2,28 +2,14 @@ import pytest -from trinity.protocol.bcc_libp2p.node import ( - REQ_RESP_HELLO_SSZ, -) -from trinity.protocol.bcc_libp2p.configs import ( - ResponseCode, -) -from trinity.protocol.bcc_libp2p.exceptions import ( - HandshakeFailure, -) -from trinity.protocol.bcc_libp2p.messages import ( - HelloRequest, -) -from trinity.protocol.bcc_libp2p.utils import ( - read_req, - write_resp, -) +from trinity.protocol.bcc_libp2p.configs import ResponseCode +from trinity.protocol.bcc_libp2p.exceptions import HandshakeFailure +from trinity.protocol.bcc_libp2p.messages import HelloRequest +from trinity.protocol.bcc_libp2p.node import REQ_RESP_HELLO_SSZ +from trinity.protocol.bcc_libp2p.utils import read_req, write_resp -@pytest.mark.parametrize( - "num_nodes", - (2,), -) +@pytest.mark.parametrize("num_nodes", (2,)) @pytest.mark.asyncio async def test_hello_success(nodes_with_chain): nodes = nodes_with_chain @@ -34,19 +20,19 @@ async def test_hello_success(nodes_with_chain): assert nodes[0].peer_id in nodes[1].handshaked_peers -@pytest.mark.parametrize( - "num_nodes", - (2,), -) +@pytest.mark.parametrize("num_nodes", (2,)) @pytest.mark.asyncio -async def test_hello_failure_invalid_hello_packet(nodes_with_chain, monkeypatch, mock_timeout): +async def test_hello_failure_invalid_hello_packet( + nodes_with_chain, monkeypatch, mock_timeout +): nodes = nodes_with_chain await nodes[0].dial_peer_maddr(nodes[1].listen_maddr_with_peer_id) def _make_inconsistent_hello_packet(): return HelloRequest( - fork_version=b"\x12\x34\x56\x78", # version different from another node. + fork_version=b"\x12\x34\x56\x78" # version different from another node. ) + monkeypatch.setattr(nodes[0], "_make_hello_packet", _make_inconsistent_hello_packet) monkeypatch.setattr(nodes[1], "_make_hello_packet", _make_inconsistent_hello_packet) # Test: Handshake fails when either side sends invalid hello packets. @@ -56,10 +42,7 @@ def _make_inconsistent_hello_packet(): assert nodes[0].peer_id not in nodes[1].handshaked_peers -@pytest.mark.parametrize( - "num_nodes", - (2,), -) +@pytest.mark.parametrize("num_nodes", (2,)) @pytest.mark.asyncio async def test_hello_failure_failure_response(nodes_with_chain): nodes = nodes_with_chain @@ -69,6 +52,7 @@ async def fake_handle_hello(stream): await read_req(stream, HelloRequest) # The overridden `resp_code` can be anything other than `ResponseCode.SUCCESS` await write_resp(stream, "error msg", ResponseCode.INVALID_REQUEST) + # Mock the handler. nodes[1].host.set_stream_handler(REQ_RESP_HELLO_SSZ, fake_handle_hello) # Test: Handshake fails when the response is not success. diff --git a/tests/libp2p/bcc/test_utils.py b/tests/libp2p/bcc/test_utils.py index 6ca27dc8ec..ce90848530 100644 --- a/tests/libp2p/bcc/test_utils.py +++ b/tests/libp2p/bcc/test_utils.py @@ -1,22 +1,16 @@ import asyncio -from typing import ( - NamedTuple, -) - -import pytest +from typing import NamedTuple from eth_keys import datatypes +from libp2p.peer.id import ID +import pytest -from trinity.protocol.bcc_libp2p.configs import ( - ResponseCode, -) +from trinity.protocol.bcc_libp2p.configs import ResponseCode from trinity.protocol.bcc_libp2p.exceptions import ( ReadMessageFailure, WriteMessageFailure, ) -from trinity.protocol.bcc_libp2p.messages import ( - HelloRequest, -) +from trinity.protocol.bcc_libp2p.messages import HelloRequest from trinity.protocol.bcc_libp2p.utils import ( peer_id_from_pubkey, read_req, @@ -25,18 +19,13 @@ write_resp, ) -from libp2p.peer.id import ( - ID, -) - - # Wrong type of `fork_version`, which should be `bytes4`. invalid_ssz_msg = HelloRequest(fork_version="1") def test_peer_id_from_pubkey(): pubkey = datatypes.PublicKey( - b'n\x85UD\xe9^\xbfo\x05\xd1z\xbd\xe5k\x87Y\xe9\xfa\xb3z:\xf8z\xc5\xd7K\xa6\x00\xbbc\xda4M\x10\x1cO\x88\tl\x82\x7f\xd7\xec6\xd8\xdc\xe2\x9c\xdcG\xa5\xea|\x9e\xc57\xf8G\xbe}\xfa\x10\xe9\x12' # noqa: E501 + b"n\x85UD\xe9^\xbfo\x05\xd1z\xbd\xe5k\x87Y\xe9\xfa\xb3z:\xf8z\xc5\xd7K\xa6\x00\xbbc\xda4M\x10\x1cO\x88\tl\x82\x7f\xd7\xec6\xd8\xdc\xe2\x9c\xdcG\xa5\xea|\x9e\xc57\xf8G\xbe}\xfa\x10\xe9\x12" # noqa: E501 ) peer_id_expected = ID.from_base58("QmQiv6sR3qHqhUVgC5qUBVWi8YzM6HknYbu4oQKVAqPCGF") assert peer_id_from_pubkey(pubkey) == peer_id_expected @@ -57,7 +46,7 @@ async def read(self, n: int = -1) -> bytes: buf = bytearray() # Exit with empty bytes directly if `n == 0`. if n == 0: - return b'' + return b"" # Force to blocking wait for first byte. buf.extend(await self._queue.get()) while not self._queue.empty(): @@ -72,12 +61,7 @@ async def write(self, data: bytes) -> int: return len(data) -@pytest.mark.parametrize( - "msg", - ( - HelloRequest(), - ) -) +@pytest.mark.parametrize("msg", (HelloRequest(),)) @pytest.mark.asyncio async def test_read_write_req_msg(msg): s = FakeNetStream() @@ -86,12 +70,7 @@ async def test_read_write_req_msg(msg): assert msg_read == msg -@pytest.mark.parametrize( - "msg", - ( - HelloRequest(), - ) -) +@pytest.mark.parametrize("msg", (HelloRequest(),)) @pytest.mark.asyncio async def test_read_write_resp_msg(msg): s = FakeNetStream() @@ -150,7 +129,8 @@ async def test_read_resp_failure(monkeypatch, mock_timeout): async def _fake_read(n): return b"" - monkeypatch.setattr(s, 'read', _fake_read) + + monkeypatch.setattr(s, "read", _fake_read) with pytest.raises(ReadMessageFailure): await read_resp(s, HelloRequest) diff --git a/tests/p2p/test_connection.py b/tests/p2p/test_connection.py index 77664c14d0..5d307dfad1 100644 --- a/tests/p2p/test_connection.py +++ b/tests/p2p/test_connection.py @@ -26,7 +26,7 @@ async def test_connection_waits_to_feed_protocol_streams(): async with ConnectionPairFactory(start_streams=False) as (alice_connection, bob_connection): got_ping = asyncio.Event() - async def _handle_ping(msg): + async def _handle_ping(conn, msg): got_ping.set() alice_connection.add_command_handler(Ping, _handle_ping) @@ -142,16 +142,16 @@ async def test_connection_protocol_and_command_handlers(): done = asyncio.Event() - async def _handler_second_protocol(cmd, msg): + async def _handler_second_protocol(conn, cmd, msg): messages_second_protocol.append((cmd, msg)) - async def _handler_cmd_A(msg): + async def _handler_cmd_A(conn, msg): messages_cmd_A.append(msg) - async def _handler_cmd_D(msg): + async def _handler_cmd_D(conn, msg): messages_cmd_D.append(msg) - async def _handler_cmd_C(msg): + async def _handler_cmd_C(conn, msg): done.set() alice_connection.add_protocol_handler(SecondProtocol, _handler_second_protocol) diff --git a/tests/p2p/test_peer_pair_factory.py b/tests/p2p/test_peer_pair_factory.py index 752f77a46e..f4166893bf 100644 --- a/tests/p2p/test_peer_pair_factory.py +++ b/tests/p2p/test_peer_pair_factory.py @@ -3,6 +3,7 @@ import pytest from p2p.tools.factories import ParagonPeerPairFactory +from p2p.p2p_proto import Ping, Pong @pytest.mark.asyncio @@ -11,17 +12,17 @@ async def test_connection_factory_with_ParagonPeer(): got_ping = asyncio.Event() got_pong = asyncio.Event() - def handle_ping(cmd, msg): + async def handle_ping(conn, msg): got_ping.set() bob.base_protocol.send_pong() - def handle_pong(cmd, msg): + async def handle_pong(conn, msg): got_pong.set() - alice.handle_p2p_msg = handle_pong - bob.handle_p2p_msg = handle_ping + alice.connection.add_command_handler(Pong, handle_pong) + bob.connection.add_command_handler(Ping, handle_ping) alice.base_protocol.send_ping() - await asyncio.wait_for(got_ping.wait(), timeout=0.1) - await asyncio.wait_for(got_pong.wait(), timeout=0.1) + await asyncio.wait_for(got_ping.wait(), timeout=1) + await asyncio.wait_for(got_pong.wait(), timeout=1) diff --git a/tests/plugins/eth2/beacon/test_receive_server.py b/tests/plugins/eth2/beacon/test_receive_server.py index 2aa45132ef..8deb9692d6 100644 --- a/tests/plugins/eth2/beacon/test_receive_server.py +++ b/tests/plugins/eth2/beacon/test_receive_server.py @@ -637,9 +637,12 @@ class MockState: slot = XIAO_LONG_BAO_CONFIG.GENESIS_SLOT state = MockState() + def mock_get_head_state(self): + return state + def mock_get_attestation_data_slot(state, data, config): return data.slot - mocker.patch("eth2.beacon.state_machines.base.BeaconStateMachine.state", state) + mocker.patch("eth2.beacon.chains.base.BeaconChain.get_head_state", mock_get_head_state) mocker.patch( "trinity.protocol.bcc.servers.get_attestation_data_slot", mock_get_attestation_data_slot, diff --git a/tests/plugins/eth2/beacon/test_validator.py b/tests/plugins/eth2/beacon/test_validator.py index 774e59bf1b..65dba99989 100644 --- a/tests/plugins/eth2/beacon/test_validator.py +++ b/tests/plugins/eth2/beacon/test_validator.py @@ -127,7 +127,7 @@ def _get_slot_with_validator_selected(candidate_indices, state, config): async def test_validator_propose_block_succeeds(event_loop, event_bus): alice, bob = await get_linked_validators(event_loop=event_loop, event_bus=event_bus) state_machine = alice.chain.get_state_machine() - state = state_machine.state + state = alice.chain.get_head_state() slot, proposer_index = _get_slot_with_validator_selected( alice.validator_privkeys, @@ -159,7 +159,7 @@ async def test_validator_propose_block_succeeds(event_loop, event_bus): async def test_validator_propose_block_fails(event_loop, event_bus): alice, bob = await get_linked_validators(event_loop=event_loop, event_bus=event_bus) state_machine = alice.chain.get_state_machine() - state = state_machine.state + state = alice.chain.get_head_state() assert set(alice.validator_privkeys).intersection(set(bob.validator_privkeys)) == set() slot, proposer_index = _get_slot_with_validator_selected( @@ -183,7 +183,7 @@ async def test_validator_propose_block_fails(event_loop, event_bus): async def test_validator_skip_block(event_loop, event_bus): alice = await get_validator(event_loop=event_loop, event_bus=event_bus, indices=[0]) state_machine = alice.chain.get_state_machine() - state = state_machine.state + state = alice.chain.get_head_state() slot = state.slot + 1 post_state = alice.skip_block( slot=slot, @@ -255,7 +255,7 @@ async def handle_second_tick(slot): async def test_validator_handle_first_tick(event_loop, event_bus, monkeypatch): alice, bob = await get_linked_validators(event_loop=event_loop, event_bus=event_bus) state_machine = alice.chain.get_state_machine() - state = state_machine.state + state = alice.chain.get_head_state() # test: `handle_first_tick` should call `propose_block` if the validator get selected slot_to_propose, index = _get_slot_with_validator_selected( @@ -279,8 +279,7 @@ async def propose_block(proposer_index, slot, state, state_machine, head_block): @pytest.mark.asyncio async def test_validator_handle_second_tick(event_loop, event_bus, monkeypatch): alice, bob = await get_linked_validators(event_loop=event_loop, event_bus=event_bus) - state_machine = alice.chain.get_state_machine() - state = state_machine.state + state = alice.chain.get_head_state() # test: `handle_second_tick` should call `attest` # and skip_block` if `state.slot` is behind latest slot @@ -308,7 +307,7 @@ async def test_validator_get_committee_assigment(event_loop, event_bus): alice_indices = [7] alice = await get_validator(event_loop=event_loop, event_bus=event_bus, indices=alice_indices) state_machine = alice.chain.get_state_machine() - state = state_machine.state + state = alice.chain.get_head_state() epoch = compute_epoch_of_slot(state.slot, state_machine.config.SLOTS_PER_EPOCH) assert alice.this_epoch_assignment[alice_indices[0]][0] == -1 @@ -322,7 +321,7 @@ async def test_validator_attest(event_loop, event_bus, monkeypatch): alice = await get_validator(event_loop=event_loop, event_bus=event_bus, indices=alice_indices) head = alice.chain.get_canonical_head() state_machine = alice.chain.get_state_machine() - state = state_machine.state + state = alice.chain.get_head_state() epoch = compute_epoch_of_slot(state.slot, state_machine.config.SLOTS_PER_EPOCH) assignment = alice._get_this_epoch_assignment(alice_indices[0], epoch) @@ -357,7 +356,7 @@ async def test_validator_include_ready_attestations(event_loop, event_bus, monke alice_indices = list(range(8)) alice = await get_validator(event_loop=event_loop, event_bus=event_bus, indices=alice_indices) state_machine = alice.chain.get_state_machine() - state = state_machine.state + state = alice.chain.get_head_state() attesting_slot = state.slot + 1 attestations = await alice.attest(attesting_slot) diff --git a/tox.ini b/tox.ini index 43af8420c0..cb67f15f3c 100644 --- a/tox.ini +++ b/tox.ini @@ -7,6 +7,7 @@ envlist= py37-rpc-state-{quadratic,sstore,zero_knowledge} py37-libp2p py{36,37}-lint + py{36,37}-lint-eth2 py{36,37}-wheel-cli py36-docs @@ -15,6 +16,15 @@ max-line-length= 100 exclude= ignore= +[isort] +force_sort_within_sections=True +known_third_party=hypothesis,pytest,async_generator,cytoolz,trio_typing,pytest_trio,factory +multi_line_output=3 +include_trailing_comma=True +force_grid_wrap=0 +use_parentheses=True +line_length=88 + [testenv] usedevelop=True passenv = @@ -132,16 +142,23 @@ commands = pytest -n 1 {posargs:tests/libp2p} [common-lint] -deps = .[p2p,trinity,lint,eth2] +deps = .[p2p,trinity,lint] commands= - flake8 {toxinidir}/p2p - flake8 {toxinidir}/tests - flake8 {toxinidir}/tests-trio - flake8 {toxinidir}/trinity - flake8 {toxinidir}/scripts - flake8 {toxinidir}/eth2 - flake8 {toxinidir}/setup.py - mypy -p p2p -p trinity -p eth2 --config-file {toxinidir}/mypy.ini + flake8 --config {toxinidir}/tox.ini {toxinidir}/p2p + flake8 --config {toxinidir}/tox.ini {toxinidir}/tests/__init__.py + flake8 --config {toxinidir}/tox.ini {toxinidir}/tests/conftest.py + flake8 --config {toxinidir}/tox.ini {toxinidir}/tests/core + flake8 --config {toxinidir}/tox.ini {toxinidir}/tests/integration + flake8 --config {toxinidir}/tox.ini {toxinidir}/tests/json-fixtures-over-rpc + flake8 --config {toxinidir}/tox.ini {toxinidir}/tests/p2p + flake8 --config {toxinidir}/tox.ini {toxinidir}/tests/p2p-trio + flake8 --config {toxinidir}/tox.ini {toxinidir}/tests/plugins + flake8 --config {toxinidir}/tox.ini {toxinidir}/tests/trinity_long_run + flake8 --config {toxinidir}/tox.ini {toxinidir}/tests-trio + flake8 --config {toxinidir}/tox.ini {toxinidir}/trinity + flake8 --config {toxinidir}/tox.ini {toxinidir}/scripts + flake8 --config {toxinidir}/tox.ini {toxinidir}/setup.py + mypy -p p2p -p trinity --config-file {toxinidir}/mypy.ini [testenv:py36-lint] @@ -152,3 +169,25 @@ commands= {[common-lint]commands} [testenv:py37-lint] deps = {[common-lint]deps} commands= {[common-lint]commands} + + +[testenv:py36-lint-eth2] +deps = .[lint,eth2-lint,eth2] +commands= + flake8 --config {toxinidir}/.flake8-eth2 {toxinidir}/eth2 + flake8 --config {toxinidir}/.flake8-eth2 {toxinidir}/tests/eth2 + flake8 --config {toxinidir}/.flake8-eth2 {toxinidir}/tests/libp2p + mypy -p eth2 --config-file {toxinidir}/mypy.ini + black --check eth2 tests/eth2 tests/libp2p + isort --recursive --check-only eth2 tests/eth2 tests/libp2p + + +[testenv:py37-lint-eth2] +deps = .[lint,eth2-lint,eth2] +commands= + flake8 --config {toxinidir}/.flake8-eth2 {toxinidir}/eth2 + flake8 --config {toxinidir}/.flake8-eth2 {toxinidir}/tests/eth2 + flake8 --config {toxinidir}/.flake8-eth2 {toxinidir}/tests/libp2p + mypy -p eth2 --config-file {toxinidir}/mypy.ini + black --check eth2 tests/eth2 tests/libp2p + isort --recursive --check-only eth2 tests/eth2 tests/libp2p \ No newline at end of file diff --git a/trinity/__init__.py b/trinity/__init__.py index 549aab7e8a..d3145bf0bd 100644 --- a/trinity/__init__.py +++ b/trinity/__init__.py @@ -18,11 +18,11 @@ def is_uvloop_supported() -> bool: if is_uvloop_supported(): # Set `uvloop` as the default event loop - import asyncio # noqa: E402 + import asyncio from eth._warnings import catch_and_ignore_import_warning with catch_and_ignore_import_warning(): - import uvloop # noqa: E402 + import uvloop asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) diff --git a/trinity/assets/eip1085/devnet.json b/trinity/assets/eip1085/devnet.json new file mode 100644 index 0000000000..26e4ae4fbf --- /dev/null +++ b/trinity/assets/eip1085/devnet.json @@ -0,0 +1,24 @@ +{ + "accounts":{ + "0x00000000000000000000000000000000deadbeef":{ + "balance":"0x0ad78ebc5ac6200000" + }, + "0x00000000000000000000000000000000deadcafe":{ + "balance":"0x0ad78ebc5ac6200000" + } + }, + "genesis":{ + "author":"0x0000000000000000000000000000000000000000", + "difficulty":"0x1", + "extraData":"0xdeadbeef", + "gasLimit":"0x900000", + "nonce":"0x00000000deadbeef", + "timestamp":"0x0" + }, + "params":{ + "petersburgForkBlock":"0x0", + "chainId":"0xdeadbeef", + "miningMethod":"ethash" + }, + "version":"1" +} diff --git a/trinity/cli_parser.py b/trinity/cli_parser.py index 1a297e0502..a15fba91a5 100644 --- a/trinity/cli_parser.py +++ b/trinity/cli_parser.py @@ -281,7 +281,7 @@ def __call__(self, chain_parser.add_argument( '--genesis', help=( - "File containing a custom genesis block header" + "File containing a custom genesis configuration file per EIP1085" ), action=EIP1085GenesisLoader, ) diff --git a/trinity/plugins/builtin/json_rpc/plugin.py b/trinity/plugins/builtin/json_rpc/plugin.py index a74f5bfa57..e15c8885c7 100644 --- a/trinity/plugins/builtin/json_rpc/plugin.py +++ b/trinity/plugins/builtin/json_rpc/plugin.py @@ -3,9 +3,6 @@ _SubParsersAction, ) import asyncio -from typing import ( - Tuple -) from lahja import EndpointAPI @@ -31,7 +28,6 @@ RPCServer, ) from trinity.rpc.modules import ( - BaseRPCModule, initialize_beacon_modules, initialize_eth1_modules, ) @@ -61,40 +57,44 @@ def configure_parser(cls, arg_parser: ArgumentParser, subparser: _SubParsersActi help="Disables the JSON-RPC Server", ) - def setup_eth1_modules(self, trinity_config: TrinityConfig) -> Tuple[BaseRPCModule, ...]: - eth1_app_config = trinity_config.get_app_config(Eth1AppConfig) + def chain_for_eth1_config(self, trinity_config: TrinityConfig, + eth1_app_config: Eth1AppConfig) -> AsyncChainAPI: chain_config = eth1_app_config.get_chain_config() - chain: AsyncChainAPI db = DBClient.connect(trinity_config.database_ipc_path) if eth1_app_config.database_mode is Eth1DbMode.LIGHT: header_db = HeaderDB(db) event_bus_light_peer_chain = EventBusLightPeerChain(self.event_bus) - chain = chain_config.light_chain_class(header_db, peer_chain=event_bus_light_peer_chain) + return chain_config.light_chain_class( + header_db, peer_chain=event_bus_light_peer_chain + ) elif eth1_app_config.database_mode is Eth1DbMode.FULL: - chain = chain_config.full_chain_class(db) + return chain_config.full_chain_class(db) else: raise Exception(f"Unsupported Database Mode: {eth1_app_config.database_mode}") - return initialize_eth1_modules(chain, self.event_bus) - - def setup_beacon_modules(self) -> Tuple[BaseRPCModule, ...]: - - return initialize_beacon_modules(None, self.event_bus) + def chain_for_config(self, trinity_config: TrinityConfig) -> AsyncChainAPI: + if trinity_config.has_app_config(BeaconAppConfig): + return None + elif trinity_config.has_app_config(Eth1AppConfig): + eth1_app_config = trinity_config.get_app_config(Eth1AppConfig) + return self.chain_for_eth1_config(trinity_config, eth1_app_config) + else: + raise Exception("Unsupported Node Type") def do_start(self) -> None: - trinity_config = self.boot_info.trinity_config + chain = self.chain_for_config(trinity_config) if trinity_config.has_app_config(Eth1AppConfig): - modules = self.setup_eth1_modules(trinity_config) + modules = initialize_eth1_modules(chain, self.event_bus) elif trinity_config.has_app_config(BeaconAppConfig): - modules = self.setup_beacon_modules() + modules = initialize_beacon_modules(chain, self.event_bus) else: raise Exception("Unsupported Node Type") - rpc = RPCServer(modules, self.event_bus) + rpc = RPCServer(modules, chain, self.event_bus) ipc_server = IPCServer(rpc, self.boot_info.trinity_config.jsonrpc_ipc_path) asyncio.ensure_future(exit_with_services( diff --git a/trinity/plugins/builtin/request_server/plugin.py b/trinity/plugins/builtin/request_server/plugin.py index 2db79c161b..3618a9f404 100644 --- a/trinity/plugins/builtin/request_server/plugin.py +++ b/trinity/plugins/builtin/request_server/plugin.py @@ -27,18 +27,10 @@ from trinity.extensibility import ( AsyncioIsolatedPlugin, ) -from trinity.protocol.bcc.servers import ( - BCCRequestServer, -) -from trinity.protocol.eth.servers import ( - ETHRequestServer -) -from trinity.protocol.les.servers import ( - LightRequestServer, -) -from trinity._utils.shutdown import ( - exit_with_services, -) +from trinity.protocol.bcc.servers import BCCRequestServer +from trinity.protocol.eth.servers import ETHRequestServer +from trinity.protocol.les.servers import LightRequestServer +from trinity._utils.shutdown import exit_with_services class RequestServerPlugin(AsyncioIsolatedPlugin): diff --git a/trinity/plugins/builtin/syncer/cli.py b/trinity/plugins/builtin/syncer/cli.py index 200dedc12a..e7ed764106 100644 --- a/trinity/plugins/builtin/syncer/cli.py +++ b/trinity/plugins/builtin/syncer/cli.py @@ -11,14 +11,20 @@ ) from eth_utils import ( + decode_hex, is_hex, + to_int, remove_0x_prefix, - decode_hex, ValidationError, ) from trinity.sync.common.checkpoint import Checkpoint +from .etherscan_api import ( + get_block_by_number, + get_latest_block, +) + def is_block_hash(value: str) -> bool: return is_hex(value) and len(remove_0x_prefix(value)) == 64 @@ -34,7 +40,31 @@ def parse_checkpoint_uri(uri: str) -> Checkpoint: except ValueError as e: raise ValidationError(str(e)) - scheme, netloc, path, query = parsed.scheme, parsed.netloc, parsed.path.lower(), parsed.query + path = parsed.path.lower() + if path.startswith('/byhash'): + return parse_byhash_uri(parsed) + elif path == '/byetherscan/latest': + return parse_byetherscan_uri(parsed) + else: + raise ValidationError("Not a valid checkpoint URI") + + +BLOCKS_FROM_TIP = 50 + + +def parse_byetherscan_uri(parsed: urllib.parse.ParseResult) -> Checkpoint: + + latest_block_number = get_latest_block() + checkpoint_block_number = latest_block_number - BLOCKS_FROM_TIP + checkpoint_block_response = get_block_by_number(checkpoint_block_number) + checkpoint_score = to_int(hexstr=checkpoint_block_response['totalDifficulty']) + checkpoint_hash = checkpoint_block_response['hash'] + + return Checkpoint(Hash32(decode_hex(checkpoint_hash)), checkpoint_score) + + +def parse_byhash_uri(parsed: urllib.parse.ParseResult) -> Checkpoint: + scheme, netloc, query = parsed.scheme, parsed.netloc, parsed.query try: parsed_query = urllib.parse.parse_qsl(query) @@ -48,10 +78,9 @@ def parse_checkpoint_uri(uri: str) -> Checkpoint: # allows copying out a value from e.g etherscan. score = remove_non_digits(query_dict.get('score', '')) - is_by_hash = path.startswith('/byhash') parts = PurePosixPath(parsed.path).parts - if len(parts) != 3 or scheme != 'eth' or netloc != 'block' or not is_by_hash or not score: + if len(parts) != 3 or scheme != 'eth' or netloc != 'block' or not score: raise ValidationError( 'checkpoint string must be of the form' '"eth://block/byhash/?score="' diff --git a/trinity/plugins/builtin/syncer/etherscan_api.py b/trinity/plugins/builtin/syncer/etherscan_api.py new file mode 100644 index 0000000000..67277cdbe4 --- /dev/null +++ b/trinity/plugins/builtin/syncer/etherscan_api.py @@ -0,0 +1,54 @@ +from typing import ( + Any, + Dict, +) +from eth_utils import ( + to_hex, + to_int, +) + +import requests + +from trinity.exceptions import BaseTrinityError + + +ETHERSCAN_API_URL = "https://api.etherscan.io/api" +ETHERSCAN_PROXY_API_URL = f"{ETHERSCAN_API_URL}?module=proxy" + + +class EtherscanAPIError(BaseTrinityError): + pass + + +def etherscan_post(action: str) -> Any: + response = requests.post(f"{ETHERSCAN_PROXY_API_URL}&action={action}") + + if response.status_code not in [200, 201]: + raise EtherscanAPIError( + f"Invalid status code: {response.status_code}, {response.reason}" + ) + + try: + value = response.json() + except ValueError as err: + raise EtherscanAPIError(f"Invalid response: {response.text}") from err + + message = value.get('message', '') + result = value['result'] + + api_error = message == 'NOTOK' or result == 'Error!' + + if api_error: + raise EtherscanAPIError(f"API error: {message}, result: {result}") + + return value['result'] + + +def get_latest_block() -> int: + response = etherscan_post("eth_blockNumber") + return to_int(hexstr=response) + + +def get_block_by_number(block_number: int) -> Dict[str, Any]: + num = to_hex(primitive=block_number) + return etherscan_post(f"eth_getBlockByNumber&tag={num}&boolean=false") diff --git a/trinity/plugins/builtin/syncer/plugin.py b/trinity/plugins/builtin/syncer/plugin.py index bd5c12c697..508f88ff0f 100644 --- a/trinity/plugins/builtin/syncer/plugin.py +++ b/trinity/plugins/builtin/syncer/plugin.py @@ -206,7 +206,8 @@ def configure_parser(cls, arg_parser: ArgumentParser) -> None: action=NormalizeCheckpointURI, help=( "Start beam sync from a trusted checkpoint specified using URI syntax:" - "eth://block/byhash/?score=" + "By specific block, eth://block/byhash/?score=" + "Let etherscan pick a block near the tip, eth://block/byetherscan/latest" ), default=None, ) diff --git a/trinity/plugins/eth2/beacon/validator.py b/trinity/plugins/eth2/beacon/validator.py index 085c7863ff..df6136d80e 100644 --- a/trinity/plugins/eth2/beacon/validator.py +++ b/trinity/plugins/eth2/beacon/validator.py @@ -143,7 +143,7 @@ def _get_this_epoch_assignment(self, # update `this_epoch_assignment` if it's outdated if this_epoch > self.this_epoch_assignment[validator_index][0]: state_machine = self.chain.get_state_machine() - state = state_machine.state + state = self.chain.get_head_state() self.this_epoch_assignment[validator_index] = ( this_epoch, get_committee_assignment( @@ -158,7 +158,7 @@ def _get_this_epoch_assignment(self, async def handle_first_tick(self, slot: Slot) -> None: head = self.chain.get_canonical_head() state_machine = self.chain.get_state_machine() - state = state_machine.state + state = self.chain.get_head_state() self.logger.debug( # Align with debug log below bold_green("Head epoch=%s slot=%s state_root=%s"), @@ -212,7 +212,7 @@ async def handle_first_tick(self, slot: Slot) -> None: async def handle_second_tick(self, slot: Slot) -> None: state_machine = self.chain.get_state_machine() - state = state_machine.state + state = self.chain.get_head_state() if state.slot < slot: self.skip_block( slot=slot, @@ -307,7 +307,7 @@ async def attest(self, slot: Slot) -> Tuple[Attestation, ...]: attestations: Tuple[Attestation, ...] = () head = self.chain.get_canonical_head() state_machine = self.chain.get_state_machine() - state = state_machine.state + state = self.chain.get_head_state() epoch = compute_epoch_of_slot(slot, self.slots_per_epoch) validator_assignments = { diff --git a/trinity/protocol/bcc/peer.py b/trinity/protocol/bcc/peer.py index d83ed3d891..8f459afe4c 100644 --- a/trinity/protocol/bcc/peer.py +++ b/trinity/protocol/bcc/peer.py @@ -10,8 +10,8 @@ BroadcastConfig, ) -from p2p.abc import CommandAPI, NodeAPI -from p2p.handshake import DevP2PReceipt, HandshakeReceipt +from p2p.abc import CommandAPI, HandshakeReceiptAPI, NodeAPI +from p2p.handshake import DevP2PReceipt from p2p.peer import ( BasePeer, BasePeerFactory, @@ -78,7 +78,7 @@ class BCCPeer(BasePeer): def process_handshake_receipts(self, devp2p_receipt: DevP2PReceipt, - protocol_receipts: Sequence[HandshakeReceipt]) -> None: + protocol_receipts: Sequence[HandshakeReceiptAPI]) -> None: super().process_handshake_receipts(devp2p_receipt, protocol_receipts) for receipt in protocol_receipts: if isinstance(receipt, BCCHandshakeReceipt): diff --git a/trinity/protocol/bcc/servers.py b/trinity/protocol/bcc/servers.py index 80424ddfbb..ecf0f9c231 100644 --- a/trinity/protocol/bcc/servers.py +++ b/trinity/protocol/bcc/servers.py @@ -420,7 +420,7 @@ def _validate_attestations(self, attestations: Iterable[Attestation]) -> Iterable[Attestation]: state_machine = self.chain.get_state_machine() config = state_machine.config - state = state_machine.state + state = self.chain.get_head_state() for attestation in attestations: # Fast forward to state in future slot in order to pass # attestation.data.slot validity check @@ -567,7 +567,7 @@ def _is_block_seen(self, block: BaseBeaconBlock) -> bool: def get_ready_attestations(self) -> Iterable[Attestation]: state_machine = self.chain.get_state_machine() config = state_machine.config - state = state_machine.state + state = self.chain.get_head_state() for attestation in self.attestation_pool.get_all(): data = attestation.data attestation_slot = get_attestation_data_slot(state, data, config) diff --git a/trinity/protocol/eth/peer.py b/trinity/protocol/eth/peer.py index 5c20a5924a..90fc13ffdd 100644 --- a/trinity/protocol/eth/peer.py +++ b/trinity/protocol/eth/peer.py @@ -16,8 +16,8 @@ BroadcastConfig, ) -from p2p.abc import CommandAPI, NodeAPI -from p2p.handshake import DevP2PReceipt, HandshakeReceipt +from p2p.abc import CommandAPI, ConnectionAPI, HandshakeReceiptAPI, NodeAPI +from p2p.handshake import DevP2PReceipt from p2p.protocol import ( Payload, ) @@ -80,7 +80,7 @@ class ETHPeer(BaseChainPeer): def process_handshake_receipts(self, devp2p_receipt: DevP2PReceipt, - protocol_receipts: Sequence[HandshakeReceipt]) -> None: + protocol_receipts: Sequence[HandshakeReceiptAPI]) -> None: super().process_handshake_receipts(devp2p_receipt, protocol_receipts) for receipt in protocol_receipts: if isinstance(receipt, ETHHandshakeReceipt): @@ -106,17 +106,18 @@ def requests(self) -> ETHExchangeHandler: self._requests = ETHExchangeHandler(self) return self._requests - def handle_sub_proto_msg(self, cmd: CommandAPI, msg: Payload) -> None: - if isinstance(cmd, NewBlock): - msg = cast(Dict[str, Any], msg) - header, _, _ = msg['block'] - actual_head = header.parent_hash - actual_td = msg['total_difficulty'] - header.difficulty - if actual_td > self.head_td: - self.head_hash = actual_head - self.head_td = actual_td - - super().handle_sub_proto_msg(cmd, msg) + def setup_protocol_handlers(self) -> None: + self.connection.add_command_handler(NewBlock, self._handle_new_block) + + async def _handle_new_block(self, connection: ConnectionAPI, msg: Payload) -> None: + msg = cast(Dict[str, Any], msg) + header, _, _ = msg['block'] + actual_head = header.parent_hash + actual_td = msg['total_difficulty'] - header.difficulty + + if actual_td > self.head_td: + self.head_hash = actual_head + self.head_td = actual_td class ETHProxyPeer(BaseProxyPeer): diff --git a/trinity/protocol/les/peer.py b/trinity/protocol/les/peer.py index 71b938bcec..09a7604e0e 100644 --- a/trinity/protocol/les/peer.py +++ b/trinity/protocol/les/peer.py @@ -25,14 +25,12 @@ BroadcastConfig, ) -from p2p.abc import CommandAPI, NodeAPI -from p2p.handshake import DevP2PReceipt, HandshakeReceipt +from p2p.abc import CommandAPI, ConnectionAPI, HandshakeReceiptAPI, NodeAPI +from p2p.handshake import DevP2PReceipt, Handshaker from p2p.peer_pool import BasePeerPool from p2p.typing import Payload from trinity.rlp.block_body import BlockBody -from p2p.handshake import Handshaker - from trinity.protocol.common.peer import ( BaseChainPeer, BaseProxyPeer, @@ -85,7 +83,7 @@ class LESPeer(BaseChainPeer): def process_handshake_receipts(self, devp2p_receipt: DevP2PReceipt, - protocol_receipts: Sequence[HandshakeReceipt]) -> None: + protocol_receipts: Sequence[HandshakeReceiptAPI]) -> None: super().process_handshake_receipts(devp2p_receipt, protocol_receipts) for receipt in protocol_receipts: if isinstance(receipt, LESHandshakeReceipt): @@ -112,14 +110,14 @@ def requests(self) -> LESExchangeHandler: self._requests = LESExchangeHandler(self) return self._requests - def handle_sub_proto_msg(self, cmd: CommandAPI, msg: Payload) -> None: - if isinstance(cmd, Announce): - head_info = cast(Dict[str, Union[int, Hash32, BlockNumber]], msg) - self.head_td = cast(int, head_info['head_td']) - self.head_hash = cast(Hash32, head_info['head_hash']) - self.head_number = cast(BlockNumber, head_info['head_number']) + def setup_protocol_handlers(self) -> None: + self.connection.add_command_handler(Announce, self._handle_announce) - super().handle_sub_proto_msg(cmd, msg) + async def _handle_announce(self, connection: ConnectionAPI, msg: Payload) -> None: + head_info = cast(Dict[str, Union[int, Hash32, BlockNumber]], msg) + self.head_td = cast(int, head_info['head_td']) + self.head_hash = cast(Hash32, head_info['head_hash']) + self.head_number = cast(BlockNumber, head_info['head_number']) class LESProxyPeer(BaseProxyPeer): @@ -165,8 +163,8 @@ async def get_handshakers(self) -> Tuple[Handshaker, ...]: head_num=head.block_number, genesis_hash=genesis_hash, serve_headers=True, - serve_chain_since=0, # TODO: these should be configurable to allow us to serve this data. + serve_chain_since=None, serve_state_since=None, serve_recent_state=None, serve_recent_chain=None, diff --git a/trinity/rpc/main.py b/trinity/rpc/main.py index ee71190669..a9e0daf98b 100644 --- a/trinity/rpc/main.py +++ b/trinity/rpc/main.py @@ -14,6 +14,7 @@ ValidationError, ) +from trinity.chains.base import AsyncChainAPI from trinity.rpc.modules import ( BaseRPCModule, ) @@ -64,9 +65,11 @@ class RPCServer: def __init__(self, modules: Sequence[BaseRPCModule], + chain: AsyncChainAPI=None, event_bus: EndpointAPI=None) -> None: self.event_bus = event_bus self.modules: Dict[str, BaseRPCModule] = {} + self.chain = chain for module in modules: name = module.name.lower() @@ -113,7 +116,7 @@ async def _get_result(self, method = self._lookup_method(request['method']) params = request.get('params', []) result = await execute_with_retries( - self.event_bus, method, params + self.event_bus, method, params, self.chain, ) if request['method'] == 'evm_resetToGenesisFixture': diff --git a/trinity/rpc/modules/_util.py b/trinity/rpc/modules/_util.py new file mode 100644 index 0000000000..4e6658515c --- /dev/null +++ b/trinity/rpc/modules/_util.py @@ -0,0 +1,36 @@ +from typing import ( + Union, +) + +from eth_typing import ( + BlockNumber, +) +from eth_utils import ( + is_integer, +) + +from eth.rlp.headers import ( + BlockHeader, +) + +from trinity.chains.base import AsyncChainAPI + + +async def get_header(chain: AsyncChainAPI, at_block: Union[str, int]) -> BlockHeader: + if at_block == 'pending': + raise NotImplementedError("RPC interface does not support the 'pending' block at this time") + elif at_block == 'latest': + at_header = chain.get_canonical_head() + elif at_block == 'earliest': + # TODO find if genesis block can be non-zero. Why does 'earliest' option even exist? + block = await chain.coro_get_canonical_block_by_number(BlockNumber(0)) + at_header = block.header + # mypy doesn't have user defined type guards yet + # https://github.com/python/mypy/issues/5206 + elif is_integer(at_block) and at_block >= 0: # type: ignore + block = await chain.coro_get_canonical_block_by_number(BlockNumber(int(at_block))) + at_header = block.header + else: + raise TypeError("Unrecognized block reference: %r" % at_block) + + return at_header diff --git a/trinity/rpc/modules/eth.py b/trinity/rpc/modules/eth.py index dd5c9f81e3..c2791cad9b 100644 --- a/trinity/rpc/modules/eth.py +++ b/trinity/rpc/modules/eth.py @@ -41,10 +41,10 @@ SpoofTransaction, ) +from trinity.chains.base import AsyncChainAPI from trinity.constants import ( TO_NETWORKING_BROADCAST_CONFIG, ) -from trinity.chains.base import AsyncChainAPI from trinity.rpc.format import ( block_to_dict, header_to_dict, @@ -65,25 +65,7 @@ validate_transaction_gas_estimation_dict, ) - -async def get_header(chain: AsyncChainAPI, at_block: Union[str, int]) -> BlockHeader: - if at_block == 'pending': - raise NotImplementedError("RPC interface does not support the 'pending' block at this time") - elif at_block == 'latest': - at_header = chain.get_canonical_head() - elif at_block == 'earliest': - # TODO find if genesis block can be non-zero. Why does 'earliest' option even exist? - block = await chain.coro_get_canonical_block_by_number(BlockNumber(0)) - at_header = block.header - # mypy doesn't have user defined type guards yet - # https://github.com/python/mypy/issues/5206 - elif is_integer(at_block) and at_block >= 0: # type: ignore - block = await chain.coro_get_canonical_block_by_number(BlockNumber(int(at_block))) - at_header = block.header - else: - raise TypeError("Unrecognized block reference: %r" % at_block) - - return at_header +from ._util import get_header async def state_at_block( @@ -155,6 +137,7 @@ async def blockNumber(self) -> str: num = self.chain.get_canonical_head().block_number return hex(num) + @retryable(which_block_arg_name='at_block') @format_params(identity, to_int_if_hex) async def call(self, txn_dict: Dict[str, Any], at_block: Union[str, int]) -> str: header = await get_header(self.chain, at_block) @@ -168,6 +151,7 @@ async def coinbase(self) -> str: coinbase_address = ZERO_ADDRESS return encode_hex(coinbase_address) + @retryable(which_block_arg_name='at_block') @format_params(identity, to_int_if_hex) async def estimateGas(self, txn_dict: Dict[str, Any], at_block: Union[str, int]) -> str: header = await get_header(self.chain, at_block) @@ -179,7 +163,7 @@ async def estimateGas(self, txn_dict: Dict[str, Any], at_block: Union[str, int]) async def gasPrice(self) -> str: return hex(int(os.environ.get('TRINITY_GAS_PRICE', to_wei(1, 'gwei')))) - @retryable + @retryable(which_block_arg_name='at_block') @format_params(decode_hex, to_int_if_hex) async def getBalance(self, address: Address, at_block: Union[str, int]) -> str: state = await state_at_block(self.chain, at_block) @@ -211,14 +195,14 @@ async def getBlockTransactionCountByNumber(self, at_block: Union[str, int]) -> s block = await get_block_at_number(self.chain, at_block) return hex(len(block.transactions)) - @retryable + @retryable(which_block_arg_name='at_block') @format_params(decode_hex, to_int_if_hex) async def getCode(self, address: Address, at_block: Union[str, int]) -> str: state = await state_at_block(self.chain, at_block) code = state.get_code(address) return encode_hex(code) - @retryable + @retryable(which_block_arg_name='at_block') @format_params(decode_hex, to_int_if_hex, to_int_if_hex) async def getStorageAt(self, address: Address, position: int, at_block: Union[str, int]) -> str: if not is_integer(position) or position < 0: diff --git a/trinity/rpc/retry.py b/trinity/rpc/retry.py index c3f6df4150..14e2ce8d92 100644 --- a/trinity/rpc/retry.py +++ b/trinity/rpc/retry.py @@ -2,10 +2,12 @@ Tools for retrying failed RPC methods. If we're beam syncing we can fault in missing data from remote peers. """ +import inspect import itertools from typing import ( Any, Callable, + Optional, TypeVar, ) @@ -17,31 +19,73 @@ MissingStorageTrieNode, ) +from trinity.chains.base import AsyncChainAPI from trinity.sync.common.events import ( CollectMissingAccount, CollectMissingBytecode, CollectMissingStorage, ) +from trinity.rpc.modules._util import get_header + Func = Callable[..., Any] Meth = TypeVar('Meth', bound=Func) RETRYABLE_ATTRIBUTE_NAME = '_is_rpc_retryable' +AT_BLOCK_ATTRIBUTE_NAME = '_at_block_parameter' MAX_RETRIES = 1000 -def retryable(func: Meth) -> Meth: - setattr(func, RETRYABLE_ATTRIBUTE_NAME, True) - return func +def retryable(which_block_arg_name: str) -> Func: + """ + A decorator which marks eth_* RPCs which: + - are idempotent + - throw errors which the beam syncer can help to recover from + + :param which_block_arg_name: names one of the arguments of the wrapped function. + Specifically, the arg used to pass in the block identifier ("at_block", usually) + """ + def make_meth_retryable(meth: Meth) -> Meth: + sig = inspect.signature(meth) + if which_block_arg_name not in sig.parameters: + raise Exception( + f'"{which_block_arg_name}" does not name an argument to this function' + ) + + setattr(meth, RETRYABLE_ATTRIBUTE_NAME, True) + setattr(meth, AT_BLOCK_ATTRIBUTE_NAME, which_block_arg_name) + return meth + return make_meth_retryable def is_retryable(func: Func) -> bool: return getattr(func, RETRYABLE_ATTRIBUTE_NAME, False) -async def execute_with_retries(event_bus: EndpointAPI, func: Func, params: Any) -> None: +async def check_requested_block_age(chain: Optional[AsyncChainAPI], + func: Func, params: Any) -> None: + sig = inspect.signature(func) + params = sig.bind(*params) + + try: + at_block_name = getattr(func, AT_BLOCK_ATTRIBUTE_NAME) + except AttributeError as e: + raise Exception("Function {func} was not decorated with @retryable") from e + + at_block = params.arguments[at_block_name] + + requested_header = await get_header(chain, at_block) + requested_block = requested_header.block_number + current_block = chain.get_canonical_head().block_number + + if requested_block < current_block - 64: + raise Exception(f'block "{at_block}" is too old to be fetched over the network') + + +async def execute_with_retries(event_bus: EndpointAPI, func: Func, params: Any, + chain: Optional[AsyncChainAPI]) -> None: """ If a beam sync (or anything which responds to CollectMissingAccount) is running then attempt to fetch missing data from it before giving up. @@ -63,6 +107,8 @@ async def execute_with_retries(event_bus: EndpointAPI, func: Func, params: Any) if not event_bus.is_any_endpoint_subscribed_to(CollectMissingAccount): raise + await check_requested_block_age(chain, func, params) + await event_bus.request(CollectMissingAccount( exc.missing_node_hash, exc.address_hash, @@ -81,6 +127,8 @@ async def execute_with_retries(event_bus: EndpointAPI, func: Func, params: Any) if not event_bus.is_any_endpoint_subscribed_to(CollectMissingBytecode): raise + await check_requested_block_age(chain, func, params) + await event_bus.request(CollectMissingBytecode( bytecode_hash=exc.missing_code_hash, urgent=True, @@ -97,6 +145,8 @@ async def execute_with_retries(event_bus: EndpointAPI, func: Func, params: Any) if not event_bus.is_any_endpoint_subscribed_to(CollectMissingStorage): raise + await check_requested_block_age(chain, func, params) + await event_bus.request(CollectMissingStorage( missing_node_hash=exc.missing_node_hash, storage_key=exc.requested_key, diff --git a/trinity/server.py b/trinity/server.py index 2f1a18f915..0ed74c621f 100644 --- a/trinity/server.py +++ b/trinity/server.py @@ -36,13 +36,11 @@ from trinity.db.eth1.header import BaseAsyncHeaderDB from trinity.protocol.common.context import ChainContext from trinity.protocol.common.peer import BasePeerPool -from trinity.protocol.eth.peer import ETHPeerPool -from trinity.protocol.les.peer import LESPeerPool from trinity.protocol.bcc.context import BeaconContext from trinity.protocol.bcc.peer import BCCPeerPool -from trinity.protocol.bcc.servers import ( - BCCReceiveServer, -) +from trinity.protocol.bcc.servers import BCCReceiveServer +from trinity.protocol.eth.peer import ETHPeerPool +from trinity.protocol.les.peer import LESPeerPool DIAL_IN_OUT_RATIO = 0.75 BOUND_IP = '0.0.0.0' @@ -169,7 +167,7 @@ async def _receive_handshake( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: factory = self.peer_pool.get_peer_factory() handshakers = await factory.get_handshakers() - multiplexer, devp2p_receipt, protocol_receipts = await receive_handshake( + connection = await receive_handshake( reader=reader, writer=writer, private_key=self.privkey, @@ -179,12 +177,7 @@ async def _receive_handshake( ) # Create and register peer in peer_pool - peer = factory.create_peer( - multiplexer=multiplexer, - devp2p_receipt=devp2p_receipt, - protocol_receipts=protocol_receipts, - inbound=True, - ) + peer = factory.create_peer(connection) if self.peer_pool.is_full: await peer.disconnect(DisconnectReason.too_many_peers) diff --git a/trinity/tools/factories.py b/trinity/tools/factories.py index d56e37e792..2ffbfb5447 100644 --- a/trinity/tools/factories.py +++ b/trinity/tools/factories.py @@ -33,7 +33,7 @@ from trinity.protocol.eth.peer import ETHPeer, ETHPeerFactory from trinity.protocol.eth.proto import ETHHandshakeParams, ETHProtocol -from trinity.protocol.les.handshaker import LESV1Handshaker +from trinity.protocol.les.handshaker import LESV2Handshaker, LESV1Handshaker from trinity.protocol.les.peer import LESPeer, LESPeerFactory from trinity.protocol.les.proto import LESHandshakeParams, LESProtocol, LESProtocolV2 @@ -188,6 +188,20 @@ class Meta: ) +class LESV1HandshakerFactory(factory.Factory): + class Meta: + model = LESV1Handshaker + + handshake_params = factory.SubFactory(LESHandshakeParamsFactory, version=LESProtocol.version) + + +class LESV2HandshakerFactory(factory.Factory): + class Meta: + model = LESV2Handshaker + + handshake_params = factory.SubFactory(LESHandshakeParamsFactory, version=LESProtocolV2.version) + + class LESV1Peer(LESPeer): supported_sub_protocols = (LESProtocol,) # type: ignore