diff --git a/distributed/protocol/compression.py b/distributed/protocol/compression.py index 2555a7672e..626af01186 100644 --- a/distributed/protocol/compression.py +++ b/distributed/protocol/compression.py @@ -8,6 +8,7 @@ import logging from collections.abc import Callable from contextlib import suppress +from functools import partial from random import randint from typing import Literal @@ -64,12 +65,16 @@ if parse_version(lz4.__version__) < parse_version("0.23.1"): raise ImportError("Need lz4 >= 0.23.1") - from lz4.block import compress as lz4_compress - from lz4.block import decompress as lz4_decompress + import lz4.block compressions["lz4"] = { - "compress": lz4_compress, - "decompress": lz4_decompress, + "compress": lz4.block.compress, + # Avoid expensive deep copies when deserializing writeable numpy arrays + # See distributed.protocol.numpy.deserialize_numpy_ndarray + # Note that this is only useful for buffers smaller than distributed.comm.shard; + # larger ones are deep-copied between decompression and serialization anyway in + # order to merge them. + "decompress": partial(lz4.block.decompress, return_bytearray=True), } default_compression = "lz4" diff --git a/distributed/protocol/numpy.py b/distributed/protocol/numpy.py index ca269e68fa..02e73cfb9e 100644 --- a/distributed/protocol/numpy.py +++ b/distributed/protocol/numpy.py @@ -138,7 +138,12 @@ def deserialize_numpy_ndarray(header, frames): # This should exclusively happen when the underlying buffer is read-only, e.g. # a read-only mmap.mmap or a bytes object. # Specifically, these are the known use cases: - # 1. decompressed output of a buffer that was not sharded + # 1. decompression with a library that does not support output to bytearray + # (lz4 does; snappy, zlib, and zstd don't). + # Note that this only applies to buffers whose uncompressed size was small + # enough that they weren't sharded (distributed.comm.shard); for larger + # buffers the decompressed output is deep-copied beforehand into a bytearray + # in order to merge it. # 2. unspill with zict <2.3.0 (https://github.com/dask/zict/pull/74) x = np.require(x, requirements=["W"]) diff --git a/distributed/protocol/tests/test_numpy.py b/distributed/protocol/tests/test_numpy.py index c336eaf7dd..2f3a8fdbe5 100644 --- a/distributed/protocol/tests/test_numpy.py +++ b/distributed/protocol/tests/test_numpy.py @@ -17,7 +17,7 @@ serialize, to_serialize, ) -from distributed.protocol.compression import maybe_compress +from distributed.protocol.compression import default_compression, maybe_compress from distributed.protocol.numpy import itemsize from distributed.protocol.utils import BIG_BYTES_SHARD_SIZE from distributed.system import MEMORY_LIMIT @@ -216,8 +216,8 @@ def test_itemsize(dt, size): assert itemsize(np.dtype(dt)) == size +@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy") def test_compress_numpy(): - pytest.importorskip("lz4") x = np.ones(10000000, dtype="i4") frames = dumps({"x": to_serialize(x)}) assert sum(map(nbytes, frames)) < x.nbytes diff --git a/distributed/protocol/tests/test_protocol.py b/distributed/protocol/tests/test_protocol.py index 57e56be498..af79507f60 100644 --- a/distributed/protocol/tests/test_protocol.py +++ b/distributed/protocol/tests/test_protocol.py @@ -31,6 +31,25 @@ def test_protocol(): assert loads(dumps(msg)) == msg +def test_default_compression(): + """Test that the default compression algorithm is lz4 -> snappy -> None. + If neither is installed, test that we don't fall back to the very slow zlib. + """ + try: + import lz4 # noqa: F401 + + assert default_compression == "lz4" + return + except ImportError: + pass + try: + import snappy # noqa: F401 + + assert default_compression == "snappy" + except ImportError: + assert default_compression is None + + @pytest.mark.parametrize( "config,default", [ @@ -49,8 +68,8 @@ def test_compression_config(config, default): assert get_default_compression() == default +@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy") def test_compression_1(): - pytest.importorskip("lz4") np = pytest.importorskip("numpy") x = np.ones(1000000) b = x.tobytes() @@ -60,8 +79,8 @@ def test_compression_1(): assert {"x": b} == y +@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy") def test_compression_2(): - pytest.importorskip("lz4") np = pytest.importorskip("numpy") x = np.random.random(10000) msg = dumps(to_serialize(x.data)) @@ -69,8 +88,8 @@ def test_compression_2(): assert all(c is None for c in compression) +@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy") def test_compression_3(): - pytest.importorskip("lz4") np = pytest.importorskip("numpy") x = np.ones(1000000) frames = dumps({"x": Serialize(x.data)}) @@ -79,8 +98,8 @@ def test_compression_3(): assert {"x": x.data} == y +@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy") def test_compression_without_deserialization(): - pytest.importorskip("lz4") np = pytest.importorskip("numpy") x = np.ones(1000000) @@ -91,6 +110,18 @@ def test_compression_without_deserialization(): assert all(len(frame) < 1000000 for frame in msg["x"].frames) +def test_lz4_decompression_avoids_deep_copy(): + """Test that lz4 output is a bytearray, not bytes, so that numpy deserialization is + not forced to perform a deep copy to obtain a writeable array. + Note that zlib, zstandard, and snappy don't have this option. + """ + pytest.importorskip("lz4") + a = bytearray(1_000_000) + b = compressions["lz4"]["compress"](a) + c = compressions["lz4"]["decompress"](b) + assert isinstance(c, bytearray) + + def test_small(): assert sum(map(nbytes, dumps(b""))) < 10 assert sum(map(nbytes, dumps(1))) < 10 @@ -106,7 +137,13 @@ def test_small_and_big(): @pytest.mark.parametrize( "lib,compression", - [(None, None), ("zlib", "zlib"), ("lz4", "lz4"), ("zstandard", "zstd")], + [ + (None, None), + ("zlib", "zlib"), + ("lz4", "lz4"), + ("snappy", "snappy"), + ("zstandard", "zstd"), + ], ) def test_maybe_compress(lib, compression): if lib: @@ -126,7 +163,13 @@ def test_maybe_compress(lib, compression): @pytest.mark.parametrize( "lib,compression", - [(None, None), ("zlib", "zlib"), ("lz4", "lz4"), ("zstandard", "zstd")], + [ + (None, None), + ("zlib", "zlib"), + ("lz4", "lz4"), + ("snappy", "snappy"), + ("zstandard", "zstd"), + ], ) def test_compression_thread_safety(lib, compression): if lib: @@ -164,7 +207,13 @@ def test_compress_decompress(fn): @pytest.mark.parametrize( "lib,compression", - [(None, None), ("zlib", "zlib"), ("lz4", "lz4"), ("zstandard", "zstd")], + [ + (None, None), + ("zlib", "zlib"), + ("lz4", "lz4"), + ("snappy", "snappy"), + ("zstandard", "zstd"), + ], ) def test_maybe_compress_config_default(lib, compression): if lib: @@ -183,9 +232,9 @@ def test_maybe_compress_config_default(lib, compression): assert compressions[rc]["decompress"](rd) == payload +@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy") def test_maybe_compress_sample(): np = pytest.importorskip("numpy") - lz4 = pytest.importorskip("lz4") payload = np.random.randint(0, 255, size=10000).astype("u1").tobytes() fmt, compressed = maybe_compress(payload) assert fmt is None @@ -202,10 +251,9 @@ def test_large_bytes(): assert len(frames[1]) < 1000 -@pytest.mark.slow +@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy") def test_large_messages(): np = pytest.importorskip("numpy") - pytest.importorskip("lz4") if MEMORY_LIMIT < 8e9: pytest.skip("insufficient memory") @@ -248,8 +296,8 @@ def test_loads_deserialize_False(): assert result == 123 +@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy") def test_loads_without_deserialization_avoids_compression(): - pytest.importorskip("lz4") b = b"0" * 100000 msg = {"x": 1, "data": to_serialize(b)} @@ -311,12 +359,12 @@ def test_dumps_loads_Serialized(): assert result == result3 +@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy") def test_maybe_compress_memoryviews(): np = pytest.importorskip("numpy") - pytest.importorskip("lz4") x = np.arange(1000000, dtype="int64") compression, payload = maybe_compress(x.data) - assert compression == "lz4" + assert compression == default_compression assert len(payload) < x.nbytes * 0.75 diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py index 6832b7375b..cb5d20d1e0 100644 --- a/distributed/protocol/tests/test_serialize.py +++ b/distributed/protocol/tests/test_serialize.py @@ -33,6 +33,7 @@ serialize_bytes, to_serialize, ) +from distributed.protocol.compression import default_compression from distributed.protocol.serialize import check_dask_serializable from distributed.utils import ensure_memoryview, nbytes from distributed.utils_test import NO_AMM, gen_test, inc @@ -279,9 +280,9 @@ def test_serialize_bytes(kwargs): assert str(x) == str(y) +@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy") @pytest.mark.skipif(np is None, reason="Test needs numpy") def test_serialize_list_compress(): - pytest.importorskip("lz4") x = np.ones(1000000) L = serialize_bytelist(x) assert sum(map(nbytes, L)) < x.nbytes / 2