From c6587b88c3d6d834d4a821890ad2af80e77d000d Mon Sep 17 00:00:00 2001 From: Tom Vo Date: Mon, 7 Nov 2022 16:29:00 -0800 Subject: [PATCH] Support for N axis dimensions mapped to N coordinates (#343) --- docs/api.rst | 8 +- docs/examples/general-utilities.ipynb | 46 +- tests/fixtures.py | 235 ++-- tests/test_axis.py | 700 +++++++-- tests/test_bounds.py | 253 +++- tests/test_dataset.py | 1871 ++++++++++++++++++------- tests/test_regrid.py | 173 ++- tests/test_spatial.py | 86 +- tests/test_temporal.py | 94 +- xcdat/__init__.py | 6 +- xcdat/axis.py | 522 ++++--- xcdat/bounds.py | 256 ++-- xcdat/dataset.py | 659 ++++----- xcdat/regridder/accessor.py | 23 +- xcdat/regridder/grid.py | 63 +- xcdat/spatial.py | 42 +- xcdat/temporal.py | 127 +- xcdat/utils.py | 27 +- 18 files changed, 3536 insertions(+), 1655 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index 0b7bab0b..16151f71 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -22,11 +22,11 @@ Below is a list of top-level API functions available in ``xcdat``. open_dataset open_mfdataset center_times - decode_non_cf_time + decode_time swap_lon_axis compare_datasets - get_axis_coord - get_axis_dim + get_dim_coords + get_dim_keys create_gaussian_grid create_global_mean_grid create_grid @@ -165,7 +165,7 @@ It is especially useful for those who are transitioning over from CDAT to xarray - ``Dataset.spatial.average("VAR_KEY", axis=["X", "Y"])`` specifying ``lat_bounds`` and ``lon_bounds`` - ``cdutil.averager(TransientVariable, axis="xy")``, optionally subset ``TransientVariable`` with ``cdutil.region.domain()`` * - Decode time coordinates with CF/Non-CF units? - - ``xr.decode_cf()`` specifying ``decode_times=True``, or ``xcdat.decode_non_cf_time()`` + - ``xr.decode_cf()`` specifying ``decode_times=True``, or ``xcdat.decode_time()`` - ``cdms2.axis.Axis.asComponentTime()`` * - Temporally averaging with a single time-averaged snapshot and time coordinates removed? - ``Dataset.temporal.average("VAR_KEY")`` diff --git a/docs/examples/general-utilities.ipynb b/docs/examples/general-utilities.ipynb index 241016ae..bf4d5554 100644 --- a/docs/examples/general-utilities.ipynb +++ b/docs/examples/general-utilities.ipynb @@ -4533,26 +4533,42 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "2ab105f4", "metadata": {}, "source": [ - "## Get an axis coordinate variable\n", + "## Get the dimension coordinates for an axis.\n", "\n", - "In `xarray`, you can get a coordinate variable by directly referencing its name (e.g., `ds.lat`).\n", + "In `xarray`, you can get a dimension coordinates by directly referencing its name (e.g., `ds.lat`). `xcdat` provides an alternative way to get dimension coordinates agnostically by simply passing the CF axis key to applicable APIs.\n", "\n", - "`xcdat` provides an alternative way to get coordinate variable agnostically by simply passing the name of the axis to the related API.\n", + "* Related API: [xcdat.get_dim_coords()](../generated/xcdat.get_dim_coords.rst)\n", "\n", - "* Related API: [xcdat.get_axis_coord()](../generated/xcdat.get_axis_coord.rst)\n", + "Helpful knowledge:\n", "\n", - "This API requires that either the `axis` attr or `standard_name` attr is set, or the name of the dimension follows the valid short-hand convention (e.g., 'lat' for latitude)." + "* This API uses ``cf_xarray`` to interpret CF axis names and coordinate names in the xarray object attributes. 
Refer to [Metadata Interpretation](../faqs.rst) for more information.\n", + "\n", + "Xarray documentation on coordinates ([source](https://docs.xarray.dev/en/stable/user-guide/data-structures.html#coordinates)):\n", + "\n", + "* There are two types of coordinates in xarray:\n", + "\n", + " * **dimension coordinates** are one dimensional coordinates with a name equal to their sole dimension (marked by * when printing a dataset or data array). They are used for label based indexing and alignment, like the index found on a pandas DataFrame or Series. Indeed, these “dimension” coordinates use a pandas.Index internally to store their values.\n", + "\n", + " * **non-dimension coordinates** are variables that contain coordinate data, but are not a dimension coordinate. They can be multidimensional (see Working with Multidimensional Coordinates), and there is no relationship between the name of a non-dimension coordinate and the name(s) of its dimension(s). Non-dimension coordinates can be useful for indexing or plotting; otherwise, xarray does not make any direct use of the values associated with them. They are not used for alignment or automatic indexing, nor are they required to match when doing arithmetic (see Coordinates).\n", + "\n", + "* Xarray’s terminology differs from the [CF terminology](https://cfconventions.org/cf-conventions/v1.6.0/cf-conventions.html#terminology), where the “dimension coordinates” are called “coordinate variables”, and the “non-dimension coordinates” are called “auxiliary coordinate variables” (see [GH1295](https://github.com/pydata/xarray/issues/1295) for more details).\n", + "\n", + "\n", + "\n" ] }, { + "attachments": {}, "cell_type": "markdown", "id": "15b5441d", "metadata": {}, "source": [ + "\n", "### 1. `axis` attr" ] }, @@ -4606,16 +4622,6 @@ "ds.lat.attrs[\"standard_name\"]" ] }, - { - "cell_type": "markdown", - "id": "3740761b", - "metadata": {}, - "source": [ - "### 3. Name of the dimension\n", - "\n", - "Must be the short name (e.g., \"lat\" for latitude and \"lon\" for longitude)" - ] - }, { "cell_type": "code", "execution_count": 22, @@ -4637,14 +4643,6 @@ "\"lat\" in ds.dims" ] }, - { - "cell_type": "markdown", - "id": "b0591ecd", - "metadata": {}, - "source": [ - "Requires at least at one of the three keys above to be set." 
- ] - }, { "cell_type": "code", "execution_count": 24, @@ -5107,7 +5105,7 @@ } ], "source": [ - "xcdat.get_coord_var(ds, axis=\"Y\")" + "xcdat.get_axis_coord(ds, axis=\"Y\")" ] } ], diff --git a/tests/fixtures.py b/tests/fixtures.py index 6ac3c513..c74bb869 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -1,4 +1,5 @@ """This module stores reusable test fixtures.""" +import cftime import numpy as np import xarray as xr @@ -13,26 +14,26 @@ # TIME # ==== -time_cf = xr.DataArray( +time_decoded = xr.DataArray( data=np.array( [ - "2000-01-16T12:00:00.000000000", - "2000-02-15T12:00:00.000000000", - "2000-03-16T12:00:00.000000000", - "2000-04-16T00:00:00.000000000", - "2000-05-16T12:00:00.000000000", - "2000-06-16T00:00:00.000000000", - "2000-07-16T12:00:00.000000000", - "2000-08-16T12:00:00.000000000", - "2000-09-16T00:00:00.000000000", - "2000-10-16T12:00:00.000000000", - "2000-11-16T00:00:00.000000000", - "2000-12-16T12:00:00.000000000", - "2001-01-16T12:00:00.000000000", - "2001-02-15T00:00:00.000000000", - "2001-12-16T12:00:00.000000000", + cftime.DatetimeGregorian(2000, 1, 16, 12, 0, 0, 0, has_year_zero=False), + cftime.DatetimeGregorian(2000, 2, 15, 12, 0, 0, 0, has_year_zero=False), + cftime.DatetimeGregorian(2000, 3, 16, 12, 0, 0, 0, has_year_zero=False), + cftime.DatetimeGregorian(2000, 4, 16, 0, 0, 0, 0, has_year_zero=False), + cftime.DatetimeGregorian(2000, 5, 16, 12, 0, 0, 0, has_year_zero=False), + cftime.DatetimeGregorian(2000, 6, 16, 0, 0, 0, 0, has_year_zero=False), + cftime.DatetimeGregorian(2000, 7, 16, 12, 0, 0, 0, has_year_zero=False), + cftime.DatetimeGregorian(2000, 8, 16, 12, 0, 0, 0, has_year_zero=False), + cftime.DatetimeGregorian(2000, 9, 16, 0, 0, 0, 0, has_year_zero=False), + cftime.DatetimeGregorian(2000, 10, 16, 12, 0, 0, 0, has_year_zero=False), + cftime.DatetimeGregorian(2000, 11, 16, 0, 0, 0, 0, has_year_zero=False), + cftime.DatetimeGregorian(2000, 12, 16, 12, 0, 0, 0, has_year_zero=False), + cftime.DatetimeGregorian(2001, 1, 16, 12, 0, 0, 0, has_year_zero=False), + cftime.DatetimeGregorian(2001, 2, 15, 0, 0, 0, 0, has_year_zero=False), + cftime.DatetimeGregorian(2001, 12, 16, 12, 0, 0, 0, has_year_zero=False), ], - dtype="datetime64[ns]", + dtype=object, ), dims=["time"], attrs={ @@ -41,64 +42,89 @@ "standard_name": "time", }, ) -# NOTE: With `decode_times=True`, the "calendar" and "units" attributes are -# stored in `.encoding`. -time_cf.encoding["calendar"] = "standard" -time_cf.encoding["units"] = "days since 2000-01-01" - - -# NOTE: With `decode_times=False`, the "calendar" and "units" attributes are -# stored in `.attrs`. 
-time_non_cf = xr.DataArray( +time_encoded = xr.DataArray( data=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], dims=["time"], attrs={ "axis": "T", "long_name": "time", "standard_name": "time", - "calendar": "standard", - "units": "months since 2000-01-01", }, ) -time_non_cf_unsupported = xr.DataArray( - data=np.arange(1850 + 1 / 24.0, 1851 + 3 / 12.0, 1 / 12.0), - dims=["time"], - attrs={ - "units": "year A.D.", - "long_name": "time", - "standard_name": "time", - }, -) -time_bnds = xr.DataArray( +time_bnds_decoded = xr.DataArray( name="time_bnds", data=np.array( [ - ["2000-01-01T00:00:00.000000000", "2000-02-01T00:00:00.000000000"], - ["2000-02-01T00:00:00.000000000", "2000-03-01T00:00:00.000000000"], - ["2000-03-01T00:00:00.000000000", "2000-04-01T00:00:00.000000000"], - ["2000-04-01T00:00:00.000000000", "2000-05-01T00:00:00.000000000"], - ["2000-05-01T00:00:00.000000000", "2000-06-01T00:00:00.000000000"], - ["2000-06-01T00:00:00.000000000", "2000-07-01T00:00:00.000000000"], - ["2000-07-01T00:00:00.000000000", "2000-08-01T00:00:00.000000000"], - ["2000-08-01T00:00:00.000000000", "2000-09-01T00:00:00.000000000"], - ["2000-09-01T00:00:00.000000000", "2000-10-01T00:00:00.000000000"], - ["2000-10-01T00:00:00.000000000", "2000-11-01T00:00:00.000000000"], - ["2000-11-01T00:00:00.000000000", "2000-12-01T00:00:00.000000000"], - ["2000-12-01T00:00:00.000000000", "2001-01-01T00:00:00.000000000"], - ["2001-01-01T00:00:00.000000000", "2001-02-01T00:00:00.000000000"], - ["2001-02-01T00:00:00.000000000", "2001-03-01T00:00:00.000000000"], - ["2001-12-01T00:00:00.000000000", "2002-01-01T00:00:00.000000000"], + [ + cftime.DatetimeGregorian(2000, 1, 1, 0, 0, 0, 0, has_year_zero=False), + cftime.DatetimeGregorian(2000, 2, 1, 0, 0, 0, 0, has_year_zero=False), + ], + [ + cftime.DatetimeGregorian(2000, 2, 1, 0, 0, 0, 0, has_year_zero=False), + cftime.DatetimeGregorian(2000, 3, 1, 0, 0, 0, 0, has_year_zero=False), + ], + [ + cftime.DatetimeGregorian(2000, 3, 1, 0, 0, 0, 0, has_year_zero=False), + cftime.DatetimeGregorian(2000, 4, 1, 0, 0, 0, 0, has_year_zero=False), + ], + [ + cftime.DatetimeGregorian(2000, 4, 1, 0, 0, 0, 0, has_year_zero=False), + cftime.DatetimeGregorian(2000, 5, 1, 0, 0, 0, 0, has_year_zero=False), + ], + [ + cftime.DatetimeGregorian(2000, 5, 1, 0, 0, 0, 0, has_year_zero=False), + cftime.DatetimeGregorian(2000, 6, 1, 0, 0, 0, 0, has_year_zero=False), + ], + [ + cftime.DatetimeGregorian(2000, 6, 1, 0, 0, 0, 0, has_year_zero=False), + cftime.DatetimeGregorian(2000, 7, 1, 0, 0, 0, 0, has_year_zero=False), + ], + [ + cftime.DatetimeGregorian(2000, 7, 1, 0, 0, 0, 0, has_year_zero=False), + cftime.DatetimeGregorian(2000, 8, 1, 0, 0, 0, 0, has_year_zero=False), + ], + [ + cftime.DatetimeGregorian(2000, 8, 1, 0, 0, 0, 0, has_year_zero=False), + cftime.DatetimeGregorian(2000, 9, 1, 0, 0, 0, 0, has_year_zero=False), + ], + [ + cftime.DatetimeGregorian(2000, 9, 1, 0, 0, 0, 0, has_year_zero=False), + cftime.DatetimeGregorian(2000, 10, 1, 0, 0, 0, 0, has_year_zero=False), + ], + [ + cftime.DatetimeGregorian(2000, 10, 1, 0, 0, 0, 0, has_year_zero=False), + cftime.DatetimeGregorian(2000, 11, 1, 0, 0, 0, 0, has_year_zero=False), + ], + [ + cftime.DatetimeGregorian(2000, 11, 1, 0, 0, 0, 0, has_year_zero=False), + cftime.DatetimeGregorian(2000, 12, 1, 0, 0, 0, 0, has_year_zero=False), + ], + [ + cftime.DatetimeGregorian(2000, 12, 1, 0, 0, 0, 0, has_year_zero=False), + cftime.DatetimeGregorian(2001, 1, 1, 0, 0, 0, 0, has_year_zero=False), + ], + [ + cftime.DatetimeGregorian(2001, 1, 1, 0, 0, 0, 0, 
has_year_zero=False), + cftime.DatetimeGregorian(2001, 2, 1, 0, 0, 0, 0, has_year_zero=False), + ], + [ + cftime.DatetimeGregorian(2001, 2, 1, 0, 0, 0, 0, has_year_zero=False), + cftime.DatetimeGregorian(2001, 3, 1, 0, 0, 0, 0, has_year_zero=False), + ], + [ + cftime.DatetimeGregorian(2001, 12, 1, 0, 0, 0, 0, has_year_zero=False), + cftime.DatetimeGregorian(2002, 1, 1, 0, 0, 0, 0, has_year_zero=False), + ], ], - dtype="datetime64[ns]", + dtype=object, ), - coords={"time": time_cf}, dims=["time", "bnds"], attrs={ "xcdat_bounds": "True", }, ) -time_bnds_non_cf = xr.DataArray( +time_bnds_encoded = xr.DataArray( name="time_bnds", data=[ [-1, 0], @@ -117,20 +143,9 @@ [12, 13], [13, 14], ], - coords={"time": time_non_cf}, dims=["time", "bnds"], attrs={"xcdat_bounds": "True"}, ) -tb = [] -for t in time_non_cf_unsupported: - tb.append([t - 1 / 24.0, t + 1 / 24.0]) -time_bnds_non_cf_unsupported = xr.DataArray( - name="time_bnds", - data=tb, - coords={"time": time_non_cf_unsupported}, - dims=["time", "bnds"], - attrs={"is_generated": "True"}, -) # LATITUDE # ======== @@ -171,69 +186,84 @@ # VARIABLES # ========= -ts_cf = xr.DataArray( +ts_decoded = xr.DataArray( name="ts", data=np.ones((15, 4, 4)), - coords={"time": time_cf, "lat": lat, "lon": lon}, + coords={"time": time_decoded, "lat": lat, "lon": lon}, dims=["time", "lat", "lon"], ) -ts_non_cf = xr.DataArray( +ts_encoded = xr.DataArray( name="ts", data=np.ones((15, 4, 4)), - coords={"time": time_non_cf, "lat": lat, "lon": lon}, + coords={"time": time_encoded, "lat": lat, "lon": lon}, dims=["time", "lat", "lon"], ) def generate_dataset( - cf_compliant: bool, has_bounds: bool, unsupported: bool = False + decode_times: bool, + cf_compliant: bool, + has_bounds: bool, ) -> xr.Dataset: """Generates a dataset using coordinate and data variable fixtures. Parameters ---------- - cf_compliant : bool, optional - CF compliant time units. - has_bounds : bool, optional + decode_times : bool + If True, represent time coordinates `cftime` objects. If False, + represent time coordinates as numbers. + cf_compliant : bool + If True, use CF compliant time units ("days since ..."). If False, + use non-CF compliant time units ("months since ..."). + has_bounds : bool Include bounds for coordinates. This also adds the "bounds" attribute to existing coordinates to link them to their respective bounds. - unsupported : bool, optional - Create time units that are unsupported and cannot be decoded. - Note that cf_compliant must be set to False. Returns ------- xr.Dataset Test dataset. """ - - if unsupported & cf_compliant: - raise ValueError( - "Cannot set cf_compliant=True and unsupported=True. \n" - "Set cf_compliant=False." + # First, create a dataset with either encoded or decoded coordinates. + if decode_times: + ds = xr.Dataset( + data_vars={ + "ts": ts_decoded.copy(), + }, + coords={"lat": lat.copy(), "lon": lon.copy(), "time": time_decoded.copy()}, ) - if has_bounds: + # Add the calendar and units attr to the encoding dict. + ds["time"].encoding["calendar"] = "standard" + if cf_compliant: + ds["time"].encoding["units"] = "days since 2000-01-01" + else: + ds["time"].encoding["units"] = "months since 2000-01-01" + + else: ds = xr.Dataset( data_vars={ - "ts": ts_cf.copy(), - "lat_bnds": lat_bnds.copy(), - "lon_bnds": lon_bnds.copy(), + "ts": ts_encoded.copy(), }, - coords={"lat": lat.copy(), "lon": lon.copy()}, + coords={"lat": lat.copy(), "lon": lon.copy(), "time": time_encoded.copy()}, ) + # Add the calendar and units attr to the attrs dict. 
+ ds["time"].attrs["calendar"] = "standard" if cf_compliant: - ds.coords["time"] = time_cf.copy() - ds["time_bnds"] = time_bnds.copy() - elif not cf_compliant: - if unsupported: - ds.coords["time"] = time_non_cf_unsupported.copy() - ds["time_bnds"] = time_bnds_non_cf_unsupported.copy() - else: - ds.coords["time"] = time_non_cf.copy() - ds["time_bnds"] = time_bnds_non_cf.copy() + ds["time"].attrs["units"] = "days since 2000-01-01" + else: + ds["time"].attrs["units"] = "months since 2000-01-01" + + if has_bounds: + ds["lat_bnds"] = lat_bnds.copy() + ds["lon_bnds"] = lon_bnds.copy() + + if decode_times: + ds["time_bnds"] = time_bnds_decoded.copy() + else: + ds["time_bnds"] = time_bnds_encoded.copy() # If the "bounds" attribute is included in an existing DataArray and # added to a new Dataset, it will get dropped. Therefore, it needs to be @@ -242,15 +272,4 @@ def generate_dataset( ds["lon"].attrs["bounds"] = "lon_bnds" ds["time"].attrs["bounds"] = "time_bnds" - elif not has_bounds: - ds = xr.Dataset( - data_vars={"ts": ts_cf.copy()}, - coords={"lat": lat.copy(), "lon": lon.copy()}, - ) - - if cf_compliant: - ds.coords["time"] = time_cf.copy() - elif not cf_compliant: - ds.coords["time"] = time_non_cf.copy() - return ds diff --git a/tests/test_axis.py b/tests/test_axis.py index a44369e5..1a985737 100644 --- a/tests/test_axis.py +++ b/tests/test_axis.py @@ -3,19 +3,17 @@ import xarray as xr from tests.fixtures import generate_dataset -from xcdat.axis import center_times, get_axis_coord, get_axis_dim, swap_lon_axis +from xcdat.axis import ( + CFAxisKey, + center_times, + get_dim_coords, + get_dim_keys, + swap_lon_axis, +) -class TestGetAxisCoord: - def test_raises_error_if_coord_var_does_not_exist(self): - ds = xr.Dataset() - - with pytest.raises(KeyError): - get_axis_coord(ds, "Y") - - def test_raises_error_if_axis_or_standard_name_is_not_set_or_dim_name_is_not_valid( - self, - ): +class TestGetDimKeys: + def test_raises_error_if_dim_name_is_not_valid(self): ds = xr.Dataset( coords={ "invalid_lat_shortname": xr.DataArray( @@ -25,53 +23,122 @@ def test_raises_error_if_axis_or_standard_name_is_not_set_or_dim_name_is_not_val ) with pytest.raises(KeyError): - get_axis_coord(ds, "Y") + get_dim_keys(ds, "Y") - def test_returns_coord_var_if_axis_attr_is_set(self): - # Set the dimension name to something other than "lat" to make sure - # axis attr is being used for the match. + def test_returns_dim_name(self): ds = xr.Dataset( coords={ - "lat_not_short_name": xr.DataArray( - data=np.ones(3), dims="lat_not_short_name", attrs={"axis": "Y"} + "lat": xr.DataArray( + data=np.ones(3), dims="lat", attrs={"standard_name": "latitude"} ) } ) - result = get_axis_coord(ds, "Y") - expected = ds.lat_not_short_name + dim = get_dim_keys(ds, "Y") - assert result.identical(expected) + assert dim == "lat" - def test_returns_coord_var_if_standard_name_attr_is_set(self): - # Set the dimension name to something other than "lat" to make sure - # standard_name attr is being used for the match. 
+ def test_returns_dim_names(self): ds = xr.Dataset( coords={ - "lat_not_short_name": xr.DataArray( - data=np.ones(3), - dims="lat_not_short_name", - attrs={"standard_name": "latitude"}, - ) + "ilev": xr.DataArray(data=np.ones(3), dims="ilev", attrs={"axis": "Z"}), + "lev": xr.DataArray(data=np.ones(3), dims="lev", attrs={"axis": "Z"}), } ) - result = get_axis_coord(ds, "Y") - expected = ds.lat_not_short_name + dim = get_dim_keys(ds, "Z") - assert result.identical(expected) + assert dim == ["ilev", "lev"] - def test_returns_coord_var_if_dim_name_is_valid(self): - ds = xr.Dataset(coords={"lat": xr.DataArray(data=np.ones(3), dims="lat")}) - result = get_axis_coord(ds, "Y") - expected = ds.lat +class TestGetDimCoords: + @pytest.fixture(autouse=True) + def setup(self): + # A dataset with "axis" attr set on all dim coord vars. + self.ds_axis = xr.Dataset( + data_vars={ + "hyai": xr.DataArray( + name="hyai", + data=np.ones(3), + dims="ilev", + attrs={"long_name": "hybrid A coefficient at layer interfaces"}, + ), + "hyam": xr.DataArray( + name="hyam", + data=np.ones(3), + dims="lev", + attrs={"long_name": "hybrid B coefficient at layer interfaces"}, + ), + }, + coords={ + "ilev": xr.DataArray( + data=np.ones(3), + dims="ilev", + attrs={"axis": "Z"}, + ), + "lev": xr.DataArray( + data=np.ones(3), + dims="lev", + attrs={"axis": "Z"}, + ), + "lat": xr.DataArray( + data=np.ones(3), + dims="lat", + attrs={"axis": "Y"}, + ), + }, + ) - assert result.identical(expected) + # A dataset with "standard_name" attr set on all dim coord vars. + self.ds_sn = xr.Dataset( + data_vars={ + "hyai": xr.DataArray( + name="hyai", + data=np.ones(3), + dims="ilev", + attrs={"long_name": "hybrid A coefficient at layer interfaces"}, + ), + "hyam": xr.DataArray( + name="hyam", + data=np.ones(3), + dims="lev", + attrs={"long_name": "hybrid B coefficient at layer interfaces"}, + ), + }, + coords={ + "ilev": xr.DataArray( + data=np.ones(3), + dims="ilev", + attrs={ + "standard_name": "atmosphere_hybrid_sigma_pressure_coordinate" + }, + ), + "lev": xr.DataArray( + data=np.ones(3), + dims="lev", + attrs={ + "standard_name": "atmosphere_hybrid_sigma_pressure_coordinate" + }, + ), + "lat": xr.DataArray( + data=np.ones(3), + dims="lat", + attrs={"standard_name": "latitude"}, + ), + }, + ) + def test_raises_error_if_dim_does_not_exist(self): + ds = xr.Dataset() + dims: CFAxisKey = ["X", "Y", "T", "Z"] # type: ignore -class TestGetAxisDim: - def test_raises_error_if_dim_name_is_not_valid(self): + for dim in dims: + with pytest.raises(KeyError): + get_dim_coords(ds, dim) # type: ignore + + def test_raises_error_if_axis_or_standard_name_is_not_set_or_dim_name_is_not_valid( + self, + ): ds = xr.Dataset( coords={ "invalid_lat_shortname": xr.DataArray( @@ -81,26 +148,113 @@ def test_raises_error_if_dim_name_is_not_valid(self): ) with pytest.raises(KeyError): - get_axis_dim(ds, "Y") + get_dim_coords(ds, "Y") - def test_returns_dim_name(self): + def test_raises_error_if_a_dataarray_has_multiple_dims_for_the_same_axis(self): + da = xr.DataArray( + coords={ + "ilev": xr.DataArray( + data=np.ones(3), + dims="ilev", + attrs={"axis": "Z"}, + ), + "lev": xr.DataArray( + data=np.ones(3), + dims="lev", + attrs={"axis": "Z"}, + ), + }, + dims=["ilev", "lev"], + ) + + with pytest.raises(ValueError): + get_dim_coords(da, "Z") + + def test_raises_error_if_multidimensional_coords_are_only_present_for_an_axis(self): + lat = xr.DataArray( + data=np.array([[0, 1, 2], [3, 4, 5]]), + dims=["placeholder_1", "placeholder_2"], + attrs={"units": "degrees_north", 
"axis": "Y"}, + ) + ds = xr.Dataset(coords={"lat": lat}) + + with pytest.raises(KeyError): + get_dim_coords(ds, "Y") + + def test_returns_dataset_dimension_coordinate_vars_using_common_var_names( + self, + ): ds = xr.Dataset( coords={ - "lat": xr.DataArray( - data=np.ones(3), dims="lat", attrs={"standard_name": "latitude"} - ) + "lat": xr.DataArray(data=np.ones(3), dims="lat"), + "lon": xr.DataArray(data=np.ones(3), dims="lon"), + "time": xr.DataArray(data=np.ones(3), dims="time"), + "atmosphere_sigma_coordinate": xr.DataArray( + data=np.ones(3), dims="atmosphere_sigma_coordinate" + ), } ) - dim = get_axis_dim(ds, "Y") + result = get_dim_coords(ds, "X") + assert result.identical(ds["lon"]) # type: ignore - assert dim == "lat" + result = get_dim_coords(ds, "Y") + assert result.identical(ds["lat"]) # type: ignore + + result = get_dim_coords(ds, "T") + assert result.identical(ds["time"]) # type: ignore + + result = get_dim_coords(ds, "Z") + assert result.identical(ds["atmosphere_sigma_coordinate"]) # type: ignore + + def test_returns_dataset_dimension_coordinate_vars_using_axis_attr(self): + # For example, E3SM datasets might have "ilev" and "lev" dimensions + # with the dim coord var attr "axis" both mapped to "Z". + result = get_dim_coords(self.ds_axis, "Z") + expected = xr.Dataset( + coords={"ilev": self.ds_axis.ilev, "lev": self.ds_axis.lev} + ) + + assert result.identical(expected) # type: ignore + + def test_returns_dataset_dimension_coordinate_vars_using_standard_name_attr(self): + # For example, E3SM datasets might have "ilev" and "lev" dimensions + # with the dim coord var attr "standard_name" both mapped to + # "atmosphere_hybrid_sigma_pressure_coordinate". + result = get_dim_coords(self.ds_sn, "Z") + expected = xr.Dataset(coords={"ilev": self.ds_sn.ilev, "lev": self.ds_sn.lev}) + + assert result.identical(expected) # type: ignore + + def test_returns_dataarray_dimension_coordinate_var_using_axis_attr(self): + result = get_dim_coords(self.ds_axis.hyai, "Z") + expected = self.ds_axis.ilev + + assert result.identical(expected) + + result = get_dim_coords(self.ds_axis.hyam, "Z") + expected = self.ds_axis.lev + + assert result.identical(expected) + + def test_returns_dataarray_dimension_coordinate_var_using_standard_name_attr(self): + result = get_dim_coords(self.ds_sn.hyai, "Z") + expected = self.ds_sn.ilev + + assert result.identical(expected) + + result = get_dim_coords(self.ds_sn.hyam, "Z") + expected = self.ds_sn.lev + + assert result.identical(expected) class TestCenterTimes: @pytest.fixture(autouse=True) def setup(self): - self.ds = generate_dataset(cf_compliant=True, has_bounds=True) + self.ds = generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True + ) def test_raises_error_if_time_coord_var_does_not_exist_in_dataset(self): ds = self.ds.copy() @@ -109,67 +263,145 @@ def test_raises_error_if_time_coord_var_does_not_exist_in_dataset(self): with pytest.raises(KeyError): center_times(ds) - def test_raises_error_if_time_bounds_does_not_exist_in_the_dataset(self): - ds = self.ds.copy() - ds = ds.drop_vars("time_bnds") + def test_skips_centering_time_coords_for_a_dimension_if_bounds_do_not_exist(self): + ds = xr.Dataset( + coords={ + "time": xr.DataArray( + name="time", + data=np.array( + [ + "2000-01-31T12:00:00.000000000", + "2000-02-29T12:00:00.000000000", + "2000-03-31T12:00:00.000000000", + ], + dtype="datetime64[ns]", + ), + dims="time", + attrs={ + "long_name": "time", + "standard_name": "time", + "axis": "T", + "bounds": "time_bnds", + }, + ), + }, + ) - 
with pytest.raises(KeyError): - center_times(ds) + # Compare result of the method against the expected. + result = center_times(ds) + expected = ds.copy() - def test_gets_time_as_the_midpoint_between_time_bounds(self): - ds = self.ds.copy() + assert result.identical(expected) - # Make the time coordinates uncentered. - uncentered_time = np.array( - [ - "2000-01-31T12:00:00.000000000", - "2000-02-29T12:00:00.000000000", - "2000-03-31T12:00:00.000000000", - "2000-04-30T00:00:00.000000000", - "2000-05-31T12:00:00.000000000", - "2000-06-30T00:00:00.000000000", - "2000-07-31T12:00:00.000000000", - "2000-08-31T12:00:00.000000000", - "2000-09-30T00:00:00.000000000", - "2000-10-16T12:00:00.000000000", - "2000-11-30T00:00:00.000000000", - "2000-12-31T12:00:00.000000000", - "2001-01-31T12:00:00.000000000", - "2001-02-28T00:00:00.000000000", - "2001-12-31T12:00:00.000000000", - ], - dtype="datetime64[ns]", + def test_returns_time_coords_as_the_midpoint_between_time_bounds(self): + ds = xr.Dataset( + data_vars={ + "time_bnds": xr.DataArray( + name="time_bnds", + data=np.array( + [ + [ + "2000-01-01T00:00:00.000000000", + "2000-02-01T00:00:00.000000000", + ], + [ + "2000-02-01T00:00:00.000000000", + "2000-03-01T00:00:00.000000000", + ], + [ + "2000-03-01T00:00:00.000000000", + "2000-04-01T00:00:00.000000000", + ], + ], + dtype="datetime64[ns]", + ), + dims=["time", "bnds"], + attrs={ + "xcdat_bounds": "True", + }, + ), + }, + coords={ + "time": xr.DataArray( + name="time", + data=np.array( + [ + "2000-01-31T12:00:00.000000000", + "2000-02-29T12:00:00.000000000", + "2000-03-31T12:00:00.000000000", + ], + dtype="datetime64[ns]", + ), + dims="time", + attrs={ + "long_name": "time", + "standard_name": "time", + "axis": "T", + "bounds": "time_bnds", + }, + ), + "time2": xr.DataArray( + name="time2", + data=np.array( + [ + "2000-01-31T12:00:00.000000000", + "2000-02-29T12:00:00.000000000", + "2000-03-31T12:00:00.000000000", + ], + dtype="datetime64[ns]", + ), + dims="time", + attrs={ + "long_name": "time", + "standard_name": "time", + "axis": "T", + "bounds": "time_bnds", + }, + ), + }, ) - ds.time.data[:] = uncentered_time + + result = center_times(ds) # Compare result of the method against the expected. 
- expected = ds.copy() - expected_time_data = np.array( - [ - "2000-01-16T12:00:00.000000000", - "2000-02-15T12:00:00.000000000", - "2000-03-16T12:00:00.000000000", - "2000-04-16T00:00:00.000000000", - "2000-05-16T12:00:00.000000000", - "2000-06-16T00:00:00.000000000", - "2000-07-16T12:00:00.000000000", - "2000-08-16T12:00:00.000000000", - "2000-09-16T00:00:00.000000000", - "2000-10-16T12:00:00.000000000", - "2000-11-16T00:00:00.000000000", - "2000-12-16T12:00:00.000000000", - "2001-01-16T12:00:00.000000000", - "2001-02-15T00:00:00.000000000", - "2001-12-16T12:00:00.000000000", - ], - dtype="datetime64[ns]", - ) - expected = expected.assign_coords( - { + expected = xr.Dataset( + data_vars={ + "time_bnds": xr.DataArray( + name="time_bnds", + data=np.array( + [ + [ + "2000-01-01T00:00:00.000000000", + "2000-02-01T00:00:00.000000000", + ], + [ + "2000-02-01T00:00:00.000000000", + "2000-03-01T00:00:00.000000000", + ], + [ + "2000-03-01T00:00:00.000000000", + "2000-04-01T00:00:00.000000000", + ], + ], + dtype="datetime64[ns]", + ), + dims=["time", "bnds"], + attrs={ + "xcdat_bounds": "True", + }, + ), + }, + coords={ "time": xr.DataArray( name="time", - data=expected_time_data, - coords={"time": expected_time_data}, + data=np.array( + [ + "2000-01-16T12:00:00.000000000", + "2000-02-15T12:00:00.000000000", + "2000-03-16T12:00:00.000000000", + ], + dtype="datetime64[ns]", + ), dims="time", attrs={ "long_name": "time", @@ -177,21 +409,41 @@ def test_gets_time_as_the_midpoint_between_time_bounds(self): "axis": "T", "bounds": "time_bnds", }, - ) - } + ), + "time2": xr.DataArray( + name="time2", + data=np.array( + [ + "2000-01-16T12:00:00.000000000", + "2000-02-15T12:00:00.000000000", + "2000-03-16T12:00:00.000000000", + ], + dtype="datetime64[ns]", + ), + dims="time", + attrs={ + "long_name": "time", + "standard_name": "time", + "axis": "T", + "bounds": "time_bnds", + }, + ), + }, ) - # Update time bounds with centered time coordinates. 
- time_bounds = ds.time_bnds.copy() - time_bounds["time"] = expected.time - expected["time_bnds"] = time_bounds - result = center_times(ds) assert result.identical(expected) class TestSwapLonAxis: + def test_raises_error_if_no_longitude_axis_exists(self): + ds = generate_dataset(decode_times=True, cf_compliant=False, has_bounds=True) + ds = ds.drop_dims("lon") + + with pytest.raises(KeyError): + swap_lon_axis(ds, to=(-180, 180)) + def test_raises_error_with_incorrect_lon_orientation_for_swapping(self): - ds = generate_dataset(cf_compliant=True, has_bounds=True) + ds = generate_dataset(decode_times=True, cf_compliant=False, has_bounds=True) with pytest.raises(ValueError): swap_lon_axis(ds, to=9000) # type: ignore @@ -256,11 +508,11 @@ def test_does_not_swap_if_desired_orientation_is_the_same_as_the_existing_orient }, ) - result = swap_lon_axis(ds_360, to=(0, 360)) + result = swap_lon_axis(ds_360, to=(0, 360), sort_ascending=False) assert result.identical(ds_360) - def test_swap_from_360_to_180_and_sorts(self): + def test_does_not_swap_bounds_if_bounds_do_not_exist(self): ds_360 = xr.Dataset( coords={ "lon": xr.DataArray( @@ -268,18 +520,70 @@ def test_swap_from_360_to_180_and_sorts(self): data=np.array([60, 150, 271]), dims=["lon"], attrs={"units": "degrees_east", "axis": "X", "bounds": "lon_bnds"}, - ) + ), + }, + data_vars={ + "ts": xr.DataArray( + name="ts", + data=np.array([0, 1, 2]), + dims=["lon"], + attrs={"test_attr": "test"}, + ), + }, + ) + + result = swap_lon_axis(ds_360, to=(-180, 180)) + expected = xr.Dataset( + coords={ + "lon": xr.DataArray( + data=np.array([-89, 60, 150]), + dims=["lon"], + attrs={"units": "degrees_east", "axis": "X", "bounds": "lon_bnds"}, + ), + }, + data_vars={ + "ts": xr.DataArray( + name="ts", + data=np.array([2, 0, 1]), + dims=["lon"], + attrs={"test_attr": "test"}, + ), + }, + ) + + assert result.identical(expected) + + def test_swaps_single_dim_from_360_to_180_and_sorts(self): + ds_360 = xr.Dataset( + coords={ + "lon": xr.DataArray( + name="lon", + data=np.array([60, 150, 271]), + dims=["lon"], + attrs={"units": "degrees_east", "axis": "X", "bounds": "lon_bnds"}, + ), + "lon2": xr.DataArray( + name="lon2", + data=np.array([60, 150, 271]), + dims=["lon"], + attrs={"units": "degrees_east", "axis": "X", "bounds": "lon_bnds"}, + ), }, data_vars={ + "ts": xr.DataArray( + name="ts", + data=np.array([0, 1, 2]), + dims=["lon"], + attrs={"test_attr": "test"}, + ), "lon_bnds": xr.DataArray( name="lon_bnds", data=np.array([[0, 120], [120, 181], [181, 360]]), dims=["lon", "bnds"], attrs={"xcdat_bounds": "True"}, - ) + ), }, ) - result = swap_lon_axis(ds_360, to=(-180, 180)) expected = xr.Dataset( coords={ @@ -287,21 +591,35 @@ def test_swap_from_360_to_180_and_sorts(self): data=np.array([-89, 60, 150]), dims=["lon"], attrs={"units": "degrees_east", "axis": "X", "bounds": "lon_bnds"}, - ) + ), + "lon2": xr.DataArray( + name="lon2", + data=np.array([-89, 60, 150]), + dims=["lon"], + attrs={"units": "degrees_east", "axis": "X", "bounds": "lon_bnds"}, + ), }, data_vars={ + "ts": xr.DataArray( + name="ts", + data=np.array([2, 0, 1]), + dims=["lon"], + attrs={"test_attr": "test"}, + ), "lon_bnds": xr.DataArray( name="lon_bnds", data=np.array([[-179, 0], [0, 120], [120, -179]]), dims=["lon", "bnds"], attrs={"xcdat_bounds": "True"}, - ) + ), }, ) assert result.identical(expected) - def test_swap_from_180_to_360_and_sorts_with_prime_meridian_cell_in_lon_bnds(self): + def test_swaps_single_dim_from_180_to_360_and_sorts_with_prime_meridian_cell_in_lon_bnds( + self, + ): 
ds_180 = xr.Dataset( coords={ "lon": xr.DataArray( @@ -309,9 +627,21 @@ def test_swap_from_180_to_360_and_sorts_with_prime_meridian_cell_in_lon_bnds(sel data=np.array([-180, -1, 0, 1, 179]), dims=["lon"], attrs={"units": "degrees_east", "axis": "X", "bounds": "lon_bnds"}, - ) + ), + "lon2": xr.DataArray( + name="lon2", + data=np.array([-180, -1, 0, 1, 179]), + dims=["lon"], + attrs={"units": "degrees_east", "axis": "X", "bounds": "lon_bnds"}, + ), }, data_vars={ + "ts": xr.DataArray( + name="ts", + data=np.array([0, 1, 2, 3, 4]), + dims=["lon"], + attrs={"test_attr": "test"}, + ), "lon_bnds": xr.DataArray( name="lon_bnds", data=np.array( @@ -326,15 +656,112 @@ def test_swap_from_180_to_360_and_sorts_with_prime_meridian_cell_in_lon_bnds(sel dims=["lon", "bnds"], attrs={"xcdat_bounds": "True"}, ), + }, + ) + result = swap_lon_axis(ds_180, to=(0, 360)) + expected = xr.Dataset( + coords={ + "lon": xr.DataArray( + name="lon", + data=np.array([0, 1, 179, 180, 359, 360]), + dims=["lon"], + attrs={"units": "degrees_east", "axis": "X", "bounds": "lon_bnds"}, + ), + "lon2": xr.DataArray( + name="lon2", + data=np.array([0, 1, 179, 180, 359, 360]), + dims=["lon"], + attrs={"units": "degrees_east", "axis": "X", "bounds": "lon_bnds"}, + ), + }, + data_vars={ + "ts": xr.DataArray( + name="ts", + data=np.array([2, 3, 4, 0, 1, 2]), + dims=["lon"], + attrs={"test_attr": "test"}, + ), + "lon_bnds": xr.DataArray( + name="lon_bnds", + data=np.array( + [ + [0, 0.5], + [0.5, 1.5], + [1.5, 179.5], + [179.5, 358.5], + [358.5, 359.5], + [359.5, 360], + ] + ), + dims=["lon", "bnds"], + attrs={"xcdat_bounds": "True"}, + ), + }, + ) + + assert result.identical(expected) + + def test_swaps_all_dims_from_180_to_360_and_sorts_with_prime_meridian_cell_in_lon_bnds( + self, + ): + ds_180 = xr.Dataset( + coords={ + "lon": xr.DataArray( + name="lon", + data=np.array([-180, -1, 0, 1, 179]), + dims=["lon"], + attrs={"units": "degrees_east", "axis": "X", "bounds": "lon_bnds"}, + ), + "zlon": xr.DataArray( + name="zlon", + data=np.array([-180, -1, 0, 1, 179]), + dims=["zlon"], + attrs={"units": "degrees_east", "axis": "X", "bounds": "zlon_bnds"}, + ), + }, + data_vars={ "ts": xr.DataArray( name="ts", data=np.array([0, 1, 2, 3, 4]), dims=["lon"], attrs={"test_attr": "test"}, ), + "ts2": xr.DataArray( + name="ts2", + data=np.array([0, 1, 2, 3, 4]), + dims=["zlon"], + attrs={"test_attr": "test"}, + ), + "lon_bnds": xr.DataArray( + name="lon_bnds", + data=np.array( + [ + [-180.5, -1.5], + [-1.5, -0.5], + [-0.5, 0.5], + [0.5, 1.5], + [1.5, 179.5], + ] + ), + dims=["lon", "bnds"], + attrs={"xcdat_bounds": "True"}, + ), + "zlon_bnds": xr.DataArray( + name="zlon_bnds", + data=np.array( + [ + [-180.5, -1.5], + [-1.5, -0.5], + [-0.5, 0.5], + [0.5, 1.5], + [1.5, 179.5], + ] + ), + dims=["zlon", "bnds"], + attrs={"xcdat_bounds": "True"}, + ), }, ) - result = swap_lon_axis(ds_180, to=(0, 360)) expected = xr.Dataset( coords={ @@ -343,9 +770,27 @@ def test_swap_from_180_to_360_and_sorts_with_prime_meridian_cell_in_lon_bnds(sel data=np.array([0, 1, 179, 180, 359, 360]), dims=["lon"], attrs={"units": "degrees_east", "axis": "X", "bounds": "lon_bnds"}, - ) + ), + "zlon": xr.DataArray( + name="zlon", + data=np.array([0, 1, 179, 180, 359, 360]), + dims=["zlon"], + attrs={"units": "degrees_east", "axis": "X", "bounds": "zlon_bnds"}, + ), }, data_vars={ + "ts": xr.DataArray( + name="ts", + data=np.array([2, 3, 4, 0, 1, 2]), + dims=["lon"], + attrs={"test_attr": "test"}, + ), + "ts2": xr.DataArray( + name="ts2", + data=np.array([2, 3, 4, 0, 1, 2]), + 
dims=["zlon"], + attrs={"test_attr": "test"}, + ), "lon_bnds": xr.DataArray( name="lon_bnds", data=np.array( @@ -361,11 +806,20 @@ def test_swap_from_180_to_360_and_sorts_with_prime_meridian_cell_in_lon_bnds(sel dims=["lon", "bnds"], attrs={"xcdat_bounds": "True"}, ), - "ts": xr.DataArray( - name="ts", - data=np.array([2, 3, 4, 0, 1, 2]), - dims=["lon"], - attrs={"test_attr": "test"}, + "zlon_bnds": xr.DataArray( + name="zlon_bnds", + data=np.array( + [ + [0, 0.5], + [0.5, 1.5], + [1.5, 179.5], + [179.5, 358.5], + [358.5, 359.5], + [359.5, 360], + ] + ), + dims=["zlon", "bnds"], + attrs={"xcdat_bounds": "True"}, ), }, ) diff --git a/tests/test_bounds.py b/tests/test_bounds.py index 0e101cf9..f4c48ee8 100644 --- a/tests/test_bounds.py +++ b/tests/test_bounds.py @@ -14,8 +14,12 @@ class TestBoundsAccessor: @pytest.fixture(autouse=True) def setup(self): - self.ds = generate_dataset(cf_compliant=True, has_bounds=False) - self.ds_with_bnds = generate_dataset(cf_compliant=True, has_bounds=True) + self.ds = generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=False + ) + self.ds_with_bnds = generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True + ) def test__init__(self): obj = BoundsAccessor(self.ds) @@ -54,8 +58,12 @@ def test_keys_property_returns_a_list_of_sorted_bounds_keys(self): class TestAddMissingBounds: @pytest.fixture(autouse=True) def setup(self): - self.ds = generate_dataset(cf_compliant=True, has_bounds=False) - self.ds_with_bnds = generate_dataset(cf_compliant=True, has_bounds=True) + self.ds = generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=False + ) + self.ds_with_bnds = generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True + ) def test_adds_bounds_to_the_dataset(self): ds = self.ds_with_bnds.copy() @@ -65,37 +73,28 @@ def test_adds_bounds_to_the_dataset(self): result = ds.bounds.add_missing_bounds() assert result.identical(self.ds_with_bnds) - def test_adds_bounds_to_the_dataset_skips_nondimensional_axes(self): - # generate dataset with height coordinate - ds = generate_dataset(cf_compliant=True, has_bounds=True) - ds = ds.assign_coords({"height": 2}) - - # drop bounds - dsm = ds.drop_vars(["lat_bnds", "lon_bnds"]).copy() + def test_skips_adding_bounds_for_coords_that_are_1_dim_singleton(self): + # Length <=1 + lon = xr.DataArray( + data=np.array([0]), + dims=["lon"], + attrs={"units": "degrees_east", "axis": "X"}, + ) + ds = xr.Dataset(coords={"lon": lon}) - # test bounds re-generation - result = dsm.bounds.add_missing_bounds() + result = ds.bounds.add_missing_bounds() - # dataset with missing bounds added should match dataset with bounds - # and added height coordinate assert result.identical(ds) - def test_skips_adding_bounds_for_coords_that_are_multidimensional_or_len_of_1(self): - # Multidimensional - lat = xr.DataArray( - data=np.array([[0, 1, 2], [3, 4, 5]]), - dims=["placeholder_1", "placeholder_2"], - attrs={"units": "degrees_north", "axis": "Y"}, - ) - # Length <=1 + def test_skips_adding_bounds_for_coords_that_are_0_dim_singleton(self): + # 0-dimensional array lon = xr.DataArray( - data=np.array([0]), - dims=["lon"], + data=float(0), attrs={"units": "degrees_east", "axis": "X"}, ) - ds = xr.Dataset(coords={"lat": lat, "lon": lon}) + ds = xr.Dataset(coords={"lon": lon}) - result = ds.bounds.add_missing_bounds("Y") + result = ds.bounds.add_missing_bounds() assert result.identical(ds) @@ -103,26 +102,58 @@ def test_skips_adding_bounds_for_coords_that_are_multidimensional_or_len_of_1(se class 
TestGetBounds: @pytest.fixture(autouse=True) def setup(self): - self.ds = generate_dataset(cf_compliant=True, has_bounds=False) - self.ds_with_bnds = generate_dataset(cf_compliant=True, has_bounds=True) + self.ds = generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=False + ) + self.ds_with_bnds = generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True + ) def test_raises_error_with_invalid_axis_key(self): with pytest.raises(ValueError): self.ds.bounds.get_bounds("incorrect_axis_argument") - def test_raises_error_if_bounds_attr_is_not_set_on_coord_var(self): + def test_raises_error_if_no_bounds_are_found_because_none_exist(self): ds = xr.Dataset( + data_vars={"ts": xr.DataArray(data=np.ones(3), dims="lat")}, coords={ - "lat": xr.DataArray(data=np.ones(3), dims="lat", attrs={"axis": "Y"}) + "lat": xr.DataArray( + data=np.ones(3), + dims="lat", + attrs={"axis": "Y", "bounds": "lat_bnds"}, + ) + }, + ) + + # No "Y" axis bounds are found in the entire dataset. + with pytest.raises(KeyError): + ds.bounds.get_bounds("Y") + + # No "Y" axis bounds are found for the specified var_key. + with pytest.raises(KeyError): + ds.bounds.get_bounds("Y", var_key="ts") + + def test_raises_error_if_no_bounds_are_found_because_bounds_attr_not_set(self): + ds = xr.Dataset( + coords={ + "lat": xr.DataArray( + data=np.ones(3), + dims="lat", + attrs={"axis": "Y"}, + ) }, data_vars={ - "lat_bnds": xr.DataArray(data=np.ones((3, 3)), dims=["lat", "bnds"]) + "var": xr.DataArray(data=np.ones((3)), dims=["lat"]), + "lat_bnds": xr.DataArray(data=np.ones((3, 3)), dims=["lat", "bnds"]), }, ) + with pytest.raises(KeyError): ds.bounds.get_bounds("Y") - def test_raises_error_if_bounds_attr_is_set_but_no_bounds_data_var_exists(self): + def test_raises_error_if_no_bounds_are_found_with_bounds_attr_set_because_none_exist( + self, + ): ds = xr.Dataset( coords={ "lat": xr.DataArray( @@ -130,13 +161,55 @@ def test_raises_error_if_bounds_attr_is_set_but_no_bounds_data_var_exists(self): dims="lat", attrs={"axis": "Y", "bounds": "lat_bnds"}, ) - } + }, + data_vars={ + "var": xr.DataArray(data=np.ones((3)), dims=["lat"]), + }, ) with pytest.raises(KeyError): - ds.bounds.get_bounds("Y") + ds.bounds.get_bounds("Y", var_key="var") - def test_returns_bounds(self): + def test_returns_single_coord_var_axis_bounds_as_datarray_object(self): + ds = xr.Dataset( + coords={ + "lat": xr.DataArray( + data=np.ones(3), + dims="lat", + attrs={"axis": "Y", "bounds": "lat_bnds"}, + ) + }, + data_vars={ + "lat_bnds": xr.DataArray(data=np.ones((3, 3)), dims=["lat", "bnds"]), + }, + ) + + result = ds.bounds.get_bounds("Y", var_key="lat") + expected = ds.lat_bnds + + assert result.identical(expected) + + def test_returns_single_data_var_axis_bounds_as_datarray_object(self): + ds = xr.Dataset( + coords={ + "lat": xr.DataArray( + data=np.ones(3), + dims="lat", + attrs={"axis": "Y", "bounds": "lat_bnds"}, + ) + }, + data_vars={ + "var": xr.DataArray(data=np.ones((3)), dims=["lat"]), + "lat_bnds": xr.DataArray(data=np.ones((3, 3)), dims=["lat", "bnds"]), + }, + ) + + result = ds.bounds.get_bounds("Y", var_key="var") + expected = ds.lat_bnds + + assert result.identical(expected) + + def test_returns_single_dataset_axis_bounds_as_a_dataarray_object(self): ds = xr.Dataset( coords={ "lat": xr.DataArray( @@ -150,41 +223,67 @@ def test_returns_bounds(self): }, ) - lat_bnds = ds.bounds.get_bounds("Y") + result = ds.bounds.get_bounds("Y") + expected = ds.lat_bnds - assert lat_bnds.identical(ds.lat_bnds) + assert 
result.identical(expected) + def test_returns_multiple_dataset_axis_bounds_as_a_dataset_object(self): + ds = xr.Dataset( + coords={ + "lat": xr.DataArray( + data=np.ones(3), + dims="lat", + attrs={ + "axis": "Y", + "standard_name": "latitude", + "bounds": "lat_bnds", + }, + ), + "lat2": xr.DataArray( + data=np.ones(3), + dims="lat2", + attrs={ + "axis": "Y", + "standard_name": "latitude", + "bounds": "lat2_bnds", + }, + ), + }, + data_vars={ + "var": xr.DataArray(data=np.ones(3), dims=["lat"]), + "lat_bnds": xr.DataArray(data=np.ones((3, 3)), dims=["lat", "bnds"]), + "lat2_bnds": xr.DataArray(data=np.ones((3, 3)), dims=["lat2", "bnds"]), + }, + ) -class TestAddBounds: - @pytest.fixture(autouse=True) - def setup(self): - self.ds = generate_dataset(cf_compliant=True, has_bounds=False) - self.ds_with_bnds = generate_dataset(cf_compliant=True, has_bounds=True) + result = ds.bounds.get_bounds("Y") + expected = ds.drop_vars("var") - def test_raises_error_if_bounds_already_exist(self): - ds = self.ds_with_bnds.copy() + assert result.identical(expected) - with pytest.raises(ValueError): - ds.bounds.add_bounds("Y") - def test_raises_errors_for_data_dim_and_length(self): - # Multidimensional - lat = xr.DataArray( - data=np.array([[0, 1, 2], [3, 4, 5]]), - dims=["placeholder_1", "placeholder_2"], - attrs={"units": "degrees_north", "axis": "Y"}, +class TestAddBounds: + @pytest.fixture(autouse=True) + def setup(self): + self.ds = generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=False + ) + self.ds_with_bnds = generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True ) + + def test_raises_error_for_singleton_coords(self): # Length <=1 lon = xr.DataArray( data=np.array([0]), dims=["lon"], attrs={"units": "degrees_east", "axis": "X"}, ) - ds = xr.Dataset(coords={"lat": lat, "lon": lon}) + ds = xr.Dataset(coords={"lon": lon}) - # If coords dimensions does not equal 1. 
with pytest.raises(ValueError): - ds.bounds.add_bounds("Y") + ds.bounds.add_bounds("X") def test_raises_error_if_lat_coord_var_units_is_not_in_degrees(self): lat = xr.DataArray( @@ -213,8 +312,32 @@ def test_adds_bounds_and_sets_units_to_degrees_north_if_lat_coord_var_is_missing assert result.lat.attrs["units"] == "degrees_north" assert result.lat.attrs["bounds"] == "lat_bnds" + def test_skips_adding_bounds_for_coord_vars_with_bounds(self): + ds = self.ds_with_bnds.copy() + result = ds.bounds.add_bounds("Y") + + assert ds.identical(result) + def test_add_bounds_for_dataset_with_time_coords_as_datetime_objects(self): ds = self.ds.copy() + ds = ds.drop_dims("time") + ds["time"] = xr.DataArray( + name="time", + data=np.array( + [ + "2000-01-01T12:00:00.000000000", + "2000-02-01T12:00:00.000000000", + "2000-03-01T12:00:00.000000000", + ], + dtype="datetime64[ns]", + ), + dims=["time"], + attrs={ + "axis": "T", + "long_name": "time", + "standard_name": "time", + }, + ) result = ds.bounds.add_bounds("Y") assert result.lat_bnds.equals(lat_bnds) @@ -233,21 +356,9 @@ def test_add_bounds_for_dataset_with_time_coords_as_datetime_objects(self): name="time_bnds", data=np.array( [ - ["2000-01-01T12:00:00.000000000", "2000-01-31T12:00:00.000000000"], - ["2000-01-31T12:00:00.000000000", "2000-03-01T12:00:00.000000000"], - ["2000-03-01T12:00:00.000000000", "2000-03-31T18:00:00.000000000"], - ["2000-03-31T18:00:00.000000000", "2000-05-01T06:00:00.000000000"], - ["2000-05-01T06:00:00.000000000", "2000-05-31T18:00:00.000000000"], - ["2000-05-31T18:00:00.000000000", "2000-07-01T06:00:00.000000000"], - ["2000-07-01T06:00:00.000000000", "2000-08-01T00:00:00.000000000"], - ["2000-08-01T00:00:00.000000000", "2000-08-31T18:00:00.000000000"], - ["2000-08-31T18:00:00.000000000", "2000-10-01T06:00:00.000000000"], - ["2000-10-01T06:00:00.000000000", "2000-10-31T18:00:00.000000000"], - ["2000-10-31T18:00:00.000000000", "2000-12-01T06:00:00.000000000"], - ["2000-12-01T06:00:00.000000000", "2001-01-01T00:00:00.000000000"], - ["2001-01-01T00:00:00.000000000", "2001-01-31T06:00:00.000000000"], - ["2001-01-31T06:00:00.000000000", "2001-07-17T06:00:00.000000000"], - ["2001-07-17T06:00:00.000000000", "2002-05-17T18:00:00.000000000"], + ["1999-12-17T00:00:00.000000000", "2000-01-17T00:00:00.000000000"], + ["2000-01-17T00:00:00.000000000", "2000-02-16T00:00:00.000000000"], + ["2000-02-16T00:00:00.000000000", "2000-03-16T00:00:00.000000000"], ], dtype="datetime64[ns]", ), diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 82d55f85..a10177e1 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,5 +1,4 @@ import logging -import pathlib import warnings import cftime @@ -9,12 +8,9 @@ from tests.fixtures import generate_dataset from xcdat.dataset import ( - _has_cf_compliant_time, _keep_single_var, _postprocess_dataset, - _preprocess_non_cf_dataset, - _split_time_units_attr, - decode_non_cf_time, + decode_time, open_dataset, open_mfdataset, ) @@ -31,35 +27,41 @@ def setup(self, tmp_path): dir.mkdir() self.file_path = f"{dir}/file.nc" - def test_non_cf_compliant_time_is_not_decoded(self): - ds = generate_dataset(cf_compliant=False, has_bounds=True) + def test_skip_decoding_time_explicitly(self): + ds = generate_dataset(decode_times=False, cf_compliant=True, has_bounds=True) ds.to_netcdf(self.file_path) result = open_dataset(self.file_path, decode_times=False) - expected = generate_dataset(cf_compliant=False, has_bounds=True) + expected = generate_dataset( + decode_times=False, + cf_compliant=True, + 
has_bounds=True, + ) + assert result.identical(expected) - def test_non_cf_compliant_and_unsupported_time_is_not_decoded(self, caplog): + def test_skips_decoding_non_cf_compliant_time_with_unsupported_units(self, caplog): # Update logger level to silence the logger warning during test runs. caplog.set_level(logging.ERROR) - ds = generate_dataset(cf_compliant=False, has_bounds=True, unsupported=True) + ds = generate_dataset(decode_times=False, cf_compliant=False, has_bounds=True) + ds["time"].attrs["units"] = "year A.D." ds.to_netcdf(self.file_path) # even though decode_times=True, it should fail to decode unsupported time axis - result = open_dataset(self.file_path, decode_times=True) + result = open_dataset(self.file_path, decode_times=False) expected = ds assert result.identical(expected) - def test_non_cf_compliant_time_is_decoded(self): - ds = generate_dataset(cf_compliant=False, has_bounds=False) + def test_decode_time_in_days(self): + ds = generate_dataset(decode_times=False, cf_compliant=True, has_bounds=True) ds.to_netcdf(self.file_path) - result = open_dataset(self.file_path, data_var="ts") + result = open_dataset(self.file_path, data_var="ts", decode_times=True) - # Generate an expected dataset with decoded non-CF compliant time units. - expected = generate_dataset(cf_compliant=True, has_bounds=True) + # Generate an expected dataset with decoded CF compliant time units. + expected = ds.copy() expected["time"] = xr.DataArray( name="time", data=np.array( @@ -68,46 +70,46 @@ def test_non_cf_compliant_time_is_decoded(self): 2000, 1, 1, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 2, 1, 0, 0, 0, 0, has_year_zero=False + 2000, 1, 2, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 3, 1, 0, 0, 0, 0, has_year_zero=False + 2000, 1, 3, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 4, 1, 0, 0, 0, 0, has_year_zero=False + 2000, 1, 4, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 5, 1, 0, 0, 0, 0, has_year_zero=False + 2000, 1, 5, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 6, 1, 0, 0, 0, 0, has_year_zero=False + 2000, 1, 6, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 7, 1, 0, 0, 0, 0, has_year_zero=False + 2000, 1, 7, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 8, 1, 0, 0, 0, 0, has_year_zero=False + 2000, 1, 8, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 9, 1, 0, 0, 0, 0, has_year_zero=False + 2000, 1, 9, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 10, 1, 0, 0, 0, 0, has_year_zero=False + 2000, 1, 10, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 11, 1, 0, 0, 0, 0, has_year_zero=False + 2000, 1, 11, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 12, 1, 0, 0, 0, 0, has_year_zero=False + 2000, 1, 12, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2001, 1, 1, 0, 0, 0, 0, has_year_zero=False + 2000, 1, 13, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2001, 2, 1, 0, 0, 0, 0, has_year_zero=False + 2000, 1, 14, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2001, 3, 1, 0, 0, 0, 0, has_year_zero=False + 2000, 1, 15, 0, 0, 0, 0, has_year_zero=False ), ], dtype="object", @@ -120,122 +122,122 @@ def test_non_cf_compliant_time_is_decoded(self): [ [ cftime.DatetimeGregorian( - 1999, 12, 16, 12, 0, 0, 0, has_year_zero=False + 1999, 12, 31, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( 
- 2000, 1, 16, 12, 0, 0, 0, has_year_zero=False + 2000, 1, 1, 0, 0, 0, 0, has_year_zero=False ), ], [ cftime.DatetimeGregorian( - 2000, 1, 16, 12, 0, 0, 0, has_year_zero=False + 2000, 1, 1, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 2, 15, 12, 0, 0, 0, has_year_zero=False + 2000, 1, 2, 0, 0, 0, 0, has_year_zero=False ), ], [ cftime.DatetimeGregorian( - 2000, 2, 15, 12, 0, 0, 0, has_year_zero=False + 2000, 1, 2, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 3, 16, 12, 0, 0, 0, has_year_zero=False + 2000, 1, 3, 0, 0, 0, 0, has_year_zero=False ), ], [ cftime.DatetimeGregorian( - 2000, 3, 16, 12, 0, 0, 0, has_year_zero=False + 2000, 1, 3, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 4, 16, 0, 0, 0, 0, has_year_zero=False + 2000, 1, 4, 0, 0, 0, 0, has_year_zero=False ), ], [ cftime.DatetimeGregorian( - 2000, 4, 16, 0, 0, 0, 0, has_year_zero=False + 2000, 1, 4, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 5, 16, 12, 0, 0, 0, has_year_zero=False + 2000, 1, 5, 0, 0, 0, 0, has_year_zero=False ), ], [ cftime.DatetimeGregorian( - 2000, 5, 16, 12, 0, 0, 0, has_year_zero=False + 2000, 1, 5, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 6, 16, 0, 0, 0, 0, has_year_zero=False + 2000, 1, 6, 0, 0, 0, 0, has_year_zero=False ), ], [ cftime.DatetimeGregorian( - 2000, 6, 16, 0, 0, 0, 0, has_year_zero=False + 2000, 1, 6, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 7, 16, 12, 0, 0, 0, has_year_zero=False + 2000, 1, 7, 0, 0, 0, 0, has_year_zero=False ), ], [ cftime.DatetimeGregorian( - 2000, 7, 16, 12, 0, 0, 0, has_year_zero=False + 2000, 1, 7, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 8, 16, 12, 0, 0, 0, has_year_zero=False + 2000, 1, 8, 0, 0, 0, 0, has_year_zero=False ), ], [ cftime.DatetimeGregorian( - 2000, 8, 16, 12, 0, 0, 0, has_year_zero=False + 2000, 1, 8, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 9, 16, 0, 0, 0, 0, has_year_zero=False + 2000, 1, 9, 0, 0, 0, 0, has_year_zero=False ), ], [ cftime.DatetimeGregorian( - 2000, 9, 16, 0, 0, 0, 0, has_year_zero=False + 2000, 1, 9, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 10, 16, 12, 0, 0, 0, has_year_zero=False + 2000, 1, 10, 0, 0, 0, 0, has_year_zero=False ), ], [ cftime.DatetimeGregorian( - 2000, 10, 16, 12, 0, 0, 0, has_year_zero=False + 2000, 1, 10, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 11, 16, 0, 0, 0, 0, has_year_zero=False + 2000, 1, 11, 0, 0, 0, 0, has_year_zero=False ), ], [ cftime.DatetimeGregorian( - 2000, 11, 16, 0, 0, 0, 0, has_year_zero=False + 2000, 1, 11, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 12, 16, 12, 0, 0, 0, has_year_zero=False + 2000, 1, 12, 0, 0, 0, 0, has_year_zero=False ), ], [ cftime.DatetimeGregorian( - 2000, 12, 16, 12, 0, 0, 0, has_year_zero=False + 2000, 1, 12, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2001, 1, 16, 12, 0, 0, 0, has_year_zero=False + 2000, 1, 13, 0, 0, 0, 0, has_year_zero=False ), ], [ cftime.DatetimeGregorian( - 2001, 1, 16, 12, 0, 0, 0, has_year_zero=False + 2000, 1, 13, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2001, 2, 15, 0, 0, 0, 0, has_year_zero=False + 2000, 1, 14, 0, 0, 0, 0, has_year_zero=False ), ], [ cftime.DatetimeGregorian( - 2001, 2, 15, 0, 0, 0, 0, has_year_zero=False + 2000, 1, 14, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2001, 3, 15, 0, 0, 0, 0, 
has_year_zero=False + 2000, 1, 15, 0, 0, 0, 0, has_year_zero=False ), ], ], @@ -244,123 +246,63 @@ def test_non_cf_compliant_time_is_decoded(self): dims=["time", "bnds"], attrs={"xcdat_bounds": "True"}, ) - expected.time.attrs = { "axis": "T", "long_name": "time", "standard_name": "time", "bounds": "time_bnds", } + + assert result.identical(expected) + + # Check encoding is preserved. expected.time.encoding = { + "zlib": False, + "szip": False, + "zstd": False, + "bzip2": False, + "blosc": False, + "shuffle": False, + "complevel": 0, + "fletcher32": False, + "contiguous": True, + "chunksizes": None, # Set source as result source because it changes every test run. "source": result.time.encoding["source"], - "dtype": np.dtype(np.int64), - "original_shape": expected.time.data.shape, - "units": "months since 2000-01-01", + "original_shape": (15,), + "dtype": np.dtype("int64"), + "units": "days since 2000-01-01", + "calendar": "standard", + } + expected.time_bnds.encoding = { + "zlib": False, + "szip": False, + "zstd": False, + "bzip2": False, + "blosc": False, + "shuffle": False, + "complevel": 0, + "fletcher32": False, + "contiguous": True, + "chunksizes": None, + "source": result.time.encoding["source"], + "original_shape": (15, 2), + "dtype": np.dtype("int64"), + "units": "days since 2000-01-01", "calendar": "standard", } - assert result.identical(expected) assert result.time.encoding == expected.time.encoding + assert result.time_bnds.encoding == expected.time_bnds.encoding - def test_preserves_lat_and_lon_bounds_if_they_exist(self): - ds = generate_dataset(cf_compliant=True, has_bounds=True) - - # Suppress UserWarning regarding missing time.encoding "units" because - # it is not relevant to this test. - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - ds.to_netcdf(self.file_path) - - result = open_dataset(self.file_path, data_var="ts") - expected = ds.copy() - - assert result.identical(expected) - - def test_keeps_specified_var(self): - ds = generate_dataset(cf_compliant=True, has_bounds=True) - - # Create a modified version of the Dataset with a new var - ds_mod = ds.copy() - ds_mod["tas"] = ds_mod.ts.copy() - - # Suppress UserWarning regarding missing time.encoding "units" because - # it is not relevant to this test. - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - ds_mod.to_netcdf(self.file_path) + def test_decode_time_in_months(self): + ds = generate_dataset(decode_times=False, cf_compliant=False, has_bounds=True) + ds.to_netcdf(self.file_path) result = open_dataset(self.file_path, data_var="ts") - expected = ds.copy() - - assert result.identical(expected) - - -class TestOpenMfDataset: - @pytest.fixture(autouse=True) - def setUp(self, tmp_path): - # Create temporary directory to save files. - dir = tmp_path / "input_data" - dir.mkdir() - self.file_path1 = f"{dir}/file1.nc" - self.file_path2 = f"{dir}/file2.nc" - - def test_mfdataset_keeps_time_encoding_dict(self): - ds1 = generate_dataset(cf_compliant=True, has_bounds=True) - ds1.to_netcdf(self.file_path1) - - # Create another dataset that extends the time coordinates by 1 value, - # to mimic a multifile dataset. 
- ds2 = generate_dataset(cf_compliant=True, has_bounds=True) - ds2 = ds2.isel(dict(time=slice(0, 1))) - ds2["time"].values[:] = np.array( - ["2002-01-16T12:00:00.000000000"], - dtype="datetime64[ns]", - ) - ds2.to_netcdf(self.file_path2) - - result = open_mfdataset([self.file_path1, self.file_path2], decode_times=True) - expected = ds1.merge(ds2) - - assert result.identical(expected) - - # We mainly care for the "source" and "original_shape" attrs (updated - # internally by xCDAT), and the "calendar" and "units" attrs. We don't - # perform equality assertion on the entire time `.encoding` dict because - # there might be different encoding attributes added or removed between - # xarray versions (e.g., "bzip2", "ztsd", "blosc", and "szip" are added - # in v2022.06.0), which makes that assertion fragile. - paths = result.time.encoding["source"] - assert self.file_path1 in paths[0] - assert self.file_path2 in paths[1] - assert result.time.encoding["original_shape"] == (16,) - assert result.time.encoding["calendar"] == "standard" - assert result.time.encoding["units"] == "days since 2000-01-01" - - def test_non_cf_compliant_time_is_not_decoded(self): - ds1 = generate_dataset(cf_compliant=False, has_bounds=True) - ds1.to_netcdf(self.file_path1) - ds2 = generate_dataset(cf_compliant=False, has_bounds=True) - ds2 = ds2.rename_vars({"ts": "tas"}) - ds2.to_netcdf(self.file_path2) - - result = open_mfdataset([self.file_path1, self.file_path2], decode_times=False) - - expected = ds1.merge(ds2) - assert result.identical(expected) - - def test_non_cf_compliant_time_is_decoded(self): - ds1 = generate_dataset(cf_compliant=False, has_bounds=False) - ds2 = generate_dataset(cf_compliant=False, has_bounds=False) - ds2 = ds2.rename_vars({"ts": "tas"}) - - ds1.to_netcdf(self.file_path1) - ds2.to_netcdf(self.file_path2) - - result = open_mfdataset([self.file_path1, self.file_path2], data_var="ts") # Generate an expected dataset with decoded non-CF compliant time units. 
- expected = generate_dataset(cf_compliant=True, has_bounds=True) + expected = ds.copy() expected["time"] = xr.DataArray( name="time", data=np.array( @@ -415,128 +357,129 @@ def test_non_cf_compliant_time_is_decoded(self): ), dims="time", ) + expected["time_bnds"] = xr.DataArray( name="time_bnds", data=np.array( [ [ cftime.DatetimeGregorian( - 1999, 12, 16, 12, 0, 0, 0, has_year_zero=False + 1999, 12, 1, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 1, 16, 12, 0, 0, 0, has_year_zero=False + 2000, 1, 1, 0, 0, 0, 0, has_year_zero=False ), ], [ cftime.DatetimeGregorian( - 2000, 1, 16, 12, 0, 0, 0, has_year_zero=False + 2000, 1, 1, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 2, 15, 12, 0, 0, 0, has_year_zero=False + 2000, 2, 1, 0, 0, 0, 0, has_year_zero=False ), ], [ cftime.DatetimeGregorian( - 2000, 2, 15, 12, 0, 0, 0, has_year_zero=False + 2000, 2, 1, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 3, 16, 12, 0, 0, 0, has_year_zero=False + 2000, 3, 1, 0, 0, 0, 0, has_year_zero=False ), ], [ cftime.DatetimeGregorian( - 2000, 3, 16, 12, 0, 0, 0, has_year_zero=False + 2000, 3, 1, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 4, 16, 0, 0, 0, 0, has_year_zero=False + 2000, 4, 1, 0, 0, 0, 0, has_year_zero=False ), ], [ cftime.DatetimeGregorian( - 2000, 4, 16, 0, 0, 0, 0, has_year_zero=False + 2000, 4, 1, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 5, 16, 12, 0, 0, 0, has_year_zero=False + 2000, 5, 1, 0, 0, 0, 0, has_year_zero=False ), ], [ cftime.DatetimeGregorian( - 2000, 5, 16, 12, 0, 0, 0, has_year_zero=False + 2000, 5, 1, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 6, 16, 0, 0, 0, 0, has_year_zero=False + 2000, 6, 1, 0, 0, 0, 0, has_year_zero=False ), ], [ cftime.DatetimeGregorian( - 2000, 6, 16, 0, 0, 0, 0, has_year_zero=False + 2000, 6, 1, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 7, 16, 12, 0, 0, 0, has_year_zero=False + 2000, 7, 1, 0, 0, 0, 0, has_year_zero=False ), ], [ cftime.DatetimeGregorian( - 2000, 7, 16, 12, 0, 0, 0, has_year_zero=False + 2000, 7, 1, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 8, 16, 12, 0, 0, 0, has_year_zero=False + 2000, 8, 1, 0, 0, 0, 0, has_year_zero=False ), ], [ cftime.DatetimeGregorian( - 2000, 8, 16, 12, 0, 0, 0, has_year_zero=False + 2000, 8, 1, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 9, 16, 0, 0, 0, 0, has_year_zero=False + 2000, 9, 1, 0, 0, 0, 0, has_year_zero=False ), ], [ cftime.DatetimeGregorian( - 2000, 9, 16, 0, 0, 0, 0, has_year_zero=False + 2000, 9, 1, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 10, 16, 12, 0, 0, 0, has_year_zero=False + 2000, 10, 1, 0, 0, 0, 0, has_year_zero=False ), ], [ cftime.DatetimeGregorian( - 2000, 10, 16, 12, 0, 0, 0, has_year_zero=False + 2000, 10, 1, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 11, 16, 0, 0, 0, 0, has_year_zero=False + 2000, 11, 1, 0, 0, 0, 0, has_year_zero=False ), ], [ cftime.DatetimeGregorian( - 2000, 11, 16, 0, 0, 0, 0, has_year_zero=False + 2000, 11, 1, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2000, 12, 16, 12, 0, 0, 0, has_year_zero=False + 2000, 12, 1, 0, 0, 0, 0, has_year_zero=False ), ], [ cftime.DatetimeGregorian( - 2000, 12, 16, 12, 0, 0, 0, has_year_zero=False + 2000, 12, 1, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2001, 1, 16, 12, 0, 0, 0, has_year_zero=False + 2001, 1, 1, 0, 0, 
0, 0, has_year_zero=False ), ], [ cftime.DatetimeGregorian( - 2001, 1, 16, 12, 0, 0, 0, has_year_zero=False + 2001, 1, 1, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2001, 2, 15, 0, 0, 0, 0, has_year_zero=False + 2001, 2, 1, 0, 0, 0, 0, has_year_zero=False ), ], [ cftime.DatetimeGregorian( - 2001, 2, 15, 0, 0, 0, 0, has_year_zero=False + 2001, 2, 1, 0, 0, 0, 0, has_year_zero=False ), cftime.DatetimeGregorian( - 2001, 3, 15, 0, 0, 0, 0, has_year_zero=False + 2001, 3, 1, 0, 0, 0, 0, has_year_zero=False ), ], ], @@ -552,209 +495,1022 @@ def test_non_cf_compliant_time_is_decoded(self): "standard_name": "time", "bounds": "time_bnds", } + + assert result.identical(expected) + + # Check encoding is preserved. expected.time.encoding = { + "zlib": False, + "szip": False, + "zstd": False, + "bzip2": False, + "blosc": False, + "shuffle": False, + "complevel": 0, + "fletcher32": False, + "contiguous": True, + "chunksizes": None, # Set source as result source because it changes every test run. "source": result.time.encoding["source"], - "dtype": np.dtype(np.int64), - "original_shape": expected.time.data.shape, + "original_shape": (15,), + "dtype": np.dtype("int64"), + "units": "months since 2000-01-01", + "calendar": "standard", + } + expected.time_bnds.encoding = { + "zlib": False, + "szip": False, + "zstd": False, + "bzip2": False, + "blosc": False, + "shuffle": False, + "complevel": 0, + "fletcher32": False, + "contiguous": True, + "chunksizes": None, + # Set source as result source because it changes every test run. + "source": result.time.encoding["source"], + "original_shape": (15, 2), + "dtype": np.dtype("int64"), "units": "months since 2000-01-01", "calendar": "standard", } - assert result.identical(expected) assert result.time.encoding == expected.time.encoding + assert result.time_bnds.encoding == expected.time_bnds.encoding - def test_keeps_specified_var(self): - ds1 = generate_dataset(cf_compliant=True, has_bounds=True) - ds2 = generate_dataset(cf_compliant=True, has_bounds=True) - ds2 = ds2.rename_vars({"ts": "tas"}) + def test_keeps_specified_var_and_preserves_bounds(self): + ds = generate_dataset(decode_times=True, cf_compliant=True, has_bounds=True) + + # Create a modified version of the Dataset with a new var + ds_mod = ds.copy() + ds_mod["tas"] = ds_mod.ts.copy() # Suppress UserWarning regarding missing time.encoding "units" because # it is not relevant to this test. with warnings.catch_warnings(): warnings.simplefilter("ignore") - ds1.to_netcdf(self.file_path1) - ds2.to_netcdf(self.file_path2) + ds_mod.to_netcdf(self.file_path) - result = open_mfdataset([self.file_path1, self.file_path2], data_var="ts") + result = open_dataset(self.file_path, data_var="ts") + expected = ds.copy() - # Generate an expected dataset with decoded non-CF compliant time units. - expected = generate_dataset(cf_compliant=True, has_bounds=True) assert result.identical(expected) -class Test_HasCFCompliantTime: +class TestOpenMfDataset: @pytest.fixture(autouse=True) def setUp(self, tmp_path): # Create temporary directory to save files. - self.dir = tmp_path / "input_data" - self.dir.mkdir() - - # Paths to the dummy datasets. 
- self.file_path = f"{self.dir}/file.nc" - - def test_non_cf_compliant_time(self): - # Generate dummy dataset with non-CF compliant time units - ds = generate_dataset(cf_compliant=False, has_bounds=False) - ds.to_netcdf(self.file_path) - - result = _has_cf_compliant_time(self.file_path) - - # Check that False is returned when the dataset has non-cf_compliant time - assert result is False - - def test_no_time_axis(self): - # Generate dummy dataset with CF compliant time - ds = generate_dataset(cf_compliant=True, has_bounds=False) - # remove time axis - ds = ds.isel(time=0) - ds = ds.squeeze(drop=True) - ds = ds.reset_coords() - ds = ds.drop_vars("time") - ds.to_netcdf(self.file_path) - - result = _has_cf_compliant_time(self.file_path) - - # Check that None is returned when there is no time axis - assert result is None - - def test_glob_cf_compliant_time(self): - # Generate dummy datasets with CF compliant time - ds = generate_dataset(cf_compliant=True, has_bounds=False) - ds.to_netcdf(self.file_path) - - result = _has_cf_compliant_time(f"{self.dir}/*.nc") - - # Check that the wildcard path input is correctly evaluated - assert result is True - - def test_list_cf_compliant_time(self): - # Generate dummy datasets with CF compliant time units - ds = generate_dataset(cf_compliant=True, has_bounds=False) - ds.to_netcdf(self.file_path) - - flist = [self.file_path, self.file_path, self.file_path] - result = _has_cf_compliant_time(flist) - - # Check that the list input is correctly evaluated - assert result is True - - def test_cf_compliant_time_with_string_path(self): - # Generate dummy dataset with CF compliant time units - ds = generate_dataset(cf_compliant=True, has_bounds=False) - ds.to_netcdf(self.file_path) - - result = _has_cf_compliant_time(self.file_path) + dir = tmp_path / "input_data" + dir.mkdir() + self.file_path1 = f"{dir}/file1.nc" + self.file_path2 = f"{dir}/file2.nc" - # Check that True is returned when the dataset has cf_compliant time - assert result is True + def test_skip_decoding_times_explicitly(self): + ds1 = generate_dataset(decode_times=False, cf_compliant=False, has_bounds=True) + ds1.to_netcdf(self.file_path1) + ds2 = generate_dataset(decode_times=False, cf_compliant=False, has_bounds=True) + ds2 = ds2.rename_vars({"ts": "tas"}) + ds2.to_netcdf(self.file_path2) - def test_cf_compliant_time_with_pathlib_path(self): - # Generate dummy dataset with CF compliant time units - ds = generate_dataset(cf_compliant=True, has_bounds=False) - ds.to_netcdf(self.file_path) + result = open_mfdataset([self.file_path1, self.file_path2], decode_times=False) - result = _has_cf_compliant_time(pathlib.Path(self.file_path)) + expected = ds1.merge(ds2) + assert result.identical(expected) - # Check that True is returned when the dataset has cf_compliant time - assert result is True + def test_user_specified_callable_results_in_subsetting_dataset_on_time_slice(self): + def callable(ds): + return ds.isel(time=slice(0, 1)) - def test_cf_compliant_time_with_list_of_list_of_strings(self): - # Generate dummy dataset with CF compliant time units - ds = generate_dataset(cf_compliant=True, has_bounds=False) - ds.to_netcdf(self.file_path) + ds = generate_dataset(decode_times=False, cf_compliant=False, has_bounds=True) + ds.to_netcdf(self.file_path1) - result = _has_cf_compliant_time([self.file_path]) + result = open_mfdataset(self.file_path1, decode_times=True, preprocess=callable) + expected = ds.copy().isel(time=slice(0, 1)) + expected["time"] = xr.DataArray( + name="time", + 
data=np.array([cftime.datetime(2000, 1, 1)]), + dims=["time"], + ) + expected["time_bnds"] = xr.DataArray( + name="time_bnds", + data=np.array( + [[cftime.datetime(1999, 12, 1), cftime.datetime(2000, 1, 1)]], + ), + dims=["time", "bnds"], + ) + expected.time.attrs = { + "axis": "T", + "long_name": "time", + "standard_name": "time", + "bounds": "time_bnds", + } - # Check that True is returned when the dataset has cf_compliant time - assert result is True + expected.time_bnds.attrs = {"xcdat_bounds": "True"} - def test_cf_compliant_time_with_list_of_list_of_pathlib_paths(self): - # Generate dummy dataset with CF compliant time units - ds = generate_dataset(cf_compliant=True, has_bounds=False) - ds.to_netcdf(self.file_path) + assert result.identical(expected) - result = _has_cf_compliant_time([[pathlib.Path(self.file_path)]]) + def test_decode_time_in_months(self): + # Generate two dataset files with different variables. + ds1 = generate_dataset(decode_times=False, cf_compliant=False, has_bounds=True) + ds1.to_netcdf(self.file_path1) - # Check that True is returned when the dataset has cf_compliant time - assert result is True + ds2 = generate_dataset(decode_times=False, cf_compliant=False, has_bounds=True) + ds2 = ds2.rename_vars({"ts": "tas"}) + ds2.to_netcdf(self.file_path2) + # Open both dataset files as a single Dataset object. + result = open_mfdataset([self.file_path1, self.file_path2], data_var="ts") -class TestDecodeNonCFTimeUnits: - @pytest.fixture(autouse=True) - def setup(self): - time = xr.DataArray( + # Create an expected Dataset object. + expected = generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True + ) + expected["time"] = xr.DataArray( name="time", - data=[1, 2, 3], - dims=["time"], + data=np.array( + [ + cftime.DatetimeGregorian( + 2000, 1, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 2, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 3, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 4, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 5, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 6, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 7, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 8, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 9, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 10, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 11, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 12, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2001, 1, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2001, 2, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2001, 3, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + dtype="object", + ), + dims="time", attrs={ - "bounds": "time_bnds", "axis": "T", "long_name": "time", "standard_name": "time", - # calendar attr and units is specified by test. 
+ "bounds": "time_bnds", }, ) - time_bnds = xr.DataArray( + + expected["time_bnds"] = xr.DataArray( name="time_bnds", - data=[[0, 1], [1, 2], [2, 3]], + data=np.array( + [ + [ + cftime.DatetimeGregorian( + 1999, 12, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 1, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 1, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 2, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 2, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 3, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 3, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 4, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 4, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 5, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 5, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 6, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 6, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 7, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 7, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 8, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 8, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 9, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 9, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 10, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 10, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 11, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 11, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 12, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 12, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2001, 1, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2001, 1, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2001, 2, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2001, 2, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2001, 3, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + ], + dtype="object", + ), dims=["time", "bnds"], + attrs={"xcdat_bounds": "True"}, ) - time_bnds.encoding = { + + # Make sure the expected is chunked. + expected = expected.chunk(chunks={"time": 15, "bnds": 2}) + + # Check encoding is preserved. The extra metadata like "zlib" are from + # the netCDF4 files. + expected.time.encoding = { "zlib": False, + "szip": False, + "zstd": False, + "bzip2": False, + "blosc": False, "shuffle": False, "complevel": 0, "fletcher32": False, - "contiguous": False, - "chunksizes": (1, 2), - "source": "None", - "original_shape": (1980, 2), - "dtype": np.dtype("float64"), + "contiguous": True, + "chunksizes": None, + # Set source as result source because it changes every test run. 
+ "source": result.time.encoding["source"], + "original_shape": (15,), + "dtype": np.dtype("int64"), + "units": "months since 2000-01-01", + "calendar": "standard", + } + expected.time_bnds.encoding = { + "units": "months since 2000-01-01", + "calendar": "standard", + } + + assert result.identical(expected) + assert result.time.encoding == expected.time.encoding + + # FIXME: For some reason the encoding attributes get dropped only in + # the test and not real-world datasets. + assert result.time_bnds.encoding != expected.time_bnds.encoding + + def test_keeps_specified_var_and_preserves_bounds(self): + # Generate two dataset files with different variables. + ds1 = generate_dataset(decode_times=False, cf_compliant=False, has_bounds=True) + ds1.to_netcdf(self.file_path1) + + ds2 = generate_dataset(decode_times=False, cf_compliant=False, has_bounds=True) + ds2 = ds2.rename_vars({"ts": "tas"}) + ds2.to_netcdf(self.file_path2) + + # Open both dataset files as a single Dataset object. + result = open_mfdataset( + [self.file_path1, self.file_path2], data_var="ts", decode_times=False + ) + + # Create an expected Dataset object and check identical with result. + expected = generate_dataset( + decode_times=False, cf_compliant=False, has_bounds=True + ) + expected = expected.chunk(chunks={"time": 15, "bnds": 2}) + + assert result.identical(expected) + + +class TestDecodeTime: + @pytest.fixture(autouse=True) + def setup(self): + time = xr.DataArray( + name="time", + data=[1, 2, 3], + dims=["time"], + attrs={ + "bounds": "time_bnds", + "axis": "T", + "long_name": "time", + "standard_name": "time", + }, + ) + time_bnds = xr.DataArray( + name="time_bnds", + data=[[0, 1], [1, 2], [2, 3]], + dims=["time", "bnds"], + ) + self.ds = xr.Dataset({"time": time, "time_bnds": time_bnds}) + + def test_raises_error_if_no_time_coordinates_could_be_mapped_to(self, caplog): + ds = generate_dataset(decode_times=False, cf_compliant=False, has_bounds=True) + + # Remove time attributes and rename the coordinate variable before + # attempting to decode. + ds.time.attrs = {} + ds = ds.rename({"time": "invalid_time"}) + + with pytest.raises(KeyError): + decode_time(ds) + + def test_skips_decoding_time_coords_if_units_is_not_set(self, caplog): + # Update logger level to silence the logger warning during test runs. + caplog.set_level(logging.ERROR) + + ds = generate_dataset(decode_times=False, cf_compliant=False, has_bounds=True) + + del ds.time.attrs["units"] + + result = decode_time(ds) + assert ds.identical(result) + + def test_skips_decoding_time_coords_if_units_is_not_supported(self, caplog): + # Update logger level to silence the logger warning during test runs. 
+ caplog.set_level(logging.ERROR) + + ds = generate_dataset(decode_times=False, cf_compliant=False, has_bounds=True) + + ds.time.attrs["units"] = "year AD" + + result = decode_time(ds) + assert ds.identical(result) + + def test_skips_decoding_time_bounds_if_bounds_dont_exist(self): + ds = xr.Dataset( + coords={ + "time": xr.DataArray( + name="time", + data=[1, 2, 3], + dims=["time"], + attrs={ + "bounds": "time_bnds", + "axis": "T", + "long_name": "time", + "standard_name": "time", + "calendar": "standard", + "units": "months since 2000-01-01", + }, + ), + "time2": xr.DataArray( + name="time2", + data=[1, 2, 3], + dims="time", + attrs={ + "bounds": "time_bnds", + "axis": "T", + "long_name": "time", + "standard_name": "time", + "calendar": "standard", + "units": "months since 2000-01-01", + }, + ), + }, + ) + + result = decode_time(ds) + expected = xr.Dataset( + coords={ + "time": xr.DataArray( + name="time", + data=np.array( + [ + cftime.DatetimeGregorian( + 2000, 2, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 3, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 4, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + dtype="object", + ), + dims=["time"], + attrs={ + "bounds": "time_bnds", + "axis": "T", + "long_name": "time", + "standard_name": "time", + }, + ), + "time2": xr.DataArray( + name="time2", + data=np.array( + [ + cftime.DatetimeGregorian( + 2000, 2, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 3, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 4, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + dtype="object", + ), + dims=["time"], + attrs={ + "bounds": "time_bnds", + "axis": "T", + "long_name": "time", + "standard_name": "time", + }, + ), + }, + ) + + assert result.identical(expected) + + # Check encoding is preserved. 
+ expected.time.encoding = { + "units": "months since 2000-01-01", + "calendar": "standard", + } + expected.time2.encoding = { + "units": "months since 2000-01-01", + "calendar": "standard", + } + + assert result.time.encoding == expected.time.encoding + assert result.time2.encoding == expected.time.encoding + + def test_decodes_all_time_coordinates_and_time_bounds(self): + ds = xr.Dataset( + coords={ + "time": xr.DataArray( + name="time", + data=[1, 2, 3], + dims=["time"], + attrs={ + "bounds": "time_bnds", + "axis": "T", + "long_name": "time", + "standard_name": "time", + "calendar": "standard", + "units": "months since 2000-01-01", + }, + ), + "time2": xr.DataArray( + name="time2", + data=[1, 2, 3], + dims="time", + attrs={ + "bounds": "time_bnds", + "axis": "T", + "long_name": "time", + "standard_name": "time", + "calendar": "standard", + "units": "months since 2000-01-01", + }, + ), + }, + data_vars={ + "time_bnds": xr.DataArray( + name="time_bnds", + data=[[0, 1], [1, 2], [2, 3]], + dims=["time", "bnds"], + ) + }, + ) + + result = decode_time(ds) + expected = xr.Dataset( + coords={ + "time": xr.DataArray( + name="time", + data=np.array( + [ + cftime.DatetimeGregorian( + 2000, 2, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 3, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 4, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + dtype="object", + ), + dims=["time"], + attrs={ + "bounds": "time_bnds", + "axis": "T", + "long_name": "time", + "standard_name": "time", + }, + ), + "time2": xr.DataArray( + name="time2", + data=np.array( + [ + cftime.DatetimeGregorian( + 2000, 2, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 3, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 4, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + dtype="object", + ), + dims=["time"], + attrs={ + "bounds": "time_bnds", + "axis": "T", + "long_name": "time", + "standard_name": "time", + }, + ), + }, + data_vars={ + "time_bnds": xr.DataArray( + name="time_bnds", + data=np.array( + [ + [ + cftime.DatetimeGregorian( + 2000, 1, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 2, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 2, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 3, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 3, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 4, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + ], + dtype="object", + ), + dims=["time", "bnds"], + attrs=ds.time_bnds.attrs, + ), + }, + ) + + assert result.identical(expected) + + # Check the encoding is preserved. + expected.time.encoding = { + "units": "months since 2000-01-01", + "calendar": "standard", + } + expected.time2.encoding = { + "units": "months since 2000-01-01", + "calendar": "standard", + } + expected.time_bnds.encoding = { + "units": "months since 2000-01-01", + "calendar": "standard", + } + + assert result.time.encoding == expected.time.encoding + assert result.time2.encoding == expected.time2.encoding + assert result.time_bnds.encoding == expected.time_bnds.encoding + + def test_decodes_time_coords_and_bounds_without_calendar_attr_set(self, caplog): + # Update logger level to silence the logger warning during test runs. 
+ caplog.set_level(logging.ERROR) + + ds = xr.Dataset( + coords={ + "time": xr.DataArray( + name="time", + data=[1, 2, 3], + dims=["time"], + attrs={ + "bounds": "time_bnds", + "axis": "T", + "long_name": "time", + "standard_name": "time", + "units": "months since 2000-01-01", + }, + ), + }, + data_vars={ + "time_bnds": xr.DataArray( + name="time_bnds", + data=[[0, 1], [1, 2], [2, 3]], + dims=["time", "bnds"], + ) + }, + ) + + result = decode_time(ds) + expected = xr.Dataset( + coords={ + "time": xr.DataArray( + name="time", + data=np.array( + [ + cftime.DatetimeGregorian( + 2000, 2, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 3, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 4, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + dtype="object", + ), + dims=["time"], + attrs={ + "bounds": "time_bnds", + "axis": "T", + "long_name": "time", + "standard_name": "time", + }, + ), + }, + data_vars={ + "time_bnds": xr.DataArray( + name="time_bnds", + data=np.array( + [ + [ + cftime.DatetimeGregorian( + 2000, 1, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 2, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 2, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 3, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 3, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 4, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + ], + dtype="object", + ), + dims=["time", "bnds"], + ), + }, + ) + + assert result.identical(expected) + + # Check the encoding is preserved. + expected.time.encoding = { + "units": "months since 2000-01-01", + "calendar": "standard", + } + expected.time_bnds.encoding = { + "units": "months since 2000-01-01", + "calendar": "standard", + } + + assert result.time.encoding == expected.time.encoding + assert result.time_bnds.encoding == expected.time_bnds.encoding + + def test_decode_time_in_days(self): + ds = generate_dataset(decode_times=False, cf_compliant=True, has_bounds=True) + + result = decode_time(ds) + + # Generate an expected dataset with decoded CF compliant time units. 
+ expected = ds.copy() + expected["time"] = xr.DataArray( + name="time", + data=np.array( + [ + cftime.DatetimeGregorian( + 2000, 1, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 1, 2, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 1, 3, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 1, 4, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 1, 5, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 1, 6, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 1, 7, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 1, 8, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 1, 9, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 1, 10, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 1, 11, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 1, 12, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 1, 13, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 1, 14, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 1, 15, 0, 0, 0, 0, has_year_zero=False + ), + ], + dtype="object", + ), + dims="time", + ) + expected["time_bnds"] = xr.DataArray( + name="time_bnds", + data=np.array( + [ + [ + cftime.DatetimeGregorian( + 1999, 12, 31, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 1, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 1, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 1, 2, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 1, 2, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 1, 3, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 1, 3, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 1, 4, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 1, 4, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 1, 5, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 1, 5, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 1, 6, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 1, 6, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 1, 7, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 1, 7, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 1, 8, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 1, 8, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 1, 9, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 1, 9, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 1, 10, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 1, 10, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 1, 11, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 1, 11, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 1, 12, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 1, 12, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 1, 13, 0, 
0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 1, 13, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 1, 14, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 1, 14, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 1, 15, 0, 0, 0, 0, has_year_zero=False + ), + ], + ], + dtype="object", + ), + dims=["time", "bnds"], + attrs={"xcdat_bounds": "True"}, + ) + expected.time.attrs = { + "axis": "T", + "long_name": "time", + "standard_name": "time", + "bounds": "time_bnds", } - self.ds = xr.Dataset({"time": time, "time_bnds": time_bnds}) - - def test_returns_original_dataset_if_calendar_attr_is_not_set(self, caplog): - # Update logger level to silence the logger warning during test runs. - caplog.set_level(logging.ERROR) - ds = generate_dataset(cf_compliant=False, has_bounds=True) - - del ds.time.attrs["calendar"] - - result = decode_non_cf_time(ds) - assert ds.identical(result) - - def test_returns_original_dataset_if_units_attr_is_not_set(self, caplog): - # Update logger level to silence the logger warning during test runs. - caplog.set_level(logging.ERROR) - - ds = generate_dataset(cf_compliant=False, has_bounds=True) + assert result.identical(expected) - del ds.time.attrs["units"] + # Check encoding is preserved. + expected.time.encoding = { + "units": "days since 2000-01-01", + "calendar": "standard", + } + expected.time_bnds.encoding = { + "units": "days since 2000-01-01", + "calendar": "standard", + } - result = decode_non_cf_time(ds) - assert ds.identical(result) + assert result.time.encoding == expected.time.encoding + assert result.time_bnds.encoding == expected.time_bnds.encoding - def test_returns_original_dataset_if_units_attr_is_in_an_unsupported_format( - self, caplog + def test_decodes_time_coords_and_bounds_in_months_with_a_reference_date_at_the_start_of_the_month( + self, ): - # Update logger level to silence the logger warning during test runs. - caplog.set_level(logging.ERROR) - - ds = generate_dataset(cf_compliant=False, has_bounds=True) - - ds.time.attrs["units"] = "year AD" - - result = decode_non_cf_time(ds) - assert ds.identical(result) - - def test_decodes_months_with_a_reference_date_at_the_start_of_the_month(self): ds = self.ds.copy() calendar = "standard" ds.time.attrs["calendar"] = calendar ds.time.attrs["units"] = "months since 2000-01-01" - result = decode_non_cf_time(ds) + result = decode_time(ds) expected = xr.Dataset( { "time": xr.DataArray( @@ -813,30 +1569,33 @@ def test_decodes_months_with_a_reference_date_at_the_start_of_the_month(self): dtype="object", ), dims=["time", "bnds"], - attrs=ds.time_bnds.attrs, ), } ) assert result.identical(expected) + # Check the encoding is preserved. 
expected.time.encoding = { - "source": "None", - "dtype": np.dtype(np.int64), - "original_shape": expected.time.data.shape, - "units": ds.time.attrs["units"], - "calendar": calendar, + "units": "months since 2000-01-01", + "calendar": "standard", + } + expected.time_bnds.encoding = { + "units": "months since 2000-01-01", + "calendar": "standard", } - expected.time_bnds.encoding = ds.time_bnds.encoding + assert result.time.encoding == expected.time.encoding assert result.time_bnds.encoding == expected.time_bnds.encoding - def test_decodes_months_with_a_reference_date_at_the_middle_of_the_month(self): + def test_decodes_time_coords_and_bounds_in_months_with_a_reference_date_at_the_middle_of_the_month( + self, + ): ds = self.ds.copy() calendar = "standard" ds.time.attrs["calendar"] = calendar ds.time.attrs["units"] = "months since 2000-01-15" - result = decode_non_cf_time(ds) + result = decode_time(ds) expected = xr.Dataset( { "time": xr.DataArray( @@ -881,24 +1640,28 @@ def test_decodes_months_with_a_reference_date_at_the_middle_of_the_month(self): ) assert result.identical(expected) + # Check the encoding is preserved. expected.time.encoding = { - "source": "None", - "dtype": np.dtype(np.int64), - "original_shape": expected.time.data.shape, - "units": ds.time.attrs["units"], - "calendar": ds.time.attrs["calendar"], + "units": "months since 2000-01-15", + "calendar": "standard", + } + expected.time_bnds.encoding = { + "units": "months since 2000-01-15", + "calendar": "standard", } - expected.time_bnds.encoding = ds.time_bnds.encoding + assert result.time.encoding == expected.time.encoding assert result.time_bnds.encoding == expected.time_bnds.encoding - def test_decodes_months_with_a_reference_date_at_the_end_of_the_month(self): + def test_decodes_time_coords_and_bounds_in_months_with_a_reference_date_at_the_end_of_the_month( + self, + ): ds = self.ds.copy() calendar = "standard" ds.time.attrs["calendar"] = calendar ds.time.attrs["units"] = "months since 1999-12-31" - result = decode_non_cf_time(ds) + result = decode_time(ds) expected = xr.Dataset( { "time": xr.DataArray( @@ -943,24 +1706,28 @@ def test_decodes_months_with_a_reference_date_at_the_end_of_the_month(self): ) assert result.identical(expected) + # Check the encoding is preserved. expected.time.encoding = { - "source": "None", - "dtype": np.dtype(np.int64), - "original_shape": expected.time.data.shape, - "units": ds.time.attrs["units"], - "calendar": ds.time.attrs["calendar"], + "units": "months since 1999-12-31", + "calendar": "standard", } - expected.time_bnds.encoding = ds.time_bnds.encoding + expected.time_bnds.encoding = { + "units": "months since 1999-12-31", + "calendar": "standard", + } + assert result.time.encoding == expected.time.encoding assert result.time_bnds.encoding == expected.time_bnds.encoding - def test_decodes_months_with_a_reference_date_on_a_leap_year(self): + def test_decodes_time_coords_and_bounds_in_months_with_a_reference_date_on_a_leap_year( + self, + ): ds = self.ds.copy() calendar = "standard" ds.time.attrs["calendar"] = calendar ds.time.attrs["units"] = "months since 2000-02-29" - result = decode_non_cf_time(ds) + result = decode_time(ds) expected = xr.Dataset( { @@ -1006,25 +1773,29 @@ def test_decodes_months_with_a_reference_date_on_a_leap_year(self): ) assert result.identical(expected) + # Check the encoding is preserved. 
expected.time.encoding = { - "source": "None", - "dtype": np.dtype(np.int64), - "original_shape": expected.time.data.shape, - "units": ds.time.attrs["units"], - "calendar": ds.time.attrs["calendar"], + "units": "months since 2000-02-29", + "calendar": "standard", + } + expected.time_bnds.encoding = { + "units": "months since 2000-02-29", + "calendar": "standard", } - expected.time_bnds.encoding = ds.time_bnds.encoding + assert result.time.encoding == expected.time.encoding assert result.time_bnds.encoding == expected.time_bnds.encoding - def test_decodes_years_with_a_reference_date_at_the_middle_of_the_year(self): + def test_decodes_time_coords_and_bounds_in_years_with_a_reference_date_in_the_mid_year( + self, + ): ds = self.ds.copy() calendar = "standard" ds.time.attrs["calendar"] = calendar ds.time.attrs["units"] = "years since 2000-06-01" - result = decode_non_cf_time(ds) + result = decode_time(ds) expected = xr.Dataset( { @@ -1070,25 +1841,29 @@ def test_decodes_years_with_a_reference_date_at_the_middle_of_the_year(self): ) assert result.identical(expected) + # Check the encoding is preserved. expected.time.encoding = { - "source": "None", - "dtype": np.dtype(np.int64), - "original_shape": expected.time.data.shape, - "units": ds.time.attrs["units"], - "calendar": ds.time.attrs["calendar"], + "units": "years since 2000-06-01", + "calendar": "standard", } - expected.time_bnds.encoding = ds.time_bnds.encoding + expected.time_bnds.encoding = { + "units": "years since 2000-06-01", + "calendar": "standard", + } + assert result.time.encoding == expected.time.encoding assert result.time_bnds.encoding == expected.time_bnds.encoding - def test_decodes_years_with_a_reference_date_on_a_leap_year(self): + def test_decodes_time_coords_and_bounds_in_years_with_a_reference_date_on_a_leap_year( + self, + ): ds = self.ds.copy() calendar = "standard" ds.time.attrs["calendar"] = calendar ds.time.attrs["units"] = "years since 2000-02-29" - result = decode_non_cf_time(ds) + result = decode_time(ds) expected = xr.Dataset( { @@ -1134,14 +1909,16 @@ def test_decodes_years_with_a_reference_date_on_a_leap_year(self): ) assert result.identical(expected) + # Check the encoding is preserved. 
expected.time.encoding = { - "source": "None", - "dtype": np.dtype(np.int64), - "original_shape": expected.time.data.shape, - "units": ds.time.attrs["units"], - "calendar": ds.time.attrs["calendar"], + "units": "years since 2000-02-29", + "calendar": "standard", + } + expected.time_bnds.encoding = { + "units": "years since 2000-02-29", + "calendar": "standard", } - expected.time_bnds.encoding = ds.time_bnds.encoding + assert result.time.encoding == expected.time.encoding assert result.time_bnds.encoding == expected.time_bnds.encoding @@ -1149,47 +1926,156 @@ def test_decodes_years_with_a_reference_date_on_a_leap_year(self): class Test_PostProcessDataset: @pytest.fixture(autouse=True) def setup(self): - self.ds = generate_dataset(cf_compliant=True, has_bounds=True) + self.ds = generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True + ) - def test_keeps_specified_var(self): - ds = generate_dataset(cf_compliant=True, has_bounds=True) + def test_centers_time_coords_and_converts_datetime_dtype_to_cftime_object_type( + self, + ): + ds = generate_dataset(decode_times=True, cf_compliant=False, has_bounds=True) - # Create a modified version of the Dataset with a new var - ds_mod = ds.copy() - ds_mod["tas"] = ds_mod.ts.copy() + # Create a dataset with uncentered time coordinates that are decoded as + # dtype="datetime[ns]" + ds_uncentered = ds.copy() + ds_uncentered["time"] = xr.DataArray( + data=np.array( + [ + "2000-01-31T12:00:00.000000000", + "2000-02-29T12:00:00.000000000", + "2000-03-31T12:00:00.000000000", + "2000-04-30T00:00:00.000000000", + "2000-05-31T12:00:00.000000000", + "2000-06-30T00:00:00.000000000", + "2000-07-31T12:00:00.000000000", + "2000-08-31T12:00:00.000000000", + "2000-09-30T00:00:00.000000000", + "2000-10-16T12:00:00.000000000", + "2000-11-30T00:00:00.000000000", + "2000-12-31T12:00:00.000000000", + "2001-01-31T12:00:00.000000000", + "2001-02-28T00:00:00.000000000", + "2001-12-31T12:00:00.000000000", + ], + dtype="datetime64[ns]", + ), + dims=ds.time.dims, + attrs=ds.time.attrs, + ) + ds_uncentered.time.encoding = { + "source": None, + "original_shape": ds.time.data.shape, + "dtype": np.dtype("float64"), + "units": "days since 2000-01-01", + "calendar": "standard", + "_FillValue": False, + } - result = _postprocess_dataset(ds, data_var="ts") + # Compare result of the method against the expected. 
+ result = _postprocess_dataset(ds_uncentered, center_times=True) expected = ds.copy() - assert result.identical(expected) + expected["time"] = xr.DataArray( + name="time", + data=np.array( + [ + cftime.DatetimeGregorian( + 2000, 1, 16, 12, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 2, 15, 12, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 3, 16, 12, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 4, 16, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 5, 16, 12, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 6, 16, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 7, 16, 12, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 8, 16, 12, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 9, 16, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 10, 16, 12, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 11, 16, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 12, 16, 12, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2001, 1, 16, 12, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2001, 2, 15, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2001, 12, 16, 12, 0, 0, 0, has_year_zero=False + ), + ], + dtype="object", + ), + dims="time", + attrs={ + "long_name": "time", + "standard_name": "time", + "axis": "T", + "bounds": "time_bnds", + }, + ) + + expected.time.encoding = { + "source": None, + "original_shape": (15,), + "dtype": np.dtype("float64"), + "units": "days since 2000-01-01", + "calendar": "standard", + "_FillValue": False, + } - def test_centers_time(self): - ds = generate_dataset(cf_compliant=True, has_bounds=True) + # Compare result of the function against the expected. 
+ assert result.identical(expected) + assert result.time.encoding == expected.time.encoding + def test_centers_time_coordinates_and_maintains_cftime_object_type(self): + # Create a dataset with uncentered time coordinates + ds = generate_dataset(decode_times=True, cf_compliant=False, has_bounds=True) uncentered_time = np.array( [ - "2000-01-31T12:00:00.000000000", - "2000-02-29T12:00:00.000000000", - "2000-03-31T12:00:00.000000000", - "2000-04-30T00:00:00.000000000", - "2000-05-31T12:00:00.000000000", - "2000-06-30T00:00:00.000000000", - "2000-07-31T12:00:00.000000000", - "2000-08-31T12:00:00.000000000", - "2000-09-30T00:00:00.000000000", - "2000-10-16T12:00:00.000000000", - "2000-11-30T00:00:00.000000000", - "2000-12-31T12:00:00.000000000", - "2001-01-31T12:00:00.000000000", - "2001-02-28T00:00:00.000000000", - "2001-12-31T12:00:00.000000000", + cftime.DatetimeGregorian(2000, 1, 31, 12, 0, 0, 0), + cftime.DatetimeGregorian(2000, 2, 29, 12, 0, 0, 0), + cftime.DatetimeGregorian(2000, 3, 31, 12, 0, 0, 0), + cftime.DatetimeGregorian(2000, 4, 30, 0, 0, 0, 0), + cftime.DatetimeGregorian(2000, 5, 31, 12, 0, 0, 0), + cftime.DatetimeGregorian(2000, 6, 30, 0, 0, 0, 0), + cftime.DatetimeGregorian(2000, 7, 31, 12, 0, 0, 0), + cftime.DatetimeGregorian(2000, 8, 31, 12, 0, 0, 0), + cftime.DatetimeGregorian(2000, 9, 30, 0, 0, 0, 0), + cftime.DatetimeGregorian(2000, 10, 16, 12, 0, 0, 0), + cftime.DatetimeGregorian(2000, 11, 30, 0, 0, 0, 0), + cftime.DatetimeGregorian(2000, 12, 31, 12, 0, 0, 0), + cftime.DatetimeGregorian(2001, 1, 31, 12, 0, 0, 0), + cftime.DatetimeGregorian(2001, 2, 28, 0, 0, 0, 0), + cftime.DatetimeGregorian(2001, 12, 31, 12, 0, 0, 0), ], - dtype="datetime64[ns]", + dtype="object", ) ds.time.data[:] = uncentered_time ds.time.encoding = { "source": None, - "dtype": np.dtype(np.int64), "original_shape": ds.time.data.shape, + "dtype": np.dtype("float64"), "units": "days since 2000-01-01", "calendar": "standard", "_FillValue": False, @@ -1198,70 +2084,94 @@ def test_centers_time(self): # Compare result of the method against the expected. 
result = _postprocess_dataset(ds, center_times=True) expected = ds.copy() - expected_time_data = np.array( - [ - "2000-01-16T12:00:00.000000000", - "2000-02-15T12:00:00.000000000", - "2000-03-16T12:00:00.000000000", - "2000-04-16T00:00:00.000000000", - "2000-05-16T12:00:00.000000000", - "2000-06-16T00:00:00.000000000", - "2000-07-16T12:00:00.000000000", - "2000-08-16T12:00:00.000000000", - "2000-09-16T00:00:00.000000000", - "2000-10-16T12:00:00.000000000", - "2000-11-16T00:00:00.000000000", - "2000-12-16T12:00:00.000000000", - "2001-01-16T12:00:00.000000000", - "2001-02-15T00:00:00.000000000", - "2001-12-16T12:00:00.000000000", - ], - dtype="datetime64[ns]", - ) - expected = expected.assign_coords( - { - "time": xr.DataArray( - name="time", - data=expected_time_data, - coords={"time": expected_time_data}, - dims="time", - attrs={ - "long_name": "time", - "standard_name": "time", - "axis": "T", - "bounds": "time_bnds", - }, - ) - } + expected["time"] = xr.DataArray( + name="time", + data=np.array( + [ + cftime.DatetimeGregorian( + 2000, 1, 16, 12, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 2, 15, 12, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 3, 16, 12, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 4, 16, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 5, 16, 12, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 6, 16, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 7, 16, 12, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 8, 16, 12, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 9, 16, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 10, 16, 12, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 11, 16, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 12, 16, 12, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2001, 1, 16, 12, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2001, 2, 15, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2001, 12, 16, 12, 0, 0, 0, has_year_zero=False + ), + ], + dtype="object", + ), + dims="time", + attrs={ + "long_name": "time", + "standard_name": "time", + "axis": "T", + "bounds": "time_bnds", + }, ) + expected.time.encoding = { "source": None, - "dtype": np.dtype("int64"), "original_shape": (15,), + "dtype": np.dtype("float64"), "units": "days since 2000-01-01", "calendar": "standard", "_FillValue": False, } # Update time bounds with centered time coordinates. - time_bounds = ds.time_bnds.copy() - time_bounds["time"] = expected.time - expected["time_bnds"] = time_bounds + expected["time_bnds"] = ds.time_bnds.copy() + expected["time_bnds"]["time"] = expected.time # Compare result of the function against the expected. assert result.identical(expected) assert result.time.encoding == expected.time.encoding def test_raises_error_if_dataset_has_no_time_coords_but_center_times_is_true(self): - ds = generate_dataset(cf_compliant=True, has_bounds=False) + ds = generate_dataset(decode_times=True, cf_compliant=False, has_bounds=False) ds = ds.drop_dims("time") - with pytest.raises(ValueError): + with pytest.raises(KeyError): _postprocess_dataset(ds, center_times=True) def test_adds_missing_lat_and_lon_bounds(self): # Create expected dataset without bounds. 
- ds = generate_dataset(cf_compliant=True, has_bounds=False) + ds = generate_dataset(decode_times=True, cf_compliant=False, has_bounds=False) data_vars = list(ds.data_vars.keys()) assert "lat_bnds" not in data_vars @@ -1303,9 +2213,7 @@ def test_orients_longitude_bounds_from_180_to_360_and_sorts_with_prime_meridian_ }, ).chunk({"lon": 2}) - result = _postprocess_dataset( - ds, data_var=None, center_times=False, add_bounds=True, lon_orient=(0, 360) - ) + result = _postprocess_dataset(ds, lon_orient=(0, 360)) expected = xr.Dataset( coords={ "lon": xr.DataArray( @@ -1338,18 +2246,19 @@ def test_orients_longitude_bounds_from_180_to_360_and_sorts_with_prime_meridian_ def test_raises_error_if_dataset_has_no_longitude_coords_but_lon_orient_is_specified( self, ): - ds = generate_dataset(cf_compliant=True, has_bounds=False) - + ds = generate_dataset(decode_times=True, cf_compliant=False, has_bounds=False) ds = ds.drop_dims("lon") - with pytest.raises(ValueError): + with pytest.raises(KeyError): _postprocess_dataset(ds, lon_orient=(0, 360)) class Test_KeepSingleVar: @pytest.fixture(autouse=True) def setup(self): - self.ds = generate_dataset(cf_compliant=True, has_bounds=True) + self.ds = generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True + ) self.ds_mod = self.ds.copy() self.ds_mod["tas"] = self.ds_mod.ts.copy() @@ -1385,53 +2294,3 @@ def test_bounds_always_persist(self): assert ds.get("lat_bnds") is not None assert ds.get("lon_bnds") is not None assert ds.get("time_bnds") is not None - - -class Test_PreProcessNonCFDataset: - @pytest.fixture(autouse=True) - def setup(self): - self.ds = generate_dataset(cf_compliant=False, has_bounds=True) - - def test_user_specified_callable_results_in_subsetting_dataset_on_time_slice(self): - def callable(ds): - return ds.isel(time=slice(0, 1)) - - ds = self.ds.copy() - - result = _preprocess_non_cf_dataset(ds, callable) - expected = ds.copy().isel(time=slice(0, 1)) - expected["time"] = xr.DataArray( - name="time", - data=np.array([cftime.datetime(2000, 1, 1)]), - dims=["time"], - ) - expected["time_bnds"] = xr.DataArray( - name="time_bnds", - data=np.array( - [[cftime.datetime(1999, 12, 1), cftime.datetime(2000, 1, 1)]], - ), - dims=["time", "bnds"], - ) - expected.time.attrs = { - "axis": "T", - "long_name": "time", - "standard_name": "time", - "bounds": "time_bnds", - } - - expected.time_bnds.attrs = {"xcdat_bounds": "True"} - - assert result.identical(expected) - - -class Test_SplitTimeUnitsAttr: - def test_splits_units_attr_to_unit_and_reference_date(self): - assert _split_time_units_attr("months since 1800") == ("months", "1800") - assert _split_time_units_attr("months since 1800-01-01") == ( - "months", - "1800-01-01", - ) - assert _split_time_units_attr("months since 1800-01-01 00:00:00") == ( - "months", - "1800-01-01 00:00:00", - ) diff --git a/tests/test_regrid.py b/tests/test_regrid.py index b2f65c35..c5219b21 100644 --- a/tests/test_regrid.py +++ b/tests/test_regrid.py @@ -141,7 +141,9 @@ def setup(self): @pytest.mark.filterwarnings("ignore:.*invalid value.*true_divide.*:RuntimeWarning") def test_output_bounds(self): - ds = fixtures.generate_dataset(cf_compliant=True, has_bounds=True) + ds = fixtures.generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True + ) output_grid = grid.create_gaussian_grid(32) @@ -225,6 +227,7 @@ def test_unknown_variable(self): with pytest.raises(KeyError): regridder.horizontal("unknown", self.coarse_2d_ds) + @pytest.mark.filterwarnings("ignore:.*invalid 
value.*true_divide.*:RuntimeWarning") def test_regrid_input_mask(self): regridder = regrid2.Regrid2Regridder(self.coarse_2d_ds, self.fine_2d_ds) @@ -416,7 +419,9 @@ def test_reversed_extract_bounds(self): class TestXESMFRegridder: @pytest.fixture(autouse=True) def setup(self): - self.ds = fixtures.generate_dataset(cf_compliant=True, has_bounds=True) + self.ds = fixtures.generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True + ) self.new_grid = grid.create_uniform_grid(-90, 90, 4.0, -180, 180, 5.0) @pytest.mark.xfail @@ -481,7 +486,9 @@ def test_invalid_extra_method(self): ) def test_preserve_bounds(self): - ds = fixtures.generate_dataset(cf_compliant=True, has_bounds=True) + ds = fixtures.generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True + ) ds = ds.drop_vars(["lat_bnds", "lon_bnds"]) @@ -573,6 +580,75 @@ def test_global_mean_grid(self): assert np.all(mean_grid.lon == np.array([180.0])) assert np.all(mean_grid.lon_bnds == np.array([[-22.5, 405]])) + def test_raises_error_for_global_mean_grid_if_an_axis_has_multiple_dimensions(self): + source_grid = xr.Dataset( + coords={ + "lat": xr.DataArray( + name="lat", + data=np.array([-80, -40, 0, 40, 80]), + dims="lat", + attrs={"units": "degrees_north", "axis": "Y", "bounds": "lat_bnds"}, + ), + "lon": xr.DataArray( + name="lon", + data=np.array([0, 45, 90, 180, 270, 360]), + dims="lon", + attrs={"units": "degrees_east", "axis": "X", "bounds": "lon_bnds"}, + ), + }, + data_vars={ + "lat_bnds": xr.DataArray( + name="lat_bnds", + data=np.array( + [ + [-90.0, -60.0], + [-60.0, -20.0], + [-20.0, 20.0], + [20.0, 60.0], + [60.0, 90.0], + ] + ), + dims=["lat", "bnds"], + attrs={"units": "degrees_north", "axis": "Y", "bounds": "lat_bnds"}, + ), + "lon_bnds": xr.DataArray( + name="lon_bnds", + data=np.array( + [ + [-22.5, 22.5], + [22.5, 67.5], + [67.5, 135.0], + [135.0, 225.0], + [225.0, 315.0], + [315.0, 405.0], + ] + ), + dims=["lon", "bnds"], + attrs={"xcdat_bounds": True}, + ), + }, + ) + + source_grid_with_2_lats = source_grid.copy() + source_grid_with_2_lats["lat2"] = xr.DataArray( + name="lat2", + data=np.array([-80, -40, 0, 40, 80]), + dims="lat2", + attrs={"units": "degrees_north", "axis": "Y", "bounds": "lat_bnds"}, + ) + with pytest.raises(ValueError): + grid.create_global_mean_grid(source_grid_with_2_lats) + + source_grid_with_2_lons = source_grid.copy() + source_grid_with_2_lons["lon2"] = xr.DataArray( + name="lon2", + data=np.array([0, 45, 90, 180, 270, 360]), + dims="lon2", + attrs={"units": "degrees_east", "axis": "X", "bounds": "lon_bnds"}, + ) + with pytest.raises(ValueError): + grid.create_global_mean_grid(source_grid_with_2_lons) + def test_zonal_grid(self): source_grid = grid.create_grid( np.array([-80, -40, 0, 40, 80]), np.array([-160, -80, 80, 160]) @@ -588,6 +664,62 @@ def test_zonal_grid(self): assert np.all(zonal_grid.lon == np.array([0.0])) assert np.all(zonal_grid.lon_bnds == np.array([-200, 200])) + def test_raises_error_for_zonal_grid_if_an_axis_has_multiple_dimensions(self): + source_grid = xr.Dataset( + coords={ + "lat": xr.DataArray( + name="lat", + data=np.array([-80, -40, 0, 40, 80]), + dims="lat", + attrs={"units": "degrees_north", "axis": "Y", "bounds": "lat_bnds"}, + ), + "lon": xr.DataArray( + name="lon", + data=np.array([-160, -80, 80, 160]), + dims="lon", + attrs={"units": "degrees_east", "axis": "X", "bounds": "lon_bnds"}, + ), + }, + data_vars={ + "lat_bnds": xr.DataArray( + name="lat_bnds", + data=np.array( + [[-90, -60], [-60, -20], [-20, 20], [20, 60], [60, 90]] + ), 
+ dims=["lat", "bnds"], + attrs={"units": "degrees_north", "axis": "Y", "bounds": "lat_bnds"}, + ), + "lon_bnds": xr.DataArray( + name="lon_bnds", + data=np.array( + [[-200.0, -120.0], [-120.0, 0.0], [0.0, 120.0], [120.0, 200.0]] + ), + dims=["lon", "bnds"], + attrs={"xcdat_bounds": True}, + ), + }, + ) + + source_grid_with_2_lats = source_grid.copy() + source_grid_with_2_lats["lat2"] = xr.DataArray( + name="lat2", + data=np.array([-80, -40, 0, 40, 80]), + dims="lat2", + attrs={"units": "degrees_north", "axis": "Y", "bounds": "lat_bnds"}, + ) + with pytest.raises(ValueError): + grid.create_zonal_grid(source_grid_with_2_lats) + + source_grid_with_2_lons = source_grid.copy() + source_grid_with_2_lons["lon2"] = xr.DataArray( + name="lon2", + data=np.array([0, 45, 90, 180, 270, 360]), + dims="lon2", + attrs={"units": "degrees_east", "axis": "X", "bounds": "lon_bnds"}, + ) + with pytest.raises(ValueError): + grid.create_zonal_grid(source_grid_with_2_lons) + class TestAccessor: @pytest.fixture(autouse=True) @@ -596,7 +728,9 @@ def setup(self): self.ac = accessor.RegridderAccessor(self.data) def test_grid_missing_axis(self): - ds = fixtures.generate_dataset(True, True) + ds = fixtures.generate_dataset( + decode_times=True, cf_compliant=True, has_bounds=True + ) ds_no_lat = ds.drop_dims(["lat"]) @@ -609,7 +743,9 @@ def test_grid_missing_axis(self): ds_no_lon.regridder.grid def test_grid(self): - ds_bounds = fixtures.generate_dataset(True, True) + ds_bounds = fixtures.generate_dataset( + decode_times=True, cf_compliant=True, has_bounds=True + ) grid = ds_bounds.regridder.grid @@ -618,7 +754,9 @@ def test_grid(self): assert "lat_bnds" in grid assert "lon_bnds" in grid - ds_no_bounds = fixtures.generate_dataset(True, False) + ds_no_bounds = fixtures.generate_dataset( + decode_times=True, cf_compliant=True, has_bounds=False + ) grid = ds_no_bounds.regridder.grid @@ -627,6 +765,17 @@ def test_grid(self): assert "lat_bnds" in grid assert "lon_bnds" in grid + def test_grid_raises_error_when_dataset_has_multiple_dims_for_an_axis(self): + ds_bounds = fixtures.generate_dataset( + decode_times=True, cf_compliant=True, has_bounds=True + ) + ds_bounds.coords["lat2"] = xr.DataArray( + data=[], dims="lat2", attrs={"axis": "Y"} + ) + + with pytest.raises(ValueError): + ds_bounds.regridder.grid + def test_valid_tool(self): mock_regridder = mock.MagicMock() mock_regridder.return_value.horizontal.return_value = "output data" @@ -649,7 +798,9 @@ def test_invalid_tool(self): @requires_xesmf @pytest.mark.filterwarnings("ignore:.*invalid value.*true_divide.*:RuntimeWarning") def test_convenience_methods(self): - ds = fixtures.generate_dataset(cf_compliant=True, has_bounds=True) + ds = fixtures.generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True + ) out_grid = grid.create_gaussian_grid(32) @@ -665,7 +816,9 @@ def test_convenience_methods(self): def test_raises_error_if_xesmf_is_not_installed(self): # TODO Find a way to mock the value of `_has_xesmf` to False or # to remove the `xesmf` module entirely - ds = fixtures.generate_dataset(cf_compliant=True, has_bounds=True) + ds = fixtures.generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True + ) out_grid = grid.create_gaussian_grid(32) with pytest.raises(ModuleNotFoundError): @@ -674,7 +827,9 @@ def test_raises_error_if_xesmf_is_not_installed(self): class TestBase: def test_preserve_bounds(self): - ds_with_bounds = fixtures.generate_dataset(cf_compliant=True, has_bounds=True) + ds_with_bounds = fixtures.generate_dataset( + 
decode_times=True, cf_compliant=False, has_bounds=True + ) ds_without_bounds = ds_with_bounds.drop_vars(["lat_bnds", "lon_bnds"]) diff --git a/tests/test_spatial.py b/tests/test_spatial.py index 6b5df530..1a486c5b 100644 --- a/tests/test_spatial.py +++ b/tests/test_spatial.py @@ -10,7 +10,9 @@ class TestSpatialAccessor: @pytest.fixture(autouse=True) def setup(self): - self.ds = generate_dataset(cf_compliant=True, has_bounds=True) + self.ds = generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True + ) def test__init__(self): ds = self.ds.copy() @@ -28,7 +30,9 @@ def test_decorator_call(self): class TestAverage: @pytest.fixture(autouse=True) def setup(self): - self.ds = generate_dataset(cf_compliant=True, has_bounds=True) + self.ds = generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True + ) # Limit to just 3 data points to simplify testing. self.ds = self.ds.isel(time=slice(None, 3)) @@ -48,9 +52,13 @@ def test_raises_error_if_axis_list_contains_unsupported_axis(self): def test_raises_error_if_lat_axis_coords_cant_be_found(self): ds = self.ds.copy() - ds = ds.rename_dims({"lat": "invalid_lat"}) + # Update CF metadata to invalid values so cf_xarray can't interpret them. del ds.lat.attrs["axis"] del ds.lat.attrs["standard_name"] + del ds.lat.attrs["units"] + # Update coordinate name. + ds = ds.rename({"lat": "invalid_lat"}) + ds = ds.set_index(invalid_lat="invalid_lat") with pytest.raises(KeyError): ds.spatial.average("ts", axis=["X", "Y"]) @@ -58,9 +66,14 @@ def test_raises_error_if_lat_axis_coords_cant_be_found(self): def test_raises_error_if_lon_axis_coords_cant_be_found(self): ds = self.ds.copy() - ds = ds.rename_dims({"lon": "invalid_lon"}) + # Update CF metadata to invalid values so cf_xarray can't interpret them. del ds.lon.attrs["axis"] del ds.lon.attrs["standard_name"] + del ds.lon.attrs["units"] + # Update coordinate name. + ds = ds.rename({"lon": "invalid_lon"}) + ds = ds.set_index(invalid_lon="invalid_lon") + with pytest.raises(KeyError): ds.spatial.average("ts", axis=["X", "Y"]) @@ -252,7 +265,9 @@ def test_spatial_average_for_lat_and_lon_region_with_custom_weights(self): class TestGetWeights: @pytest.fixture(autouse=True) def setup(self): - self.ds = generate_dataset(cf_compliant=True, has_bounds=True) + self.ds = generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True + ) def test_bounds_reordered_when_upper_indexed_first(self): domain_bounds = xr.DataArray( @@ -275,7 +290,42 @@ def test_bounds_reordered_when_upper_indexed_first(self): ) assert result.identical(expected_domain_bounds) - def test_weights_for_region_in_lat_and_lon_domains(self): + def test_raises_error_if_dataset_has_multiple_bounds_variables_for_an_axis(self): + ds = self.ds.copy() + + # Create a second "Y" axis dimension and associated bounds + ds["lat2"] = ds.lat.copy() + ds["lat2"].name = "lat2" + ds["lat2"].attrs["bounds"] = "lat_bnds2" + ds["lat_bnds2"] = ds.lat_bnds.copy() + ds["lat_bnds2"].name = "lat_bnds2" + + # Check raises error when there are > 1 bounds for the dataset. 
+ with pytest.raises(TypeError): + ds.spatial.get_weights(axis=["Y", "X"]) + + def test_data_var_weights_for_region_in_lat_and_lon_domains(self): + ds = self.ds.copy() + + result = ds.spatial.get_weights( + axis=["Y", "X"], lat_bounds=(-5, 5), lon_bounds=(-170, -120), data_var="ts" + ) + expected = xr.DataArray( + data=np.array( + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 4.35778714, 0.0], + [0.0, 0.0, 4.35778714, 0.0], + [0.0, 0.0, 0.0, 0.0], + ] + ), + coords={"lat": self.ds.lat, "lon": self.ds.lon}, + dims=["lat", "lon"], + ) + + xr.testing.assert_allclose(result, expected) + + def test_dataset_weights_for_region_in_lat_and_lon_domains(self): result = self.ds.spatial.get_weights( axis=["Y", "X"], lat_bounds=(-5, 5), lon_bounds=(-170, -120) ) @@ -294,7 +344,7 @@ def test_weights_for_region_in_lat_and_lon_domains(self): xr.testing.assert_allclose(result, expected) - def test_area_weights_for_region_in_lat_domain(self): + def test_dataset_weights_for_region_in_lat_domain(self): result = self.ds.spatial.get_weights( axis=["Y", "X"], lat_bounds=(-5, 5), lon_bounds=None ) @@ -332,7 +382,9 @@ def test_weights_for_region_in_lon_domain(self): xr.testing.assert_allclose(result, expected) - def test_weights_for_region_in_lon_domain_with_region_spanning_p_meridian(self): + def test_dataset_weights_for_region_in_lon_domain_with_region_spanning_p_meridian( + self, + ): ds = self.ds.copy() result = ds.spatial._get_longitude_weights( @@ -348,7 +400,7 @@ def test_weights_for_region_in_lon_domain_with_region_spanning_p_meridian(self): xr.testing.assert_allclose(result, expected) - def test_weights_all_longitudes_for_equal_region_bounds(self): + def test_dataset_weights_all_longitudes_for_equal_region_bounds(self): expected = xr.DataArray( data=np.array( [1.875, 178.125, 178.125, 1.875], @@ -357,14 +409,14 @@ def test_weights_all_longitudes_for_equal_region_bounds(self): dims=["lon"], ) result = self.ds.spatial.get_weights( - axis=["X"], - lat_bounds=None, - lon_bounds=np.array([0.0, 360.0]), + axis=["X"], lat_bounds=None, lon_bounds=np.array([0.0, 360.0]) ) xr.testing.assert_allclose(result, expected) - def test_weights_for_equal_region_bounds_representing_entire_lon_domain(self): + def test_dataset_weights_for_equal_region_bounds_representing_entire_lon_domain( + self, + ): expected = xr.DataArray( data=np.array( [1.875, 178.125, 178.125, 1.875], @@ -384,7 +436,9 @@ class Test_SwapLonAxis: # converting it to a public method in the future. @pytest.fixture(autouse=True) def setup(self): - self.ds = generate_dataset(cf_compliant=True, has_bounds=True) + self.ds = generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True + ) def test_raises_error_with_incorrect_orientation_to_swap_to(self): domain = xr.DataArray( @@ -532,7 +586,9 @@ class Test_ScaleDimToRegion: # that has edge cases with some complexities. 
@pytest.fixture(autouse=True) def setup(self): - self.ds = generate_dataset(cf_compliant=True, has_bounds=True) + self.ds = generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True + ) @requires_dask def test_scales_chunked_lat_bounds_when_not_wrapping_around_prime_meridian(self): diff --git a/tests/test_temporal.py b/tests/test_temporal.py index 8d307107..fc25ef34 100644 --- a/tests/test_temporal.py +++ b/tests/test_temporal.py @@ -11,24 +11,32 @@ class TestTemporalAccessor: def test__init__(self): - ds: xr.Dataset = generate_dataset(cf_compliant=True, has_bounds=True) + ds: xr.Dataset = generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True + ) obj = TemporalAccessor(ds) assert obj._dataset.identical(ds) def test_decorator(self): - ds: xr.Dataset = generate_dataset(cf_compliant=True, has_bounds=True) + ds: xr.Dataset = generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True + ) obj = ds.temporal assert obj._dataset.identical(ds) - def test_raises_error_if_calendar_encoding_attr_is_not_set(self): - ds: xr.Dataset = generate_dataset(cf_compliant=True, has_bounds=True) - ds.time.encoding = {} - with pytest.raises(KeyError): - TemporalAccessor(ds) +class TestAverage: + def test_raises_error_if_calendar_encoding_attr_not_found_on_data_var_time_coords( + self, + ): + ds: xr.Dataset = generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True + ) + ds.ts.time.encoding = {} + with pytest.raises(KeyError): + ds.temporal.average("ts") -class TestAverage: def test_averages_for_yearly_time_series(self): ds = xr.Dataset( coords={ @@ -422,6 +430,17 @@ def setup(self): attrs={"test_attr": "test"}, ) + def test_raises_error_if_calendar_encoding_attr_not_found_on_data_var_time_coords( + self, + ): + ds: xr.Dataset = generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True + ) + ds.ts.time.encoding = {} + + with pytest.raises(KeyError): + ds.temporal.group_average("ts", freq="year") + def test_weighted_annual_averages(self): ds = self.ds.copy() @@ -971,7 +990,20 @@ class TestClimatology: # for better test reliability and accuracy. This may require subsetting. @pytest.fixture(autouse=True) def setup(self): - self.ds: xr.Dataset = generate_dataset(cf_compliant=True, has_bounds=True) + self.ds: xr.Dataset = generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True + ) + + def test_raises_error_if_calendar_encoding_attr_not_found_on_data_var_time_coords( + self, + ): + ds: xr.Dataset = generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True + ) + ds.ts.time.encoding = {} + + with pytest.raises(KeyError): + ds.temporal.climatology("ts", freq="season") def test_weighted_seasonal_climatology_with_DJF(self): ds = self.ds.copy() @@ -1542,10 +1574,23 @@ class TestDepartures: # better test reliability and accuracy. This may require subsetting. 
@pytest.fixture(autouse=True) def setup(self): - self.ds: xr.Dataset = generate_dataset(cf_compliant=True, has_bounds=True) + self.ds: xr.Dataset = generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True + ) self.seasons = ["JJA", "MAM", "SON", "DJF"] + def test_raises_error_if_calendar_encoding_attr_not_found_on_data_var_time_coords( + self, + ): + ds: xr.Dataset = generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True + ) + ds.ts.time.encoding = {} + + with pytest.raises(KeyError): + ds.temporal.departures("ts", freq="year") + def test_weighted_seasonal_departures_with_DJF(self): ds = self.ds.copy() @@ -1823,12 +1868,15 @@ class Test_GetWeights: class TestWeightsForAverageMode: @pytest.fixture(autouse=True) def setup(self): - self.ds: xr.Dataset = generate_dataset(cf_compliant=True, has_bounds=True) + self.ds: xr.Dataset = generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True + ) def test_weights_for_yearly_averages(self): ds = self.ds.copy() # Set object attrs required to test the method. + ds.temporal.dim = "time" ds.temporal._mode = "average" ds.temporal._freq = "year" ds.temporal._weighted = "True" @@ -1892,6 +1940,7 @@ def test_weights_for_monthly_averages(self): ds = self.ds.copy() # Set object attrs required to test the method. + ds.temporal.dim = "time" ds.temporal._mode = "average" ds.temporal._freq = "month" ds.temporal._weighted = "True" @@ -1953,12 +2002,15 @@ def test_weights_for_monthly_averages(self): class TestWeightsForGroupAverageMode: @pytest.fixture(autouse=True) def setup(self): - self.ds: xr.Dataset = generate_dataset(cf_compliant=True, has_bounds=True) + self.ds: xr.Dataset = generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True + ) def test_weights_for_yearly_averages(self): ds = self.ds.copy() # Set object attrs required to test the method. + ds.temporal.dim = "time" ds.temporal._mode = "group_average" ds.temporal._freq = "year" ds.temporal._weighted = "True" @@ -2022,6 +2074,7 @@ def test_weights_for_monthly_averages(self): ds = self.ds.copy() # Set object attrs required to test the method. + ds.temporal.dim = "time" ds.temporal._mode = "group_average" ds.temporal._freq = "month" ds.temporal._weighted = "True" @@ -2166,6 +2219,7 @@ def test_weights_for_seasonal_averages_with_DJF_and_drop_incomplete_seasons( ) # Set object attrs required to test the method. + ds.temporal.dim = "time" ds.temporal._mode = "group_average" ds.temporal._freq = "season" ds.temporal._weighted = "True" @@ -2228,6 +2282,7 @@ def test_weights_for_seasonal_averages_with_JFD(self): ds = self.ds.copy() # Set object attrs required to test the method. + ds.temporal.dim = "time" ds.temporal._mode = "group_average" ds.temporal._freq = "season" ds.temporal._weighted = "True" @@ -2298,6 +2353,7 @@ def test_custom_season_time_series_weights(self): ds = self.ds.copy() # Set object attrs required to test the method. + ds.temporal.dim = "time" ds.temporal._mode = "group_average" ds.temporal._freq = "season" ds.temporal._weighted = "True" @@ -2375,6 +2431,7 @@ def test_weights_for_daily_averages(self): ds = self.ds.copy() # Set object attrs required to test the method. + ds.temporal.dim = "time" ds.temporal._mode = "group_average" ds.temporal._freq = "day" ds.temporal._weighted = "True" @@ -2420,6 +2477,7 @@ def test_weights_for_hourly_averages(self): ds = self.ds.copy() # Set object attrs required to test the method. 
+ ds.temporal.dim = "time" ds.temporal._mode = "group_average" ds.temporal._freq = "hour" ds.temporal._weighted = "True" @@ -2465,7 +2523,9 @@ def test_weights_for_hourly_averages(self): class TestWeightsForClimatologyMode: @pytest.fixture(autouse=True) def setup(self): - self.ds: xr.Dataset = generate_dataset(cf_compliant=True, has_bounds=True) + self.ds: xr.Dataset = generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True + ) def test_weights_for_seasonal_climatology_with_DJF(self): ds = self.ds.copy() @@ -2568,6 +2628,7 @@ def test_weights_for_seasonal_climatology_with_DJF(self): ) # Set object attrs required to test the method. + ds.temporal.dim = "time" ds.temporal._mode = "climatology" ds.temporal._freq = "season" ds.temporal._weighted = "True" @@ -2625,6 +2686,7 @@ def test_weights_for_seasonal_climatology_with_JFD(self): ds = self.ds.copy() # Set object attrs required to test the method. + ds.temporal.dim = "time" ds.temporal._mode = "climatology" ds.temporal._freq = "season" ds.temporal._weighted = "True" @@ -2690,6 +2752,7 @@ def test_weights_for_annual_climatology(self): ds = self.ds.copy() # Set object attrs required to test the method. + ds.temporal.dim = "time" ds.temporal._mode = "climatology" ds.temporal._freq = "month" ds.temporal._weighted = "True" @@ -2753,6 +2816,7 @@ def test_weights_for_daily_climatology(self): ds = self.ds.copy() # Set object attrs required to test the method. + ds.temporal.dim = "time" ds.temporal._mode = "climatology" ds.temporal._freq = "day" ds.temporal._weighted = "True" @@ -2816,7 +2880,9 @@ class Test_Averager: # test these cases for the public methods that call this private method. @pytest.fixture(autouse=True) def setup(self): - self.ds = generate_dataset(cf_compliant=True, has_bounds=True) + self.ds = generate_dataset( + decode_times=True, cf_compliant=False, has_bounds=True + ) def test_raises_error_with_incorrect_mode_arg(self): with pytest.raises(ValueError): diff --git a/xcdat/__init__.py b/xcdat/__init__.py index 9f8e5c63..f526145d 100644 --- a/xcdat/__init__.py +++ b/xcdat/__init__.py @@ -1,12 +1,12 @@ """Top-level package for xcdat.""" from xcdat.axis import ( # noqa: F401 center_times, - get_axis_coord, - get_axis_dim, + get_dim_coords, + get_dim_keys, swap_lon_axis, ) from xcdat.bounds import BoundsAccessor # noqa: F401 -from xcdat.dataset import decode_non_cf_time, open_dataset, open_mfdataset # noqa: F401 +from xcdat.dataset import decode_time, open_dataset, open_mfdataset # noqa: F401 from xcdat.regridder.accessor import RegridderAccessor # noqa: F401 from xcdat.regridder.grid import ( # noqa: F401 create_gaussian_grid, diff --git a/xcdat/axis.py b/xcdat/axis.py index 326b7776..70aeef16 100644 --- a/xcdat/axis.py +++ b/xcdat/axis.py @@ -6,106 +6,145 @@ import numpy as np import xarray as xr -from dask.array.core import Array +from cf_xarray.criteria import coordinate_criteria + +from xcdat.utils import _if_multidim_dask_array_then_load # https://cf-xarray.readthedocs.io/en/latest/coord_axes.html#axis-names -CFAxisName = Literal["X", "Y", "T", "Z"] +CFAxisKey = Literal["X", "Y", "T", "Z"] # https://cf-xarray.readthedocs.io/en/latest/coord_axes.html#coordinate-names -CFStandardName = Literal["latitude", "longitude", "time", "height", "pressure"] -ShortName = Literal["lat", "lon"] - -# The key is the accepted value for method and function arguments, and the -# values are the CF-compliant axis and standard names that are interpreted in -# the dataset. 
-CF_NAME_MAP: Dict[CFAxisName, List[Union[CFAxisName, CFStandardName, ShortName]]] = { - "X": ["X", "longitude", "lon"], - "Y": ["Y", "latitude", "lat"], - "T": ["T", "time"], - "Z": ["Z", "height", "pressure"], +CFStandardNameKey = Literal[ + "latitude", "longitude", "time", "vertical", "height", "pressure" +] + +# A dictionary that maps the xCDAT `axis` arguments to keys used for `cf_xarray` +# accessor class indexing. For example, if we pass `axis="X"` to a function, +# we can fetch specific `cf_xarray` mapping tables such as `ds.cf.axes["X"]` +# or `ds.cf.coordinates["longitude"]`. +# More information: https://cf-xarray.readthedocs.io/en/latest/coord_axes.html +CF_ATTR_MAP: Dict[CFAxisKey, Dict[str, Union[CFAxisKey, CFStandardNameKey]]] = { + "X": {"axis": "X", "coordinate": "longitude"}, + "Y": {"axis": "Y", "coordinate": "latitude"}, + "T": {"axis": "T", "coordinate": "time"}, + "Z": {"axis": "Z", "coordinate": "vertical"}, } +# A dictionary that maps common variable names to coordinate variables. This +# map is used as fall-back when coordinate variables don't have CF attributes +# set for ``cf_xarray`` to interpret using `CF_ATTR_MAP`. +VAR_NAME_MAP: Dict[CFAxisKey, List[str]] = { + "X": ["longitude", "lon"], + "Y": ["latitude", "lat"], + "T": ["time"], + "Z": coordinate_criteria["Z"]["standard_name"], +} -def get_axis_coord( - obj: Union[xr.Dataset, xr.DataArray], axis: CFAxisName -) -> xr.DataArray: - """Gets the coordinate variable for an axis. - This function uses ``cf_xarray`` to try to find the matching coordinate - variable by checking the following attributes in order: +def get_dim_keys( + obj: Union[xr.Dataset, xr.DataArray], axis: CFAxisKey +) -> Union[str, List[str]]: + """Gets the dimension key(s) for an axis. - - ``"axis"`` - - ``"standard_name"`` - - Dimension name + Each dimension should have a corresponding dimension coordinate variable, + which has a 1:1 map in keys and is denoted by the * symbol when printing out + the xarray object. - - Must follow the valid short-hand convention - - For example, ``"lat"`` for latitude and ``"lon"`` for longitude Parameters ---------- obj : Union[xr.Dataset, xr.DataArray] The Dataset or DataArray object. - axis : CFAxisName - The CF-compliant axis name ("X", "Y", "T", "Z"). + axis : CFAxisKey + The CF axis key ("X", "Y", "T", or "Z") Returns ------- - xr.DataArray - The coordinate variable. + Union[str, List[str]] + The dimension string or a list of dimensions strings for an axis. + """ + dims = sorted([str(dim) for dim in get_dim_coords(obj, axis).dims]) + + return dims[0] if len(dims) == 1 else dims + + +def get_dim_coords( + obj: Union[xr.Dataset, xr.DataArray], axis: CFAxisKey +) -> Union[xr.Dataset, xr.DataArray]: + """Gets the dimension coordinates for an axis. + + This function uses ``cf_xarray`` to attempt to map the axis to its + dimension coordinates by interpreting the CF axis and coordinate names + found in the coordinate attributes. Refer to [1]_ for a list of CF axis and + coordinate names that can be interpreted by ``cf_xarray``. + + If ``obj`` is an ``xr.Dataset,``, this function can return a single + dimension coordinate variable as an ``xr.DataArray`` or multiple dimension + coordinate variables in an ``xr Dataset``. If ``obj`` is an ``xr.DataArray``, + this function should return a single dimension coordinate variable as an + ``xr.DataArray``. + + Parameters + ---------- + obj : Union[xr.Dataset, xr.DataArray] + The Dataset or DataArray object. + axis : CFAxisKey + The CF axis key ("X", "Y", "T", "Z"). 
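As a rough usage sketch of the two renamed accessors (the file path and the coordinate names below are assumed for illustration, not taken from this patch):

>>> import xcdat
>>> ds = xcdat.open_dataset("example.nc")   # hypothetical file
>>> xcdat.get_dim_keys(ds, axis="T")        # e.g., "time"
>>> xcdat.get_dim_coords(ds, axis="Y")      # the "lat" DataArray, or a Dataset
...                                         # if several dimensions map to "Y"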
+ + Returns + ------- + Union[xr.Dataset, xr.DataArray] + A Dataset of dimension coordinate variables or a DataArray for + the single dimension coordinate variable. Raises ------ + ValueError + If the ``obj`` is an ``xr.DataArray`` and more than one dimension is + mapped to the same axis. KeyError - If the coordinate variable was not found. + If no dimension coordinate variables were found for the ``axis``. Notes ----- - Refer to [1]_ for a list of CF-compliant ``"axis"`` and ``"standard_name"`` - attr names that can be interpreted by ``cf_xarray``. + Multidimensional coordinates are ignored. References ---------- - .. [1] https://cf-xarray.readthedocs.io/en/latest/coord_axes.html#axes-and-coordinates """ - keys = CF_NAME_MAP[axis] - coord_var = None - - for key in keys: - try: - coord_var = obj.cf[key] - break - except KeyError: - pass + # Get the object's index keys, with each being a dimension. + # NOTE: xarray does not include multidimensional coordinates as index keys. + # Example: ["lat", "lon", "time"] + index_keys = obj.indexes.keys() + + # Attempt to map the axis it all of its coordinate variable(s) using the + # axis and coordinate names in the object attributes (if they are set). + # Example: Returns ["time", "time_centered"] with `axis="T"` + coord_keys = _get_all_coord_keys(obj, axis) + # Filter the index keys to just the dimension coordinate keys. + # Example: Returns ["time"], since "time_centered" is not in `index_keys` + dim_coord_keys = list(set(index_keys) & set(coord_keys)) + + if isinstance(obj, xr.DataArray) and len(dim_coord_keys) > 1: + raise ValueError( + f"This DataArray has more than one dimension {dim_coord_keys} mapped to the " + f"'{axis}' axis, which is an unexpected behavior. Try dropping extraneous " + "dimensions from the DataArray first (might affect data shape)." + ) - if coord_var is None: + if len(dim_coord_keys) == 0: raise KeyError( - f"A coordinate variable for the {axis} axis was not found. Make sure " - "the coordinate variable exists and either the (1) 'axis' attr or (2) " - "'standard_name' attr is set, or (3) the dimension name follows the " - "short-hand convention (e.g.,'lat')." + f"No '{axis}' axis dimension coordinate variables were found in the " + f"xarray object. Make sure dimension coordinate variables exist, they are " + "one dimensional, and their CF 'axis' or 'standard_name' attrs are " + "correctly set." ) - return coord_var + dim_coords = obj[ + dim_coord_keys if len(dim_coord_keys) > 1 else dim_coord_keys[0] + ].copy() -def get_axis_dim(obj: Union[xr.Dataset, xr.DataArray], axis: CFAxisName) -> str: - """Gets the dimension for an axis. - - The coordinate name should be identical to the dimension name, so this - function simply returns the coordinate name. - - Parameters - ---------- - obj : Union[xr.Dataset, xr.DataArray] - The Dataset or DataArray object. - axis : CFAxisName - The CF-compliant axis name ("X", "Y", "T", "Z") - - Returns - ------- - str - The dimension for an axis. - """ - return str(get_axis_coord(obj, axis).name) + return dim_coords def center_times(dataset: xr.Dataset) -> xr.Dataset: @@ -116,6 +155,10 @@ def center_times(dataset: xr.Dataset) -> xr.Dataset: coordinates, ensures calculations using these values are performed reliably regardless of the recorded interval. + This method attempts to get bounds for each time variable using the + CF "bounds" attribute. Coordinate variables that cannot be mapped to + bounds will be skipped. 
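A hedged sketch of the centering behavior described above, assuming a hypothetical dataset whose time coordinates already have bounds:

>>> import xcdat
>>> ds = xcdat.open_dataset("example.nc", center_times=False)  # hypothetical file
>>> ds_centered = xcdat.center_times(ds)
>>> # Each centered value is the bounds midpoint: lower + (upper - lower) / 2.
>>> # Time coordinates without mappable bounds are left untouched.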
+ Parameters ---------- dataset : xr.Dataset @@ -127,25 +170,35 @@ def center_times(dataset: xr.Dataset) -> xr.Dataset: The Dataset with centered time coordinates. """ ds = dataset.copy() - time: xr.DataArray = get_axis_coord(ds, "T") - time_bounds = ds.bounds.get_bounds("T") - - lower_bounds, upper_bounds = (time_bounds[:, 0].data, time_bounds[:, 1].data) - bounds_diffs: np.timedelta64 = (upper_bounds - lower_bounds) / 2 - bounds_mids: np.ndarray = lower_bounds + bounds_diffs - - time_centered = xr.DataArray( - name=time.name, - data=bounds_mids, - coords={"time": bounds_mids}, - attrs=time.attrs, - ) - time_centered.encoding = time.encoding - ds = ds.assign_coords({"time": time_centered}) - - # Update time bounds with centered time coordinates. - time_bounds[time_centered.name] = time_centered - ds[time_bounds.name] = time_bounds + coords = get_dim_coords(ds, "T") + + for coord in coords.coords.values(): + try: + coord_bounds = ds.bounds.get_bounds("T", str(coord.name)) + except KeyError: + coord_bounds = None + + if coord_bounds is not None: + lower_bounds, upper_bounds = ( + coord_bounds[:, 0].data, + coord_bounds[:, 1].data, + ) + bounds_diffs: np.timedelta64 = (upper_bounds - lower_bounds) / 2 + bounds_mids: np.ndarray = lower_bounds + bounds_diffs + + coord_centered = xr.DataArray( + name=coord.name, + data=bounds_mids, + dims=coord.dims, + attrs=coord.attrs, + ) + coord_centered.encoding = coord.encoding + ds = ds.assign_coords({coord.name: coord_centered}) + + # Update time bounds with centered time coordinates. + coord_bounds[coord_centered.name] = coord_centered + ds[coord_bounds.name] = coord_bounds + return ds @@ -180,68 +233,187 @@ def swap_lon_axis( The Dataset with swapped lon axes orientation. """ ds = dataset.copy() - lon: xr.DataArray = get_axis_coord(ds, "X").copy() - lon_bounds: xr.DataArray = dataset.bounds.get_bounds("X").copy() + coord_keys = get_dim_coords(ds, "X").coords.keys() + # Attempt to swap the orientation for longitude coordinates. with xr.set_options(keep_attrs=True): - if to == (-180, 180): - new_lon = ((lon + 180) % 360) - 180 - new_lon_bounds = ((lon_bounds + 180) % 360) - 180 - ds = _reassign_lon(ds, new_lon, new_lon_bounds) - elif to == (0, 360): - new_lon = lon % 360 - new_lon_bounds = lon_bounds % 360 - ds = _reassign_lon(ds, new_lon, new_lon_bounds) - - # Handle cases where a prime meridian cell exists, which can occur - # after swapping to (0, 360). - p_meridian_index = _get_prime_meridian_index(new_lon_bounds) - if p_meridian_index is not None: - ds = _align_lon_to_360(ds, p_meridian_index) - else: - raise ValueError( - "Currently, only (-180, 180) and (0, 360) are supported longitude axis " - "orientations." - ) + for key in coord_keys: + new_coord = _swap_lon_axis(ds.coords[key], to) + + if ds.coords[key].identical(new_coord): + continue - # If the swapped axis orientation is the same as the existing axis - # orientation, return the original Dataset. 
- if new_lon.identical(lon): - return dataset + ds.coords[key] = new_coord + + try: + bounds = ds.bounds.get_bounds("X") + except KeyError: + bounds = None + + if isinstance(bounds, xr.DataArray): + ds = _swap_lon_bounds(ds, str(bounds.name), to) + elif isinstance(bounds, xr.Dataset): + for key in bounds.data_vars.keys(): + ds = _swap_lon_bounds(ds, str(key), to) if sort_ascending: - ds = ds.sortby(new_lon.name, ascending=True) + ds = ds.sortby(list(coord_keys), ascending=True) + + return ds + + +def _get_all_coord_keys( + obj: Union[xr.Dataset, xr.DataArray], axis: CFAxisKey +) -> List[str]: + """Gets all dimension and non-dimension coordinate keys for an axis. + + This function uses ``cf_xarray`` to interpret CF axis and coordinate name + metadata to map an ``axis`` to its coordinate keys. Refer to [2]_ for more + information on the ``cf_xarray`` mapping tables. + + It also loops over a list of statically defined coordinate variable names to + see if they exist in the object, and appends keys that do exist. + + Parameters + ---------- + obj : Union[xr.Dataset, xr.DataArray] + The Dataset or DataArray object. + axis : CFAxisKey + The CF axis key ("X", "Y", "T", or "Z"). + + Returns + ------- + List[str] + The axis coordinate variable keys. + + References + ---------- + .. [2] https://cf-xarray.readthedocs.io/en/latest/coord_axes.html#axes-and-coordinates + """ + cf_attrs = CF_ATTR_MAP[axis] + var_names = VAR_NAME_MAP[axis] + + keys: List[str] = [] + + try: + keys = keys + obj.cf.axes[cf_attrs["axis"]] + except KeyError: + pass + + try: + keys = keys + obj.cf.coordinates[cf_attrs["coordinate"]] + except KeyError: + pass + + for name in var_names: + if name in obj.coords.keys(): + keys.append(name) + + return list(set(keys)) + + +def _swap_lon_bounds(ds: xr.Dataset, key: str, to: Tuple[float, float]): + bounds = ds[key].copy() + new_bounds = _swap_lon_axis(bounds, to) + + if not ds[key].identical(new_bounds): + ds[key] = new_bounds + + # Handle cases where a prime meridian cell exists, which can occur + # after swapping longitude bounds to (0, 360). This involves extending + # the longitude and bounds by one cell to take into account the prime + # meridian. It also results in extending the data variables by one + # value. + if to == (0, 360): + p_meridian_index = _get_prime_meridian_index(ds[key]) + + if p_meridian_index is not None: + ds = _align_lon_to_360(ds, ds[key], p_meridian_index) return ds -def _reassign_lon(dataset: xr.Dataset, lon: xr.DataArray, lon_bounds: xr.DataArray): +def _swap_lon_axis(coords: xr.DataArray, to: Tuple[float, float]) -> xr.DataArray: + """Swaps the axis orientation for longitude coordinates. + + Parameters + ---------- + coords : xr.DataArray + Coordinates on a longitude axis. + to : Tuple[float, float] + The new longitude axis orientation. + + Returns + ------- + xr.DataArray + The longitude coordinates the opposite axis orientation If the + coordinates are already on the specified axis orientation, the same + coordinates are returned. """ - Reassign longitude coordinates and bounds to the Dataset after swapping the - orientation. + if to == (-180, 180): + new_coords = ((coords + 180) % 360) - 180 + elif to == (0, 360): + # Swap the coordinates. + # Example with 180 coords: [-180, -0, 179] -> [0, 180, 360] + # Example with 360 coords: [60, 150, 360] -> [60, 150, 0] + new_coords = coords % 360 + + # Check if the original coordinates contain an element with a value of + # 360. 
If this element exists, use its index to revert its swapped + # value of 0 (360 % 360 is 0) back to 360. This case usually happens + # if the coordinate are already on the (0, 360) axis orientation. + # Example with 360 coords: [60, 150, 0] -> [60, 150, 360] + index_with_360 = np.where(coords == 360) + + if len(index_with_360) > 0: + _if_multidim_dask_array_then_load(new_coords) + + new_coords[index_with_360] = 360 + else: + raise ValueError( + "Currently, only (-180, 180) and (0, 360) are supported longitude axis " + "orientations." + ) + + return new_coords + + +def _get_prime_meridian_index(lon_bounds: xr.DataArray) -> Optional[np.ndarray]: + """Gets the index of the prime meridian cell in the longitude bounds. + + A prime meridian cell can exist when converting the axis orientation + from [-180, 180) to [0, 360). Parameters ---------- - dataset : xr.Dataset - The Dataset. - lon : xr.DataArray - The swapped longitude coordinates. lon_bounds : xr.DataArray - The swapped longitude bounds. + The longitude bounds. Returns ------- - xr.Dataset - The Dataset with swapped longitude coordinates and bounds. + Optional[np.ndarray] + An array with a single element representing the index of the prime + meridian index if it exists. Otherwise, None if the cell does not exist. + + Raises + ------ + ValueError + If more than one grid cell spans the prime meridian. """ - lon[lon.name] = lon_bounds[lon.name] = lon + p_meridian_index = np.where(lon_bounds[:, 1] - lon_bounds[:, 0] < 0)[0] - dataset[lon.name] = lon - dataset[lon_bounds.name] = lon_bounds - return dataset + if p_meridian_index.size == 0: # pragma:no cover + return None + elif p_meridian_index.size > 1: + raise ValueError("More than one grid cell spans prime meridian.") + return p_meridian_index -def _align_lon_to_360(dataset: xr.Dataset, p_meridian_index: np.ndarray) -> xr.Dataset: +def _align_lon_to_360( + ds: xr.Dataset, + lon_bounds: xr.DataArray, + p_meridian_index: np.ndarray, +) -> xr.Dataset: """Handles a prime meridian cell to align longitude axis to (0, 360). This method ensures the domain bounds are within 0 to 360 by handling @@ -269,45 +441,39 @@ def _align_lon_to_360(dataset: xr.Dataset, p_meridian_index: np.ndarray) -> xr.D xr.Dataset The Dataset. """ - ds = dataset.copy() - lon: xr.DataArray = get_axis_coord(ds, "X") - lon_bounds: xr.DataArray = dataset.bounds.get_bounds("X") + dim = get_dim_keys(lon_bounds, "X") - # If chunking, must convert the xarray data structure from lazy - # Dask arrays into eager, in-memory NumPy arrays before performing - # manipulations on the data. Otherwise, it raises `NotImplementedError - # xarray can't set arrays with multiple array indices to dask yet`. - if isinstance(lon_bounds.data, Array): - lon_bounds.load() + # Create a dataset to store updated longitude variables. + ds_lon = xr.Dataset() - # Align the the longitude bounds using the prime meridian index. - lon_bounds = _align_lon_bounds_to_360(lon_bounds, p_meridian_index) + # Align the the longitude bounds to the 360 orientation using the prime + # meridian index. This function splits the grid cell into two parts (east + # and west), which appends an extra set of bounds for the 360 coordinate. + ds_lon[lon_bounds.name] = _align_lon_bounds_to_360(lon_bounds, p_meridian_index) - # Concatenate the longitude coordinates with 360 to handle the prime - # meridian cell and update the coordinates for the longitude bounds. 
- p_meridian_cell = xr.DataArray([360.0], coords={lon.name: [360.0]}, dims=[lon.name]) - lon = xr.concat((lon, p_meridian_cell), dim=lon.name) - lon_bounds[lon.name] = lon + # After appending the extra set of bounds, update the last coordinate from + # 0 to 360. + for key, coord in ds_lon.coords.items(): + coord.values[-1] = 360 + ds_lon[key] = coord # Get the data variables related to the longitude axis and concatenate each # with the value at the prime meridian. - lon_vars = {} - for key, value in ds.cf.data_vars.items(): - if key != lon_bounds.name and lon.name in value.dims: - lon_vars[key] = value - - for name, var in lon_vars.items(): - p_meridian_val = var.isel({lon.name: p_meridian_index}) - new_var = xr.concat((var, p_meridian_val), dim=lon.name) - new_var[lon.name] = lon - lon_vars[name] = new_var - - # Create a Dataset with longitude data vars and merge it to the Dataset - # without longitude data vars. - ds_lon = xr.Dataset(data_vars={**lon_vars, lon_bounds.name: lon_bounds}) - ds_no_lon = ds.get([v for v in ds.data_vars if lon.name not in ds[v].dims]) # type: ignore - ds = xr.merge((ds_no_lon, ds_lon)) - return ds + for key, var in ds.cf.data_vars.items(): + if key != lon_bounds.name and dim in var.dims: + # Concatenate the prime meridian cell to the variable + p_meridian_val = var.isel({dim: p_meridian_index}).copy() + new_var = xr.concat((var, p_meridian_val), dim=dim) + + # Update the longitude coordinates for the variable. + new_var[dim] = ds_lon[dim] + ds_lon[var.name] = new_var + + # Create a new dataset of non-longitude vars and updated longitude vars. + ds_no_lon = ds.get([v for v in ds.data_vars if dim not in ds[v].dims]) # type: ignore + ds_final = xr.merge((ds_no_lon, ds_lon)) + + return ds_final def _align_lon_bounds_to_360( @@ -345,6 +511,8 @@ def _align_lon_bounds_to_360( ValueError If longitude bounds are inclusively between 0 and 360. """ + _if_multidim_dask_array_then_load(bounds) + # Example array: [[359, 1], [1, 90], [90, 180], [180, 359]] # Reorient bound to span across zero (i.e., [359, 1] -> [-1, 1]). # Result: [[-1, 1], [1, 90], [90, 180], [180, 359]] @@ -353,7 +521,7 @@ def _align_lon_bounds_to_360( # Extend the array to nlon+1 by concatenating the grid cell that # spans the prime meridian to the end. # Result: [[-1, 1], [1, 90], [90, 180], [180, 359], [-1, 1]] - dim = get_axis_dim(bounds, "X") + dim = get_dim_keys(bounds, "X") bounds = xr.concat((bounds, bounds[p_meridian_index, :]), dim=dim) # Add an equivalent bound that spans 360 @@ -365,37 +533,5 @@ def _align_lon_bounds_to_360( # Update the lower-most min and upper-most max bounds to [0, 360]. # Result: [[0, 1], [1, 90], [90, 180], [180, 359], [359, 360]] bounds[p_meridian_index, 0], bounds[-1, 1] = (0.0, 360.0) - return bounds - - -def _get_prime_meridian_index(lon_bounds: xr.DataArray) -> Optional[np.ndarray]: - """Gets the index of the prime meridian cell in the longitude bounds. - - A prime meridian cell can exist when converting the axis orientation - from [-180, 180) to [0, 360). - - Parameters - ---------- - lon_bounds : xr.DataArray - The longitude bounds. - - Returns - ------- - Optional[np.ndarray] - An array with a single elementing representing the index of the prime - meridian index if it exists. Otherwise, None if the cell does not exist. - Raises - ------ - ValueError - If more than one grid cell spans the prime meridian. - """ - p_meridian_index = np.where(lon_bounds[:, 1] - lon_bounds[:, 0] < 0)[0] - - # FIXME: When does this conditional return true? 
It seems like swapping from - # (-180, to 180) to (0, 360) always produces a prime meridian cell? - if p_meridian_index.size == 0: # pragma:no cover - return None - elif p_meridian_index.size > 1: - raise ValueError("More than one grid cell spans prime meridian.") - return p_meridian_index + return bounds diff --git a/xcdat/bounds.py b/xcdat/bounds.py index 06f55082..7b5aff13 100644 --- a/xcdat/bounds.py +++ b/xcdat/bounds.py @@ -1,7 +1,7 @@ """Bounds module for functions related to coordinate bounds.""" import collections import warnings -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union import cf_xarray as cfxr # noqa: F401 import cftime @@ -9,8 +9,9 @@ import pandas as pd import xarray as xr -from xcdat.axis import CF_NAME_MAP, CFAxisName, get_axis_coord +from xcdat.axis import CF_ATTR_MAP, CFAxisKey, get_dim_coords from xcdat.logger import setup_custom_logger +from xcdat.spatial import _get_data_var logger = setup_custom_logger(__name__) @@ -138,89 +139,103 @@ def add_missing_bounds(self, width: float = 0.5) -> xr.Dataset: ------- xr.Dataset """ - axes = CF_NAME_MAP.keys() + ds = self._dataset.copy() + axes = CF_ATTR_MAP.keys() for axis in axes: - # Check if the axis coordinates can be mapped to. try: - get_axis_coord(self._dataset, axis) + coords = get_dim_coords(ds, axis) except KeyError: continue - # Determine if the axis is also a dimension by determining if there - # is overlap between the CF axis names and the dimension names. If - # not, skip over axis for validation. - if len(set(CF_NAME_MAP[axis]) & set(self._dataset.dims.keys())) == 0: - continue - - # Check if bounds also exist using the "bounds" attribute. - # Otherwise, try to add bounds if it meets the function's criteria. - try: - self.get_bounds(axis) - continue - except KeyError: - pass + for coord in coords.coords.values(): + try: + self.get_bounds(axis, str(coord.name)) + continue + except KeyError: + pass - try: - self._dataset = self.add_bounds(axis, width) - except ValueError: - continue + try: + bounds = self._create_bounds(axis, coord, width) + ds[bounds.name] = bounds + ds[coord.name].attrs["bounds"] = bounds.name + except ValueError: + continue - return self._dataset + return ds - def get_bounds(self, axis: CFAxisName) -> xr.DataArray: - """Get bounds for axis coordinates if both exist. + def get_bounds( + self, axis: CFAxisKey, var_key: Optional[str] = None + ) -> Union[xr.Dataset, xr.DataArray]: + """Gets coordinate bounds. Parameters ---------- - axis : CFAxisName - The CF-compliant axis name ("X", "Y", "T", "Z"). + axis : CFAxisKey + The CF axis key ("X", "Y", "T", "Z"). + var_key: Optional[str] + The key of the coordinate or data variable to get axis bounds for. + This parameter is useful if you only want the single bounds + DataArray related to the axis on the variable (e.g., "tas" has + a "lat" dimension and you want "lat_bnds"). Returns ------- - xr.DataArray - The coordinate bounds. + Union[xr.Dataset, xr.DataArray] + A Dataset of N bounds variables, or a single bounds variable + DataArray. Raises ------ ValueError If an incorrect ``axis`` argument is passed. - KeyError - If the coordinate variable was not found for the ``axis``. - - KeyError - If the coordinate bounds were not found for the ``axis``. + KeyError: + If bounds were not found for the specific ``axis``. 
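For example, assuming a dataset with a "ts" variable on a "lat" dimension whose bounds are stored in "lat_bnds" (names are illustrative only):

>>> ds.bounds.get_bounds("Y")                  # every "Y" bounds variable found
>>> ds.bounds.get_bounds("Y", var_key="ts")    # only the "lat_bnds" DataArray for "ts"
>>> ds.bounds.get_bounds("Y", var_key="lat")   # same result, keyed by the coordinate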
""" self._validate_axis_arg(axis) - coord_var = get_axis_coord(self._dataset, axis) - try: - bounds_key = coord_var.attrs["bounds"] - except KeyError: - raise KeyError( - f"The coordinate variable '{coord_var.name}' has no 'bounds' attr. " - "Set the 'bounds' attr to the name of the bounds data variable." - ) + if var_key is None: + # Get all bounds keys in the Dataset for this axis. + bounds_keys = self._get_bounds_keys(axis) + else: + # Get the obj in the Dataset using the key. + obj = _get_data_var(self._dataset, key=var_key) - try: - bounds_var = self._dataset[bounds_key].copy() - except KeyError: + # Check if the object is a data variable or a coordinate variable. + # If it is a data variable, derive the axis coordinate variable. + if obj.name in list(self._dataset.data_vars): + coord = get_dim_coords(obj, axis) + elif obj.name in list(self._dataset.coords): + coord = obj + + try: + bounds_keys = [coord.attrs["bounds"]] + except KeyError: + bounds_keys = [] + + if len(bounds_keys) == 0: raise KeyError( - f"Bounds were not found for the coordinate variable '{coord_var.name}'. " - "Add bounds with `Dataset.bounds.add_bounds()`." + f"No bounds were found for the '{axis}' axis. Make sure bounds vars " + "exist in the Dataset with names that match the 'bounds' keys, or try " + "adding bounds." ) - return bounds_var + bounds: Union[xr.Dataset, xr.DataArray] = self._dataset[ + bounds_keys if len(bounds_keys) > 1 else bounds_keys[0] + ].copy() + + return bounds - def add_bounds(self, axis: CFAxisName, width: float = 0.5) -> xr.Dataset: + def add_bounds(self, axis: CFAxisKey, width: float = 0.5) -> xr.Dataset: """Add bounds for an axis using its coordinate points. - The coordinates must meet the following criteria in order to add - bounds: + This method loops over the axis's coordinate variables and attempts to + add bounds for each of them if they don't exist. The coordinates must + meet the following criteria in order to add bounds: 1. The axis for the coordinates are "X", "Y", "T", or "Z" - 2. Coordinates are a single dimension, not multidimensional + 2. Coordinates are single dimensional, not multidimensional 3. Coordinates are a length > 1 (not singleton) 4. Bounds must not already exist. * Determined by attempting to map the coordinate variable's @@ -228,8 +243,8 @@ def add_bounds(self, axis: CFAxisName, width: float = 0.5) -> xr.Dataset: Parameters ---------- - axis : CFAxisName - The CF-compliant axis name ("X", "Y", "T", "Z"). + axis : CFAxisKey + The CF axis key ("X", "Y", "T", or "Z"). width : float, optional Width of the bounds relative to the position of the nearest points, by default 0.5. @@ -245,33 +260,80 @@ def add_bounds(self, axis: CFAxisName, width: float = 0.5) -> xr.Dataset: If bounds already exist. They must be dropped first. """ + ds = self._dataset.copy() self._validate_axis_arg(axis) + coord_vars: Union[xr.DataArray, xr.Dataset] = get_dim_coords( + self._dataset, axis + ) + + for coord in coord_vars.coords.values(): + # Check if the coord var has a "bounds" attr and the bounds actually + # exist in the Dataset. If it does not, then add the bounds. + try: + bounds_key = ds[coord.name].attrs["bounds"] + ds[bounds_key] + + continue + except KeyError: + bounds = self._create_bounds(axis, coord, width) + + ds[bounds.name] = bounds + ds[coord.name].attrs["bounds"] = bounds.name + + return ds + + def _get_bounds_keys(self, axis: CFAxisKey) -> List[str]: + """Get bounds keys for an axis's coordinate variables in the dataset. 
+ + This function attempts to map bounds to an axis using ``cf_xarray`` + and its interpretation of the CF "bounds" attribute. + + Parameters + ---------- + axis : CFAxisKey + The CF axis key ("X", "Y", "T", or "Z"). + + Returns + ------- + List[str] + The axis bounds key(s). + """ + cf_method = self._dataset.cf.bounds + cf_attrs = CF_ATTR_MAP[axis] + + keys: List[str] = [] + try: - self.get_bounds(axis) - raise ValueError( - f"{axis} bounds already exist. Drop them first to add new bounds." - ) + keys = keys + cf_method[cf_attrs["axis"]] except KeyError: - dataset = self._add_bounds(axis, width) + pass - return dataset + try: + keys = cf_method[cf_attrs["coordinate"]] + except KeyError: + pass - def _add_bounds(self, axis: CFAxisName, width: float = 0.5) -> xr.Dataset: - """Add bounds for an axis using its coordinate points. + return list(set(keys)) + + def _create_bounds( + self, axis: CFAxisKey, coord_var: xr.DataArray, width: float + ) -> xr.DataArray: + """Creates bounds for an axis using its coordinate points. Parameters ---------- - axis : CFAxisName - The CF-compliant axis name ("X", "Y", "T", "Z"). - width : float, optional - Width of the bounds relative to the position of the nearest points, - by default 0.5. + axis: CFAxisKey + The CF axis key ("X", "Y", "T" ,"Z"). + coord_var : xr.DataArray + The coordinate variable for the axis. + width : float + Width of the bounds relative to the position of the nearest points. Returns ------- - xr.Dataset - The dataset with new coordinate bounds for an axis. + xr.DataArray + The axis coordinate bounds. Raises ------ @@ -289,20 +351,11 @@ def _add_bounds(self, axis: CFAxisName, width: float = 0.5) -> xr.Dataset: .. [2] https://cf-xarray.readthedocs.io/en/latest/generated/xarray.Dataset.cf.add_bounds.html# """ - # Add coordinate bounds to the dataset - ds = self._dataset.copy() - coord_var: xr.DataArray = get_axis_coord(ds, axis) - - # Validate coordinate shape and dimensions - if coord_var.ndim != 1: + is_singleton = coord_var.size <= 1 + if is_singleton: raise ValueError( f"Cannot generate bounds for coordinate variable '{coord_var.name}'" - " because it is multidimensional coordinates." - ) - if coord_var.shape[0] <= 1: - raise ValueError( - f"Cannot generate bounds for coordinate variable '{coord_var.name}'" - " which has a length <= 1." + " which has a length <= 1 (singleton)." ) # Retrieve coordinate dimension to calculate the diffs between points. @@ -314,14 +367,10 @@ def _add_bounds(self, axis: CFAxisName, width: float = 0.5) -> xr.Dataset: diffs = np.insert(diffs, 0, diffs[0]) diffs = np.append(diffs, diffs[-1]) - # In xarray and xCDAT, time coordinates with non-CF compliant calendars - # (360-day, noleap) and/or units ("months", "years") are decoded using - # `cftime` objects instead of `datetime` objects. `cftime` objects only - # support arithmetic using `timedelta` objects, so the values of `diffs` - # must be casted from `dtype="timedelta64[ns]"` to `timedelta`. - if coord_var.name in ("T", "time") and issubclass( - type(coord_var.values[0]), cftime.datetime - ): + # `cftime` objects only support arithmetic using `timedelta` objects, so + # the values of `diffs` must be casted from `dtype="timedelta64[ns]"` + # to `timedelta` objects. 
+ if axis == "T" and issubclass(type(coord_var.values[0]), cftime.datetime): diffs = pd.to_timedelta(diffs) # FIXME: These lines produces the warning: `PerformanceWarning: @@ -336,10 +385,10 @@ def _add_bounds(self, axis: CFAxisName, width: float = 0.5) -> xr.Dataset: upper_bounds = coord_var + diffs[1:] * (1 - width) # Transpose both bound arrays into a 2D array. - bounds = np.array([lower_bounds, upper_bounds]).transpose() + data = np.array([lower_bounds, upper_bounds]).transpose() # Clip latitude bounds at (-90, 90) - if coord_var.name in ("lat", "latitude", "grid_latitude"): + if axis == "Y": units = coord_var.attrs.get("units") if units is None: @@ -355,31 +404,26 @@ def _add_bounds(self, axis: CFAxisName, width: float = 0.5) -> xr.Dataset: ) if (coord_var >= -90).all() and (coord_var <= 90).all(): - np.clip(bounds, -90, 90, out=bounds) + np.clip(data, -90, 90, out=data) # Create the bounds data variable and add it to the Dataset. - bounds_var = xr.DataArray( + bounds = xr.DataArray( name=f"{coord_var.name}_bnds", - data=bounds, + data=data, coords={coord_var.name: coord_var}, - dims=[coord_var.name, "bnds"], + dims=[*coord_var.dims, "bnds"], attrs={"xcdat_bounds": "True"}, ) - ds[bounds_var.name] = bounds_var - - # Update the attributes of the coordinate variable. - coord_var.attrs["bounds"] = bounds_var.name - ds[coord_var.name] = coord_var - return ds + return bounds - def _validate_axis_arg(self, axis: CFAxisName): - cf_axis_names = CF_NAME_MAP.keys() + def _validate_axis_arg(self, axis: CFAxisKey): + cf_axis_keys = CF_ATTR_MAP.keys() - if axis not in cf_axis_names: - keys = ", ".join(f"'{key}'" for key in cf_axis_names) + if axis not in cf_axis_keys: + keys = ", ".join(f"'{key}'" for key in cf_axis_keys) raise ValueError( f"Incorrect 'axis' argument value. Supported values include {keys}." ) - get_axis_coord(self._dataset, axis) + get_dim_coords(self._dataset, axis) diff --git a/xcdat/dataset.py b/xcdat/dataset.py index 70808407..13c3c403 100644 --- a/xcdat/dataset.py +++ b/xcdat/dataset.py @@ -2,19 +2,21 @@ import pathlib from datetime import datetime from functools import partial -from glob import glob -from typing import Any, Callable, Dict, Hashable, List, Literal, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import numpy as np import xarray as xr from dateutil import parser from dateutil import relativedelta as rd from xarray.coding.cftime_offsets import get_date_type -from xarray.coding.times import convert_times +from xarray.coding.times import convert_times, decode_cf_datetime +from xarray.coding.variables import lazy_elemwise_func, pop_to, unpack_for_decoding +from xarray.core.variable import as_variable -from xcdat import bounds # noqa: F401 +from xcdat import bounds as bounds_accessor # noqa: F401 +from xcdat.axis import _get_all_coord_keys from xcdat.axis import center_times as center_times_func -from xcdat.axis import get_axis_coord, get_axis_dim, swap_lon_axis +from xcdat.axis import swap_lon_axis from xcdat.logger import setup_custom_logger logger = setup_custom_logger(__name__) @@ -57,16 +59,16 @@ def open_dataset( features. decode_times: bool, optional If True, attempt to decode times encoded in the standard NetCDF - datetime format into datetime objects. Otherwise, leave them encoded - as numbers. This keyword may not be supported by all the backends, - by default True. + datetime format into cftime.datetime objects. Otherwise, leave them + encoded as numbers. 
This keyword may not be supported by all the + backends, by default True. center_times: bool, optional - If True, center time coordinates using the midpoint between its upper - and lower bounds. Otherwise, use the provided time coordinates, by - default False. + If True, attempt to center time coordinates using the midpoint between + its upper and lower bounds. Otherwise, use the provided time + coordinates, by default False. lon_orient: Optional[Tuple[float, float]], optional - The orientation to use for the Dataset's longitude axis (if it exists), - by default None. + The orientation to use for the Dataset's longitude axis (if it exists). + Either `(-180, 180)` or `(0, 360)`, by default None. Supported options: @@ -95,16 +97,10 @@ def open_dataset( .. [1] https://xarray.pydata.org/en/stable/generated/xarray.open_dataset.html """ + ds = xr.open_dataset(path, decode_times=False, **kwargs) # type: ignore + if decode_times: - cf_compliant_time: Optional[bool] = _has_cf_compliant_time(path) - # xCDAT attempts to decode non-CF compliant time coordinates. - if cf_compliant_time is False: - ds = xr.open_dataset(path, decode_times=False, **kwargs) # type: ignore - ds = decode_non_cf_time(ds) - else: - ds = xr.open_dataset(path, decode_times=True, **kwargs) # type: ignore - else: - ds = xr.open_dataset(path, decode_times=False, **kwargs) # type: ignore + ds = decode_time(ds) ds = _postprocess_dataset(ds, data_var, center_times, add_bounds, lon_orient) @@ -140,13 +136,14 @@ def open_mfdataset( data_var: Optional[str], optional The key of the data variable to keep in the Dataset, by default None. decode_times: bool, optional - If True, decode times encoded in the standard NetCDF datetime format - into datetime objects. Otherwise, leave them encoded as numbers. - This keyword may not be supported by all the backends, by default True. + If True, attempt to decode times encoded in the standard NetCDF + datetime format into cftime.datetime objects. Otherwise, leave them + encoded as numbers. This keyword may not be supported by all the + backends, by default True. center_times: bool, optional - If True, center time coordinates using the midpoint between its upper - and lower bounds. Otherwise, use the provided time coordinates, by - default False. + If True, attempt to center time coordinates using the midpoint between + its upper and lower bounds. Otherwise, use the provided time + coordinates, by default False. lon_orient: Optional[Tuple[float, float]], optional The orientation to use for the Dataset's longitude axis (if it exists), by default None. @@ -203,67 +200,35 @@ def open_mfdataset( .. [2] https://xarray.pydata.org/en/stable/generated/xarray.open_mfdataset.html """ - # `xr.open_mfdataset()` drops the time coordinates encoding dictionary if - # multiple files are merged with `decode_times=True` (refer to - # https://github.com/pydata/xarray/issues/2436). The workaround is to store - # the time encoding from the first dataset as a variable, and add the time - # encoding back to final merged dataset. - time_encoding = None - - if decode_times: - time_encoding = _keep_time_encoding(paths) - - cf_compliant_time: Optional[bool] = _has_cf_compliant_time(paths) - # xCDAT attempts to decode non-CF compliant time coordinates using the - # preprocess keyword arg with `xr.open_mfdataset()`. 
- if cf_compliant_time is False: - decode_times = False - preprocess = partial(_preprocess_non_cf_dataset, callable=preprocess) + preprocess = partial(_preprocess, decode_times=decode_times, callable=preprocess) ds = xr.open_mfdataset( paths, - decode_times=decode_times, + decode_times=False, data_vars=data_vars, preprocess=preprocess, **kwargs, # type: ignore ) - ds = _postprocess_dataset(ds, data_var, center_times, add_bounds, lon_orient) - if time_encoding is not None: - time_dim = get_axis_dim(ds, "T") - ds[time_dim].encoding = time_encoding - # Update "original_shape" to reflect the final time coordinates shape. - ds[time_dim].encoding["original_shape"] = ds[time_dim].shape + ds = _postprocess_dataset(ds, data_var, center_times, add_bounds, lon_orient) return ds -def decode_non_cf_time(dataset: xr.Dataset) -> xr.Dataset: - """Decodes time coordinates and time bounds with non-CF compliant units. +def decode_time(dataset: xr.Dataset) -> xr.Dataset: + """Decodes CF and non-CF time coordinates and time bounds using ``cftime``. By default, ``xarray`` only supports decoding time with CF compliant units - [3]_. This function enables decoding time with non-CF compliant units. + [3]_. This function enables also decoding time with non-CF compliant units. + It skips decoding time coordinates that have already been decoded as + ``"datetime64[ns]"`` or ``cftime.datetime``. - The time coordinates must have a "calendar" attribute set to a CF calendar - type supported by ``cftime`` ("noleap", "360_day", "365_day", "366_day", - "gregorian", "proleptic_gregorian", "julian", "all_leap", or "standard") - and a "units" attribute set to a supported format ("months since ..." or - "years since ..."). - - The logic for this function: - - 1. Extract units and reference date strings from the "units" attribute. - - * For example with "months since 1800-01-01", the units are "months" and - reference date is "1800-01-01". - - 2. Using the reference date, create a reference ``datetime`` object. - 3. Starting from the reference ``datetime`` object, use the numerically - encoded time coordinate values (each representing an offset) to create an - array of ``cftime`` objects based on the calendar type. - 4. Using the array of ``cftime`` objects, create a new xr.DataArray - of time coordinates to replace the numerically encoded ones. - 5. If it exists, create a time bounds DataArray using steps 3 and 4. + For time coordinates to be decodable, they must have a "calendar" attribute + set to a CF calendar type supported by ``cftime``. CF calendar types + include "noleap", "360_day", "365_day", "366_day", "gregorian", + "proleptic_gregorian", "julian", "all_leap", or "standard". They must also + have a "units" attribute set to a format supported by xcdat ("months since + ..." or "years since ..."). 
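For instance, time coordinates carrying attributes like the following (values are hypothetical) can be decoded here even though plain xarray rejects the non-CF "months since" units:

>>> ds.time.attrs
{'axis': 'T', 'units': 'months since 2000-01-01', 'calendar': 'noleap'}
>>> ds = xcdat.decode_time(ds)
>>> # If "calendar" is missing, decoding falls back to "standard" with a warning.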
Parameters ---------- @@ -293,9 +258,9 @@ def decode_non_cf_time(dataset: xr.Dataset) -> xr.Dataset: Examples -------- - Decode the time coordinates with non-CF units in a Dataset: + Decode the time coordinates in a Dataset: - >>> from xcdat.dataset import decode_non_cf_time + >>> from xcdat.dataset import decode_time >>> >>> ds.time @@ -310,7 +275,7 @@ def decode_non_cf_time(dataset: xr.Dataset) -> xr.Dataset: standard_name: time calendar: noleap >>> - >>> ds_decoded = decode_non_cf_time(ds) + >>> ds_decoded = decode_time(ds) >>> ds_decoded.time array([cftime.DatetimeNoLeap(1850, 1, 1, 0, 0, 0, 0, has_year_zero=True), @@ -337,206 +302,94 @@ def decode_non_cf_time(dataset: xr.Dataset) -> xr.Dataset: 'calendar': 'noleap'} """ ds = dataset.copy() - time = get_axis_coord(ds, "T") - time_attrs = time.attrs - - # NOTE: When opening datasets with `decode_times=False`, the "calendar" and - # "units" attributes are stored in `.attrs` (unlike `decode_times=True` - # which stores them in `.encoding`). Since xCDAT manually decodes non-CF - # compliant time coordinates by first setting `decode_times=False`, the - # "calendar" and "units" attrs are popped from the `.attrs` dict and stored - # in the `.encoding` dict to mimic xarray's behavior. - calendar = time_attrs.pop("calendar", None) - units_attr = time_attrs.pop("units", None) - - if calendar is None: - logger.warning( - "This dataset's time coordinates do not have a 'calendar' attribute set, " - "so time coordinates could not be decoded. Set the 'calendar' attribute " - f"(`ds.{time.name}.attrs['calendar]`) and try decoding the time " - "coordinates again." - ) - return ds - - if units_attr is None: - logger.warning( - "This dataset's time coordinates do not have a 'units' attribute set, " - "so the time coordinates could not be decoded. Set the 'units' attribute " - f"(`ds.{time.name}.attrs['units']`) and try decoding the time " - "coordinates again." + coord_keys = _get_all_coord_keys(ds, "T") + + if len(coord_keys) == 0: + raise KeyError( + "Unable to map to time coordinates in this dataset to perform decoding. " + "Make sure that the time coordinates have the CF 'axis' or 'standard_name' " + "attribute set (e.g., ds['time'].attrs['axis'] = 'T' or " + "ds['time'].attrs['standard_name'] = 'time'), and try decoding again. " ) - return ds - try: - units, ref_date = _split_time_units_attr(units_attr) - except ValueError: - logger.warning( - f"This dataset's time coordinates 'units' attribute ('{units_attr}') is " - "not in a supported format ('months since...' or 'years since...'), so the " - "time coordinates could not be decoded." - ) - return ds - - data = _get_cftime_coords(ref_date, time.values, calendar, units) - decoded_time = xr.DataArray( - name=time.name, - data=data, - dims=time.dims, - coords={time.name: data}, - # As mentioned in a comment above, the units and calendar attributes are - # popped from the `.attrs` dict. - attrs=time_attrs, - ) - decoded_time.encoding = { - "source": ds.encoding.get("source", "None"), - "dtype": time.dtype, - "original_shape": time.shape, - # The units and calendar attributes are now saved in the `.encoding` - # dict. 
- "units": units_attr, - "calendar": calendar, - } - ds = ds.assign_coords({time.name: decoded_time}) - - try: - time_bounds = ds.bounds.get_bounds("T") - except KeyError: - time_bounds = None - - if time_bounds is not None: - lowers = _get_cftime_coords(ref_date, time_bounds.values[:, 0], calendar, units) - uppers = _get_cftime_coords(ref_date, time_bounds.values[:, 1], calendar, units) - data_bounds = np.vstack((lowers, uppers)).T - - decoded_time_bnds = xr.DataArray( - name=time_bounds.name, - data=data_bounds, - dims=time_bounds.dims, - coords=time_bounds.coords, - attrs=time_bounds.attrs, - ) - decoded_time_bnds.coords[time.name] = decoded_time - decoded_time_bnds.encoding = time_bounds.encoding - ds = ds.assign({time_bounds.name: decoded_time_bnds}) + for key in coord_keys: + coords = ds[key].copy() + + if _is_decodable(coords) and not _is_decoded(coords): + if coords.attrs.get("calendar") is None: + coords.attrs["calendar"] = "standard" + logger.warning( + f"'{coords.name}' does not have a calendar attribute set. " + "Defaulting to CF 'standard' calendar." + ) + + decoded_time = _decode_time(coords) + ds = ds.assign_coords({coords.name: decoded_time}) + + try: + bounds = ds.bounds.get_bounds("T", var_key=coords.name) + except KeyError: + bounds = None + + if bounds is not None and not _is_decoded(bounds): + # Bounds don't typically store the "units" and "calendar" + # attributes required for decoding, so these attributes need to be + # copied from the coordinates. + bounds.attrs.update( + { + "units": coords.attrs["units"], + "calendar": coords.attrs["calendar"], + } + ) + decoded_bounds = _decode_time(bounds) + ds = ds.assign({bounds.name: decoded_bounds}) return ds -def _keep_time_encoding(paths: Paths) -> Dict[Hashable, Any]: - """ - Returns the time encoding attributes from the first dataset in a list of - paths. - - Time encoding information is critical for several xCDAT operations such as - temporal averaging (e.g., uses the "calendar" attr). This function is a - workaround to the undesired xarray behavior/quirk with - `xr.open_mfdataset()`, which drops the `.encoding` dict from the final - merged dataset (refer to https://github.com/pydata/xarray/issues/2436). - - Parameters - ---------- - paths: Paths - The paths to the dataset(s). - - Returns - ------- - Dict[Hashable, Any] - The time encoding dictionary. - """ - first_path = _get_first_path(paths) - - # xcdat.open_dataset() is called instead of xr.open_dataset() because - # we want to handle decoding non-CF compliant as well. - # FIXME: Remove `type: ignore` comment after properly handling the type - # annotations in `_get_first_path()`. - ds = open_dataset(first_path, decode_times=True, add_bounds=False) # type: ignore - time_coord = get_axis_coord(ds, "T") - - time_encoding = time_coord.encoding - time_encoding["source"] = paths - - return time_coord.encoding +def _preprocess( + ds: xr.Dataset, decode_times: Optional[bool], callable: Optional[Callable] = None +) -> xr.Dataset: + """Preprocesses each dataset passed to ``open_mfdataset()``. + This function accepts a user specified preprocess function, which is + executed before additional internal preprocessing functions. -def _has_cf_compliant_time(paths: Paths) -> Optional[bool]: - """Checks if a dataset has time coordinates with CF compliant units. - If the dataset does not contain a time dimension, None is returned. - Otherwise, the units attribute is extracted from the time coordinates to - determine whether it is CF or non-CF compliant. 
+ An internal call to ``decode_time()`` is performed, which decodes + both CF and non-CF time coordinates and bounds (if they exist). By default, + if ``decode_times=False`` is passed to ``open_mfdataset()``, xarray will + concatenate time values using the first dataset's ``units`` attribute. This + results in an issue for cases where the numerically encoded time values are + the same and the ``units`` attribute differs between datasets. For example, + two files have the same time values, but the units of the first file is + "months since 2000-01-01" and the second is "months since 2001-01-01". Since + the first dataset's units are used in xarray for concatenating datasets, the + time values corresponding to the second file will be dropped since they + appear to be the same as the first file. Calling ``decode_time()`` + on each dataset individually before concatenating solves the aforementioned + issue. Parameters ---------- - path : Union[str, pathlib.Path, List[str], List[pathlib.Path], \ - List[List[str]], List[List[pathlib.Path]]] - Either a file (``"file.nc"``), a string glob in the form - ``"path/to/my/files/*.nc"``, or an explicit list of files to open. - Paths can be given as strings or as pathlib Paths. If concatenation - along more than one dimension is desired, then ``paths`` must be a - nested list-of-lists (see ``combine_nested`` for details). (A string - glob will be expanded to a 1-dimensional list.) - + ds : xr.Dataset + The Dataset. + callable : Optional[Callable], optional + A user specified optional callable function for preprocessing. Returns ------- - Optional[bool] - None if time dimension does not exist, True if CF compliant, or False if - non-CF compliant. - - Notes - ----- - This function only checks one file for multi-file datasets to optimize - performance because it is slower to combine all files then check for CF - compliance. + xr.Dataset + The preprocessed Dataset. """ - first_path = _get_first_path(paths) - ds = xr.open_dataset(first_path, decode_times=False) # type: ignore - - if ds.cf.dims.get("T") is None: - return None - - time = ds.cf["T"] - - # If the time units attr cannot be split, it is not cf_compliant. - try: - units = _split_time_units_attr(time.attrs.get("units"))[0] - except ValueError: - return False - - cf_compliant = units not in NON_CF_TIME_UNITS - - return cf_compliant - - -def _get_first_path(path: Paths) -> Optional[Union[pathlib.Path, str]]: - """Returns the first path from a list of paths. - - Parameters - ---------- - path : Paths - A list of paths. + ds_new = ds.copy() - Returns - ------- - str - Returns the first path from a list of paths. - """ - # FIXME: This function should throw an exception if the first file - # is not a supported type. - # FIXME: The `type: ignore` comments should be removed after properly - # handling the types. - first_file: Optional[Union[pathlib.Path, str]] = None + if callable: + ds_new = callable(ds) - if isinstance(path, str) and "*" in path: - first_file = glob(path)[0] - elif isinstance(path, str) or isinstance(path, pathlib.Path): - first_file = path - elif isinstance(path, list): - if any(isinstance(sublist, list) for sublist in path): - first_file = path[0][0] # type: ignore - else: - first_file = path[0] # type: ignore + if decode_times: + ds_new = decode_time(ds_new) - return first_file + return ds_new def _postprocess_dataset( @@ -583,27 +436,169 @@ def _postprocess_dataset( ValueError If ``lon_orient is not None`` but there are no longitude coordinates. 
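A hypothetical instance of the concatenation pitfall described in the ``_preprocess`` docstring above (file names and units are made up):

>>> # file_a.nc: time values [0, 1, ..., 11], units "months since 2000-01-01"
>>> # file_b.nc: time values [0, 1, ..., 11], units "months since 2001-01-01"
>>> # Decoding each file in the preprocess step keeps both years distinct
>>> # before xarray concatenates the datasets.
>>> ds = xcdat.open_mfdataset(["file_a.nc", "file_b.nc"], decode_times=True)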
""" + ds = dataset.copy() + if data_var is not None: - dataset = _keep_single_var(dataset, data_var) + ds = _keep_single_var(dataset, data_var) if center_times: - if dataset.cf.dims.get("T") is not None: - dataset = center_times_func(dataset) - else: - raise ValueError("This dataset does not have a time coordinates to center.") + ds = center_times_func(dataset) if add_bounds: - dataset = dataset.bounds.add_missing_bounds() + ds = ds.bounds.add_missing_bounds() if lon_orient is not None: - if dataset.cf.dims.get("X") is not None: - dataset = swap_lon_axis(dataset, to=lon_orient, sort_ascending=True) - else: - raise ValueError( - "This dataset does not have longitude coordinates to reorient." - ) + ds = swap_lon_axis(ds, to=lon_orient, sort_ascending=True) + + return ds + + +def _is_decodable(coords: xr.DataArray) -> bool: + """Checks if time coordinates are decodable. + + Time coordinates must have a "units" attribute in a supported format + to be decodable. + + Parameters + ---------- + coords : xr.DataArray + The time coordinates. + + Returns + ------- + bool + """ + units = coords.attrs.get("units") + + if units is None: + logger.warning( + f"'{coords.name}' does not have a 'units' attribute set so it " + "could not be decoded. Try setting the 'units' attribute " + "(`ds.{coords.name}.attrs['units']`) and try decoding again." + ) + return False + + if isinstance(units, str) and "since" not in units: + logger.warning( + f"The 'units' attribute ({units}) for '{coords.name}' is not in the " + "supported format 'X since Y', so it could not be decoded." + ) + return False + + return True + + +def _is_decoded(da: xr.DataArray) -> bool: + """Check if a time-based DataArray is decoded. + + This is determined by checking if the `encoding` dictionary has "units" and + "calendar" attributes set. + + Parameters + ---------- + da : xr.DataArray + A time-based DataArray (e.g,. coordinates, bounds) + + Returns + ------- + bool + """ + units = da.encoding.get("units") + calendar = da.encoding.get("calendar") + + return calendar is not None and units is not None + + +def _decode_time(da: xr.DataArray) -> xr.Variable: + """Lazily decodes a DataArray of numerically encoded time with cftime. + + The ``xr.DataArray`` is converted to an ``xr.Variable`` so that + ``xr.coding.variables.lazy_elemwise_func`` can be leveraged to lazily decode + time. + + This function is based on ``xarray.coding.times.CFDatetimeCoder.decode``. + + Parameters + ---------- + coords : xr.DataArray + A DataArray of numerically encoded time. + + Returns + ------- + xr.Variable + A Variable of time decoded as ``cftime`` objects. + """ + variable = as_variable(da) + dims, data, attrs, encoding = unpack_for_decoding(variable) + + units = pop_to(attrs, encoding, "units") + calendar = pop_to(attrs, encoding, "calendar") + + transform = partial(_get_cftime_coords, units=units, calendar=calendar) + data = lazy_elemwise_func(data, transform, np.dtype("object")) + + return xr.Variable(dims, data, attrs, encoding) + + +def _get_cftime_coords(offsets: np.ndarray, units: str, calendar: str) -> np.ndarray: + """Get an array of cftime coordinates starting from a reference date. + + This function calls xarray's ``decode_cf_datetime()`` if the units are + CF compliant because ``decode_cf_datetime()`` considers leap days when + decoding time offsets to ``cftime`` objects. + + For non-CF compliant units ("[months|years] since ..."), this function + performs custom decoding. 
It flattens the array, performs decoding on the + time offsets, then reshapes the array back to its original shape. + + Parameters + ---------- + offsets : np.ndarray + An array of numerically encoded time offsets from the reference date. + units : str + The time units. + calendar : str + The CF calendar type supported by ``cftime``. This includes "noleap", + "360_day", "365_day", "366_day", "gregorian", "proleptic_gregorian", + "julian", "all_leap", and "standard". + + Returns + ------- + np.ndarray + An array of ``cftime`` coordinates. + """ + units_type, ref_date = units.split(" since ") + + if units_type not in NON_CF_TIME_UNITS: + return decode_cf_datetime(offsets, units, calendar=calendar, use_cftime=True) - return dataset + offsets = np.asarray(offsets) + flat_offsets = offsets.ravel() + + # Convert offsets to `np.float64` to avoid "TypeError: unsupported type + # for timedelta days component: numpy.int64". + flat_offsets = flat_offsets.astype("float") + + # We don't need to do calendar arithmetic here because the units and + # offsets are in "months" or "years", which means leap days should not + # be factored. + ref_datetime: datetime = parser.parse(ref_date, default=datetime(2000, 1, 1)) + times = np.array( + [ + ref_datetime + rd.relativedelta(**{units_type: offset}) + for offset in flat_offsets + ], + dtype="object", + ) + # Convert the array of `datetime` objects into `cftime` objects based on + # the calendar type. + date_type = get_date_type(calendar) + coords = convert_times(times, date_type=date_type) + + # Reshape back to the original shape. + coords = coords.reshape(offsets.shape) + + return coords def _keep_single_var(dataset: xr.Dataset, key: str) -> xr.Dataset: @@ -676,123 +671,3 @@ def _get_data_var(dataset: xr.Dataset, key: str) -> xr.DataArray: raise KeyError(f"The data variable '{key}' does not exist in the Dataset.") return dv.copy() - - -def _preprocess_non_cf_dataset( - ds: xr.Dataset, callable: Optional[Callable] = None -) -> xr.Dataset: - """Preprocessing for each non-CF compliant dataset in ``open_mfdataset()``. - - This function accepts a user specified preprocess function, which is - executed before additional internal preprocessing functions. - - One call is performed to ``decode_non_cf_time()`` for decoding each - dataset's time coordinates and time bounds (if they exist) with non-CF - compliant units. By default, if ``decode_times=False`` is passed, xarray - will concatenate time values using the first dataset's ``units`` attribute. - This is an issue for cases where the numerically encoded time values are the - same and the ``units`` attribute differs between datasets. - - For example, two files have the same time values, but the units of the first - file is "months since 2000-01-01" and the second is "months since - 2001-01-01". Since the first dataset's units are used in xarray for - concatenating datasets, the time values corresponding to the second file - will be dropped since they appear to be the same as the first file. - - Calling ``decode_non_cf_time()`` on each dataset individually before - concatenating solves the aforementioned issue. - - Parameters - ---------- - ds : xr.Dataset - The Dataset. - callable : Optional[Callable], optional - A user specified optional callable function for preprocessing. - - Returns - ------- - xr.Dataset - The preprocessed Dataset. - """ - ds_new = ds.copy() - - if callable: - ds_new = callable(ds) - - # Attempt to decode non-cf-compliant time axis. 
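A small, self-contained sketch of the non-CF branch in ``_get_cftime_coords()`` above, using made-up offsets and the "noleap" calendar (this is not the exact xcdat call path, which also handles flattening/reshaping and uses xarray's ``convert_times``):

>>> from datetime import datetime
>>> import numpy as np
>>> from dateutil import relativedelta as rd
>>> from cftime import DatetimeNoLeap
>>> ref = datetime(1850, 1, 1)                      # parsed from "months since 1850-01-01"
>>> offsets = np.array([0.0, 1.0, 2.0])             # numerically encoded time values
>>> times = [ref + rd.relativedelta(months=int(o)) for o in offsets]
>>> coords = np.array(
...     [DatetimeNoLeap(t.year, t.month, t.day) for t in times], dtype="object"
... )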
- ds_new = decode_non_cf_time(ds_new) - - return ds_new - - -def _split_time_units_attr(units_attr: str) -> Tuple[str, str]: - """Splits the time coordinates' units attr into units and reference date. - - Parameters - ---------- - units_attr : str - The units attribute (e.g., "months since 1800-01-01"). - - Returns - ------- - Tuple[str, str] - The units (e.g, "months") and the reference date (e.g., "1800-01-01"). - - Raises - ------ - KeyError - If the time units attribute was not found. - - ValueError - If the time units attribute is not of the form `X since Y`. - """ - if "since" in units_attr: - units, reference_date = units_attr.split(" since ") - else: - raise ValueError( - "This dataset does not have time coordinates of the form 'X since Y'." - ) - - return units, reference_date - - -def _get_cftime_coords( - ref_date: str, offsets: np.ndarray, calendar: str, units: str -) -> np.ndarray: - """Get an array of `cftime` coordinates starting from a reference date. - - Parameters - ---------- - ref_date : str - The starting reference date. - offsets : np.ndarray - An array of numerically encoded time offsets from the reference date. - calendar : str - The CF calendar type supported by ``cftime``. This includes "noleap", - "360_day", "365_day", "366_day", "gregorian", "proleptic_gregorian", - "julian", "all_leap", and "standard". - units : str - The time units. - - Returns - ------- - np.ndarray - An array of `cftime` coordinates. - """ - # Starting from the reference date, create an array of `datetime` objects - # by adding each offset (a numerically encoded value) to the reference date. - # The `parse.parse` default is set to datetime(2000, 1, 1), with each - # component being a placeholder if the value does not exist. For example, 1 - # and 1 are placeholders for month and day if those values don't exist. - ref_datetime: datetime = parser.parse(ref_date, default=datetime(2000, 1, 1)) - offsets = np.array( - [ref_datetime + rd.relativedelta(**{units: offset}) for offset in offsets], - dtype="object", - ) - - # Convert the array of `datetime` objects into `cftime` objects based on - # the calendar type. - date_type = get_date_type(calendar) - coords = convert_times(offsets, date_type=date_type) - - return coords diff --git a/xcdat/regridder/accessor.py b/xcdat/regridder/accessor.py index 515effc3..d7c7e0a5 100644 --- a/xcdat/regridder/accessor.py +++ b/xcdat/regridder/accessor.py @@ -1,8 +1,8 @@ -from typing import Any, Dict, Literal, Tuple +from typing import Any, Dict, Literal, Optional, Tuple import xarray as xr -from xcdat.axis import CFAxisName, get_axis_coord +from xcdat.axis import CFAxisKey, get_dim_coords from xcdat.regridder import regrid2 from xcdat.utils import _has_module @@ -60,7 +60,9 @@ def grid(self) -> xr.Dataset: Raises ------ ValueError - If axis data variable is not correctly identified. + If axis dimension coordinate variable is not correctly identified. + ValueError + If axis has multiple dimensions (only one is expected). 
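When the multiple-dimensions error above is raised because two dimensions map to the same axis, one hypothetical fix is to drop the unused dimension before reading the grid (the dimension names below are illustrative only):

>>> # e.g., both "lat" and "latitude" are mapped to the "Y" axis
>>> ds = ds.drop_dims("latitude")
>>> grid = ds.regridder.grid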
""" x, x_bnds = self._get_axis_data("X") y, y_bnds = self._get_axis_data("Y") @@ -80,11 +82,20 @@ def grid(self) -> xr.Dataset: return ds - def _get_axis_data(self, name: CFAxisName) -> Tuple[xr.DataArray, xr.DataArray]: - coord_var = get_axis_coord(self._ds, name) + def _get_axis_data( + self, name: CFAxisKey + ) -> Tuple[xr.DataArray, Optional[xr.DataArray]]: + coord_var = get_dim_coords(self._ds, name) + + if isinstance(coord_var, xr.Dataset): + raise ValueError( + f"Multiple '{name}' axis dims were found in this dataset, " + f"{list(coord_var.dims)}. Please drop the unused dimension(s) before" + "getting grid information." + ) try: - bounds_var = self._ds.bounds.get_bounds(name) + bounds_var = self._ds.bounds.get_bounds(name, coord_var.name) except KeyError: bounds_var = None diff --git a/xcdat/regridder/grid.py b/xcdat/regridder/grid.py index 2a175e6d..8cc146ab 100644 --- a/xcdat/regridder/grid.py +++ b/xcdat/regridder/grid.py @@ -3,7 +3,7 @@ import numpy as np import xarray as xr -from xcdat.axis import get_axis_coord +from xcdat.axis import CFAxisKey, get_dim_coords # First 50 zeros for the bessel function # Taken from https://github.com/CDAT/cdms/blob/dd41a8dd3b5bac10a4bfdf6e56f6465e11efc51d/regrid2/Src/_regridmodule.c#L3390-L3402 @@ -378,16 +378,18 @@ def create_global_mean_grid(grid: xr.Dataset) -> xr.Dataset: xr.Dataset A dataset containing the global mean grid. """ - lat = get_axis_coord(grid, "Y") - lat_data = np.array([(lat[0] + lat[-1]) / 2.0]) + lat = get_dim_coords(grid, "Y") + _validate_grid_has_single_axis_dim("X", lat) - lat_bnds = grid.bounds.get_bounds("Y") + lat_data = np.array([(lat[0] + lat[-1]) / 2.0]) + lat_bnds = grid.bounds.get_bounds("Y", var_key=lat.name) lat_bnds = np.array([[lat_bnds[0, 0], lat_bnds[-1, 1]]]) - lon = get_axis_coord(grid, "X") - lon_data = np.array([(lon[0] + lon[-1]) / 2.0]) + lon = get_dim_coords(grid, "X") + _validate_grid_has_single_axis_dim("Y", lon) - lon_bnds = grid.bounds.get_bounds("X") + lon_data = np.array([(lon[0] + lon[-1]) / 2.0]) + lon_bnds = grid.bounds.get_bounds("X", var_key=lon.name) lon_bnds = np.array([[lon_bnds[0, 0], lon_bnds[-1, 1]]]) return create_grid(lat_data, lon_data, lat_bnds=lat_bnds, lon_bnds=lon_bnds) @@ -408,16 +410,22 @@ def create_zonal_grid(grid: xr.Dataset) -> xr.Dataset: xr.Dataset A dataset containing a zonal grid. """ - lon = get_axis_coord(grid, "X") - out_lon_data = np.array([(lon[0] + lon[-1]) / 2.0]) + lon = get_dim_coords(grid, "X") + _validate_grid_has_single_axis_dim("X", lon) - lon_bnds = grid.bounds.get_bounds("X") + out_lon_data = np.array([(lon[0] + lon[-1]) / 2.0]) + lon_bnds = grid.bounds.get_bounds("X", var_key=lon.name) lon_bnds = np.array([[lon_bnds[0, 0], lon_bnds[-1, 1]]]) - lat = get_axis_coord(grid, "Y") - lat_bnds = grid.bounds.get_bounds("Y") + lat = get_dim_coords(grid, "Y") + _validate_grid_has_single_axis_dim("Y", lat) - return create_grid(lat, out_lon_data, lat_bnds=lat_bnds, lon_bnds=lon_bnds) + lat_bnds = grid.bounds.get_bounds("Y", var_key=lat.name) + + # Ignore `Argument 1 to "create_grid" has incompatible type + # "Union[Dataset, DataArray]"; expected "Union[ndarray[Any, Any], DataArray]" + # mypy(error)` because this arg is validated to be a DataArray beforehand. 
+ return create_grid(lat, out_lon_data, lat_bnds=lat_bnds, lon_bnds=lon_bnds) # type: ignore def create_grid( @@ -503,3 +511,32 @@ def create_grid( grid = grid.bounds.add_missing_bounds() return grid + + +def _validate_grid_has_single_axis_dim( + axis: CFAxisKey, coord_var: Union[xr.DataArray, xr.Dataset] +): + """Validates that the grid's axis has a single dimension. + + If the grid has multiple dimensions (e.g., "lat" and "latitude" dims), xcdat + cannot interpret which one to use for grid operations. If ``coord_var`` is + an ``xr.Dataset``, the grid has multiple dimensions. + + Parameters + ---------- + axis : CFAxisKey + The CF axis key ("X", "Y", "T", or "Z"). + coord_var : Union[xr.DataArray, xr.Dataset] + The dimension coordinate variable(s) for the axis. + + Raises + ------ + ValueError + If the grid has multiple dimensions. + """ + if isinstance(coord_var, xr.Dataset): + raise ValueError( + f"Multiple '{axis}' axis dims were found in this dataset, " + f"{list(coord_var.dims)}. Please drop the unused dimension(s) before " + "performing grid operations." + ) diff --git a/xcdat/spatial.py b/xcdat/spatial.py index a57f5b25..c5a15d70 100644 --- a/xcdat/spatial.py +++ b/xcdat/spatial.py @@ -16,15 +16,15 @@ import cf_xarray # noqa: F401 import numpy as np import xarray as xr -from dask.array.core import Array from xcdat.axis import ( _align_lon_bounds_to_360, _get_prime_meridian_index, - get_axis_coord, - get_axis_dim, + get_dim_coords, + get_dim_keys, ) from xcdat.dataset import _get_data_var +from xcdat.utils import _if_multidim_dask_array_then_load #: Type alias for a dictionary of axis keys mapped to their bounds. AxisWeights = Dict[Hashable, xr.DataArray] @@ -188,7 +188,7 @@ def average( self._validate_region_bounds("Y", lat_bounds) if lon_bounds is not None: self._validate_region_bounds("X", lon_bounds) - self._weights = self.get_weights(axis, lat_bounds, lon_bounds) + self._weights = self.get_weights(axis, lat_bounds, lon_bounds, data_var) elif isinstance(weights, xr.DataArray): self._weights = weights @@ -205,6 +205,7 @@ def get_weights( axis: List[SpatialAxis], lat_bounds: Optional[RegionAxisBounds] = None, lon_bounds: Optional[RegionAxisBounds] = None, + data_var: Optional[str] = None, ) -> xr.DataArray: """ Get area weights for specified axis keys and an optional target domain. @@ -229,6 +230,11 @@ def get_weights( lon_bounds : Optional[RegionAxisBounds] Tuple of longitude boundaries for regional selection, by default None. + data_var: Optional[str] + The key of the data variable, by default None. Pass this argument + when the dataset has more than one bounds per axis (e.g., "lon" + and "zlon_bnds" for the "X" axis), or you want weights for a + specific data variable. Returns ------- @@ -265,7 +271,16 @@ def get_weights( axis_weights: AxisWeights = {} for key in axis: - d_bounds = self._dataset.bounds.get_bounds(key).copy() + d_bounds = self._dataset.bounds.get_bounds(axis=key, var_key=data_var) + + if isinstance(d_bounds, xr.Dataset): + raise TypeError( + "Generating area weights requires a single bounds per " + f"axis, but the dataset has multiple bounds for the '{key}' axis " + f"{list(d_bounds.data_vars)}. Pass a `data_var` key " + "to reference a specific data variable's axis bounds." + ) + # The logic for generating longitude weights depends on the # bounds being ordered such that d_bounds[:, 0] < d_bounds[:, 1]. 
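For the multiple-bounds case handled above, a hypothetical call that uses the new ``data_var`` argument (the variable and bounds names are illustrative):

>>> # The dataset is assumed to carry more than one set of "X" bounds
>>> # (e.g., "lon_bnds" and "zlon_bnds"); data_var selects the bounds tied
>>> # to the "ts" variable's axes.
>>> weights = ds.spatial.get_weights(axis=["X", "Y"], data_var="ts")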
# They are re-ordered (if need be) for the purpose of creating @@ -307,7 +322,7 @@ def _validate_axis_arg(self, axis: List[SpatialAxis]): ) # Check the axis coordinate variable exists in the Dataset. - get_axis_coord(self._dataset, key) + get_dim_coords(self._dataset, key) def _force_domain_order_low_to_high(self, domain_bounds: xr.DataArray): """Reorders the ``domain_bounds`` low-to-high. @@ -524,12 +539,8 @@ def _swap_lon_axis( """ lon_swap = lon.copy() - # If chunking, must convert convert the xarray data structure from lazy - # Dask arrays into eager, in-memory NumPy arrays before performing - # manipulations on the data. Otherwise, it raises `NotImplementedError - # xarray can't set arrays with multiple array indices to dask yet`. - if type(lon_swap.data) == Array: - lon_swap.load() + if isinstance(lon_swap, xr.DataArray): + _if_multidim_dask_array_then_load(lon_swap) # Must set keep_attrs=True or the xarray DataArray attrs will get # dropped. This has no affect on NumPy arrays. @@ -582,8 +593,7 @@ def _scale_domain_to_region( d_bounds = domain_bounds.copy() r_bounds = region_bounds.copy() - if type(d_bounds.data) == Array: - d_bounds.load() + _if_multidim_dask_array_then_load(d_bounds) # Since longitude is circular, the logic depends on whether the region # spans across the prime meridian or not. If a region does not include @@ -692,7 +702,7 @@ def _validate_weights(self, data_var: xr.DataArray, axis: List[SpatialAxis]): # Check the weights includes the same axis as the data variable. for key in axis: - dim_name = get_axis_dim(data_var, key) + dim_name = get_dim_keys(data_var, key) if dim_name not in self._weights.dims: raise KeyError( f"The weights DataArray does not include an {key} axis, or the " @@ -741,7 +751,7 @@ def _averager(self, data_var: xr.DataArray, axis: List[SpatialAxis]): dim = [] for key in axis: - dim.append(get_axis_dim(data_var, key)) + dim.append(get_dim_keys(data_var, key)) with xr.set_options(keep_attrs=True): weighted_mean = data_var.cf.weighted(weights).mean(dim=dim) diff --git a/xcdat/temporal.py b/xcdat/temporal.py index 803f934c..f5361cdf 100644 --- a/xcdat/temporal.py +++ b/xcdat/temporal.py @@ -11,7 +11,7 @@ from xarray.core.groupby import DataArrayGroupBy from xcdat import bounds # noqa: F401 -from xcdat.axis import get_axis_coord +from xcdat.axis import get_dim_coords from xcdat.dataset import _get_data_var from xcdat.logger import setup_custom_logger @@ -146,20 +146,6 @@ class TemporalAccessor: def __init__(self, dataset: xr.Dataset): self._dataset: xr.Dataset = dataset - # The name of the time dimension. - self._dim = get_axis_coord(self._dataset, "T").name - - try: - self.calendar = self._dataset[self._dim].encoding["calendar"] - self.date_type = get_date_type(self.calendar) - except KeyError: - raise KeyError( - "This dataset's time coordinates do not have a 'calendar' encoding " - "attribute set, which might indicate that the time coordinates were not " - "decoded to datetime objects. Ensure that the time coordinates are " - "decoded before performing temporal averaging operations." 
- ) - def average(self, data_var: str, weighted: bool = True, keep_weights: bool = False): """ Returns a Dataset with the average of a data variable and the time @@ -205,6 +191,8 @@ def average(self, data_var: str, weighted: bool = True, keep_weights: bool = Fal >>> ds_month = ds.temporal.average("ts", freq="month") >>> ds_month.ts """ + self._set_data_var_attrs(data_var) + freq = self._infer_freq() return self._averager(data_var, "average", freq, weighted, keep_weights) @@ -340,6 +328,8 @@ def group_average( 'drop_incomplete_djf': 'False' } """ + self._set_data_var_attrs(data_var) + return self._averager( data_var, "group_average", freq, weighted, keep_weights, season_config ) @@ -475,6 +465,8 @@ def climatology( 'drop_incomplete_djf': 'False' } """ + self._set_data_var_attrs(data_var) + return self._averager( data_var, "climatology", freq, weighted, keep_weights, season_config ) @@ -609,14 +601,15 @@ def departures( } """ ds = self._dataset.copy() - self._set_obj_attrs("departures", freq, weighted, season_config) + self._set_data_var_attrs(data_var) + self._set_arg_attrs("departures", freq, weighted, season_config) # Preprocess the dataset based on method argument values. ds = self._preprocess_dataset(ds) # Group the observation data variable. dv_obs = _get_data_var(ds, data_var) - self._labeled_time = self._label_time_coords(dv_obs[self._dim]) + self._labeled_time = self._label_time_coords(dv_obs[self.dim]) dv_obs_grouped = self._group_data(dv_obs) # Calculate the climatology of the data variable. @@ -630,7 +623,7 @@ def departures( # to work. Otherwise, the error below is thrown: `ValueError: # incompatible dimensions for a grouped binary operation: the group # variable '' is not a dimension on the other argument` - dv_climo = dv_climo.rename({self._dim: self._labeled_time.name}) + dv_climo = dv_climo.rename({self.dim: self._labeled_time.name}) # Calculate the departures for the data variable, which uses the formula # observation - climatology. @@ -666,7 +659,7 @@ def _infer_freq(self) -> Frequency: Frequency The time frequency. """ - time_coords = self._dataset[self._dim] + time_coords = self._dataset[self.dim] min_delta = pd.to_timedelta(np.diff(time_coords).min(), unit="ns") if min_delta < pd.Timedelta(days=1): @@ -689,15 +682,14 @@ def _averager( ) -> xr.Dataset: """Averages a data variable based on the averaging mode and frequency.""" ds = self._dataset.copy() - self._set_obj_attrs(mode, freq, weighted, season_config) + self._set_arg_attrs(mode, freq, weighted, season_config) # Preprocess the dataset based on method argument values. ds = self._preprocess_dataset(ds) - # Get the data variable and time bounds from the dataset and perform - # the averaging operation. + # Get the data variable and the required time axis metadata. dv = _get_data_var(ds, data_var) - time_bounds = ds.bounds.get_bounds("T") + time_bounds = ds.bounds.get_bounds("T", var_key=dv.name) if self._mode == "average": dv = self._average(dv, time_bounds) @@ -708,7 +700,7 @@ def _averager( # it becomes obsolete after the data variable is averaged. When the # averaged data variable is added to the dataset, the new time dimension # and its associated coordinates are also added. - ds = ds.drop_dims(self._dim) + ds = ds.drop_dims(self.dim) ds[dv.name] = dv if keep_weights: @@ -716,7 +708,39 @@ def _averager( return ds - def _set_obj_attrs( + def _set_data_var_attrs(self, data_var: str): + """Set data variable metadata as object attributes. 
+ + This includes the name of the data variable, the time axis dimension + name, the calendar type and its corresponding cftime object (date type). + + Parameters + ---------- + data_var : str + The key of the data variable. + + Raises + ------ + KeyError + If the data variable does not have a "calendar" encoding attribute. + """ + dv = _get_data_var(self._dataset, data_var) + + self.data_var = data_var + self.dim = get_dim_coords(dv, "T").name + + try: + self.calendar = dv[self.dim].encoding["calendar"] + self.date_type = get_date_type(self.calendar) + except KeyError: + raise KeyError( + f"The 'calendar' encoding attribute is not set on the '{data_var}' " + f"time coordinate variable ({self.dim}). This might indicate that the " + "time coordinates were not decoded, which is required for temporal " + "averaging operations. " + ) + + def _set_arg_attrs( self, mode: Mode, freq: Frequency, @@ -896,19 +920,19 @@ def _drop_incomplete_djf(self, dataset: xr.Dataset) -> xr.Dataset: # method concatenates the time dimension to non-time dimension data # vars, which is not a desired behavior. ds = dataset.copy() - ds_time = ds.get([v for v in ds.data_vars if self._dim in ds[v].dims]) # type: ignore - ds_no_time = ds.get([v for v in ds.data_vars if self._dim not in ds[v].dims]) # type: ignore + ds_time = ds.get([v for v in ds.data_vars if self.dim in ds[v].dims]) # type: ignore + ds_no_time = ds.get([v for v in ds.data_vars if self.dim not in ds[v].dims]) # type: ignore start_year, end_year = ( - ds[self._dim].dt.year.values[0], - ds[self._dim].dt.year.values[-1], + ds[self.dim].dt.year.values[0], + ds[self.dim].dt.year.values[-1], ) incomplete_seasons = (f"{start_year}-01", f"{start_year}-02", f"{end_year}-12") for year_month in incomplete_seasons: try: - coord_pt = ds.loc[dict(time=year_month)][self._dim][0] - ds_time = ds_time.where(ds_time[self._dim] != coord_pt, drop=True) + coord_pt = ds.loc[dict(time=year_month)][self.dim][0] + ds_time = ds_time.where(ds_time[self.dim] != coord_pt, drop=True) except (KeyError, IndexError): continue @@ -935,7 +959,7 @@ def _drop_leap_days(self, ds: xr.Dataset): xr.Dataset """ ds = ds.sel( # type: ignore - **{self._dim: ~((ds.time.dt.month == 2) & (ds.time.dt.day == 29))} + **{self.dim: ~((ds.time.dt.month == 2) & (ds.time.dt.day == 29))} ) return ds @@ -961,9 +985,9 @@ def _average( with xr.set_options(keep_attrs=True): if self._weighted: self._weights = self._get_weights(time_bounds) - dv = dv.weighted(self._weights).mean(dim=self._dim) + dv = dv.weighted(self._weights).mean(dim=self.dim) else: - dv = dv.mean(dim=self._dim) + dv = dv.mean(dim=self.dim) dv = self._add_operation_attrs(dv) @@ -990,7 +1014,7 @@ def _group_average( # Label the time coordinates for grouping weights and the data variable # values. - self._labeled_time = self._label_time_coords(dv[self._dim]) + self._labeled_time = self._label_time_coords(dv[self.dim]) if self._weighted: self._weights = self._get_weights(time_bounds) @@ -1020,14 +1044,14 @@ def _group_average( # with "year_season". This dimension needs to be renamed back to # the original time dimension name before the data variable is added # back to the dataset so that the original name is preserved. - dv = dv.rename({self._labeled_time.name: self._dim}) + dv = dv.rename({self._labeled_time.name: self.dim}) # After grouping and aggregating, the grouped time dimension's # attributes are removed. 
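As the new ``KeyError`` above indicates, temporal operations expect time coordinates that are already decoded and carry a "calendar" encoding. A hypothetical end-to-end workflow (file and variable names are placeholders):

>>> ds = xcdat.open_dataset("tas.nc")  # time is decoded on open by default
>>> ds_yearly = ds.temporal.group_average("ts", freq="year")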
Xarray's `keep_attrs=True` option only keeps # attributes for data variables and not their coordinates, so the # coordinate attributes have to be restored manually. - dv[self._dim].attrs = self._labeled_time.attrs - dv[self._dim].encoding = self._labeled_time.encoding + dv[self.dim].attrs = self._labeled_time.attrs + dv[self.dim].encoding = self._labeled_time.encoding dv = self._add_operation_attrs(dv) @@ -1076,13 +1100,14 @@ def _get_weights(self, time_bounds: xr.DataArray) -> xr.DataArray: # only select the general DType and not details such as the byte order # or time unit (with rare exceptions see release notes). To avoid this # warning please use the scalar types `np.float64`, or string notation.` - if type(time_lengths.values == Array): + if isinstance(time_lengths.data, Array): time_lengths.load() + time_lengths = time_lengths.astype(np.float64) grouped_time_lengths = self._group_data(time_lengths) weights: xr.DataArray = grouped_time_lengths / grouped_time_lengths.sum() - weights.name = f"{self._dim}_wts" + weights.name = f"{self.dim}_wts" # Validate the sum of weights for each group is 1.0. actual_sum = self._group_data(weights).sum().values @@ -1110,7 +1135,7 @@ def _group_data(self, data_var: xr.DataArray) -> DataArrayGroupBy: dv = data_var.copy() if self._mode == "average": - dv_gb = dv.groupby(f"{self._dim}.{self._freq}") + dv_gb = dv.groupby(f"{self.dim}.{self._freq}") else: dv.coords[self._labeled_time.name] = self._labeled_time dv_gb = dv.groupby(self._labeled_time.name) @@ -1167,11 +1192,11 @@ def _label_time_coords(self, time_coords: xr.DataArray) -> xr.DataArray: time_grouped = xr.DataArray( name="_".join(df_dt_components.columns), data=dt_objects, - coords={self._dim: time_coords[self._dim]}, - dims=[self._dim], - attrs=time_coords[self._dim].attrs, + coords={self.dim: time_coords[self.dim]}, + dims=[self.dim], + attrs=time_coords[self.dim].attrs, ) - time_grouped.encoding = time_coords[self._dim].encoding + time_grouped.encoding = time_coords[self.dim].encoding return time_grouped @@ -1215,7 +1240,7 @@ def _get_df_dt_components(self, time_coords: xr.DataArray) -> pd.DataFrame: # Use the TIME_GROUPS dictionary to determine which components # are needed to form the labeled time coordinates. for component in TIME_GROUPS[self._mode][self._freq]: - df[component] = time_coords[f"{self._dim}.{component}"].values + df[component] = time_coords[f"{self.dim}.{component}"].values # The season frequency requires additional datetime components for # processing, which are later removed before time coordinates are @@ -1224,11 +1249,11 @@ def _get_df_dt_components(self, time_coords: xr.DataArray) -> pd.DataFrame: # `TIME_GROUPS` represents the final grouping labels. if self._freq == "season": if self._mode in ["climatology", "departures"]: - df["year"] = time_coords[f"{self._dim}.year"].values - df["month"] = time_coords[f"{self._dim}.month"].values + df["year"] = time_coords[f"{self.dim}.year"].values + df["month"] = time_coords[f"{self.dim}.month"].values if self._mode == "group_average": - df["month"] = time_coords[f"{self._dim}.month"].values + df["month"] = time_coords[f"{self.dim}.month"].values df = self._process_season_df(df) @@ -1475,7 +1500,7 @@ def _keep_weights(self, ds: xr.Dataset) -> xr.Dataset: # avoid conflict with the grouped time coordinates in the Dataset (can # have a different shape). 
if self._mode in ["group_average", "climatology"]: - self._weights = self._weights.rename({self._dim: f"{self._dim}_original"}) + self._weights = self._weights.rename({self.dim: f"{self.dim}_original"}) # Only keep the original time coordinates, not the ones labeled # by group. self._weights = self._weights.drop_vars(self._labeled_time.name) @@ -1483,7 +1508,7 @@ def _keep_weights(self, ds: xr.Dataset) -> xr.Dataset: # because the final departures Dataset has the original time coordinates # restored after performing grouped subtraction. elif self._mode == "departures": - self._weights = self._weights.rename({f"{self._dim}_original": self._dim}) + self._weights = self._weights.rename({f"{self.dim}_original": self.dim}) ds[self._weights.name] = self._weights diff --git a/xcdat/utils.py b/xcdat/utils.py index 162108e0..af4113a3 100644 --- a/xcdat/utils.py +++ b/xcdat/utils.py @@ -1,8 +1,9 @@ import importlib import json -from typing import Dict, List +from typing import Dict, List, Optional, Union import xarray as xr +from dask.array.core import Array def compare_datasets(ds1: xr.Dataset, ds2: xr.Dataset) -> Dict[str, List[str]]: @@ -107,3 +108,27 @@ def _has_module(modname: str) -> bool: # pragma: no cover has = False return has + + +def _if_multidim_dask_array_then_load( + obj: Union[xr.DataArray, xr.Dataset] +) -> Optional[Union[xr.DataArray, xr.Dataset]]: + """ + If the underlying array for an xr.DataArray or xr.Dataset is a + multidimensional, lazy Dask Array, load it into an in-memory NumPy array. + + This function must be called before manipulating values in a + multidimensional Dask Array, which xarray does not support directly. + Otherwise, it raises `NotImplementedError xarray can't set arrays with + multiple array indices to dask yet`. + + Parameters + ---------- + obj : Union[xr.DataArray, xr.Dataset] + The xr.DataArray or xr.Dataset. If the xarray object is chunked, + the underlying array will be a Dask Array. + """ + if isinstance(obj.data, Array) and obj.ndim > 1: + return obj.load() + + return None
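A hypothetical use of the internal helper above, assuming a chunked two-dimensional bounds DataArray named ``lon_bnds`` (names are illustrative):

>>> from dask.array.core import Array
>>> bnds = ds.lon_bnds.chunk()               # lazy, 2-D Dask-backed bounds
>>> _if_multidim_dask_array_then_load(bnds)  # loads the data in place
>>> isinstance(bnds.data, Array)             # now an in-memory NumPy array
False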