Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deserialization: zero-copy merge subframes when possible #5208

Merged
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions distributed/protocol/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@
from ..utils import ensure_bytes, has_keyword
from . import pickle
from .compression import decompress, maybe_compress
from .utils import frame_split_size, msgpack_opts, pack_frames_prelude, unpack_frames
from .utils import (
frame_split_size,
merge_memoryviews,
msgpack_opts,
pack_frames_prelude,
unpack_frames,
)

dask_serialize = dask.utils.Dispatch("dask_serialize")
dask_deserialize = dask.utils.Dispatch("dask_deserialize")
Expand Down Expand Up @@ -463,15 +469,18 @@ def merge_and_deserialize(header, frames, deserializers=None):
deserialize
serialize_and_split
"""
merged_frames = []
if "split-num-sub-frames" not in header:
merged_frames = frames
else:
merged_frames = []
for n, offset in zip(header["split-num-sub-frames"], header["split-offsets"]):
if n == 1:
merged_frames.append(frames[offset])
else:
merged_frames.append(bytearray().join(frames[offset : offset + n]))
subframes = frames[offset : offset + n]
try:
merged = merge_memoryviews(subframes)
except (ValueError, TypeError):
merged = bytearray().join(subframes)

merged_frames.append(merged)

return deserialize(header, merged_frames, deserializers=deserializers)

Expand Down
4 changes: 3 additions & 1 deletion distributed/protocol/tests/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,14 @@ def test_dumps_serialize_numpy_large():
frames = dumps([to_serialize(x)])
dtype, shape = x.dtype, x.shape
checksum = crc32(x)
del x
[y] = loads(frames)

assert (y.dtype, y.shape) == (dtype, shape)
assert crc32(y) == checksum, "Arrays are unequal"

x[:] = 2 # shared buffer; serialization is zero-copy
assert (x == y).all(), "Data was copied"


@pytest.mark.parametrize(
"dt,size",
Expand Down
105 changes: 104 additions & 1 deletion distributed/protocol/tests/test_protocol_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from distributed.protocol.utils import pack_frames, unpack_frames
from __future__ import annotations

import pytest

from distributed.protocol.utils import merge_memoryviews, pack_frames, unpack_frames


def test_pack_frames():
Expand All @@ -8,3 +12,102 @@ def test_pack_frames():
frames2 = unpack_frames(b)

assert frames == frames2


class TestMergeMemroyviews:
def test_empty(self):
empty = merge_memoryviews([])
assert isinstance(empty, memoryview) and len(empty) == 0

def test_one(self):
base = bytearray(range(10))
base_mv = memoryview(base)
assert merge_memoryviews([base_mv]) is base_mv

@pytest.mark.parametrize(
"slices",
[
[slice(None, 3), slice(3, None)],
[slice(1, 3), slice(3, None)],
[slice(1, 3), slice(3, -1)],
[slice(0, 0), slice(None)],
[slice(None), slice(-1, -1)],
[slice(0, 0), slice(0, 0)],
[slice(None, 3), slice(3, 7), slice(7, None)],
[slice(2, 3), slice(3, 7), slice(7, 9)],
[slice(2, 3), slice(3, 7), slice(7, 9), slice(9, 9)],
[slice(1, 2), slice(2, 5), slice(5, 8), slice(8, None)],
],
)
def test_parts(self, slices: list[slice]):
base = bytearray(range(10))
base_mv = memoryview(base)

equiv_start = min(s.indices(10)[0] for s in slices)
equiv_stop = max(s.indices(10)[1] for s in slices)
equiv = base_mv[equiv_start:equiv_stop]

parts = [base_mv[s] for s in slices]
result = merge_memoryviews(parts)
assert result.obj is base
assert len(result) == len(equiv)
assert result == equiv

def test_readonly_buffer(self):
pytest.importorskip(
"numpy", reason="Read-only buffer zero-copy merging requires NumPy"
)
base = bytes(range(10))
base_mv = memoryview(base)

result = merge_memoryviews([base_mv[:4], base_mv[4:]])
assert result.obj is base
assert len(result) == len(base)
assert result == base

def test_catch_non_memoryview(self):
with pytest.raises(TypeError, match="Expected memoryview"):
merge_memoryviews([b"1234", memoryview(b"4567")])

with pytest.raises(TypeError, match="expected memoryview"):
merge_memoryviews([memoryview(b"123"), b"1234"])

@pytest.mark.parametrize(
"slices",
[
[slice(None, 3), slice(4, None)],
[slice(None, 3), slice(2, None)],
[slice(1, 3), slice(3, 6), slice(9, None)],
],
)
def test_catch_gaps(self, slices: list[slice]):
base = bytearray(range(10))
base_mv = memoryview(base)

parts = [base_mv[s] for s in slices]
with pytest.raises(ValueError, match="does not start where the previous ends"):
merge_memoryviews(parts)

def test_catch_different_buffer(self):
base = bytearray(range(8))
base_mv = memoryview(base)
with pytest.raises(ValueError, match="different buffer"):
merge_memoryviews([base_mv, memoryview(base.copy())])

def test_catch_different_non_contiguous(self):
base = bytearray(range(8))
base_mv = memoryview(base)[::-1]
with pytest.raises(ValueError, match="non-contiguous"):
merge_memoryviews([base_mv[:3], base_mv[3:]])

def test_catch_multidimensional(self):
base = bytearray(range(6))
base_mv = memoryview(base).cast("B", [3, 2])
with pytest.raises(ValueError, match="has 2 dimensions, not 1"):
merge_memoryviews([base_mv[:1], base_mv[1:]])

def test_catch_different_formats(self):
base = bytearray(range(8))
base_mv = memoryview(base)
with pytest.raises(ValueError, match="inconsistent format: I vs B"):
merge_memoryviews([base_mv[:4], base_mv[4:].cast("I")])
117 changes: 117 additions & 0 deletions distributed/protocol/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from __future__ import annotations

import ctypes
import struct
from collections.abc import Sequence

import dask

Expand Down Expand Up @@ -81,3 +85,116 @@ def unpack_frames(b):
start = end

return frames


def merge_memoryviews(mvs: Sequence[memoryview]) -> memoryview:
"""
Zero-copy "concatenate" a sequence of contiguous memoryviews.

Returns a new memoryview which slices into the underlying buffer
to extract out the portion equivalent to all of ``mvs`` being concatenated.

All the memoryviews must:
* Share the same underlying buffer (``.obj``)
* When merged, cover a continuous portion of that buffer with no gaps
* Have the same strides
* Be 1-dimensional
* Have the same format
* Be contiguous

Raises ValueError if these conditions are not met.
"""
if not mvs:
return memoryview(bytearray())
if len(mvs) == 1:
return mvs[0]

first = mvs[0]
gjoseph92 marked this conversation as resolved.
Show resolved Hide resolved
if not isinstance(first, memoryview):
raise TypeError(f"Expected memoryview; got {type(first)}")
obj = first.obj
format = first.format

first_start_addr = 0
nbytes = 0
for i, mv in enumerate(mvs):
if not isinstance(mv, memoryview):
raise TypeError(f"{i}: expected memoryview; got {type(mv)}")

if mv.nbytes == 0:
gjoseph92 marked this conversation as resolved.
Show resolved Hide resolved
continue

if mv.obj is not obj:
raise ValueError(
f"{i}: memoryview has different buffer: {mv.obj!r} vs {obj!r}"
)
if not mv.contiguous:
raise ValueError(f"{i}: memoryview non-contiguous")
if mv.ndim != 1:
raise ValueError(f"{i}: memoryview has {mv.ndim} dimensions, not 1")
if mv.format != format:
raise ValueError(f"{i}: inconsistent format: {mv.format} vs {format}")

start_addr = address_of_memoryview(mv)
if first_start_addr == 0:
first_start_addr = start_addr
else:
expected_addr = first_start_addr + nbytes
if start_addr != expected_addr:
raise ValueError(
f"memoryview {i} does not start where the previous ends. "
f"Expected {expected_addr:x}, starts {start_addr - expected_addr} byte(s) away."
)
nbytes += mv.nbytes

if nbytes == 0:
# all memoryviews were zero-length
assert len(first) == 0
return first

assert first_start_addr != 0, "Underlying buffer is null pointer?!"

base_mv = memoryview(obj).cast("B")
base_start_addr = address_of_memoryview(base_mv)
start_index = first_start_addr - base_start_addr

return base_mv[start_index : start_index + nbytes].cast(format)


one_byte_carr = ctypes.c_byte * 1
# ^ length and type don't matter, just use it to get the address of the first byte


def address_of_memoryview(mv: memoryview) -> int:
"""
Get the pointer to the first byte of a memoryview's data.

If the memoryview is read-only, NumPy must be installed.
"""
# NOTE: this method relies on pointer arithmetic to figure out
# where each memoryview starts within the underlying buffer.
# There's no direct API to get the address of a memoryview,
# so we use a trick through ctypes and the buffer protocol:
# https://mattgwwalker.wordpress.com/2020/10/15/address-of-a-buffer-in-python/
try:
carr = one_byte_carr.from_buffer(mv)
except TypeError:
# `mv` is read-only. `from_buffer` requires the buffer to be writeable.
# See https://bugs.python.org/issue11427 for discussion.
# This typically comes from `deserialize_bytes`, where `mv.obj` is an
# immutable bytestring.
pass
else:
return ctypes.addressof(carr)

try:
import numpy as np
except ImportError:
raise ValueError(
f"Cannot get address of read-only memoryview {mv} since NumPy is not installed."
)

# NumPy doesn't mind read-only buffers. We could just use this method
# for all cases, but it's nice to use the pure-Python method for the common
# case of writeable buffers (created by TCP comms, for example).
return np.asarray(mv).__array_interface__["data"][0]