Skip to content

Commit

Permalink
Resample dask arrays blockwise using pyresample (#128)
Browse files Browse the repository at this point in the history
Co-authored-by: Raphael Hagen <norlandrhagen@gmail.com>
  • Loading branch information
maxrjones and norlandrhagen authored Jun 17, 2024
1 parent bae9328 commit 5bfeaec
Show file tree
Hide file tree
Showing 7 changed files with 380 additions and 2 deletions.
2 changes: 2 additions & 0 deletions ci/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ channels:
- nodefaults
dependencies:
- dask
- jinja2
- esmpy>=8.2.0
- mercantile
- mpich
Expand All @@ -14,6 +15,7 @@ dependencies:
- pre-commit
- pydantic>=1.10
- pyproj
- pyresample
- pytest
- pytest-cov
- pytest-mypy
Expand Down
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ Top level API
pyramid_create
pyramid_reproject
pyramid_regrid
pyramid_resample
1 change: 1 addition & 0 deletions ndpyramid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
from .coarsen import pyramid_coarsen
from .reproject import pyramid_reproject
from .regrid import pyramid_regrid
from .resample import pyramid_resample
from ._version import __version__
10 changes: 10 additions & 0 deletions ndpyramid/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,18 @@ def __init__(self, **data) -> None:

super().__init__(**data)
epsg_codes = {'web-mercator': 'EPSG:3857', 'equidistant-cylindrical': 'EPSG:4326'}
area_extents = {
'web-mercator': (
-20037508.342789244,
-20037508.342789248,
20037508.342789248,
20037508.342789244,
),
'equidistant-cylindrical': (-180, 180, 90, -90),
}
self._crs = epsg_codes[self.name]
self._proj = pyproj.Proj(self._crs)
self._area_extent = area_extents[self.name]

@pydantic.validate_call
def transform(self, *, dim: int) -> rasterio.transform.Affine:
Expand Down
283 changes: 283 additions & 0 deletions ndpyramid/resample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,283 @@
from __future__ import annotations # noqa: F401

import warnings
from collections import defaultdict

import datatree as dt
import numpy as np
import xarray as xr
from pyproj.crs import CRS

from .common import Projection, ProjectionOptions
from .utils import (
add_metadata_and_zarr_encoding,
get_levels,
get_version,
multiscales_template,
)


def _da_resample(da, *, dim, projection_model, pixels_per_tile, other_chunk, resampling):
try:
from pyresample.area_config import create_area_def
from pyresample.future.resamplers.resampler import (
add_crs_xy_coords,
update_resampled_coords,
)
from pyresample.gradient import (
block_bilinear_interpolator,
block_nn_interpolator,
gradient_resampler_indices_block,
)
from pyresample.resampler import resample_blocks
from pyresample.utils.cf import load_cf_area
except ImportError as e:
raise ImportError(
'The use of pyramid_resample requires the packages pyresample and dask'
) from e
if da.encoding.get('_FillValue') is None and np.issubdtype(da.dtype, np.floating):
da.encoding['_FillValue'] = np.nan
if resampling == 'bilinear':
fun = block_bilinear_interpolator
elif resampling in ['nearest_neighbor' 'nearest_neighbour', 'nn', 'nearest']:
fun = block_nn_interpolator
else:
raise ValueError(f"Unrecognized interpolation method {resampling} for gradient resampling.")
target_area_def = create_area_def(
area_id=projection_model.name,
projection=projection_model._crs,
shape=(dim, dim),
area_extent=projection_model._area_extent,
)
try:
source_area_def = load_cf_area(da.to_dataset(name='var'), variable='var')[0]
except ValueError as e:
warnings.warn(
f"Automatic determination of source AreaDefinition from CF conventions failed with {e}."
' Falling back to AreaDefinition creation from coordinates.'
)
lx = da.x[0] - (da.x[1] - da.x[0]) / 2
rx = da.x[-1] + (da.x[-1] - da.x[-2]) / 2
uy = da.y[0] - (da.y[1] - da.y[0]) / 2
ly = da.y[-1] + (da.y[-1] - da.y[-2]) / 2
source_crs = CRS.from_string(da.rio.crs.to_string())
source_area_def = create_area_def(
area_id=2,
projection=source_crs,
shape=(da.sizes['y'], da.sizes['x']),
area_extent=(lx.values, ly.values, rx.values, uy.values),
)
indices_xy = resample_blocks(
gradient_resampler_indices_block,
source_area_def,
[],
target_area_def,
chunk_size=(other_chunk, pixels_per_tile, pixels_per_tile),
dtype=float,
)
resampled = resample_blocks(
fun,
source_area_def,
[da.data],
target_area_def,
dst_arrays=[indices_xy],
chunk_size=(other_chunk, pixels_per_tile, pixels_per_tile),
dtype=da.dtype,
)
resampled_da = xr.DataArray(resampled, dims=('time', 'y', 'x'))
resampled_da = update_resampled_coords(da, resampled_da, target_area_def)
resampled_da = add_crs_xy_coords(resampled_da, target_area_def)
resampled_da = resampled_da.drop_vars('crs')
resampled_da.attrs = {}
return resampled_da


def level_resample(
ds: xr.Dataset,
*,
x,
y,
projection: ProjectionOptions = 'web-mercator',
level: int,
pixels_per_tile: int = 128,
other_chunks: dict = None,
resampling: str | dict = 'bilinear',
clear_attrs: bool = False,
) -> xr.Dataset:
"""Create a level of a multiscale pyramid of a dataset via resampling.
Parameters
----------
ds : xarray.Dataset
The dataset to create a multiscale pyramid of.
y : string
name of the variable to use as 'y' axis of the CF area definition
x : string
name of the variable to use as 'x' axis of the CF area definition
projection : str, optional
The projection to use. Default is 'web-mercator'.
level : int
The level of the pyramid to create.
pixels_per_tile : int, optional
Number of pixels per tile
other_chunks : dict
Chunks for non-spatial dims.
resampling : str or dict, optional
Pyresample resampling method to use. Default is 'bilinear'.
If a dict, keys are variable names and values are resampling methods.
clear_attrs : bool, False
Clear the attributes of the DataArrays within the multiscale level. Default is False.
Returns
-------
xr.Dataset
The multiscale pyramid level.
Warning
-------
Pyramid generation by level is experimental and subject to change.
"""

dim = 2**level * pixels_per_tile
projection_model = Projection(name=projection)
save_kwargs = {'pixels_per_tile': pixels_per_tile}
attrs = {
'multiscales': multiscales_template(
datasets=[{'path': '.', 'level': level, 'crs': projection_model._crs}],
type='reduce',
method='pyramid_resample',
version=get_version(),
kwargs=save_kwargs,
)
}

# Convert resampling from string to dictionary if necessary
if isinstance(resampling, str):
resampling_dict: dict = defaultdict(lambda: resampling)
else:
resampling_dict = resampling
# update coord naming to x & y and ensure order of dims is time, y, x
ds = ds.rename({x: 'x', y: 'y'})
# create the data array for each level
ds_level = xr.Dataset(attrs=ds.attrs)
for k, da in ds.items():
if clear_attrs:
da.attrs.clear()
if len(da.shape) > 3:
# if extra_dim is not specified, raise an error
raise NotImplementedError(
'4+ dimensional datasets are not currently supported for pyramid_resample.'
)
else:
# if the data array is not 4D, just resample it
if other_chunks is None:
other_chunk = list(da.sizes.values())[0]
else:
other_chunk = list(other_chunks.values())[0]
ds_level[k] = _da_resample(
da,
dim=dim,
projection_model=projection_model,
pixels_per_tile=pixels_per_tile,
other_chunk=other_chunk,
resampling=resampling_dict[k],
)
ds_level.attrs['multiscales'] = attrs['multiscales']
ds_level = ds_level.rio.write_crs(projection_model._crs)
return ds_level


def pyramid_resample(
ds: xr.Dataset,
*,
x: str,
y: str,
projection: ProjectionOptions = 'web-mercator',
levels: int = None,
pixels_per_tile: int = 128,
other_chunks: dict = None,
resampling: str | dict = 'bilinear',
clear_attrs: bool = False,
) -> dt.DataTree:
"""Create a multiscale pyramid of a dataset via resampling.
Parameters
----------
ds : xarray.Dataset
The dataset to create a multiscale pyramid of.
y : string
name of the variable to use as ``y`` axis of the CF area definition
x : string
name of the variable to use as ``x`` axis of the CF area definition
projection : str, optional
The projection to use. Default is ``web-mercator``.
levels : int, optional
The number of levels to create. If None, the number of levels is
determined by the number of tiles in the dataset.
pixels_per_tile : int, optional
Number of pixels per tile, by default 128
other_chunks : dict
Chunks for non-spatial dims to pass to :py:meth:`~xr.Dataset.chunk`. Default is None
resampling : str or dict, optional
Pyresample resampling method to use (``bilinear`` or ``nearest``). Default is ``bilinear``.
If a dict, keys are variable names and values are resampling methods.
clear_attrs : bool, False
Clear the attributes of the DataArrays within the multiscale pyramid. Default is False.
Returns
-------
dt.DataTree
The multiscale pyramid.
Warnings
--------
- Pyresample expects longitude ranges between -180 - 180 degrees and latitude ranges between -90 and 90 degrees.
- 3-D datasets are expected to have a dimension order of ``(time, y, x)``.
``Ndpyramid`` and ``pyresample`` do not check the validity of these assumptions to improve performance.
"""
if not levels:
levels = get_levels(ds)
save_kwargs = {'levels': levels, 'pixels_per_tile': pixels_per_tile}
attrs = {
'multiscales': multiscales_template(
datasets=[{'path': str(i)} for i in range(levels)],
type='reduce',
method='pyramid_resample',
version=get_version(),
kwargs=save_kwargs,
)
}

# set up pyramid
plevels = {}

# pyramid data
for level in range(levels):
plevels[str(level)] = level_resample(
ds,
x=x,
y=y,
projection=projection,
level=level,
pixels_per_tile=pixels_per_tile,
other_chunks=other_chunks,
resampling=resampling,
clear_attrs=clear_attrs,
)

# create the final multiscale pyramid
plevels['/'] = xr.Dataset(attrs=attrs)
pyramid = dt.DataTree.from_dict(plevels)

projection_model = Projection(name=projection)

pyramid = add_metadata_and_zarr_encoding(
pyramid,
levels=levels,
pixels_per_tile=pixels_per_tile,
other_chunks=other_chunks,
projection=projection_model,
)
return pyramid
11 changes: 10 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,17 @@ dependencies = [
"rasterio"
]


[project.optional-dependencies]
dask = [
"dask",
"pyresample",
]
jupyter = [
'notebook',
'ipytree>=0.2.2',
'ipywidgets>=8.0.0',
'matplotlib'
]
xesmf = ["xesmf"]

test = [
Expand Down
Loading

0 comments on commit 5bfeaec

Please sign in to comment.