From e6f6ebd684b1f487f20646761ad0b591ddd1d744 Mon Sep 17 00:00:00 2001 From: masawdah Date: Thu, 27 Jun 2024 12:34:59 +0200 Subject: [PATCH] pytest --- xvec/tests/test_zonal_stats.py | 123 +++++++++++++++++++++------------ xvec/zonal.py | 42 +++++------ 2 files changed, 93 insertions(+), 72 deletions(-) diff --git a/xvec/tests/test_zonal_stats.py b/xvec/tests/test_zonal_stats.py index 42e415e..11b8cb8 100644 --- a/xvec/tests/test_zonal_stats.py +++ b/xvec/tests/test_zonal_stats.py @@ -9,7 +9,7 @@ import xvec # noqa: F401 -@pytest.mark.parametrize("method", ["rasterize", "iterate"]) +@pytest.mark.parametrize("method", ["rasterize", "iterate", "exactextract"]) def test_structure(method): da = xr.DataArray( np.ones((10, 10, 5)), @@ -24,14 +24,22 @@ def test_structure(method): polygon2 = shapely.geometry.Polygon([(6, 22), (9, 22), (9, 29), (6, 26)]) polygons = gpd.GeoSeries([polygon1, polygon2], crs="EPSG:4326") - expected = xr.DataArray( - np.array([[12.0] * 5, [18.0] * 5]), - coords={ - "geometry": polygons, - "time": pd.date_range("2023-01-01", periods=5), - }, - ).xvec.set_geom_indexes("geometry", crs="EPSG:4326") - + if method == "exactextract": + expected = xr.DataArray( + np.array([[12.0] * 5, [16.5] * 5]), + coords={ + "geometry": polygons, + "time": pd.date_range("2023-01-01", periods=5), + }, + ).xvec.set_geom_indexes("geometry", crs="EPSG:4326") + else: + expected = xr.DataArray( + np.array([[12.0] * 5, [18.0] * 5]), + coords={ + "geometry": polygons, + "time": pd.date_range("2023-01-01", periods=5), + }, + ).xvec.set_geom_indexes("geometry", crs="EPSG:4326") actual = da.xvec.zonal_stats(polygons, "x", "y", stats="sum", method=method) xr.testing.assert_identical(actual, expected) @@ -43,35 +51,36 @@ def test_structure(method): ) # dataset - ds = da.to_dataset(name="test") - - expected_ds = expected.to_dataset(name="test").set_coords("geometry") - actual_ds = ds.xvec.zonal_stats(polygons, "x", "y", stats="sum", method=method) - xr.testing.assert_identical(actual_ds, expected_ds) - - actual_ix_ds = ds.xvec.zonal_stats( - polygons, "x", "y", stats="sum", method=method, index=True - ) - xr.testing.assert_identical( - actual_ix_ds, expected_ds.assign_coords({"index": ("geometry", polygons.index)}) - ) + if method == "rasterize" or method == "iterate": + ds = da.to_dataset(name="test") + expected_ds = expected.to_dataset(name="test").set_coords("geometry") + actual_ds = ds.xvec.zonal_stats(polygons, "x", "y", stats="sum", method=method) + xr.testing.assert_identical(actual_ds, expected_ds) + + actual_ix_ds = ds.xvec.zonal_stats( + polygons, "x", "y", stats="sum", method=method, index=True + ) + xr.testing.assert_identical( + actual_ix_ds, + expected_ds.assign_coords({"index": ("geometry", polygons.index)}), + ) - # named index - polygons.index.name = "my_index" - actual_ix_named = da.xvec.zonal_stats( - polygons, "x", "y", stats="sum", method=method - ) - xr.testing.assert_identical( - actual_ix_named, - expected.assign_coords({"my_index": ("geometry", polygons.index)}), - ) - actual_ix_names_ds = ds.xvec.zonal_stats( - polygons, "x", "y", stats="sum", method=method - ) - xr.testing.assert_identical( - actual_ix_names_ds, - expected_ds.assign_coords({"my_index": ("geometry", polygons.index)}), - ) + # named index + polygons.index.name = "my_index" + actual_ix_named = da.xvec.zonal_stats( + polygons, "x", "y", stats="sum", method=method + ) + xr.testing.assert_identical( + actual_ix_named, + expected.assign_coords({"my_index": ("geometry", polygons.index)}), + ) + actual_ix_names_ds = ds.xvec.zonal_stats( + polygons, "x", "y", stats="sum", method=method + ) + xr.testing.assert_identical( + actual_ix_names_ds, + expected_ds.assign_coords({"my_index": ("geometry", polygons.index)}), + ) def test_match(): @@ -105,7 +114,7 @@ def test_dataset(method): ) -@pytest.mark.parametrize("method", ["rasterize", "iterate"]) +@pytest.mark.parametrize("method", ["rasterize", "iterate", "exactextract"]) def test_dataarray(method): ds = xr.tutorial.open_dataset("eraint_uvz") world = gpd.read_file(geodatasets.get_path("naturalearth land")) @@ -115,10 +124,13 @@ def test_dataarray(method): assert result.shape == (127, 2, 3) assert result.dims == ("geometry", "month", "level") - assert result.mean() == pytest.approx(61367.76185577) + if method == "exactextract": + assert result.mean() == pytest.approx(61625.53438858) + else: + assert result.mean() == pytest.approx(61367.76185577) -@pytest.mark.parametrize("method", ["rasterize", "iterate"]) +@pytest.mark.parametrize("method", ["rasterize", "iterate", "exactextract"]) def test_stat(method): ds = xr.tutorial.open_dataset("eraint_uvz") world = gpd.read_file(geodatasets.get_path("naturalearth land")) @@ -129,13 +141,32 @@ def test_stat(method): median_ = ds.z.xvec.zonal_stats( world.geometry, "longitude", "latitude", method=method, stats="median" ) - quantile_ = ds.z.xvec.zonal_stats( - world.geometry, "longitude", "latitude", method=method, stats="quantile", q=0.2 - ) + if method == "exactextract": + quantile_ = ds.z.xvec.zonal_stats( + world.geometry, + "longitude", + "latitude", + method=method, + stats="quantile(q=0.33)", + ) + else: + quantile_ = ds.z.xvec.zonal_stats( + world.geometry, + "longitude", + "latitude", + method=method, + stats="quantile", + q=0.2, + ) - assert mean_.mean() == pytest.approx(61367.76185577) - assert median_.mean() == pytest.approx(61370.18563539) - assert quantile_.mean() == pytest.approx(61279.93619836) + if method == "exactextract": + assert mean_.mean() == pytest.approx(61625.53438858) + assert median_.mean() == pytest.approx(61628.67168691) + assert quantile_.mean() == pytest.approx(61576.0883029) + else: + assert mean_.mean() == pytest.approx(61367.76185577) + assert median_.mean() == pytest.approx(61370.18563539) + assert quantile_.mean() == pytest.approx(61279.93619836) @pytest.mark.parametrize("method", ["rasterize", "iterate"]) diff --git a/xvec/zonal.py b/xvec/zonal.py index ee4eabe..c34930b 100644 --- a/xvec/zonal.py +++ b/xvec/zonal.py @@ -338,10 +338,8 @@ def _zonal_stats_exactextract( data = data.transpose("location", y_coords, x_coords) # Aggregation result - stats = _prep_stats(stats) - results = exactextract.exact_extract( - rast=data, vec=gpd.GeoDataFrame(geometry), ops=stats, output="pandas" - ) + gdf = gpd.GeoDataFrame(geometry=geometry, crs=crs) + results = exactextract.exact_extract(rast=data, vec=gdf, ops=stats, output="pandas") # Unstack the results if pd.api.types.is_list_like(stats): @@ -356,27 +354,19 @@ def _zonal_stats_exactextract( ).xvec.set_geom_indexes(name, crs=crs) agg[stat] = result i += locs - - vec_cube = xr.concat( - agg.values(), - dim=xr.DataArray( - list(agg.keys()), name="zonal_statistics", dims="zonal_statistics" - ), - ) + vec_cube = xr.concat( + agg.values(), + dim=xr.DataArray( + list(agg.keys()), name="zonal_statistics", dims="zonal_statistics" + ), + ) + elif isinstance(stats, str): + # Unstack the result + arr = results.values.reshape(original_shape) + vec_cube = xr.DataArray( + arr, coords=coords_info, dims=coords_info.keys() + ).xvec.set_geom_indexes(name, crs=crs) + else: + raise ValueError(f"{stats} is not a valid aggregation for exactextract method.") return vec_cube - - -def _prep_stats(stats): - if isinstance(stats, str): - stats = [stats] - - prepared_stats = [] - for stat in stats: - if isinstance(stat, str): - prepared_stats.append(stat) - else: - raise ValueError( - f'{stat} is not supported. It supports strings (e.g., "mean", "quantile(q=0.25)")' - ) - return prepared_stats