Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid deep copy on lz4 decompression #7437

Merged
merged 4 commits into from
Dec 29, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions distributed/protocol/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import logging
from collections.abc import Callable
from contextlib import suppress
from functools import partial
from random import randint
from typing import Literal

Expand Down Expand Up @@ -64,12 +65,16 @@
if parse_version(lz4.__version__) < parse_version("0.23.1"):
raise ImportError("Need lz4 >= 0.23.1")

from lz4.block import compress as lz4_compress
from lz4.block import decompress as lz4_decompress
import lz4.block

compressions["lz4"] = {
"compress": lz4_compress,
"decompress": lz4_decompress,
"compress": lz4.block.compress,
# Avoid expensive deep copies when deserializing writeable numpy arrays
# See distributed.protocol.numpy.deserialize_numpy_ndarray
# Note that this is only useful for buffers smaller than distributed.comm.shard;
# larger ones are deep-copied between decompression and serialization anyway in
# order to merge them.
"decompress": partial(lz4.block.decompress, return_bytearray=True),
}
default_compression = "lz4"

Expand Down
7 changes: 6 additions & 1 deletion distributed/protocol/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,12 @@ def deserialize_numpy_ndarray(header, frames):
# This should exclusively happen when the underlying buffer is read-only, e.g.
# a read-only mmap.mmap or a bytes object.
# Specifically, these are the known use cases:
# 1. decompressed output of a buffer that was not sharded
# 1. decompression with a library that does not support output to bytearray
# (lz4 does; snappy, zlib, and zstd don't).
# Note that this only applies to buffers whose uncompressed size was small
# enough that they weren't sharded (distributed.comm.shard); for larger
# buffers the decompressed output is deep-copied beforehand into a bytearray
# in order to merge it.
# 2. unspill with zict <2.3.0 (https://github.com/dask/zict/pull/74)
x = np.require(x, requirements=["W"])

Expand Down
1 change: 0 additions & 1 deletion distributed/protocol/tests/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,6 @@ def test_itemsize(dt, size):


def test_compress_numpy():
pytest.importorskip("lz4")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious, why this change? If we didn't have lz4, snappy, or zstandard installed (all of which are optional I think) then I'd expect this to fail.

The only compressor we have by default, I think, is zlib, and we don't compress with that by default.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, if you have snappy but not lz4 it will succeed.
zstandard does not install itself as a default compressor.
Amended the tests to reflect this.

x = np.ones(10000000, dtype="i4")
frames = dumps({"x": to_serialize(x)})
assert sum(map(nbytes, frames)) < x.nbytes
Expand Down
47 changes: 34 additions & 13 deletions distributed/protocol/tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def test_compression_config(config, default):


def test_compression_1():
pytest.importorskip("lz4")
np = pytest.importorskip("numpy")
x = np.ones(1000000)
b = x.tobytes()
Expand All @@ -61,7 +60,6 @@ def test_compression_1():


def test_compression_2():
pytest.importorskip("lz4")
np = pytest.importorskip("numpy")
x = np.random.random(10000)
msg = dumps(to_serialize(x.data))
Expand All @@ -70,7 +68,6 @@ def test_compression_2():


def test_compression_3():
pytest.importorskip("lz4")
np = pytest.importorskip("numpy")
x = np.ones(1000000)
frames = dumps({"x": Serialize(x.data)})
Expand All @@ -80,7 +77,6 @@ def test_compression_3():


def test_compression_without_deserialization():
pytest.importorskip("lz4")
np = pytest.importorskip("numpy")
x = np.ones(1000000)

Expand All @@ -91,6 +87,18 @@ def test_compression_without_deserialization():
assert all(len(frame) < 1000000 for frame in msg["x"].frames)


def test_lz4_decompression_avoids_deep_copy():
    """lz4 must decompress into a writeable bytearray, never immutable bytes.

    This allows numpy deserialization to adopt the buffer in place instead of
    deep-copying it to obtain a writeable array. zlib, zstandard, and snappy
    offer no equivalent option.
    """
    pytest.importorskip("lz4")
    payload = bytearray(1_000_000)
    lz4_codec = compressions["lz4"]
    roundtripped = lz4_codec["decompress"](lz4_codec["compress"](payload))
    assert isinstance(roundtripped, bytearray)


def test_small():
assert sum(map(nbytes, dumps(b""))) < 10
assert sum(map(nbytes, dumps(1))) < 10
Expand All @@ -106,7 +114,13 @@ def test_small_and_big():

@pytest.mark.parametrize(
"lib,compression",
[(None, None), ("zlib", "zlib"), ("lz4", "lz4"), ("zstandard", "zstd")],
[
(None, None),
("zlib", "zlib"),
("lz4", "lz4"),
("snappy", "snappy"),
("zstandard", "zstd"),
],
)
def test_maybe_compress(lib, compression):
if lib:
Expand All @@ -126,7 +140,13 @@ def test_maybe_compress(lib, compression):

@pytest.mark.parametrize(
"lib,compression",
[(None, None), ("zlib", "zlib"), ("lz4", "lz4"), ("zstandard", "zstd")],
[
(None, None),
("zlib", "zlib"),
("lz4", "lz4"),
("snappy", "snappy"),
("zstandard", "zstd"),
],
)
def test_compression_thread_safety(lib, compression):
if lib:
Expand Down Expand Up @@ -164,7 +184,13 @@ def test_compress_decompress(fn):

@pytest.mark.parametrize(
"lib,compression",
[(None, None), ("zlib", "zlib"), ("lz4", "lz4"), ("zstandard", "zstd")],
[
(None, None),
("zlib", "zlib"),
("lz4", "lz4"),
("snappy", "snappy"),
("zstandard", "zstd"),
],
)
def test_maybe_compress_config_default(lib, compression):
if lib:
Expand All @@ -185,7 +211,6 @@ def test_maybe_compress_config_default(lib, compression):

def test_maybe_compress_sample():
np = pytest.importorskip("numpy")
lz4 = pytest.importorskip("lz4")
payload = np.random.randint(0, 255, size=10000).astype("u1").tobytes()
fmt, compressed = maybe_compress(payload)
assert fmt is None
Expand All @@ -202,10 +227,8 @@ def test_large_bytes():
assert len(frames[1]) < 1000


@pytest.mark.slow
def test_large_messages():
np = pytest.importorskip("numpy")
pytest.importorskip("lz4")
if MEMORY_LIMIT < 8e9:
pytest.skip("insufficient memory")

Expand Down Expand Up @@ -249,7 +272,6 @@ def test_loads_deserialize_False():


def test_loads_without_deserialization_avoids_compression():
pytest.importorskip("lz4")
b = b"0" * 100000

msg = {"x": 1, "data": to_serialize(b)}
Expand Down Expand Up @@ -313,10 +335,9 @@ def test_dumps_loads_Serialized():

def test_maybe_compress_memoryviews():
np = pytest.importorskip("numpy")
pytest.importorskip("lz4")
x = np.arange(1000000, dtype="int64")
compression, payload = maybe_compress(x.data)
assert compression == "lz4"
assert compression in {"lz4", "snappy", "zstd", "zlib"}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I would be sad if we used zlib by default in any configuration. I'll bet that it's faster to just send data uncompressed over the network.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, you're right. I misread the code; default compression is lz4 -> snappy -> None.
I've amended the tests and added a specific test for the priority order.

assert len(payload) < x.nbytes * 0.75


Expand Down
1 change: 0 additions & 1 deletion distributed/protocol/tests/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,6 @@ def test_serialize_bytes(kwargs):

@pytest.mark.skipif(np is None, reason="Test needs numpy")
def test_serialize_list_compress():
pytest.importorskip("lz4")
x = np.ones(1000000)
L = serialize_bytelist(x)
assert sum(map(nbytes, L)) < x.nbytes / 2
Expand Down