Skip to content
This repository has been archived by the owner on Jul 1, 2021. It is now read-only.

Commit

Permalink
Support multiple identity registries
Browse files Browse the repository at this point in the history
  • Loading branch information
jannikluhn committed Jun 12, 2019
1 parent 5736c94 commit 48d4037
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 98 deletions.
102 changes: 65 additions & 37 deletions p2p/discv5/enr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
abstractmethod,
)
import base64
import collections
from typing import (
Any,
AbstractSet,
Expand Down Expand Up @@ -36,8 +37,9 @@
)

from p2p.discv5.identity_schemes import (
identity_scheme_registry,
default_identity_scheme_registry as default_id_scheme_registry,
IdentityScheme,
IdentitySchemeRegistry,
)
from p2p.discv5.constants import (
MAX_ENR_SIZE,
Expand Down Expand Up @@ -71,11 +73,14 @@ def serialize(cls, enr: "BaseENR") -> Tuple[bytes, ...]:
))

@classmethod
def deserialize(cls, serialized_enr: Sequence[bytes]) -> "UnsignedENR":
def deserialize(cls,
serialized_enr: Sequence[bytes],
identity_scheme_registry: IdentitySchemeRegistry = default_id_scheme_registry,
) -> "UnsignedENR":
cls._validate_serialized_length(serialized_enr)
sequence_number = big_endian_int.deserialize(serialized_enr[0])
kv_pairs = cls._deserialize_kv_pairs(serialized_enr)
return UnsignedENR(sequence_number, kv_pairs)
return UnsignedENR(sequence_number, kv_pairs, identity_scheme_registry)

@classmethod
@to_dict
Expand Down Expand Up @@ -141,11 +146,22 @@ def serialize(cls, enr: "ENR") -> Tuple[bytes, ...]:
return (serialized_signature,) + serialized_content

@classmethod
def deserialize(cls, serialized_enr: Sequence[bytes]) -> "ENR":
def deserialize(cls,
serialized_enr: Sequence[bytes],
identity_scheme_registry: IdentitySchemeRegistry = default_id_scheme_registry,
) -> "ENR":
cls._validate_serialized_length(serialized_enr)
signature = binary.deserialize(serialized_enr[0])
unsigned_enr = ENRContentSedes.deserialize(serialized_enr[1:])
return ENR(unsigned_enr.sequence_number, dict(unsigned_enr), signature)
unsigned_enr = ENRContentSedes.deserialize(
serialized_enr[1:],
identity_scheme_registry=identity_scheme_registry,
)
return ENR(
unsigned_enr.sequence_number,
dict(unsigned_enr),
signature,
identity_scheme_registry,
)

@classmethod
def _validate_serialized_length(cls, serialized_enr: Sequence[bytes]) -> None:
Expand Down Expand Up @@ -173,40 +189,43 @@ def _validate_serialized_length(cls, serialized_enr: Sequence[bytes]) -> None:
class BaseENR(Mapping[bytes, Any], ABC):
def __init__(self,
sequence_number: int,
kv_pairs: Mapping[bytes, Any]) -> None:
kv_pairs: Mapping[bytes, Any],
identity_scheme_registry: IdentitySchemeRegistry = default_id_scheme_registry,
) -> None:
self._sequence_number = sequence_number
self._kv_pairs = dict(kv_pairs)
self._identity_scheme = self._pick_identity_scheme(identity_scheme_registry)

self._validate_sequence_number()
self._validate_required_keys()

def _validate_sequence_number(self) -> None:
if self.sequence_number < 0:
raise ValidationError("Sequence number is negative")

def _validate_required_keys(self) -> None:
for required_key in REQUIRED_ENR_KEYS:
if required_key not in self:
raise ValidationError(f"ENR is missing required key {required_key}")

@property
def sequence_number(self) -> int:
return self._sequence_number

def get_signing_message(self) -> bytes:
return rlp.encode(self, ENRContentSedes)

def get_identity_scheme(self) -> Type[IdentityScheme]:
def _pick_identity_scheme(self,
identity_scheme_registry: IdentitySchemeRegistry,
) -> Type[IdentityScheme]:
try:
identity_scheme_id = self[b"id"]
identity_scheme_id = self[IDENTITY_SCHEME_ENR_KEY]
except KeyError:
raise Exception("unreachable: id key existence is checked during initialization")
raise ValidationError("ENR does not specify identity scheme")

try:
return identity_scheme_registry[identity_scheme_id]
except KeyError:
raise ValidationError(f"ENR uses unsupported identity scheme {identity_scheme_id}")

@property
def identity_scheme(self) -> Type[IdentityScheme]:
return self._identity_scheme

@property
def sequence_number(self) -> int:
return self._sequence_number

def get_signing_message(self) -> bytes:
return rlp.encode(self, ENRContentSedes)

#
# Mapping interface
#
Expand Down Expand Up @@ -246,9 +265,17 @@ def __hash__(self) -> int:
class UnsignedENR(BaseENR, ENRContentSedes):

def to_signed_enr(self, private_key: bytes) -> "ENR":
identity_scheme = self.get_identity_scheme()
signature = identity_scheme.create_signature(self, private_key)
return ENR(self.sequence_number, dict(self), signature)
signature = self.identity_scheme.create_signature(self, private_key)

transient_identity_scheme_registry = IdentitySchemeRegistry()
transient_identity_scheme_registry.register(self.identity_scheme)

return ENR(
self.sequence_number,
dict(self),
signature,
identity_scheme_registry=transient_identity_scheme_registry,
)

def __eq__(self, other: Any) -> bool:
return other.__class__ is self.__class__ and dict(other) == dict(self)
Expand All @@ -264,31 +291,34 @@ class ENR(BaseENR, ENRSedes):
def __init__(self,
sequence_number: int,
kv_pairs: Mapping[bytes, Any],
signature: bytes) -> None:
signature: bytes,
identity_scheme_registry: IdentitySchemeRegistry = default_id_scheme_registry,
) -> None:
self._signature = signature
super().__init__(sequence_number, kv_pairs)
super().__init__(sequence_number, kv_pairs, identity_scheme_registry)

@classmethod
def from_repr(cls, representation: str) -> "ENR":
def from_repr(cls,
representation: str,
identity_scheme_registry: IdentitySchemeRegistry = default_id_scheme_registry,
) -> "ENR":
if not representation.startswith("enr:"):
raise ValidationError(f"Invalid ENR representation: {representation}")

unpadded_b64 = representation[4:]
padded_b64 = unpadded_b64 + "=" * (4 - len(unpadded_b64) % 4)
rlp_encoded = base64.urlsafe_b64decode(padded_b64)
return rlp.decode(rlp_encoded, cls)
return rlp.decode(rlp_encoded, cls, identity_scheme_registry=identity_scheme_registry)

@property
def signature(self) -> bytes:
return self._signature

def validate_signature(self) -> None:
identity_scheme = self.get_identity_scheme()
identity_scheme.validate_signature(self)
self.identity_scheme.validate_signature(self)

def extract_node_address(self) -> bytes:
identity_scheme = self.get_identity_scheme()
return identity_scheme.extract_node_address(self)
return self.identity_scheme.extract_node_address(self)

def __eq__(self, other: Any) -> bool:
return (
Expand All @@ -314,9 +344,7 @@ def __repr__(self) -> str:
))


REQUIRED_ENR_KEYS = (
b"id",
)
IDENTITY_SCHEME_ENR_KEY = b"id"

ENR_KEY_SEDES_MAPPING = {
b"id": binary,
Expand Down
39 changes: 25 additions & 14 deletions p2p/discv5/identity_schemes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
ABC,
abstractmethod,
)
from collections import (
UserDict,
)

from typing import (
Dict,
Type,
TYPE_CHECKING,
)
Expand All @@ -30,24 +32,33 @@
ENR,
)

# https://github.com/python/mypy/issues/5264#issuecomment-399407428
if TYPE_CHECKING:
IdentitySchemeRegistryBaseType = UserDict[bytes, Type["IdentityScheme"]]
else:
IdentitySchemeRegistryBaseType = UserDict


class IdentitySchemeRegistry(IdentitySchemeRegistryBaseType):

identity_scheme_registry: Dict[bytes, Type["IdentityScheme"]] = {}
def register(self,
identity_scheme_class: Type["IdentityScheme"]
) -> Type["IdentityScheme"]:
"""Class decorator to register identity schemes."""
if identity_scheme_class.id is None:
raise ValueError("Identity schemes must define ID")

if identity_scheme_class.id in self:
raise ValueError(
f"Identity scheme with id {identity_scheme_class.id} is already registered",
)

def register_identity_scheme(identity_scheme_class: Type["IdentityScheme"],
) -> Type["IdentityScheme"]:
"""Class decorator to register identity schemes."""
if identity_scheme_class.id is None:
raise ValueError("Identity schemes must define ID")
self[identity_scheme_class.id] = identity_scheme_class

if identity_scheme_class.id in identity_scheme_registry:
raise ValueError(
f"Identity scheme with id {identity_scheme_class.id} is already registered",
)
return identity_scheme_class

identity_scheme_registry[identity_scheme_class.id] = identity_scheme_class

return identity_scheme_class
default_identity_scheme_registry = IdentitySchemeRegistry()


class IdentityScheme(ABC):
Expand All @@ -70,7 +81,7 @@ def extract_node_address(cls, enr: "ENR") -> bytes:
pass


@register_identity_scheme
@default_identity_scheme_registry.register
class V4IdentityScheme(IdentityScheme):

id = b"v4"
Expand Down
Loading

0 comments on commit 48d4037

Please sign in to comment.