Skip to content

Commit

Permalink
Avoid deep copy on lz4 decompression (#7437)
Browse files Browse the repository at this point in the history
Speed up deserialization when
a. lz4 is installed, and
b. the buffer is compressible, and
c. the buffer is smaller than 64 MiB (distributed.comm.shard)
Note that the default chunk size in dask.array is 128 MiB.

Note that this does not prevent a memory flare, as there's an unnecessary deep copy upstream as well:
https://github.com/python-lz4/python-lz4/blob/79370987909663d4e6ef743762768ebf970a2383/lz4/block/_block.c#L256
  • Loading branch information
crusaderky authored Dec 29, 2022
1 parent f3995b5 commit 875207b
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 21 deletions.
13 changes: 9 additions & 4 deletions distributed/protocol/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import logging
from collections.abc import Callable
from contextlib import suppress
from functools import partial
from random import randint
from typing import Literal

Expand Down Expand Up @@ -64,12 +65,16 @@
if parse_version(lz4.__version__) < parse_version("0.23.1"):
raise ImportError("Need lz4 >= 0.23.1")

from lz4.block import compress as lz4_compress
from lz4.block import decompress as lz4_decompress
import lz4.block

compressions["lz4"] = {
"compress": lz4_compress,
"decompress": lz4_decompress,
"compress": lz4.block.compress,
# Avoid expensive deep copies when deserializing writeable numpy arrays
# See distributed.protocol.numpy.deserialize_numpy_ndarray
# Note that this is only useful for buffers smaller than distributed.comm.shard;
# larger ones are deep-copied between decompression and serialization anyway in
# order to merge them.
"decompress": partial(lz4.block.decompress, return_bytearray=True),
}
default_compression = "lz4"

Expand Down
7 changes: 6 additions & 1 deletion distributed/protocol/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,12 @@ def deserialize_numpy_ndarray(header, frames):
# This should exclusively happen when the underlying buffer is read-only, e.g.
# a read-only mmap.mmap or a bytes object.
# Specifically, these are the known use cases:
# 1. decompressed output of a buffer that was not sharded
# 1. decompression with a library that does not support output to bytearray
# (lz4 does; snappy, zlib, and zstd don't).
# Note that this only applies to buffers whose uncompressed size was small
# enough that they weren't sharded (distributed.comm.shard); for larger
# buffers the decompressed output is deep-copied beforehand into a bytearray
# in order to merge it.
# 2. unspill with zict <2.3.0 (https://github.com/dask/zict/pull/74)
x = np.require(x, requirements=["W"])

Expand Down
4 changes: 2 additions & 2 deletions distributed/protocol/tests/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
serialize,
to_serialize,
)
from distributed.protocol.compression import maybe_compress
from distributed.protocol.compression import default_compression, maybe_compress
from distributed.protocol.numpy import itemsize
from distributed.protocol.utils import BIG_BYTES_SHARD_SIZE
from distributed.system import MEMORY_LIMIT
Expand Down Expand Up @@ -216,8 +216,8 @@ def test_itemsize(dt, size):
assert itemsize(np.dtype(dt)) == size


@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy")
def test_compress_numpy():
pytest.importorskip("lz4")
x = np.ones(10000000, dtype="i4")
frames = dumps({"x": to_serialize(x)})
assert sum(map(nbytes, frames)) < x.nbytes
Expand Down
74 changes: 61 additions & 13 deletions distributed/protocol/tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,25 @@ def test_protocol():
assert loads(dumps(msg)) == msg


def test_default_compression():
"""Test that the default compression algorithm is lz4 -> snappy -> None.
If neither is installed, test that we don't fall back to the very slow zlib.
"""
try:
import lz4 # noqa: F401

assert default_compression == "lz4"
return
except ImportError:
pass
try:
import snappy # noqa: F401

assert default_compression == "snappy"
except ImportError:
assert default_compression is None


@pytest.mark.parametrize(
"config,default",
[
Expand All @@ -49,8 +68,8 @@ def test_compression_config(config, default):
assert get_default_compression() == default


@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy")
def test_compression_1():
pytest.importorskip("lz4")
np = pytest.importorskip("numpy")
x = np.ones(1000000)
b = x.tobytes()
Expand All @@ -60,17 +79,17 @@ def test_compression_1():
assert {"x": b} == y


@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy")
def test_compression_2():
pytest.importorskip("lz4")
np = pytest.importorskip("numpy")
x = np.random.random(10000)
msg = dumps(to_serialize(x.data))
compression = msgpack.loads(msg[1]).get("compression")
assert all(c is None for c in compression)


@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy")
def test_compression_3():
pytest.importorskip("lz4")
np = pytest.importorskip("numpy")
x = np.ones(1000000)
frames = dumps({"x": Serialize(x.data)})
Expand All @@ -79,8 +98,8 @@ def test_compression_3():
assert {"x": x.data} == y


@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy")
def test_compression_without_deserialization():
pytest.importorskip("lz4")
np = pytest.importorskip("numpy")
x = np.ones(1000000)

Expand All @@ -91,6 +110,18 @@ def test_compression_without_deserialization():
assert all(len(frame) < 1000000 for frame in msg["x"].frames)


def test_lz4_decompression_avoids_deep_copy():
"""Test that lz4 output is a bytearray, not bytes, so that numpy deserialization is
not forced to perform a deep copy to obtain a writeable array.
Note that zlib, zstandard, and snappy don't have this option.
"""
pytest.importorskip("lz4")
a = bytearray(1_000_000)
b = compressions["lz4"]["compress"](a)
c = compressions["lz4"]["decompress"](b)
assert isinstance(c, bytearray)


def test_small():
assert sum(map(nbytes, dumps(b""))) < 10
assert sum(map(nbytes, dumps(1))) < 10
Expand All @@ -106,7 +137,13 @@ def test_small_and_big():

@pytest.mark.parametrize(
"lib,compression",
[(None, None), ("zlib", "zlib"), ("lz4", "lz4"), ("zstandard", "zstd")],
[
(None, None),
("zlib", "zlib"),
("lz4", "lz4"),
("snappy", "snappy"),
("zstandard", "zstd"),
],
)
def test_maybe_compress(lib, compression):
if lib:
Expand All @@ -126,7 +163,13 @@ def test_maybe_compress(lib, compression):

@pytest.mark.parametrize(
"lib,compression",
[(None, None), ("zlib", "zlib"), ("lz4", "lz4"), ("zstandard", "zstd")],
[
(None, None),
("zlib", "zlib"),
("lz4", "lz4"),
("snappy", "snappy"),
("zstandard", "zstd"),
],
)
def test_compression_thread_safety(lib, compression):
if lib:
Expand Down Expand Up @@ -164,7 +207,13 @@ def test_compress_decompress(fn):

@pytest.mark.parametrize(
"lib,compression",
[(None, None), ("zlib", "zlib"), ("lz4", "lz4"), ("zstandard", "zstd")],
[
(None, None),
("zlib", "zlib"),
("lz4", "lz4"),
("snappy", "snappy"),
("zstandard", "zstd"),
],
)
def test_maybe_compress_config_default(lib, compression):
if lib:
Expand All @@ -183,9 +232,9 @@ def test_maybe_compress_config_default(lib, compression):
assert compressions[rc]["decompress"](rd) == payload


@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy")
def test_maybe_compress_sample():
np = pytest.importorskip("numpy")
lz4 = pytest.importorskip("lz4")
payload = np.random.randint(0, 255, size=10000).astype("u1").tobytes()
fmt, compressed = maybe_compress(payload)
assert fmt is None
Expand All @@ -202,10 +251,9 @@ def test_large_bytes():
assert len(frames[1]) < 1000


@pytest.mark.slow
@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy")
def test_large_messages():
np = pytest.importorskip("numpy")
pytest.importorskip("lz4")
if MEMORY_LIMIT < 8e9:
pytest.skip("insufficient memory")

Expand Down Expand Up @@ -248,8 +296,8 @@ def test_loads_deserialize_False():
assert result == 123


@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy")
def test_loads_without_deserialization_avoids_compression():
pytest.importorskip("lz4")
b = b"0" * 100000

msg = {"x": 1, "data": to_serialize(b)}
Expand Down Expand Up @@ -311,12 +359,12 @@ def test_dumps_loads_Serialized():
assert result == result3


@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy")
def test_maybe_compress_memoryviews():
np = pytest.importorskip("numpy")
pytest.importorskip("lz4")
x = np.arange(1000000, dtype="int64")
compression, payload = maybe_compress(x.data)
assert compression == "lz4"
assert compression == default_compression
assert len(payload) < x.nbytes * 0.75


Expand Down
3 changes: 2 additions & 1 deletion distributed/protocol/tests/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
serialize_bytes,
to_serialize,
)
from distributed.protocol.compression import default_compression
from distributed.protocol.serialize import check_dask_serializable
from distributed.utils import ensure_memoryview, nbytes
from distributed.utils_test import NO_AMM, gen_test, inc
Expand Down Expand Up @@ -279,9 +280,9 @@ def test_serialize_bytes(kwargs):
assert str(x) == str(y)


@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy")
@pytest.mark.skipif(np is None, reason="Test needs numpy")
def test_serialize_list_compress():
pytest.importorskip("lz4")
x = np.ones(1000000)
L = serialize_bytelist(x)
assert sum(map(nbytes, L)) < x.nbytes / 2
Expand Down

0 comments on commit 875207b

Please sign in to comment.