diff --git a/distributed/protocol/compression.py b/distributed/protocol/compression.py
index 27d3a275398..5b19f64c80a 100644
--- a/distributed/protocol/compression.py
+++ b/distributed/protocol/compression.py
@@ -6,7 +6,7 @@
 from __future__ import annotations
 
 import zlib
-from collections.abc import Callable, Iterable
+from collections.abc import Callable, Sequence
 from contextlib import suppress
 from functools import partial
 from random import randint
@@ -32,13 +32,14 @@ class Compression(NamedTuple):
     name: None | str
     compress: Callable[[AnyBytes], AnyBytes]
     decompress: Callable[[AnyBytes], AnyBytes]
+    decompress_into: Callable[[AnyBytes, memoryview], None] | None
 
 
 compressions: dict[str | None | Literal[False], Compression] = {
-    None: Compression(None, identity, identity),
-    False: Compression(None, identity, identity),  # alias
-    "auto": Compression(None, identity, identity),
-    "zlib": Compression("zlib", zlib.compress, zlib.decompress),
+    None: Compression(None, identity, identity, None),
+    False: Compression(None, identity, identity, None),  # alias
+    "auto": Compression(None, identity, identity, None),
+    "zlib": Compression("zlib", zlib.compress, zlib.decompress, None),
 }
 
 
@@ -56,7 +57,9 @@ class Compression(NamedTuple):
     except TypeError:
         raise ImportError("Need snappy >= 0.5.3")
 
-    compressions["snappy"] = Compression("snappy", snappy.compress, snappy.decompress)
+    compressions["snappy"] = Compression(
+        "snappy", snappy.compress, snappy.decompress, None
+    )
     compressions["auto"] = compressions["snappy"]
 
 with suppress(ImportError):
@@ -77,6 +80,7 @@ class Compression(NamedTuple):
         # larger ones are deep-copied between decompression and serialization anyway in
         # order to merge them.
         partial(lz4.block.decompress, return_bytearray=True),
+        None,
     )
     compressions["auto"] = compressions["lz4"]
 
@@ -89,7 +93,10 @@ class Compression(NamedTuple):
         raise ImportError("Need cramjam >= 2.7.0")
 
     compressions["cramjam.lz4"] = Compression(
-        "cramjam.lz4", cramjam.lz4.compress_block, cramjam.lz4.decompress_block
+        "cramjam.lz4",
+        cramjam.lz4.compress_block,
+        cramjam.lz4.decompress_block,
+        cramjam.lz4.decompress_block_into,
     )
     compressions["auto"] = compressions["cramjam.lz4"]
 
@@ -112,7 +119,7 @@ def zstd_decompress(data):
         zstd_decompressor = zstandard.ZstdDecompressor()
         return zstd_decompressor.decompress(data)
 
-    compressions["zstd"] = Compression("zstd", zstd_compress, zstd_decompress)
+    compressions["zstd"] = Compression("zstd", zstd_compress, zstd_decompress, None)
 
 
 def get_compression_settings(key: str) -> str | None:
@@ -209,9 +216,52 @@ def maybe_compress(
 
 
 @context_meter.meter("decompress")
-def decompress(header: dict[str, Any], frames: Iterable[AnyBytes]) -> list[AnyBytes]:
-    """Decompress frames according to information in the header"""
-    return [
-        compressions[name].decompress(frame)
-        for name, frame in zip(header["compression"], frames)
-    ]
+def decompress(
+    header: dict[str, Any], frames: Sequence[AnyBytes]
+) -> tuple[dict[str, Any], list[AnyBytes]]:
+    """Decompress frames according to information in the header
+
+    See also
+    --------
+    merge_and_deserialize
+    """
+    from distributed.comm.utils import host_array
+
+    if "split-num-sub-frames" not in header:
+        frames = [
+            compressions[name].decompress(frame)
+            for name, frame in zip(header["compression"], frames)
+        ]
+        return header, frames
+
+    merged_frames: list[AnyBytes] = []
+    split_num_sub_frames: list[int] = []
+    split_offsets: list[int] = []
+
+    for n, offset in zip(header["split-num-sub-frames"], header["split-offsets"]):
+        compression_names = header["compression"][offset : offset + n]
+        compression = compressions[compression_names[0]]
+        subframes = frames[offset : offset + n]
+
+        if compression.decompress_into and len(set(compression_names)) == 1:
+            nbytes = header["uncompressed_size"][offset : offset + n]
+            merged = host_array(sum(nbytes))
+            merged_offset = 0
+            for frame_i, nbytes_i in zip(subframes, nbytes):
+                merged_i = merged[merged_offset : merged_offset + nbytes_i]
+                compression.decompress_into(frame_i, merged_i)
+                merged_offset += nbytes_i
+            merged_frames.append(merged)
+            split_num_sub_frames.append(1)
+            split_offsets.append(len(merged_frames) - 1)
+
+        else:
+            split_num_sub_frames.append(n)
+            split_offsets.append(len(merged_frames))
+            for name, frame in zip(compression_names, subframes):
+                merged_frames.append(compressions[name].decompress(frame))
+
+    header = header.copy()
+    header["split-num-sub-frames"] = split_num_sub_frames
+    header["split-offsets"] = split_offsets
+    return header, merged_frames
diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py
index d58ee011297..4855a6b7824 100644
--- a/distributed/protocol/core.py
+++ b/distributed/protocol/core.py
@@ -44,6 +44,10 @@ def dumps(  # type: ignore[no-untyped-def]
 
     def _inplace_compress_frames(header, frames):
         compression = list(header.get("compression", [None] * len(frames)))
+        uncompressed_size = tuple(
+            frame.nbytes if isinstance(frame, memoryview) else len(frame)
+            for frame in frames
+        )
 
         for i in range(len(frames)):
             if compression[i] is None:
@@ -52,6 +56,7 @@
                 )
 
         header["compression"] = tuple(compression)
+        header["uncompressed_size"] = tuple(uncompressed_size)
 
     def create_serialized_sub_frames(obj: Serialized | Serialize) -> list:
         if isinstance(obj, Serialized):
@@ -134,7 +139,7 @@ def _decode_default(obj):
             sub_frames = frames[offset : offset + sub_header["num-sub-frames"]]
             if deserialize:
                 if "compression" in sub_header:
-                    sub_frames = decompress(sub_header, sub_frames)
+                    sub_header, sub_frames = decompress(sub_header, sub_frames)
                 return merge_and_deserialize(
                     sub_header, sub_frames, deserializers=deserializers
                 )
diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py
index 0ba323a6a1f..8fe2d85f03b 100644
--- a/distributed/protocol/serialize.py
+++ b/distributed/protocol/serialize.py
@@ -488,6 +488,7 @@ def merge_and_deserialize(header, frames, deserializers=None):
     --------
     deserialize
     serialize_and_split
+    decompress
     """
     if "split-num-sub-frames" not in header:
         merged_frames = frames
@@ -664,6 +665,10 @@ def serialize_bytelist(
 ) -> list[bytes | bytearray | memoryview]:
     header, frames = serialize_and_split(x, **kwargs)
     if frames:
+        header["uncompressed_size"] = [
+            frame.nbytes if isinstance(frame, memoryview) else len(frame)
+            for frame in frames
+        ]
         header["compression"], frames = zip(
             *(maybe_compress(frame, compression=compression) for frame in frames)
         )
@@ -687,7 +692,7 @@ def deserialize_bytes(b):
         header = msgpack.loads(header, raw=False, use_list=False)
     else:
         header = {}
-    frames = decompress(header, frames)
+    header, frames = decompress(header, frames)
     return merge_and_deserialize(header, frames)
 
 
diff --git a/distributed/protocol/tests/test_compression.py b/distributed/protocol/tests/test_compression.py
index bb3a62fd3af..927c07ea59d 100644
--- a/distributed/protocol/tests/test_compression.py
+++ b/distributed/protocol/tests/test_compression.py
@@ -22,7 +22,7 @@ def decompress(v):
         counters[1] += 1
         return zlib.decompress(v)
 
-    compressions["dummy"] = Compression("dummy", compress, decompress)
+    compressions["dummy"] = Compression("dummy", compress, decompress, None)
     yield counters
     del compressions["dummy"]
 
diff --git a/distributed/protocol/tests/test_protocol.py b/distributed/protocol/tests/test_protocol.py
index cdbc64ea38d..70e5449762f 100644
--- a/distributed/protocol/tests/test_protocol.py
+++ b/distributed/protocol/tests/test_protocol.py
@@ -9,7 +9,7 @@
 import dask.config
 from dask.sizeof import sizeof
 
-from distributed.compatibility import WINDOWS
+from distributed.compatibility import WINDOWS, randbytes
 from distributed.protocol import dumps, loads, maybe_compress, msgpack, to_serialize
 from distributed.protocol.compression import compressions, get_compression_settings
 from distributed.protocol.cuda import cuda_deserialize, cuda_serialize
@@ -371,3 +371,47 @@ def gen_deeply_nested(depth):
 
     msg = gen_deeply_nested(sys.getrecursionlimit() // 4)
     assert isinstance(serialize(msg), tuple)
+
+
+def test_decompress_into(compression):
+    """An object is sharded into two frames and then compressed. When it's decompressed,
+    decompress_into (if available) is used to write both shards into the same buffer.
+    """
+    a = "".join(c * 2000 for c in "abcdefghij")
+    assert len(a) == 20_000
+    msg = to_serialize(a)
+    frames = dumps(msg, context={"compression": compression}, frame_split_size=10_000)
+
+    assert len(frames) == 5
+    if compression:
+        assert len(frames[2]) < 600
+        assert len(frames[3]) < 600
+    else:
+        assert len(frames[2]) == len(frames[3]) == 10_000
+
+    assert loads(frames) == a
+
+
+@pytest.mark.parametrize("swap", [False, True])
+def test_mixed_compression_for_subframes(compression, swap):
+    """Serialize an object that gets sharded into two subframes.
+    One subframe is compressed, the other isn't.
+    """
+    a = randbytes(10_000)
+    b = b"x" * 10_000
+    if swap:
+        a, b = b, a
+    msg = to_serialize(a + b)
+    frames = dumps(msg, context={"compression": compression}, frame_split_size=10_000)
+
+    assert len(frames) == 4
+    if swap and compression:
+        assert len(frames[2]) < 600
+        assert len(frames[3]) == 10_000
+    elif compression:
+        assert len(frames[2]) == 10_000
+        assert len(frames[3]) < 600
+    else:
+        assert len(frames[2]) == len(frames[3]) == 10_000
+
+    assert loads(frames) == a + b
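
Note on the `decompress_into` fast path: the loop added to `decompress()` boils down to the pattern below. This is a minimal standalone sketch, not code from this patch; it assumes only that cramjam >= 2.7 is installed, and `payload`, `shards`, and `sizes` are illustrative names standing in for a split frame and the header's `uncompressed_size` entries.

```python
import cramjam

# A 20 kB payload that was split into two 10 kB shards before compression,
# mimicking frame_split_size in dumps().
payload = bytes(range(200)) * 100
shards = [payload[:10_000], payload[10_000:]]
compressed = [bytes(cramjam.lz4.compress_block(s)) for s in shards]
sizes = [len(s) for s in shards]  # what the "uncompressed_size" header records

# Decompress every shard directly into one preallocated buffer, instead of
# allocating a separate buffer per shard and copying them together afterwards.
merged = memoryview(bytearray(sum(sizes)))
offset = 0
for frame, size in zip(compressed, sizes):
    cramjam.lz4.decompress_block_into(frame, merged[offset : offset + size])
    offset += size

assert bytes(merged) == payload
```

Because each group decompressed this way becomes a single output frame, the rewritten header must record offsets into the merged frame list (with `split-num-sub-frames` set to 1 for that group), which is what `merge_and_deserialize()` later consumes.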