Skip to content

Commit

Permalink
Deserialization: zero-copy merge subframes when possible (#5208)
Browse files Browse the repository at this point in the history
  • Loading branch information
gjoseph92 authored Nov 12, 2021
1 parent a787503 commit 5a75023
Show file tree
Hide file tree
Showing 4 changed files with 239 additions and 8 deletions.
21 changes: 15 additions & 6 deletions distributed/protocol/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@
from ..utils import ensure_bytes, has_keyword
from . import pickle
from .compression import decompress, maybe_compress
from .utils import frame_split_size, msgpack_opts, pack_frames_prelude, unpack_frames
from .utils import (
frame_split_size,
merge_memoryviews,
msgpack_opts,
pack_frames_prelude,
unpack_frames,
)

dask_serialize = dask.utils.Dispatch("dask_serialize")
dask_deserialize = dask.utils.Dispatch("dask_deserialize")
Expand Down Expand Up @@ -466,15 +472,18 @@ def merge_and_deserialize(header, frames, deserializers=None):
deserialize
serialize_and_split
"""
merged_frames = []
if "split-num-sub-frames" not in header:
merged_frames = frames
else:
merged_frames = []
for n, offset in zip(header["split-num-sub-frames"], header["split-offsets"]):
if n == 1:
merged_frames.append(frames[offset])
else:
merged_frames.append(bytearray().join(frames[offset : offset + n]))
subframes = frames[offset : offset + n]
try:
merged = merge_memoryviews(subframes)
except (ValueError, TypeError):
merged = bytearray().join(subframes)

merged_frames.append(merged)

return deserialize(header, merged_frames, deserializers=deserializers)

Expand Down
4 changes: 3 additions & 1 deletion distributed/protocol/tests/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,14 @@ def test_dumps_serialize_numpy_large():
frames = dumps([to_serialize(x)])
dtype, shape = x.dtype, x.shape
checksum = crc32(x)
del x
[y] = loads(frames)

assert (y.dtype, y.shape) == (dtype, shape)
assert crc32(y) == checksum, "Arrays are unequal"

x[:] = 2 # shared buffer; serialization is zero-copy
assert (x == y).all(), "Data was copied"


@pytest.mark.parametrize(
"dt,size",
Expand Down
105 changes: 104 additions & 1 deletion distributed/protocol/tests/test_protocol_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from distributed.protocol.utils import pack_frames, unpack_frames
from __future__ import annotations

import pytest

from distributed.protocol.utils import merge_memoryviews, pack_frames, unpack_frames


def test_pack_frames():
Expand All @@ -8,3 +12,102 @@ def test_pack_frames():
frames2 = unpack_frames(b)

assert frames == frames2


class TestMergeMemroyviews:
def test_empty(self):
empty = merge_memoryviews([])
assert isinstance(empty, memoryview) and len(empty) == 0

def test_one(self):
base = bytearray(range(10))
base_mv = memoryview(base)
assert merge_memoryviews([base_mv]) is base_mv

@pytest.mark.parametrize(
"slices",
[
[slice(None, 3), slice(3, None)],
[slice(1, 3), slice(3, None)],
[slice(1, 3), slice(3, -1)],
[slice(0, 0), slice(None)],
[slice(None), slice(-1, -1)],
[slice(0, 0), slice(0, 0)],
[slice(None, 3), slice(3, 7), slice(7, None)],
[slice(2, 3), slice(3, 7), slice(7, 9)],
[slice(2, 3), slice(3, 7), slice(7, 9), slice(9, 9)],
[slice(1, 2), slice(2, 5), slice(5, 8), slice(8, None)],
],
)
def test_parts(self, slices: list[slice]):
base = bytearray(range(10))
base_mv = memoryview(base)

equiv_start = min(s.indices(10)[0] for s in slices)
equiv_stop = max(s.indices(10)[1] for s in slices)
equiv = base_mv[equiv_start:equiv_stop]

parts = [base_mv[s] for s in slices]
result = merge_memoryviews(parts)
assert result.obj is base
assert len(result) == len(equiv)
assert result == equiv

def test_readonly_buffer(self):
pytest.importorskip(
"numpy", reason="Read-only buffer zero-copy merging requires NumPy"
)
base = bytes(range(10))
base_mv = memoryview(base)

result = merge_memoryviews([base_mv[:4], base_mv[4:]])
assert result.obj is base
assert len(result) == len(base)
assert result == base

def test_catch_non_memoryview(self):
with pytest.raises(TypeError, match="Expected memoryview"):
merge_memoryviews([b"1234", memoryview(b"4567")])

with pytest.raises(TypeError, match="expected memoryview"):
merge_memoryviews([memoryview(b"123"), b"1234"])

@pytest.mark.parametrize(
"slices",
[
[slice(None, 3), slice(4, None)],
[slice(None, 3), slice(2, None)],
[slice(1, 3), slice(3, 6), slice(9, None)],
],
)
def test_catch_gaps(self, slices: list[slice]):
base = bytearray(range(10))
base_mv = memoryview(base)

parts = [base_mv[s] for s in slices]
with pytest.raises(ValueError, match="does not start where the previous ends"):
merge_memoryviews(parts)

def test_catch_different_buffer(self):
base = bytearray(range(8))
base_mv = memoryview(base)
with pytest.raises(ValueError, match="different buffer"):
merge_memoryviews([base_mv, memoryview(base.copy())])

def test_catch_different_non_contiguous(self):
base = bytearray(range(8))
base_mv = memoryview(base)[::-1]
with pytest.raises(ValueError, match="non-contiguous"):
merge_memoryviews([base_mv[:3], base_mv[3:]])

def test_catch_multidimensional(self):
base = bytearray(range(6))
base_mv = memoryview(base).cast("B", [3, 2])
with pytest.raises(ValueError, match="has 2 dimensions, not 1"):
merge_memoryviews([base_mv[:1], base_mv[1:]])

def test_catch_different_formats(self):
base = bytearray(range(8))
base_mv = memoryview(base)
with pytest.raises(ValueError, match="inconsistent format: I vs B"):
merge_memoryviews([base_mv[:4], base_mv[4:].cast("I")])
117 changes: 117 additions & 0 deletions distributed/protocol/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from __future__ import annotations

import ctypes
import struct
from collections.abc import Sequence

import dask

Expand Down Expand Up @@ -81,3 +85,116 @@ def unpack_frames(b):
start = end

return frames


def merge_memoryviews(mvs: Sequence[memoryview]) -> memoryview:
"""
Zero-copy "concatenate" a sequence of contiguous memoryviews.
Returns a new memoryview which slices into the underlying buffer
to extract out the portion equivalent to all of ``mvs`` being concatenated.
All the memoryviews must:
* Share the same underlying buffer (``.obj``)
* When merged, cover a continuous portion of that buffer with no gaps
* Have the same strides
* Be 1-dimensional
* Have the same format
* Be contiguous
Raises ValueError if these conditions are not met.
"""
if not mvs:
return memoryview(bytearray())
if len(mvs) == 1:
return mvs[0]

first = mvs[0]
if not isinstance(first, memoryview):
raise TypeError(f"Expected memoryview; got {type(first)}")
obj = first.obj
format = first.format

first_start_addr = 0
nbytes = 0
for i, mv in enumerate(mvs):
if not isinstance(mv, memoryview):
raise TypeError(f"{i}: expected memoryview; got {type(mv)}")

if mv.nbytes == 0:
continue

if mv.obj is not obj:
raise ValueError(
f"{i}: memoryview has different buffer: {mv.obj!r} vs {obj!r}"
)
if not mv.contiguous:
raise ValueError(f"{i}: memoryview non-contiguous")
if mv.ndim != 1:
raise ValueError(f"{i}: memoryview has {mv.ndim} dimensions, not 1")
if mv.format != format:
raise ValueError(f"{i}: inconsistent format: {mv.format} vs {format}")

start_addr = address_of_memoryview(mv)
if first_start_addr == 0:
first_start_addr = start_addr
else:
expected_addr = first_start_addr + nbytes
if start_addr != expected_addr:
raise ValueError(
f"memoryview {i} does not start where the previous ends. "
f"Expected {expected_addr:x}, starts {start_addr - expected_addr} byte(s) away."
)
nbytes += mv.nbytes

if nbytes == 0:
# all memoryviews were zero-length
assert len(first) == 0
return first

assert first_start_addr != 0, "Underlying buffer is null pointer?!"

base_mv = memoryview(obj).cast("B")
base_start_addr = address_of_memoryview(base_mv)
start_index = first_start_addr - base_start_addr

return base_mv[start_index : start_index + nbytes].cast(format)


one_byte_carr = ctypes.c_byte * 1
# ^ length and type don't matter, just use it to get the address of the first byte


def address_of_memoryview(mv: memoryview) -> int:
"""
Get the pointer to the first byte of a memoryview's data.
If the memoryview is read-only, NumPy must be installed.
"""
# NOTE: this method relies on pointer arithmetic to figure out
# where each memoryview starts within the underlying buffer.
# There's no direct API to get the address of a memoryview,
# so we use a trick through ctypes and the buffer protocol:
# https://mattgwwalker.wordpress.com/2020/10/15/address-of-a-buffer-in-python/
try:
carr = one_byte_carr.from_buffer(mv)
except TypeError:
# `mv` is read-only. `from_buffer` requires the buffer to be writeable.
# See https://bugs.python.org/issue11427 for discussion.
# This typically comes from `deserialize_bytes`, where `mv.obj` is an
# immutable bytestring.
pass
else:
return ctypes.addressof(carr)

try:
import numpy as np
except ImportError:
raise ValueError(
f"Cannot get address of read-only memoryview {mv} since NumPy is not installed."
)

# NumPy doesn't mind read-only buffers. We could just use this method
# for all cases, but it's nice to use the pure-Python method for the common
# case of writeable buffers (created by TCP comms, for example).
return np.asarray(mv).__array_interface__["data"][0]

0 comments on commit 5a75023

Please sign in to comment.