Skip to content

Commit

Permalink
Add rechunking for Xarray datasets #45
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-czech committed Sep 22, 2020
1 parent 8917e20 commit 48c0883
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 5 deletions.
74 changes: 73 additions & 1 deletion rechunker/api.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
"""User-facing functions."""
import html
import textwrap
from typing import Union
from typing import Union, Mapping

import zarr
import dask
import dask.array
import xarray
import tempfile

from rechunker.algorithm import rechunking_plan
from rechunker.types import ArrayProxy, CopySpec, Executor
from xarray.backends.zarr import (
encode_zarr_attr_value,
encode_zarr_variable,
extract_zarr_variable_encoding,
DIMENSION_KEY,
)


class Rechunked:
Expand Down Expand Up @@ -177,6 +185,70 @@ def _get_executor(name: str) -> Executor:
raise ValueError(f"unrecognized executor {name}")


def rechunk_dataset(
    source: xarray.Dataset,
    encoding: Mapping,
    max_mem,
    target_store,
    temp_store=None,
    executor: Union[str, Executor] = "dask",
) -> Rechunked:
    """Rechunk the data variables of an Xarray dataset into a Zarr group.

    Parameters
    ----------
    source : xarray.Dataset
        Dataset whose data variables will be rechunked.
    encoding : Mapping
        Per-variable Zarr encoding overrides keyed by variable name
        (e.g. ``chunks``, ``compressor``, ``dtype``, ``scale_factor``,
        ``_FillValue``).  A variable's target chunking is taken from its
        (possibly overridden) ``chunks`` encoding, falling back to the
        chunking derived from the array itself.
    max_mem
        Memory budget forwarded unchanged to ``_setup_array_rechunk``.
    target_store
        Zarr store (or path) that will hold the rechunked output group.
    temp_store : optional
        Zarr store (or path) for intermediate data.  If omitted, a fresh
        on-disk temporary directory is used instead.  NOTE(review): that
        fallback directory created by ``tempfile.mkdtemp`` is never
        removed here — confirm cleanup is handled elsewhere.
    executor : str or Executor, default "dask"
        Executor instance, or the name of one resolvable by
        ``_get_executor``, used to prepare the copy plan.

    Returns
    -------
    Rechunked
        Delayed rechunking operation; run it via its ``execute`` method.
    """
    def _encode_zarr_attributes(attrs):
        # Make attribute values Zarr-serializable using xarray's encoder.
        return {k: encode_zarr_attr_value(v) for k, v in attrs.items()}

    if isinstance(executor, str):
        executor = _get_executor(executor)
    if temp_store:
        temp_group = zarr.group(temp_store)
    else:
        temp_group = zarr.group(tempfile.mkdtemp(".zarr", "temp_store_"))
    target_group = zarr.group(target_store)
    # Dataset-level attributes go on the root group of the target store.
    target_group.attrs.update(_encode_zarr_attributes(source.attrs))

    copy_specs = []
    # NOTE(review): iterating a Dataset yields data-variable names only, so
    # coordinate variables are not copied or rechunked — confirm intended.
    for variable in source:
        array = source[variable].copy()

        # Update the array encoding with provided parameters and apply it.
        # Whether "chunks" was already present (i.e. the data came from a
        # Zarr store) is recorded BEFORE applying the overrides.
        has_zarr_chunk_encoding = "chunks" in array.encoding
        array.encoding.update(encoding.get(variable, {}))
        array = encode_zarr_variable(array)

        # Determine target chunking for array and remove it prior to
        # validation/extraction ONLY if the array isn't also coming
        # from a Zarr store (otherwise blocks need to be checked for overlap)
        target_chunks = array.encoding.get("chunks")
        if not has_zarr_chunk_encoding:
            array.encoding.pop("chunks", None)
        array_encoding = extract_zarr_variable_encoding(
            array, raise_on_invalid=True, name=variable
        )

        # Default to chunking based on array shape if not explicitly provided
        default_chunks = array_encoding.pop("chunks")
        target_chunks = target_chunks or default_chunks

        # Extract array attributes along with reserved property for
        # xarray dimension names
        array_attrs = _encode_zarr_attributes(array.attrs)
        array_attrs[DIMENSION_KEY] = encode_zarr_attr_value(array.dims)

        # The same encoding dict is used for both the temporary and the
        # target arrays, so intermediate data round-trips identically.
        copy_spec = _setup_array_rechunk(
            dask.array.asarray(array),
            target_chunks,
            max_mem,
            target_group,
            target_options=array_encoding,
            temp_store_or_group=temp_group,
            temp_options=array_encoding,
            name=variable,
        )
        # Attach per-variable attrs (incl. the xarray dimensions key) to
        # the array that will be written in the target group.
        copy_spec.write.array.attrs.update(array_attrs)  # type: ignore
        copy_specs.append(copy_spec)
    plan = executor.prepare_plan(copy_specs)
    return Rechunked(executor, plan, source, temp_group, target_group)


def rechunk(
source,
target_chunks,
Expand Down
5 changes: 1 addition & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@
long_description = f.read()


install_requires = [
"dask[array]",
"zarr",
]
install_requires = ["dask[array]", "zarr", "xarray"]
doc_requires = [
"sphinx",
"sphinxcontrib-srclinks",
Expand Down
59 changes: 59 additions & 0 deletions tests/test_rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import dask.array as dsa
import dask
import dask.core
import xarray
import numpy

from rechunker import api

Expand Down Expand Up @@ -35,6 +37,63 @@ def target_chunks(request):
return request.param


@pytest.mark.parametrize("shape", [(100, 50)])
@pytest.mark.parametrize("source_chunks", [(10, 50)])
@pytest.mark.parametrize("target_chunks", [(20, 10)])
@pytest.mark.parametrize("max_mem", ["10MB"])
@pytest.mark.parametrize("pass_temp", [True, False])
def test_rechunk_dataset(
tmp_path, shape, source_chunks, target_chunks, max_mem, pass_temp
):
target_store = str(tmp_path / "target.zarr")
temp_store = str(tmp_path / "temp.zarr")

a = numpy.arange(numpy.prod(shape)).reshape(shape).astype("f4")
a[-1] = numpy.nan
ds = xarray.Dataset(
dict(
a=xarray.DataArray(
a, dims=["x", "y"], attrs={"a1": 1, "a2": [1, 2, 3], "a3": "x"}
),
b=xarray.DataArray(numpy.ones(shape[0]), dims=["x"]),
c=xarray.DataArray(numpy.ones(shape[1]), dims=["y"]),
),
attrs={"a1": 1, "a2": [1, 2, 3], "a3": "x"},
)
ds = ds.chunk(chunks=dict(zip(["x", "y"], source_chunks)))
encoding = dict(
a=dict(
chunks=target_chunks,
compressor=zarr.Blosc(cname="zstd"),
dtype="int32",
scale_factor=0.1,
_FillValue=-9999,
),
b=dict(chunks=target_chunks[:1]),
)
rechunked = api.rechunk_dataset(
ds,
encoding=encoding,
max_mem=max_mem,
target_store=target_store,
temp_store=temp_store if pass_temp else None,
)
assert isinstance(rechunked, api.Rechunked)
rechunked.execute()

# Validate encoded variables
dst = xarray.open_zarr(target_store, decode_cf=False)
assert dst.a.dtype == encoding["a"]["dtype"]
assert all(dst.a.values[-1] == encoding["a"]["_FillValue"])

# Validate decoded variables
dst = xarray.open_zarr(target_store, decode_cf=True)
assert dst.a.data.chunksize == target_chunks
assert dst.b.data.chunksize == target_chunks[:1]
assert dst.c.data.chunksize == source_chunks[1:]
xarray.testing.assert_equal(ds.compute(), dst.compute())


@pytest.mark.parametrize("shape", [(8000, 8000)])
@pytest.mark.parametrize("source_chunks", [(200, 8000)])
@pytest.mark.parametrize("dtype", ["f4"])
Expand Down

0 comments on commit 48c0883

Please sign in to comment.