Avoid deep copy on lz4 decompression #7437

Merged · 4 commits · Dec 29, 2022
13 changes: 9 additions & 4 deletions distributed/protocol/compression.py
@@ -8,6 +8,7 @@
import logging
from collections.abc import Callable
from contextlib import suppress
from functools import partial
from random import randint
from typing import Literal

@@ -64,12 +65,16 @@
if parse_version(lz4.__version__) < parse_version("0.23.1"):
raise ImportError("Need lz4 >= 0.23.1")

from lz4.block import compress as lz4_compress
from lz4.block import decompress as lz4_decompress
import lz4.block

compressions["lz4"] = {
"compress": lz4_compress,
"decompress": lz4_decompress,
"compress": lz4.block.compress,
# Avoid expensive deep copies when deserializing writeable numpy arrays
# See distributed.protocol.numpy.deserialize_numpy_ndarray
# Note that this is only useful for buffers smaller than distributed.comm.shard;
# larger ones are deep-copied between decompression and serialization anyway in
# order to merge them.
"decompress": partial(lz4.block.decompress, return_bytearray=True),
}
default_compression = "lz4"

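To illustrate the change above with a sketch that is not part of the PR: python-lz4's block API accepts return_bytearray=True, so the decompressed buffer is writeable and a NumPy view onto it does not need to be copied to become writeable. The array construction below is only illustrative; the real deserialization path lives in distributed.protocol.numpy.

import lz4.block
import numpy as np

payload = np.arange(1024, dtype="i4").tobytes()
compressed = lz4.block.compress(payload)

# Default output is immutable bytes, so any ndarray viewing it is read-only.
raw = lz4.block.decompress(compressed)
assert isinstance(raw, bytes)
assert not np.frombuffer(raw, dtype="i4").flags.writeable

# With return_bytearray=True the output is a writeable bytearray, so a view
# onto it is already writeable and no deep copy is needed downstream.
raw_w = lz4.block.decompress(compressed, return_bytearray=True)
assert isinstance(raw_w, bytearray)
assert np.frombuffer(raw_w, dtype="i4").flags.writeable

As the new comments note, this only matters for frames smaller than distributed.comm.shard; larger frames are merged into a single buffer, and therefore copied, regardless.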
7 changes: 6 additions & 1 deletion distributed/protocol/numpy.py
@@ -138,7 +138,12 @@ def deserialize_numpy_ndarray(header, frames):
# This should exclusively happen when the underlying buffer is read-only, e.g.
# a read-only mmap.mmap or a bytes object.
# Specifically, these are the known use cases:
# 1. decompressed output of a buffer that was not sharded
# 1. decompression with a library that does not support output to bytearray
# (lz4 does; snappy, zlib, and zstd don't).
# Note that this only applies to buffers whose uncompressed size was small
# enough that they weren't sharded (distributed.comm.shard); for larger
# buffers the decompressed output is deep-copied beforehand into a bytearray
# in order to merge it.
# 2. unspill with zict <2.3.0 (https://github.com/dask/zict/pull/74)
x = np.require(x, requirements=["W"])

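For context, an illustrative sketch (not from the PR) of why the buffer type matters here: np.require only copies when the requested flag is missing, so an array backed by a writeable bytearray passes through untouched, while one backed by read-only bytes is deep-copied.

import numpy as np

read_only = np.frombuffer(bytes(8000), dtype="f8")       # e.g. bytes from a decompressor
writeable = np.frombuffer(bytearray(8000), dtype="f8")   # e.g. lz4's bytearray output

a = np.require(read_only, requirements=["W"])  # copies: the source buffer is read-only
b = np.require(writeable, requirements=["W"])  # no copy: already writeable
assert a is not read_only and a.flags.writeable
assert b is writeable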
4 changes: 2 additions & 2 deletions distributed/protocol/tests/test_numpy.py
@@ -17,7 +17,7 @@
serialize,
to_serialize,
)
from distributed.protocol.compression import maybe_compress
from distributed.protocol.compression import default_compression, maybe_compress
from distributed.protocol.numpy import itemsize
from distributed.protocol.utils import BIG_BYTES_SHARD_SIZE
from distributed.system import MEMORY_LIMIT
@@ -216,8 +216,8 @@ def test_itemsize(dt, size):
assert itemsize(np.dtype(dt)) == size


@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy")
def test_compress_numpy():
pytest.importorskip("lz4")
Member:
I'm curious, why this change? If we didn't have lz4, snappy, or zstandard installed (all of which are optional I think) then I'd expect this to fail.

The only compressor we have by default, I think, is zlib, and we don't compress with that by default.

Collaborator (author):

Actually, if you have snappy but not lz4 it will succeed.
zstandard does not install itself as a default compressor.
Amended the tests to reflect this.

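To make the point above concrete, a hedged sketch (not part of the PR) of what the amended tests rely on: maybe_compress uses whatever default_compression resolved to at import time, which is lz4 if available, otherwise snappy, otherwise no compression at all; zstandard only participates when explicitly configured.

from distributed.protocol.compression import default_compression, maybe_compress

data = b"0" * 100_000
name, out = maybe_compress(data)
assert name == default_compression  # "lz4", "snappy", or None; never "zstd" by default
if name is not None:
    assert len(out) < len(data)     # a highly compressible payload shrinks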
x = np.ones(10000000, dtype="i4")
frames = dumps({"x": to_serialize(x)})
assert sum(map(nbytes, frames)) < x.nbytes
74 changes: 61 additions & 13 deletions distributed/protocol/tests/test_protocol.py
@@ -31,6 +31,25 @@ def test_protocol():
assert loads(dumps(msg)) == msg


def test_default_compression():
"""Test that the default compression algorithm is lz4 -> snappy -> None.
If neither is installed, test that we don't fall back to the very slow zlib.
"""
try:
import lz4 # noqa: F401

assert default_compression == "lz4"
return
except ImportError:
pass
try:
import snappy # noqa: F401

assert default_compression == "snappy"
except ImportError:
assert default_compression is None


@pytest.mark.parametrize(
"config,default",
[
@@ -49,8 +68,8 @@ def test_compression_config(config, default):
assert get_default_compression() == default


@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy")
def test_compression_1():
pytest.importorskip("lz4")
np = pytest.importorskip("numpy")
x = np.ones(1000000)
b = x.tobytes()
@@ -60,17 +79,17 @@ def test_compression_1():
assert {"x": b} == y


@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy")
def test_compression_2():
pytest.importorskip("lz4")
np = pytest.importorskip("numpy")
x = np.random.random(10000)
msg = dumps(to_serialize(x.data))
compression = msgpack.loads(msg[1]).get("compression")
assert all(c is None for c in compression)


@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy")
def test_compression_3():
pytest.importorskip("lz4")
np = pytest.importorskip("numpy")
x = np.ones(1000000)
frames = dumps({"x": Serialize(x.data)})
@@ -79,8 +98,8 @@ def test_compression_3():
assert {"x": x.data} == y


@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy")
def test_compression_without_deserialization():
pytest.importorskip("lz4")
np = pytest.importorskip("numpy")
x = np.ones(1000000)

@@ -91,6 +110,18 @@ def test_compression_without_deserialization():
assert all(len(frame) < 1000000 for frame in msg["x"].frames)


def test_lz4_decompression_avoids_deep_copy():
"""Test that lz4 output is a bytearray, not bytes, so that numpy deserialization is
not forced to perform a deep copy to obtain a writeable array.
Note that zlib, zstandard, and snappy don't have this option.
"""
pytest.importorskip("lz4")
a = bytearray(1_000_000)
b = compressions["lz4"]["compress"](a)
c = compressions["lz4"]["decompress"](b)
assert isinstance(c, bytearray)


def test_small():
assert sum(map(nbytes, dumps(b""))) < 10
assert sum(map(nbytes, dumps(1))) < 10
@@ -106,7 +137,13 @@ def test_small_and_big():

@pytest.mark.parametrize(
"lib,compression",
[(None, None), ("zlib", "zlib"), ("lz4", "lz4"), ("zstandard", "zstd")],
[
(None, None),
("zlib", "zlib"),
("lz4", "lz4"),
("snappy", "snappy"),
("zstandard", "zstd"),
],
)
def test_maybe_compress(lib, compression):
if lib:
@@ -126,7 +163,13 @@ def test_maybe_compress(lib, compression):

@pytest.mark.parametrize(
"lib,compression",
[(None, None), ("zlib", "zlib"), ("lz4", "lz4"), ("zstandard", "zstd")],
[
(None, None),
("zlib", "zlib"),
("lz4", "lz4"),
("snappy", "snappy"),
("zstandard", "zstd"),
],
)
def test_compression_thread_safety(lib, compression):
if lib:
@@ -164,7 +207,13 @@ def test_compress_decompress(fn):

@pytest.mark.parametrize(
"lib,compression",
[(None, None), ("zlib", "zlib"), ("lz4", "lz4"), ("zstandard", "zstd")],
[
(None, None),
("zlib", "zlib"),
("lz4", "lz4"),
("snappy", "snappy"),
("zstandard", "zstd"),
],
)
def test_maybe_compress_config_default(lib, compression):
if lib:
@@ -183,9 +232,9 @@ def test_maybe_compress_config_default(lib, compression):
assert compressions[rc]["decompress"](rd) == payload


@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy")
def test_maybe_compress_sample():
np = pytest.importorskip("numpy")
lz4 = pytest.importorskip("lz4")
payload = np.random.randint(0, 255, size=10000).astype("u1").tobytes()
fmt, compressed = maybe_compress(payload)
assert fmt is None
@@ -202,10 +251,9 @@ def test_large_bytes():
assert len(frames[1]) < 1000


@pytest.mark.slow
@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy")
def test_large_messages():
np = pytest.importorskip("numpy")
pytest.importorskip("lz4")
if MEMORY_LIMIT < 8e9:
pytest.skip("insufficient memory")

@@ -248,8 +296,8 @@ def test_loads_deserialize_False():
assert result == 123


@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy")
def test_loads_without_deserialization_avoids_compression():
pytest.importorskip("lz4")
b = b"0" * 100000

msg = {"x": 1, "data": to_serialize(b)}
@@ -311,12 +359,12 @@ def test_dumps_loads_Serialized():
assert result == result3


@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy")
def test_maybe_compress_memoryviews():
np = pytest.importorskip("numpy")
pytest.importorskip("lz4")
x = np.arange(1000000, dtype="int64")
compression, payload = maybe_compress(x.data)
assert compression == "lz4"
assert compression == default_compression
assert len(payload) < x.nbytes * 0.75


3 changes: 2 additions & 1 deletion distributed/protocol/tests/test_serialize.py
@@ -33,6 +33,7 @@
serialize_bytes,
to_serialize,
)
from distributed.protocol.compression import default_compression
from distributed.protocol.serialize import check_dask_serializable
from distributed.utils import ensure_memoryview, nbytes
from distributed.utils_test import NO_AMM, gen_test, inc
@@ -279,9 +280,9 @@ def test_serialize_bytes(kwargs):
assert str(x) == str(y)


@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy")
@pytest.mark.skipif(np is None, reason="Test needs numpy")
def test_serialize_list_compress():
pytest.importorskip("lz4")
x = np.ones(1000000)
L = serialize_bytelist(x)
assert sum(map(nbytes, L)) < x.nbytes / 2