Skip to content

Commit

Permalink
Update get_bounds() to support mappable non-CF axes using `"bounds"…
Browse files Browse the repository at this point in the history
…` attr (#708)
  • Loading branch information
tomvothecoder authored Oct 8, 2024
1 parent e3dda64 commit 96bc649
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 21 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,4 @@ docs: ## generate Sphinx HTML documentation, including API docs
# Build
# ----------------------
install: clean ## install the package to the active Python's site-packages
python setup.py install
python -m pip install .
15 changes: 8 additions & 7 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Below is a list of top-level API functions that are available in ``xcdat``.
compare_datasets
get_dim_coords
get_dim_keys
create_bounds
create_axis
create_gaussian_grid
create_global_mean_grid
Expand Down Expand Up @@ -83,13 +84,13 @@ Classes
.. autosummary::
:toctree: generated/

xcdat.bounds.BoundsAccessor
xcdat.spatial.SpatialAccessor
xcdat.temporal.TemporalAccessor
xcdat.regridder.accessor.RegridderAccessor
xcdat.regridder.regrid2.Regrid2Regridder
xcdat.regridder.xesmf.XESMFRegridder
xcdat.regridder.xgcm.XGCMRegridder
bounds.BoundsAccessor
spatial.SpatialAccessor
temporal.TemporalAccessor
regridder.accessor.RegridderAccessor
regridder.regrid2.Regrid2Regridder
regridder.xesmf.XESMFRegridder
regridder.xgcm.XGCMRegridder

.. currentmodule:: xarray

Expand Down
55 changes: 55 additions & 0 deletions tests/test_bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,27 @@ def test_returns_single_dataset_axis_bounds_as_a_dataarray_object(self):

assert result.identical(expected)

def test_returns_single_dataset_axis_bounds_as_a_dataarray_object_for_non_cf_axis(
self,
):
ds = xr.Dataset(
coords={
"lat": xr.DataArray(
data=np.ones(3),
dims="lat",
attrs={"bounds": "lat_bnds"},
)
},
data_vars={
"lat_bnds": xr.DataArray(data=np.ones((3, 3)), dims=["lat", "bnds"])
},
)

result = ds.bounds.get_bounds("Y")
expected = ds.lat_bnds

assert result.identical(expected)

def test_returns_multiple_dataset_axis_bounds_as_a_dataset_object(self):
ds = xr.Dataset(
coords={
Expand Down Expand Up @@ -321,6 +342,40 @@ def test_returns_multiple_dataset_axis_bounds_as_a_dataset_object(self):

assert result.identical(expected)

def test_returns_multiple_dataset_axis_bounds_as_a_dataset_object_for_non_cf_axis(
self,
):
ds = xr.Dataset(
coords={
"lat": xr.DataArray(
data=np.ones(3),
dims="lat",
attrs={
"bounds": "lat_bnds",
},
),
"latitude": xr.DataArray(
data=np.ones(3),
dims="latitude",
attrs={
"bounds": "latitude_bnds",
},
),
},
data_vars={
"var": xr.DataArray(data=np.ones(3), dims=["lat"]),
"lat_bnds": xr.DataArray(data=np.ones((3, 3)), dims=["lat", "bnds"]),
"latitude_bnds": xr.DataArray(
data=np.ones((3, 3)), dims=["latitude", "bnds"]
),
},
)

result = ds.bounds.get_bounds("Y")
expected = ds.drop_vars("var")

assert result.identical(expected)


class TestAddBounds:
@pytest.fixture(autouse=True)
Expand Down
2 changes: 1 addition & 1 deletion xcdat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
get_dim_keys,
swap_lon_axis,
)
from xcdat.bounds import BoundsAccessor # noqa: F401
from xcdat.bounds import BoundsAccessor, create_bounds # noqa: F401
from xcdat.dataset import decode_time, open_dataset, open_mfdataset # noqa: F401
from xcdat.regridder.accessor import RegridderAccessor # noqa: F401
from xcdat.regridder.grid import ( # noqa: F401
Expand Down
58 changes: 46 additions & 12 deletions xcdat/bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,18 +236,7 @@ def get_bounds(
else:
# Get the obj in the Dataset using the key.
obj = _get_data_var(self._dataset, key=var_key)

# Check if the object is a data variable or a coordinate variable.
# If it is a data variable, derive the axis coordinate variable.
if obj.name in list(self._dataset.data_vars):
coord = get_dim_coords(obj, axis)
elif obj.name in list(self._dataset.coords):
coord = obj

try:
bounds_keys = [coord.attrs["bounds"]]
except KeyError:
bounds_keys = []
bounds_keys = self._get_bounds_from_attr(obj, axis)

if len(bounds_keys) == 0:
raise KeyError(
Expand Down Expand Up @@ -505,8 +494,53 @@ def _get_bounds_keys(self, axis: CFAxisKey) -> List[str]:
except KeyError:
pass

keys_from_attr = self._get_bounds_from_attr(self._dataset, axis)
keys = keys + keys_from_attr

return list(set(keys))

def _get_bounds_from_attr(
self, obj: xr.DataArray | xr.Dataset, axis: CFAxisKey
) -> List[str]:
"""Retrieve bounds attribute keys from the given xarray object.
This method extracts the "bounds" attribute keys from the coordinates
of the specified axis in the provided xarray DataArray or Dataset.
Parameters:
-----------
obj : xr.DataArray | xr.Dataset
The xarray object from which to retrieve the bounds attribute keys.
axis : CFAxisKey
The CF axis key ("X", "Y", "T", or "Z").
Returns:
--------
List[str]
A list of bounds attribute keys found in the coordinates of the
specified axis. Otherwise, an empty list is returned.
"""
coords_obj = get_dim_coords(obj, axis)
bounds_keys: List[str] = []

if isinstance(coords_obj, xr.DataArray):
bounds_keys = self._extract_bounds_key(coords_obj, bounds_keys)
elif isinstance(coords_obj, xr.Dataset):
for coord in coords_obj.coords.values():
bounds_keys = self._extract_bounds_key(coord, bounds_keys)

return bounds_keys

def _extract_bounds_key(
self, coords_obj: xr.DataArray, bounds_keys: List[str]
) -> List[str]:
bnds_key = coords_obj.attrs.get("bounds")

if bnds_key is not None:
bounds_keys.append(bnds_key)

return bounds_keys

def _create_time_bounds( # noqa: C901
self,
time: xr.DataArray,
Expand Down

0 comments on commit 96bc649

Please sign in to comment.