Skip to content

Commit

Permalink
decompress_into
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed May 10, 2023
1 parent fbb16d6 commit 2967335
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 18 deletions.
78 changes: 64 additions & 14 deletions distributed/protocol/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from __future__ import annotations

import zlib
from collections.abc import Callable, Iterable
from collections.abc import Callable, Sequence
from contextlib import suppress
from functools import partial
from random import randint
Expand All @@ -32,13 +32,14 @@ class Compression(NamedTuple):
name: None | str
compress: Callable[[AnyBytes], AnyBytes]
decompress: Callable[[AnyBytes], AnyBytes]
decompress_into: Callable[[AnyBytes, memoryview], None] | None


compressions: dict[str | None | Literal[False], Compression] = {
None: Compression(None, identity, identity),
False: Compression(None, identity, identity), # alias
"auto": Compression(None, identity, identity),
"zlib": Compression("zlib", zlib.compress, zlib.decompress),
None: Compression(None, identity, identity, None),
False: Compression(None, identity, identity, None), # alias
"auto": Compression(None, identity, identity, None),
"zlib": Compression("zlib", zlib.compress, zlib.decompress, None),
}


Expand All @@ -56,7 +57,9 @@ class Compression(NamedTuple):
except TypeError:
raise ImportError("Need snappy >= 0.5.3")

compressions["snappy"] = Compression("snappy", snappy.compress, snappy.decompress)
compressions["snappy"] = Compression(
"snappy", snappy.compress, snappy.decompress, None
)
compressions["auto"] = compressions["snappy"]

with suppress(ImportError):
Expand All @@ -77,6 +80,7 @@ class Compression(NamedTuple):
# larger ones are deep-copied between decompression and serialization anyway in
# order to merge them.
partial(lz4.block.decompress, return_bytearray=True),
None,
)
compressions["auto"] = compressions["lz4"]

Expand All @@ -89,7 +93,10 @@ class Compression(NamedTuple):
raise ImportError("Need cramjam >= 2.7.0")

compressions["cramjam.lz4"] = Compression(
"cramjam.lz4", cramjam.lz4.compress_block, cramjam.lz4.decompress_block
"cramjam.lz4",
cramjam.lz4.compress_block,
cramjam.lz4.decompress_block,
cramjam.lz4.decompress_block_into,
)
compressions["auto"] = compressions["cramjam.lz4"]

Expand All @@ -112,7 +119,7 @@ def zstd_decompress(data):
zstd_decompressor = zstandard.ZstdDecompressor()
return zstd_decompressor.decompress(data)

compressions["zstd"] = Compression("zstd", zstd_compress, zstd_decompress)
compressions["zstd"] = Compression("zstd", zstd_compress, zstd_decompress, None)


def get_compression_settings(key: str) -> str | None:
Expand Down Expand Up @@ -209,9 +216,52 @@ def maybe_compress(


@context_meter.meter("decompress")
def decompress(header: dict[str, Any], frames: Iterable[AnyBytes]) -> list[AnyBytes]:
"""Decompress frames according to information in the header"""
return [
compressions[name].decompress(frame)
for name, frame in zip(header["compression"], frames)
]
def decompress(
    header: dict[str, Any], frames: Sequence[AnyBytes]
) -> tuple[dict[str, Any], list[AnyBytes]]:
    """Decompress frames according to information in the header.

    Parameters
    ----------
    header:
        Message header; its ``"compression"`` entry names one compression
        algorithm per frame (``None`` for uncompressed frames). If the frames
        were produced by splitting larger frames, the header also carries
        ``"split-num-sub-frames"``, ``"split-offsets"`` and
        ``"uncompressed_size"`` entries describing the split.
    frames:
        The compressed frames to decompress.

    Returns
    -------
    A ``(header, frames)`` tuple. The header is a (possibly updated) copy:
    when sub-frames are merged during decompression, the split metadata is
    rewritten to match the new frame layout.

    See also
    --------
    merge_and_deserialize
    """
    # Imported here, not at module level — presumably to avoid an import
    # cycle between distributed.comm and distributed.protocol (TODO confirm).
    from distributed.comm.utils import host_array

    if "split-num-sub-frames" not in header:
        # Simple case: no frame splitting; decompress each frame on its own.
        frames = [
            compressions[name].decompress(frame)
            for name, frame in zip(header["compression"], frames)
        ]
        return header, frames

    merged_frames: list[AnyBytes] = []
    split_num_sub_frames: list[int] = []
    split_offsets: list[int] = []

    # Each (n, offset) pair describes one group of sub-frames that were split
    # off a single original frame: frames[offset : offset + n].
    for n, offset in zip(header["split-num-sub-frames"], header["split-offsets"]):
        compression_names = header["compression"][offset : offset + n]
        compression = compressions[compression_names[0]]
        subframes = frames[offset : offset + n]

        if compression.decompress_into and len(set(compression_names)) == 1:
            # Fast path: every sub-frame in the group uses the same algorithm
            # and that algorithm can decompress directly into a caller-provided
            # buffer. Decompress all sub-frames into one contiguous host
            # buffer, avoiding a separate decompress-then-concatenate copy.
            nbytes = header["uncompressed_size"][offset : offset + n]
            merged = host_array(sum(nbytes))
            merged_offset = 0
            for frame_i, nbytes_i in zip(subframes, nbytes):
                # Zero-copy slice of the output buffer for this sub-frame.
                merged_i = merged[merged_offset : merged_offset + nbytes_i]
                compression.decompress_into(frame_i, merged_i)
                merged_offset += nbytes_i
            merged_frames.append(merged)
            # The whole group has been collapsed into a single output frame.
            split_num_sub_frames.append(1)
            # NOTE(review): len(split_offsets) equals the number of groups
            # processed so far, which matches this frame's index in
            # merged_frames only while earlier groups also collapsed to one
            # frame each. Likewise the else-branch below re-uses the *input*
            # offset. Looks correct when groups don't mix merged/unmerged
            # layouts — confirm against merge_and_deserialize for mixed runs.
            split_offsets.append(len(split_offsets))

        else:
            # Mixed or not-into-capable compression: decompress sub-frames
            # individually and keep the group's original split layout.
            for name, frame in zip(compression_names, subframes):
                merged_frames.append(compressions[name].decompress(frame))
            split_num_sub_frames.append(n)
            split_offsets.append(offset)

    # Don't mutate the caller's header; return an updated copy describing the
    # new frame layout.
    header = header.copy()
    header["split-num-sub-frames"] = split_num_sub_frames
    header["split-offsets"] = split_offsets
    return header, merged_frames
7 changes: 6 additions & 1 deletion distributed/protocol/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ def dumps( # type: ignore[no-untyped-def]

def _inplace_compress_frames(header, frames):
compression = list(header.get("compression", [None] * len(frames)))
uncompressed_size = tuple(
frame.nbytes if isinstance(frame, memoryview) else len(frame)
for frame in frames
)

for i in range(len(frames)):
if compression[i] is None:
Expand All @@ -52,6 +56,7 @@ def _inplace_compress_frames(header, frames):
)

header["compression"] = tuple(compression)
header["uncompressed_size"] = tuple(uncompressed_size)

def create_serialized_sub_frames(obj: Serialized | Serialize) -> list:
if isinstance(obj, Serialized):
Expand Down Expand Up @@ -134,7 +139,7 @@ def _decode_default(obj):
sub_frames = frames[offset : offset + sub_header["num-sub-frames"]]
if deserialize:
if "compression" in sub_header:
sub_frames = decompress(sub_header, sub_frames)
sub_header, sub_frames = decompress(sub_header, sub_frames)
return merge_and_deserialize(
sub_header, sub_frames, deserializers=deserializers
)
Expand Down
7 changes: 6 additions & 1 deletion distributed/protocol/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,7 @@ def merge_and_deserialize(header, frames, deserializers=None):
--------
deserialize
serialize_and_split
decompress
"""
if "split-num-sub-frames" not in header:
merged_frames = frames
Expand Down Expand Up @@ -664,6 +665,10 @@ def serialize_bytelist(
) -> list[bytes | bytearray | memoryview]:
header, frames = serialize_and_split(x, **kwargs)
if frames:
header["uncompressed_size"] = [
frame.nbytes if isinstance(frame, memoryview) else len(frame)
for frame in frames
]
header["compression"], frames = zip(
*(maybe_compress(frame, compression=compression) for frame in frames)
)
Expand All @@ -687,7 +692,7 @@ def deserialize_bytes(b):
header = msgpack.loads(header, raw=False, use_list=False)
else:
header = {}
frames = decompress(header, frames)
header, frames = decompress(header, frames)
return merge_and_deserialize(header, frames)


Expand Down
2 changes: 1 addition & 1 deletion distributed/protocol/tests/test_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def decompress(v):
counters[1] += 1
return zlib.decompress(v)

compressions["dummy"] = Compression("dummy", compress, decompress)
compressions["dummy"] = Compression("dummy", compress, decompress, None)
yield counters
del compressions["dummy"]

Expand Down
46 changes: 45 additions & 1 deletion distributed/protocol/tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import dask.config
from dask.sizeof import sizeof

from distributed.compatibility import WINDOWS
from distributed.compatibility import WINDOWS, randbytes
from distributed.protocol import dumps, loads, maybe_compress, msgpack, to_serialize
from distributed.protocol.compression import compressions, get_compression_settings
from distributed.protocol.cuda import cuda_deserialize, cuda_serialize
Expand Down Expand Up @@ -371,3 +371,47 @@ def gen_deeply_nested(depth):

msg = gen_deeply_nested(sys.getrecursionlimit() // 4)
assert isinstance(serialize(msg), tuple)


def test_decompress_into(compression):
    """An object is sharded into two frames and then compressed. When it's decompressed,
    decompress_into (if available) is used to write both messages into the same buffer.
    """
    # 10 distinct characters, 2000 repeats each -> highly compressible 20 kB string
    payload = "".join(char * 2000 for char in "abcdefghij")
    assert len(payload) == 20_000
    frames = dumps(
        to_serialize(payload),
        context={"compression": compression},
        frame_split_size=10_000,
    )

    # Two header frames, one msgpack frame, plus the payload split in two
    assert len(frames) == 5
    if not compression:
        # No compression: each shard keeps its full 10 kB size
        assert len(frames[2]) == len(frames[3]) == 10_000
    else:
        # Repetitive data compresses to a tiny fraction of its size
        assert len(frames[2]) < 600
        assert len(frames[3]) < 600

    # Round-trip must reassemble the original object exactly
    assert loads(frames) == payload


@pytest.mark.parametrize("swap", [False, True])
def test_mixed_compression_for_subframes(compression, swap):
    """Serialize an object that gets sharded into two subframes.
    One subframe is compressed, the other isn't.
    """
    random_part = randbytes(10_000)  # incompressible
    repetitive_part = b"x" * 10_000  # highly compressible
    if swap:
        payload = repetitive_part + random_part
    else:
        payload = random_part + repetitive_part
    frames = dumps(
        to_serialize(payload),
        context={"compression": compression},
        frame_split_size=10_000,
    )

    # Two header frames, one msgpack frame, plus the two payload shards
    assert len(frames) == 4
    if not compression:
        assert len(frames[2]) == len(frames[3]) == 10_000
    elif swap:
        # Compressible shard first: only frames[2] shrinks
        assert len(frames[2]) < 600
        assert len(frames[3]) == 10_000
    else:
        # Compressible shard second: only frames[3] shrinks
        assert len(frames[2]) == 10_000
        assert len(frames[3]) < 600

    # Round-trip must reassemble the original bytes exactly
    assert loads(frames) == payload

0 comments on commit 2967335

Please sign in to comment.