Skip to content

Commit

Permalink
pytest
Browse files Browse the repository at this point in the history
  • Loading branch information
masawdah committed Jun 27, 2024
1 parent 720a07e commit e6f6ebd
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 72 deletions.
123 changes: 77 additions & 46 deletions xvec/tests/test_zonal_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import xvec # noqa: F401


@pytest.mark.parametrize("method", ["rasterize", "iterate"])
@pytest.mark.parametrize("method", ["rasterize", "iterate", "exactextract"])
def test_structure(method):
da = xr.DataArray(
np.ones((10, 10, 5)),
Expand All @@ -24,14 +24,22 @@ def test_structure(method):
polygon2 = shapely.geometry.Polygon([(6, 22), (9, 22), (9, 29), (6, 26)])
polygons = gpd.GeoSeries([polygon1, polygon2], crs="EPSG:4326")

expected = xr.DataArray(
np.array([[12.0] * 5, [18.0] * 5]),
coords={
"geometry": polygons,
"time": pd.date_range("2023-01-01", periods=5),
},
).xvec.set_geom_indexes("geometry", crs="EPSG:4326")

if method == "exactextract":
expected = xr.DataArray(
np.array([[12.0] * 5, [16.5] * 5]),
coords={
"geometry": polygons,
"time": pd.date_range("2023-01-01", periods=5),
},
).xvec.set_geom_indexes("geometry", crs="EPSG:4326")
else:
expected = xr.DataArray(
np.array([[12.0] * 5, [18.0] * 5]),
coords={
"geometry": polygons,
"time": pd.date_range("2023-01-01", periods=5),
},
).xvec.set_geom_indexes("geometry", crs="EPSG:4326")
actual = da.xvec.zonal_stats(polygons, "x", "y", stats="sum", method=method)
xr.testing.assert_identical(actual, expected)

Expand All @@ -43,35 +51,36 @@ def test_structure(method):
)

# dataset
ds = da.to_dataset(name="test")

expected_ds = expected.to_dataset(name="test").set_coords("geometry")
actual_ds = ds.xvec.zonal_stats(polygons, "x", "y", stats="sum", method=method)
xr.testing.assert_identical(actual_ds, expected_ds)

actual_ix_ds = ds.xvec.zonal_stats(
polygons, "x", "y", stats="sum", method=method, index=True
)
xr.testing.assert_identical(
actual_ix_ds, expected_ds.assign_coords({"index": ("geometry", polygons.index)})
)
if method == "rasterize" or method == "iterate":
ds = da.to_dataset(name="test")
expected_ds = expected.to_dataset(name="test").set_coords("geometry")
actual_ds = ds.xvec.zonal_stats(polygons, "x", "y", stats="sum", method=method)
xr.testing.assert_identical(actual_ds, expected_ds)

actual_ix_ds = ds.xvec.zonal_stats(
polygons, "x", "y", stats="sum", method=method, index=True
)
xr.testing.assert_identical(
actual_ix_ds,
expected_ds.assign_coords({"index": ("geometry", polygons.index)}),
)

# named index
polygons.index.name = "my_index"
actual_ix_named = da.xvec.zonal_stats(
polygons, "x", "y", stats="sum", method=method
)
xr.testing.assert_identical(
actual_ix_named,
expected.assign_coords({"my_index": ("geometry", polygons.index)}),
)
actual_ix_names_ds = ds.xvec.zonal_stats(
polygons, "x", "y", stats="sum", method=method
)
xr.testing.assert_identical(
actual_ix_names_ds,
expected_ds.assign_coords({"my_index": ("geometry", polygons.index)}),
)
# named index
polygons.index.name = "my_index"
actual_ix_named = da.xvec.zonal_stats(
polygons, "x", "y", stats="sum", method=method
)
xr.testing.assert_identical(
actual_ix_named,
expected.assign_coords({"my_index": ("geometry", polygons.index)}),
)
actual_ix_names_ds = ds.xvec.zonal_stats(
polygons, "x", "y", stats="sum", method=method
)
xr.testing.assert_identical(
actual_ix_names_ds,
expected_ds.assign_coords({"my_index": ("geometry", polygons.index)}),
)


def test_match():
Expand Down Expand Up @@ -105,7 +114,7 @@ def test_dataset(method):
)


@pytest.mark.parametrize("method", ["rasterize", "iterate"])
@pytest.mark.parametrize("method", ["rasterize", "iterate", "exactextract"])
def test_dataarray(method):
ds = xr.tutorial.open_dataset("eraint_uvz")
world = gpd.read_file(geodatasets.get_path("naturalearth land"))
Expand All @@ -115,10 +124,13 @@ def test_dataarray(method):

assert result.shape == (127, 2, 3)
assert result.dims == ("geometry", "month", "level")
assert result.mean() == pytest.approx(61367.76185577)
if method == "exactextract":
assert result.mean() == pytest.approx(61625.53438858)
else:
assert result.mean() == pytest.approx(61367.76185577)


@pytest.mark.parametrize("method", ["rasterize", "iterate"])
@pytest.mark.parametrize("method", ["rasterize", "iterate", "exactextract"])
def test_stat(method):
ds = xr.tutorial.open_dataset("eraint_uvz")
world = gpd.read_file(geodatasets.get_path("naturalearth land"))
Expand All @@ -129,13 +141,32 @@ def test_stat(method):
median_ = ds.z.xvec.zonal_stats(
world.geometry, "longitude", "latitude", method=method, stats="median"
)
quantile_ = ds.z.xvec.zonal_stats(
world.geometry, "longitude", "latitude", method=method, stats="quantile", q=0.2
)
if method == "exactextract":
quantile_ = ds.z.xvec.zonal_stats(
world.geometry,
"longitude",
"latitude",
method=method,
stats="quantile(q=0.33)",
)
else:
quantile_ = ds.z.xvec.zonal_stats(
world.geometry,
"longitude",
"latitude",
method=method,
stats="quantile",
q=0.2,
)

assert mean_.mean() == pytest.approx(61367.76185577)
assert median_.mean() == pytest.approx(61370.18563539)
assert quantile_.mean() == pytest.approx(61279.93619836)
if method == "exactextract":
assert mean_.mean() == pytest.approx(61625.53438858)
assert median_.mean() == pytest.approx(61628.67168691)
assert quantile_.mean() == pytest.approx(61576.0883029)
else:
assert mean_.mean() == pytest.approx(61367.76185577)
assert median_.mean() == pytest.approx(61370.18563539)
assert quantile_.mean() == pytest.approx(61279.93619836)


@pytest.mark.parametrize("method", ["rasterize", "iterate"])
Expand Down
42 changes: 16 additions & 26 deletions xvec/zonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,10 +338,8 @@ def _zonal_stats_exactextract(
data = data.transpose("location", y_coords, x_coords)

# Aggregation result
stats = _prep_stats(stats)
results = exactextract.exact_extract(
rast=data, vec=gpd.GeoDataFrame(geometry), ops=stats, output="pandas"
)
gdf = gpd.GeoDataFrame(geometry=geometry, crs=crs)
results = exactextract.exact_extract(rast=data, vec=gdf, ops=stats, output="pandas")

# Unstack the results
if pd.api.types.is_list_like(stats):
Expand All @@ -356,27 +354,19 @@ def _zonal_stats_exactextract(
).xvec.set_geom_indexes(name, crs=crs)
agg[stat] = result
i += locs

vec_cube = xr.concat(
agg.values(),
dim=xr.DataArray(
list(agg.keys()), name="zonal_statistics", dims="zonal_statistics"
),
)
vec_cube = xr.concat(
agg.values(),
dim=xr.DataArray(
list(agg.keys()), name="zonal_statistics", dims="zonal_statistics"
),
)
elif isinstance(stats, str):
# Unstack the result
arr = results.values.reshape(original_shape)
vec_cube = xr.DataArray(
arr, coords=coords_info, dims=coords_info.keys()
).xvec.set_geom_indexes(name, crs=crs)
else:
raise ValueError(f"{stats} is not a valid aggregation for exactextract method.")

return vec_cube


def _prep_stats(stats):
if isinstance(stats, str):
stats = [stats]

prepared_stats = []
for stat in stats:
if isinstance(stat, str):
prepared_stats.append(stat)
else:
raise ValueError(
f'{stat} is not supported. It supports strings (e.g., "mean", "quantile(q=0.25)")'
)
return prepared_stats

0 comments on commit e6f6ebd

Please sign in to comment.