From 96bc649ad59750ec37b959d9d4890891b54c62c7 Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Tue, 8 Oct 2024 11:26:52 -0700 Subject: [PATCH] Update `get_bounds()` to support mappable non-CF axes using `"bounds"` attr (#708) --- Makefile | 2 +- docs/api.rst | 15 ++++++------ tests/test_bounds.py | 55 +++++++++++++++++++++++++++++++++++++++++ xcdat/__init__.py | 2 +- xcdat/bounds.py | 58 +++++++++++++++++++++++++++++++++++--------- 5 files changed, 111 insertions(+), 21 deletions(-) diff --git a/Makefile b/Makefile index 9135ed7b..68e1e7d4 100644 --- a/Makefile +++ b/Makefile @@ -83,4 +83,4 @@ docs: ## generate Sphinx HTML documentation, including API docs # Build # ---------------------- install: clean ## install the package to the active Python's site-packages - python setup.py install + python -m pip install . diff --git a/docs/api.rst b/docs/api.rst index 1ced17a3..1e107649 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -27,6 +27,7 @@ Below is a list of top-level API functions that are available in ``xcdat``. compare_datasets get_dim_coords get_dim_keys + create_bounds create_axis create_gaussian_grid create_global_mean_grid @@ -83,13 +84,13 @@ Classes .. autosummary:: :toctree: generated/ - xcdat.bounds.BoundsAccessor - xcdat.spatial.SpatialAccessor - xcdat.temporal.TemporalAccessor - xcdat.regridder.accessor.RegridderAccessor - xcdat.regridder.regrid2.Regrid2Regridder - xcdat.regridder.xesmf.XESMFRegridder - xcdat.regridder.xgcm.XGCMRegridder + bounds.BoundsAccessor + spatial.SpatialAccessor + temporal.TemporalAccessor + regridder.accessor.RegridderAccessor + regridder.regrid2.Regrid2Regridder + regridder.xesmf.XESMFRegridder + regridder.xgcm.XGCMRegridder .. currentmodule:: xarray diff --git a/tests/test_bounds.py b/tests/test_bounds.py index 2c519fce..00838e32 100644 --- a/tests/test_bounds.py +++ b/tests/test_bounds.py @@ -287,6 +287,27 @@ def test_returns_single_dataset_axis_bounds_as_a_dataarray_object(self): assert result.identical(expected) + def test_returns_single_dataset_axis_bounds_as_a_dataarray_object_for_non_cf_axis( + self, + ): + ds = xr.Dataset( + coords={ + "lat": xr.DataArray( + data=np.ones(3), + dims="lat", + attrs={"bounds": "lat_bnds"}, + ) + }, + data_vars={ + "lat_bnds": xr.DataArray(data=np.ones((3, 3)), dims=["lat", "bnds"]) + }, + ) + + result = ds.bounds.get_bounds("Y") + expected = ds.lat_bnds + + assert result.identical(expected) + def test_returns_multiple_dataset_axis_bounds_as_a_dataset_object(self): ds = xr.Dataset( coords={ @@ -321,6 +342,40 @@ def test_returns_multiple_dataset_axis_bounds_as_a_dataset_object(self): assert result.identical(expected) + def test_returns_multiple_dataset_axis_bounds_as_a_dataset_object_for_non_cf_axis( + self, + ): + ds = xr.Dataset( + coords={ + "lat": xr.DataArray( + data=np.ones(3), + dims="lat", + attrs={ + "bounds": "lat_bnds", + }, + ), + "latitude": xr.DataArray( + data=np.ones(3), + dims="latitude", + attrs={ + "bounds": "latitude_bnds", + }, + ), + }, + data_vars={ + "var": xr.DataArray(data=np.ones(3), dims=["lat"]), + "lat_bnds": xr.DataArray(data=np.ones((3, 3)), dims=["lat", "bnds"]), + "latitude_bnds": xr.DataArray( + data=np.ones((3, 3)), dims=["latitude", "bnds"] + ), + }, + ) + + result = ds.bounds.get_bounds("Y") + expected = ds.drop_vars("var") + + assert result.identical(expected) + class TestAddBounds: @pytest.fixture(autouse=True) diff --git a/xcdat/__init__.py b/xcdat/__init__.py index cae99097..3881359e 100644 --- a/xcdat/__init__.py +++ b/xcdat/__init__.py @@ -6,7 +6,7 @@ get_dim_keys, swap_lon_axis, ) -from xcdat.bounds import BoundsAccessor # noqa: F401 +from xcdat.bounds import BoundsAccessor, create_bounds # noqa: F401 from xcdat.dataset import decode_time, open_dataset, open_mfdataset # noqa: F401 from xcdat.regridder.accessor import RegridderAccessor # noqa: F401 from xcdat.regridder.grid import ( # noqa: F401 diff --git a/xcdat/bounds.py b/xcdat/bounds.py index dfbc0aff..4ee799c3 100644 --- a/xcdat/bounds.py +++ b/xcdat/bounds.py @@ -236,18 +236,7 @@ def get_bounds( else: # Get the obj in the Dataset using the key. obj = _get_data_var(self._dataset, key=var_key) - - # Check if the object is a data variable or a coordinate variable. - # If it is a data variable, derive the axis coordinate variable. - if obj.name in list(self._dataset.data_vars): - coord = get_dim_coords(obj, axis) - elif obj.name in list(self._dataset.coords): - coord = obj - - try: - bounds_keys = [coord.attrs["bounds"]] - except KeyError: - bounds_keys = [] + bounds_keys = self._get_bounds_from_attr(obj, axis) if len(bounds_keys) == 0: raise KeyError( @@ -505,8 +494,53 @@ def _get_bounds_keys(self, axis: CFAxisKey) -> List[str]: except KeyError: pass + keys_from_attr = self._get_bounds_from_attr(self._dataset, axis) + keys = keys + keys_from_attr + return list(set(keys)) + def _get_bounds_from_attr( + self, obj: xr.DataArray | xr.Dataset, axis: CFAxisKey + ) -> List[str]: + """Retrieve bounds attribute keys from the given xarray object. + + This method extracts the "bounds" attribute keys from the coordinates + of the specified axis in the provided xarray DataArray or Dataset. + + Parameters: + ----------- + obj : xr.DataArray | xr.Dataset + The xarray object from which to retrieve the bounds attribute keys. + axis : CFAxisKey + The CF axis key ("X", "Y", "T", or "Z"). + + Returns: + -------- + List[str] + A list of bounds attribute keys found in the coordinates of the + specified axis. Otherwise, an empty list is returned. + """ + coords_obj = get_dim_coords(obj, axis) + bounds_keys: List[str] = [] + + if isinstance(coords_obj, xr.DataArray): + bounds_keys = self._extract_bounds_key(coords_obj, bounds_keys) + elif isinstance(coords_obj, xr.Dataset): + for coord in coords_obj.coords.values(): + bounds_keys = self._extract_bounds_key(coord, bounds_keys) + + return bounds_keys + + def _extract_bounds_key( + self, coords_obj: xr.DataArray, bounds_keys: List[str] + ) -> List[str]: + bnds_key = coords_obj.attrs.get("bounds") + + if bnds_key is not None: + bounds_keys.append(bnds_key) + + return bounds_keys + def _create_time_bounds( # noqa: C901 self, time: xr.DataArray,