Skip to content

Commit

Permalink
New: restore Message.loads() and Message.dumps()
Browse files Browse the repository at this point in the history
  • Loading branch information
eigenein committed Apr 25, 2023
1 parent 216a7c0 commit e4ec436
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 36 deletions.
7 changes: 0 additions & 7 deletions docs/migration.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,6 @@ The decorator has been removed. You should inherit your message classes from `#!

Replace annotations like `#!python foo: int = field(1)` with `#!python foo: Annotated[int, Field(1)]`.

## Serialization and deserialization

Replace:

- `#!python message.dumps()` with `#!python bytes(message)` or `#!python message.write_to(…)`
- `#!python loads()` with `#!python read_from(BytesIO(…))`

## Well-known types

`#!python typing.Any`, `#!python datetime.datetime`, and `#!python datetime.timedelta` are no longer mapped into the `.proto` types. Use `#!python pure_protobuf.well_known.Any_`, `#!python pure_protobuf.well_known.Timestamp`, and `#!python pure_protobuf.well_known.Duration` explicitly.
Expand Down
24 changes: 23 additions & 1 deletion pure_protobuf/message.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from abc import ABC
from io import BytesIO
from typing import IO, Any, ClassVar, Dict, Tuple

from typing_extensions import Self
Expand Down Expand Up @@ -92,15 +93,36 @@ def read_from(cls, io: IO[bytes]) -> Self:

return cls(**values)

@classmethod
def loads(cls, buffer: bytes) -> Self:
"""
Read a message from the buffer.
This is functionally the same as calling `read_from(BytesIO(buffer))`.
"""
return cls.read_from(BytesIO(buffer))

def write_to(self, io: IO[bytes]) -> None:
"""Write the message to the file."""
for _, (name, descriptor) in self.__PROTOBUF_FIELDS_BY_NUMBER__.items():
descriptor.write(getattr(self, name), io)

def __bytes__(self) -> bytes:
"""Convert the message to a bytestring."""
"""
Convert the message to a bytestring.
This is functionally the same as calling `dumps()` or `write_to(BytesIO(…))`.
"""
return to_bytes(BaseMessage.write_to, self)

def dumps(self) -> bytes:
"""
Convert the message to a bytestring.
This is functionally the same as calling `bytes(message)` or `write_to(BytesIO(…))`.
"""
return bytes(self)

def __setattr__(self, name: str, value: Any) -> None: # noqa: D105
super().__setattr__(name, value)
descriptor = self.__PROTOBUF_FIELDS_BY_NAME__[name]
Expand Down
51 changes: 23 additions & 28 deletions tests/message/test_dataclass.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from dataclasses import dataclass, field
from io import BytesIO
from typing import List, Optional

from typing_extensions import Annotated
Expand All @@ -22,9 +21,9 @@ class Message(BaseMessage):
message = Message(a=uint(150))
bytes_ = b"\x08\x96\x01"

assert Message.read_from(BytesIO()) == Message()
assert Message.loads(b"") == Message()
assert bytes(message) == bytes_
assert Message.read_from(BytesIO(bytes_)) == message
assert Message.loads(bytes_) == message


def test_simple_message_unknown_field() -> None:
Expand All @@ -33,13 +32,13 @@ class Message(BaseMessage):
a: Annotated[uint, Field(1)] = uint(0)

# fmt: off
assert Message.read_from(BytesIO(
assert Message.loads(
b"\x21\x01\x02\x03\x04\x05\x06\x07\x08" # extra 64-bit
b"\x35\x01\x02\x03\x04" # extra 32-bit
b"\x42\x01\x00" # extra bytes
b"\x10\xFF\x01" # extra varint
b"\x08\x96\x01", # field `a`
)) == Message(a=uint(150))
) == Message(a=uint(150))
# fmt: on


Expand All @@ -56,9 +55,9 @@ class Message(BaseMessage):
message = Message(b="testing")
bytes_ = b"\x12\x07\x74\x65\x73\x74\x69\x6e\x67"

assert Message.read_from(BytesIO()) == Message()
assert Message.loads(b"") == Message()
assert bytes(message) == bytes_
assert Message.read_from(BytesIO(bytes_)) == message
assert Message.loads(bytes_) == message


def test_embedded_message() -> None:
Expand All @@ -78,9 +77,9 @@ class Parent(BaseMessage):
message = Parent(c=Child(a=uint(150)))
bytes_ = b"\x1A\x03\x08\x96\x01"

assert Parent.read_from(BytesIO()) == Parent()
assert Parent.loads(b"") == Parent()
assert bytes(message) == bytes_
assert Parent.read_from(BytesIO(bytes_)) == message
assert Parent.loads(bytes_) == message


def test_merge_embedded_messages_repeated() -> None:
Expand All @@ -102,12 +101,12 @@ class Outer(BaseMessage):

assert (
# fmt: off
Outer.read_from(BytesIO(
Outer.loads(
b"\x0A\x00" # foo == None
b"\x0A\x02\x08\x00" # foo == [0]
b"\x0A\x03\x08\x96\x01" # foo == [150]
b"\x0A\x00", # foo == None
))
)
== Outer(inner=Inner(foo=[uint(0), uint(150)]))
# fmt: on
)
Expand All @@ -126,10 +125,10 @@ class Outer(BaseMessage):

assert (
# fmt: off
Outer.read_from(BytesIO(
Outer.loads(
b"\x0A\x02\x08\x01" # foo == 1
b"\x0A\x02\x08\x02", # foo == 2
))
)
== Outer(inner=Inner(foo=uint(2)))
# fmt: on
)
Expand All @@ -148,7 +147,7 @@ def test_read_unpacked_repeated_as_packed() -> None:
class Test(BaseMessage):
foo: Annotated[List[uint], Field(1, packed=True)]

assert Test.read_from(BytesIO(b"\x08\x01\x08\x02")) == Test(foo=[uint(1), uint(2)])
assert Test.loads(b"\x08\x01\x08\x02") == Test(foo=[uint(1), uint(2)])


def test_read_packed_repeated_as_unpacked() -> None:
Expand All @@ -164,7 +163,7 @@ def test_read_packed_repeated_as_unpacked() -> None:
class Test(BaseMessage):
foo: Annotated[List[uint], Field(1, packed=False)]

assert Test.read_from(BytesIO(b"\x0A\x04\x01\x96\x01\x02")) == Test(
assert Test.loads(b"\x0A\x04\x01\x96\x01\x02") == Test(
foo=[uint(1), uint(150), uint(2)],
)

Expand All @@ -181,7 +180,7 @@ class Parent(BaseMessage):
message = Parent(children=[Child(payload=uint(42)), Child(payload=uint(43))])
encoded = bytes(message)
assert encoded == bytes.fromhex("0a02102a 0a02102b")
assert Parent.read_from(BytesIO(encoded)) == message
assert Parent.loads(encoded) == message


def test_merge_grandchild() -> None:
Expand All @@ -197,17 +196,13 @@ class Child(BaseMessage):
class Parent(BaseMessage):
child: Annotated[Child, Field(1)]

assert Parent.read_from(
BytesIO(
bytes(Parent(child=Child(child=Grandchild(payload=uint(42)))))
+ bytes(Parent(child=Child(child=Grandchild(payload=uint(43))))),
),
assert Parent.loads(
bytes(Parent(child=Child(child=Grandchild(payload=uint(42)))))
+ bytes(Parent(child=Child(child=Grandchild(payload=uint(43))))),
) == Parent(child=Child(child=Grandchild(payload=uint(43))))

assert Parent.read_from(
BytesIO(
bytes(Parent(child=Child(child=Grandchild(payload=uint(42))))) + bytes(Parent(child=Child())),
),
assert Parent.loads(
bytes(Parent(child=Child(child=Grandchild(payload=uint(42))))) + bytes(Parent(child=Child())),
) == Parent(child=Child(child=Grandchild(payload=uint(42))))


Expand All @@ -227,7 +222,7 @@ class Message(BaseMessage):

part_1 = bytes(Message(field=[42, 43]))
part_2 = bytes(Message(field=[100500, 100501]))
assert Message.read_from(BytesIO(part_1 + part_2)) == Message(field=[42, 43, 100500, 100501])
assert Message.loads(part_1 + part_2) == Message(field=[42, 43, 100500, 100501])


def test_one_of_assignment_dataclass() -> None:
Expand Down Expand Up @@ -258,7 +253,7 @@ class Message(BaseMessage):
foo: Annotated[Optional[int], Field(1, one_of=foo_or_bar)] = None
bar: Annotated[Optional[int], Field(2, one_of=foo_or_bar)] = None

message = Message.read_from(BytesIO(b"\x08\x02\x10\x04"))
message = Message.loads(b"\x08\x02\x10\x04")
assert message.foo_or_bar == 2
assert message.bar == 2
assert message.foo is None
Expand All @@ -278,7 +273,7 @@ class Child(BaseMessage):
class Parent(BaseMessage):
child: Annotated[Child, Field(1)]

message = Parent.read_from(BytesIO(bytes.fromhex("0a020802 0a021004")))
message = Parent.loads(bytes.fromhex("0a020802 0a021004"))
assert message.child.foo_or_bar == 2
assert message.child.bar == 2
assert message.child.foo is None
Expand Down

0 comments on commit e4ec436

Please sign in to comment.