Add tests, handle coordinates, move to main API
eric-czech committed Sep 30, 2020
1 parent 8d0efc4 commit fc1b17a
Showing 2 changed files with 290 additions and 170 deletions.
227 changes: 128 additions & 99 deletions rechunker/api.py
@@ -1,7 +1,7 @@
"""User-facing functions."""
import html
import textwrap
from typing import Union, Mapping
from typing import Union

import zarr
import dask
@@ -17,6 +17,7 @@
extract_zarr_variable_encoding,
DIMENSION_KEY,
)
from xarray.conventions import encode_dataset_coordinates


class Rechunked:
@@ -77,18 +78,13 @@ def execute(self, **kwargs):
return self._target

def __repr__(self):
entries = []
entries.append(f"\n* Source : {repr(self._source)}")
if self._intermediate is not None:
intermediate = f"\n* Intermediate: {repr(self._intermediate)}"
else:
intermediate = ""

return textwrap.dedent(
f"""\
<Rechunked>
* Source : {repr(self._source)}{{}}
* Target : {repr(self._target)}
"""
).format(intermediate)
entries.append(f"\n* Intermediate: {repr(self._intermediate)}")
entries.append(f"\n* Target : {repr(self._target)}")
entries = "\n".join(entries)
return f"<Rechunked>{entries}\n"

def _repr_html_(self):
entries = {}
@@ -145,6 +141,10 @@ def _get_dims_from_zarr_array(z_array):
return z_array.attrs["_ARRAY_DIMENSIONS"]


def _encode_zarr_attributes(attrs):
return {k: encode_zarr_attr_value(v) for k, v in attrs.items()}


def _zarr_empty(shape, store_or_group, chunks, dtype, name=None, **kwargs):
# wrapper that maybe creates the array within a group
if name is not None:
@@ -158,6 +158,27 @@ def _zarr_empty(shape, store_or_group, chunks, dtype, name=None, **kwargs):
)


ZARR_OPTIONS = [
"compressor",
"filters",
"order",
"cache_metadata",
"cache_attrs",
"overwrite",
]


def _validate_options(options):
if not options:
return
for o in options:
if o not in ZARR_OPTIONS:
raise ValueError(
f"Zarr options must not include {o} (got {o}={options[o]}). "
f"Only the following options are supported: {ZARR_OPTIONS}."
)


def _get_executor(name: str) -> Executor:
# converts a string name into a Executor instance
# imports are conditional to avoid hard dependencies
@@ -185,70 +206,6 @@ def _get_executor(name: str) -> Executor:
raise ValueError(f"unrecognized executor {name}")


def rechunk_dataset(
source: xarray.Dataset,
encoding: Mapping,
max_mem,
target_store,
temp_store=None,
executor: Union[str, Executor] = "dask",
):
def _encode_zarr_attributes(attrs):
return {k: encode_zarr_attr_value(v) for k, v in attrs.items()}

if isinstance(executor, str):
executor = _get_executor(executor)
if temp_store:
temp_group = zarr.group(temp_store)
else:
temp_group = zarr.group(tempfile.mkdtemp(".zarr", "temp_store_"))
target_group = zarr.group(target_store)
target_group.attrs.update(_encode_zarr_attributes(source.attrs))

copy_specs = []
for variable in source:
array = source[variable].copy()

# Update the array encoding with provided parameters and apply it
has_chunk_encoding = "chunks" in array.encoding
array.encoding.update(encoding.get(variable, {}))
array = encode_zarr_variable(array)

# Determine target chunking for array and remove it prior to
# validation/extraction ONLY if the array isn't also coming
# from a Zarr store (otherwise blocks need to be checked for overlap)
target_chunks = array.encoding.get("chunks")
if not has_chunk_encoding:
array.encoding.pop("chunks", None)
array_encoding = extract_zarr_variable_encoding(
array, raise_on_invalid=True, name=variable
)

# Default to chunking based on array shape if not explicitly provided
default_chunks = array_encoding.pop("chunks")
target_chunks = target_chunks or default_chunks

# Extract array attributes along with reserved property for
# xarray dimension names
array_attrs = _encode_zarr_attributes(array.attrs)
array_attrs[DIMENSION_KEY] = encode_zarr_attr_value(array.dims)

copy_spec = _setup_array_rechunk(
dask.array.asarray(array),
target_chunks,
max_mem,
target_group,
target_options=array_encoding,
temp_store_or_group=temp_group,
temp_options=array_encoding,
name=variable,
)
copy_spec.write.array.attrs.update(array_attrs) # type: ignore
copy_specs.append(copy_spec)
plan = executor.prepare_plan(copy_specs)
return Rechunked(executor, plan, source, temp_group, target_group)


def rechunk(
source,
target_chunks,
@@ -278,7 +235,7 @@ def rechunk(
attributes (see Xarray :ref:`xarray:zarr_encoding`.)
A value of ``None`` means that the array will
be copied with no change to its chunk structure.
- For a group, a dict is required. The keys correspond to array names.
- For a group of arrays, a dict is required. The keys correspond to array names.
The values are ``target_chunks`` arguments for the array. For example,
``{'foo': (20, 10), 'bar': {'x': 3, 'y': 5}, 'baz': None}``.
*All arrays you want to rechunk must be explicitly named.* Arrays
@@ -291,16 +248,25 @@
The location in which to store the final, rechunked result.
Will be passed directly to :py:meth:`zarr.creation.create`
target_options: Dict, optional
Additional keyword arguments used to create target arrays.
See :py:meth:`zarr.creation.create` for arguments available.
Must not include any of [``shape``, ``chunks``, ``dtype``, ``store``].
Additional keyword arguments used to control array storage.
If the source is :py:class:`xarray.Dataset`, then these options will be used
to encode variables in the same manner as the ``encoding`` parameter in
:py:meth:`xarray.Dataset.to_zarr`. Otherwise, these options will be passed
to :py:meth:`zarr.creation.create`. The structure depends on ``source``.
- For a single array source, this should be a single dict such
as ``{'compressor': zarr.Blosc(), 'order': 'F'}``.
- For a group of arrays, a nested dict is required with values
like the above keyed by array name. For example,
``{'foo': {'compressor': zarr.Blosc(), 'order': 'F'}, 'bar': {'compressor': None}}``.
temp_store : str, MutableMapping, or zarr.Store object, optional
Location of temporary store for intermediate data. Can be deleted
once rechunking is complete.
temp_options: Dict, optional
Additional keyword arguments used to create intermediate arrays.
See :py:meth:`zarr.creation.create` for arguments available.
Must not include any of [``shape``, ``chunks``, ``dtype``, ``store``].
Options with the same semantics as ``target_options``, but applied to ``temp_store``
rather than ``target_store``. Defaults to ``target_options`` and has no effect when
the source is an :py:class:`xarray.Dataset`.
executor: str or rechunker.types.Executor
Implementation of the execution engine for copying between zarr arrays.
Supplying a custom Executor is currently even more experimental than the
@@ -335,7 +301,79 @@ def _setup_rechunk(
temp_store=None,
temp_options=None,
):
if isinstance(source, zarr.hierarchy.Group):
if temp_options is None:
temp_options = target_options
target_options = target_options or {}
temp_options = temp_options or {}

if isinstance(source, xarray.Dataset):
if not isinstance(target_chunks, dict):
raise ValueError(
"You must specify ``target-chunks`` as a dict when rechunking a dataset."
)

variables, attrs = encode_dataset_coordinates(source)
attrs = _encode_zarr_attributes(attrs)

if temp_store:
temp_group = zarr.group(temp_store)
else:
temp_group = zarr.group(tempfile.mkdtemp(".zarr", "temp_store_"))
target_group = zarr.group(target_store)
target_group.attrs.update(attrs)

copy_specs = []
for name, variable in variables.items():
# This isn't strictly necessary because a shallow copy
# also occurs in `encode_dataset_coordinates` but do it
# anyway in case the coord encoding function changes
variable = variable.copy()

# Update the array encoding with provided options and apply it;
# note that at this point the `options` may contain any valid property
# applicable for the `encoding` parameter in Dataset.to_zarr other than "chunks"
options = target_options.get(name, {})
if "chunks" in options:
raise ValueError(
f"Chunks must be provided in ``target_chunks`` rather than options (variable={name})"
)
variable.encoding.update(options)
variable = encode_zarr_variable(variable)

# Extract the array encoding to get a default chunking, a step
# which will also ensure that the target chunking is compatible
# with the current chunking (only necessary for on-disk arrays)
variable_encoding = extract_zarr_variable_encoding(
variable, raise_on_invalid=False, name=name
)
variable_chunks = target_chunks.get(name, variable_encoding["chunks"])

# Restrict options to only those that are specific to zarr and
# not managed internally
options = {k: v for k, v in options.items() if k in ZARR_OPTIONS}
_validate_options(options)

# Extract array attributes along with reserved property for
# xarray dimension names
variable_attrs = _encode_zarr_attributes(variable.attrs)
variable_attrs[DIMENSION_KEY] = encode_zarr_attr_value(variable.dims)

copy_spec = _setup_array_rechunk(
dask.array.asarray(variable),
variable_chunks,
max_mem,
target_group,
target_options=options,
temp_store_or_group=temp_group,
temp_options=options,
name=name,
)
copy_spec.write.array.attrs.update(variable_attrs) # type: ignore
copy_specs.append(copy_spec)

return copy_specs, temp_group, target_group

elif isinstance(source, zarr.hierarchy.Group):
if not isinstance(target_chunks, dict):
raise ValueError(
"You must specify ``target-chunks`` as a dict when rechunking a group."
@@ -344,7 +382,7 @@
if temp_store:
temp_group = zarr.group(temp_store)
else:
temp_group = None
temp_group = zarr.group(tempfile.mkdtemp(".zarr", "temp_store_"))
target_group = zarr.group(target_store)
target_group.attrs.update(source.attrs)

@@ -355,9 +393,9 @@ def _setup_rechunk(
array_target_chunks,
max_mem,
target_group,
target_options=target_options,
target_options=target_options.get(array_name),
temp_store_or_group=temp_group,
temp_options=temp_options,
temp_options=temp_options.get(array_name),
name=array_name,
)
copy_specs.append(copy_spec)
@@ -380,18 +418,9 @@ def _setup_rechunk(
return [copy_spec], intermediate, target

else:
raise ValueError("Source must be a Zarr Array or Group, or a Dask Array.")


def _validate_options(options):
if not options:
return
for k in ["shape", "chunks", "dtype", "store", "name"]:
if k in options:
raise ValueError(
f"Optional array arguments must not include {k} (provided {k}={options[k]}). "
"Values for this property are managed internally."
)
raise ValueError(
f"Source must be a Zarr Array, Zarr Group, Dask Array or Xarray Dataset (not {type(source)})."
)


def _setup_array_rechunk(
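The headline change in the diff above is that ``rechunk`` now accepts an ``xarray.Dataset`` directly, replacing the separate ``rechunk_dataset`` function removed in this commit and encoding coordinates and attributes via ``encode_dataset_coordinates``. The sketch below shows how that code path could be exercised; the dataset contents, chunk sizes, store paths, and compressor choice are illustrative assumptions, not part of the commit.

```python
# Illustrative usage sketch only (not part of the commit).
import numpy as np
import xarray as xr
import zarr
from rechunker import rechunk

# A small synthetic Dataset with one data variable and two coordinates.
ds = xr.Dataset(
    {"temperature": (("time", "x"), np.random.rand(100, 50))},
    coords={"time": np.arange(100), "x": np.arange(50)},
    attrs={"title": "example"},
)

plan = rechunk(
    ds,                                       # xarray.Dataset source, new in this commit
    target_chunks={"temperature": (10, 50)},  # variables not listed keep their encoding chunks
    max_mem=100_000_000,                      # memory limit for each copy operation, in bytes
    target_store="target.zarr",               # hypothetical store path
    target_options={"temperature": {"compressor": zarr.Blosc(cname="zstd")}},
    temp_store="temp.zarr",                   # hypothetical intermediate store
)
plan.execute()
```

Because the target group receives the encoded dataset attributes and each array is written with the xarray ``_ARRAY_DIMENSIONS`` attribute (``DIMENSION_KEY``), the result should be readable back with ``xarray.open_zarr``.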
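For a ``zarr.hierarchy.Group`` source, the same commit switches ``target_options`` and ``temp_options`` from a single dict shared by all arrays to a per-array lookup (``target_options.get(array_name)``). A sketch of what such a call could look like, with hypothetical store paths and array names:

```python
# Illustrative usage sketch only (not part of the commit).
import zarr
from rechunker import rechunk

# Assumed to be an existing group containing 1-D arrays "foo" and "bar".
source_group = zarr.open_group("source.zarr", mode="r")

plan = rechunk(
    source_group,
    target_chunks={"foo": (1000,), "bar": (1000,)},  # every array to rechunk must be named
    max_mem=100_000_000,
    target_store="group_target.zarr",
    target_options={
        "foo": {"compressor": zarr.Blosc(cname="zstd", clevel=5)},
        "bar": {"compressor": None},                 # store "bar" uncompressed
    },
    temp_store="group_temp.zarr",
)
plan.execute()
```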
