Skip to content

Commit

Permalink
ENH: multiple aggregations at once in zonal_stats
Browse files Browse the repository at this point in the history
  • Loading branch information
martinfleis committed Dec 15, 2023
1 parent 07a8e92 commit ed7d545
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 17 deletions.
8 changes: 2 additions & 6 deletions xvec/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,7 +924,7 @@ def zonal_stats(
geometry: Sequence[shapely.Geometry],
x_coords: Hashable,
y_coords: Hashable,
stats: str | Callable = "mean",
stats: str | Callable | Sequence[str | Callable | tuple] = "mean",
name: Hashable = "geometry",
index: bool = None,
method: str = "rasterize",
Expand Down Expand Up @@ -990,8 +990,6 @@ def zonal_stats(
the :class:`GeometryIndex` of ``geometry``.
"""
# TODO: allow multiple stats at the same time (concat along a new axis),
# TODO: possibly as a list of tuples to include names?
if method == "rasterize":
result = _zonal_stats_rasterize(
self,
Expand Down Expand Up @@ -1033,9 +1031,7 @@ def zonal_stats(
result = result.assign_coords({index_name: (name, geometry.index)})

# standardize the shape - each method comes with a different one
return result.transpose(
name, *tuple(d for d in self._obj.dims if d not in [x_coords, y_coords])
)
return result.transpose(name, ...)

def extract_points(
self,
Expand Down
61 changes: 61 additions & 0 deletions xvec/tests/test_zonal_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,64 @@ def test_callable(method):
world.geometry, "longitude", "latitude", method=method, stats="std"
)
xr.testing.assert_identical(da_agg, da_std)


@pytest.mark.parametrize("method", ["rasterize", "iterate"])
def test_multiple(method):
    """Check that several aggregations can be requested in a single call.

    The request mixes every accepted spelling of an aggregation: plain
    strings, a ``(name, stat, kwargs)`` tuple, a ``(name, callable)`` tuple
    and a bare callable. The result is expected to carry an additional
    ``zonal_statistics`` dimension labelled by the aggregation names.
    """
    ds = xr.tutorial.open_dataset("eraint_uvz")
    world = gpd.read_file(geodatasets.get_path("naturalearth land"))

    requested_stats = [
        "mean",
        "sum",
        ("quantile", "quantile", {"q": [0.1, 0.2, 0.3]}),
        ("numpymean", np.nanmean),
        np.nanmean,
    ]
    result = ds.xvec.zonal_stats(
        world.geometry[:10].boundary,
        "longitude",
        "latitude",
        stats=requested_stats,
        method=method,
        n_jobs=1,
    )

    # dimension order is not part of the check, only the set of names
    expected_dims = ["level", "zonal_statistics", "geometry", "month", "quantile"]
    assert sorted(result.dims) == sorted(expected_dims)

    # the bare callable is labelled by its ``__name__``
    expected_labels = ["mean", "sum", "quantile", "numpymean", "nanmean"]
    assert (result.zonal_statistics == expected_labels).all()


@pytest.mark.parametrize("method", ["rasterize", "iterate"])
def test_invalid(method):
    """Check that unsupported ``stats`` specifications raise ``ValueError``.

    Two cases are covered: an invalid item (a nested list) inside a list of
    aggregations, and a top-level value that is neither a string, a callable
    nor list-like.
    """
    ds = xr.tutorial.open_dataset("eraint_uvz")
    world = gpd.read_file(geodatasets.get_path("naturalearth land"))
    shared_kwargs = {"method": method, "n_jobs": 1}

    # a list nested inside the list of aggregations is not accepted
    stats_with_nested_list = ["mean", ["gorilla"]]
    with pytest.raises(ValueError, match=r"\['gorilla'\] is not a valid aggregation."):
        ds.xvec.zonal_stats(
            world.geometry[:10].boundary,
            "longitude",
            "latitude",
            stats=stats_with_nested_list,
            **shared_kwargs,
        )

    # a bare integer is not a valid aggregation either
    with pytest.raises(ValueError, match="3 is not a valid aggregation."):
        ds.xvec.zonal_stats(
            world.geometry[:10].boundary,
            "longitude",
            "latitude",
            stats=3,
            **shared_kwargs,
        )
81 changes: 70 additions & 11 deletions xvec/zonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,34 @@
from typing import Callable

import numpy as np
import pandas as pd
import shapely
import xarray as xr


def _agg_rasterize(groups, stats, **kwargs):
    """Apply a single aggregation to grouped data.

    Parameters
    ----------
    groups : object
        Grouped object exposing named aggregation methods and a ``reduce``
        method (the rasterize code path passes the result of ``groupby``).
    stats : str or Callable
        A string is treated as the name of a method of ``groups``; any other
        value is passed to ``groups.reduce`` as the reduction function.
    **kwargs
        Extra keyword arguments forwarded to the chosen aggregation.

    Returns
    -------
    object
        Whatever the selected aggregation returns.
    """
    if not isinstance(stats, str):
        # non-string aggregations go through the generic reduction, which is
        # called with keep_attrs=True
        return groups.reduce(stats, keep_attrs=True, **kwargs)

    named_aggregation = getattr(groups, stats)
    return named_aggregation(**kwargs)


def _agg_iterate(masked, stats, x_coords, y_coords, **kwargs):
    """Apply a single aggregation over the two spatial dimensions.

    Parameters
    ----------
    masked : object
        Object exposing named aggregation methods and a ``reduce`` method
        (the iterative code path passes the masked data here).
    stats : str or Callable
        A string is treated as the name of a method of ``masked``; any other
        value is passed to ``masked.reduce`` as the reduction function.
    x_coords : Hashable
        Name of the x dimension to aggregate over.
    y_coords : Hashable
        Name of the y dimension to aggregate over.
    **kwargs
        Extra keyword arguments forwarded to the chosen aggregation.

    Returns
    -------
    object
        Whatever the selected aggregation returns.
    """
    # both branches aggregate over the same (y, x) pair, in that order
    spatial_dims = (y_coords, x_coords)

    if not isinstance(stats, str):
        return masked.reduce(stats, dim=spatial_dims, keep_attrs=True, **kwargs)

    named_aggregation = getattr(masked, stats)
    return named_aggregation(dim=spatial_dims, keep_attrs=True, **kwargs)


def _zonal_stats_rasterize(
acc,
geometry: Sequence[shapely.Geometry],
x_coords: Hashable,
y_coords: Hashable,
stats: str | Callable = "mean",
stats: str
| Callable
| Sequence[tuple[str | Callable]]
| Sequence[str | Callable] = "mean",
name: str = "geometry",
all_touched: bool = False,
**kwargs,
Expand Down Expand Up @@ -47,10 +65,31 @@ def _zonal_stats_rasterize(
all_touched=all_touched,
)
groups = acc._obj.groupby(xr.DataArray(labels, dims=(y_coords, x_coords)))
if isinstance(stats, str):
agg = getattr(groups, stats)(**kwargs)

if pd.api.types.is_list_like(stats):
agg = {}
for stat in stats:
if isinstance(stat, str):
agg[stat] = _agg_rasterize(groups, stat, **kwargs)
elif callable(stat):
agg[stat.__name__] = _agg_rasterize(groups, stat, **kwargs)
elif isinstance(stat, tuple):
kws = stat[2] if len(stat) == 3 else {}
agg[stat[0]] = _agg_rasterize(groups, stat[1], **kws)
else:
raise ValueError(f"{stat} is not a valid aggregation.")

agg = xr.concat(
agg.values(),
dim=xr.DataArray(
list(agg.keys()), name="zonal_statistics", dims="zonal_statistics"
),
)
elif isinstance(stats, str) or callable(stats):
agg = _agg_rasterize(groups, stats, **kwargs)
else:
agg = groups.reduce(stats, keep_attrs=True, **kwargs)
raise ValueError(f"{stats} is not a valid aggregation.")

vec_cube = (
agg.reindex(group=range(len(geometry)))
.assign_coords(group=geometry)
Expand All @@ -68,7 +107,10 @@ def _zonal_stats_iterative(
geometry: Sequence[shapely.Geometry],
x_coords: Hashable,
y_coords: Hashable,
stats: str | Callable = "mean",
stats: str
| Callable
| Sequence[tuple[str | Callable]]
| Sequence[str | Callable] = "mean",
name: str = "geometry",
all_touched: bool = False,
n_jobs: int = -1,
Expand Down Expand Up @@ -216,14 +258,31 @@ def _agg_geom(
all_touched=all_touched,
)
masked = acc._obj.where(xr.DataArray(mask, dims=(y_coords, x_coords)))
if isinstance(stats, str):
result = getattr(masked, stats)(
dim=(y_coords, x_coords), keep_attrs=True, **kwargs
if pd.api.types.is_list_like(stats):
agg = {}
for stat in stats:
if isinstance(stat, str):
agg[stat] = _agg_iterate(masked, stat, x_coords, y_coords, **kwargs)
elif callable(stat):
agg[stat.__name__] = _agg_iterate(
masked, stat, x_coords, y_coords, **kwargs
)
elif isinstance(stat, tuple):
kws = stat[2] if len(stat) == 3 else {}
agg[stat[0]] = _agg_iterate(masked, stat[1], x_coords, y_coords, **kws)
else:
raise ValueError(f"{stat} is not a valid aggregation.")

result = xr.concat(
agg.values(),
dim=xr.DataArray(
list(agg.keys()), name="zonal_statistics", dims="zonal_statistics"
),
)
elif isinstance(stats, str) or callable(stats):
result = _agg_iterate(masked, stats, x_coords, y_coords, **kwargs)
else:
result = masked.reduce(
stats, dim=(y_coords, x_coords), keep_attrs=True, **kwargs
)
raise ValueError(f"{stats} is not a valid aggregation.")

del mask
gc.collect()
Expand Down

0 comments on commit ed7d545

Please sign in to comment.