Fix open_mfdataset() dropping time encoding attrs
tomvothecoder committed Aug 11, 2022
1 parent 1d242e5 commit a9916df
Showing 2 changed files with 146 additions and 38 deletions.
40 changes: 40 additions & 0 deletions tests/test_dataset.py
@@ -194,6 +194,46 @@ def setUp(self, tmp_path):
self.file_path1 = f"{dir}/file1.nc"
self.file_path2 = f"{dir}/file2.nc"

def test_mfdataset_keeps_time_encoding_dict(self):
# FIXME: This test always passes because `xr.open_mfdataset()` always
# keeps the time encoding attrs here, which isn't the expected behavior.
# Based on this test, if datasets are generated in xarray and written
# out with `to_netcdf()` and then opened and merged using
# `xr.open_mfdataset()`, the time encoding attributes are not dropped.
# On the other hand, if multiple real-world datasets that did not
# originate from xarray (i.e., were not written out with `.to_netcdf()`)
# are opened using `xr.open_mfdataset()`, the time encoding attrs are
# dropped (refer to https://github.com/pydata/xarray/issues/2436). My
# theory is that xarray maintains the time encoding attrs if datasets
# are written out with `.to_netcdf()`, and drops them in other cases,
# such as opening multiple datasets from other sources.
ds1 = generate_dataset(cf_compliant=True, has_bounds=True)
ds1.to_netcdf(self.file_path1)
ds2 = generate_dataset(cf_compliant=True, has_bounds=True)
ds2 = ds2.rename_vars({"ts": "tas"})
ds2.to_netcdf(self.file_path2)

result = open_mfdataset([self.file_path1, self.file_path2], decode_times=True)

expected = ds1.copy().merge(ds2.copy())
expected.time.encoding = {
"zlib": False,
"shuffle": False,
"complevel": 0,
"fletcher32": False,
"contiguous": True,
"chunksizes": None,
# Set "source" to the result's source because it changes every test run.
"source": result.time.encoding["source"],
"original_shape": (15,),
"dtype": np.dtype(np.int64),
"units": "hours since 2000-01-16 12:00:00",
"calendar": "proleptic_gregorian",
}

assert result.identical(expected)
assert result.time.encoding == expected.time.encoding

def test_non_cf_compliant_time_is_not_decoded(self):
ds1 = generate_dataset(cf_compliant=False, has_bounds=True)
ds1.to_netcdf(self.file_path1)
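
For context, here is a minimal standalone sketch (not part of this commit; file names and values are illustrative) of the quirk the FIXME above describes. Files that round-trip through xarray's `to_netcdf()` keep their time encoding through the merge, whereas the linked issue reports it being dropped for files from other sources:

import numpy as np
import pandas as pd
import xarray as xr

time = pd.date_range("2000-01-01", periods=3)
ds1 = xr.Dataset({"ts": ("time", np.arange(3.0))}, coords={"time": time})
ds2 = xr.Dataset({"tas": ("time", np.arange(3.0))}, coords={"time": time})
ds1.to_netcdf("file1.nc")
ds2.to_netcdf("file2.nc")

# Inspect whether encoding attrs such as "units" and "calendar" survived.
merged = xr.open_mfdataset(["file1.nc", "file2.nc"], decode_times=True)
print(merged.time.encoding)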
144 changes: 106 additions & 38 deletions xcdat/dataset.py
@@ -3,7 +3,7 @@
from datetime import datetime
from functools import partial
from glob import glob
from typing import Any, Callable, Dict, Hashable, List, Literal, Optional, Tuple, Union

import numpy as np
import xarray as xr
@@ -14,14 +14,24 @@

from xcdat import bounds # noqa: F401
from xcdat.axis import center_times as center_times_func
from xcdat.axis import get_axis_coord, get_axis_dim, swap_lon_axis
from xcdat.logger import setup_custom_logger

logger = setup_custom_logger(__name__)

#: List of non-CF compliant time units.
NON_CF_TIME_UNITS: List[str] = ["months", "years"]

# Type annotation for the `paths` arg.
Paths = Union[
str,
pathlib.Path,
List[str],
List[pathlib.Path],
List[List[str]],
List[List[pathlib.Path]],
]


def open_dataset(
path: str,
@@ -87,10 +97,9 @@ def open_dataset(
"""
if decode_times:
cf_compliant_time: Optional[bool] = _has_cf_compliant_time(path)
# xCDAT attempts to decode non-CF compliant time coordinates.
if cf_compliant_time is False:
ds = xr.open_dataset(path, decode_times=False, **kwargs)
# Attempt to decode the non-CF compliant time axis.
ds = decode_non_cf_time(ds)
else:
ds = xr.open_dataset(path, decode_times=True, **kwargs)
@@ -103,14 +112,7 @@


def open_mfdataset(
paths: Paths,
data_var: Optional[str] = None,
add_bounds: bool = True,
decode_times: bool = True,
@@ -201,10 +203,19 @@ def open_mfdataset(
.. [2] https://xarray.pydata.org/en/stable/generated/xarray.open_mfdataset.html
"""
# `xr.open_mfdataset()` drops the time coordinates' encoding dictionary if
# multiple files are merged with `decode_times=True` (refer to
# https://github.com/pydata/xarray/issues/2436). The workaround is to store
# the time encoding from the first dataset as a variable, then add the time
# encoding back to the final merged dataset in the postprocessing function.
time_encoding = None

if decode_times:
time_encoding = _keep_time_encoding(paths)

cf_compliant_time: Optional[bool] = _has_cf_compliant_time(paths)
# xCDAT attempts to decode non-CF compliant time coordinates using the
# preprocess keyword arg with `xr.open_mfdataset()`.
if cf_compliant_time is False:
decode_times = False
preprocess = partial(_preprocess_non_cf_dataset, callable=preprocess)
@@ -216,7 +227,9 @@
preprocess=preprocess,
**kwargs,
)
ds = _postprocess_dataset(ds, data_var, center_times, add_bounds, lon_orient)
ds = _postprocess_dataset(
ds, data_var, center_times, add_bounds, lon_orient, time_encoding
)

return ds
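
With the workaround in place, a usage sketch of the public API (hypothetical file names) — the merged dataset's time coordinates should still carry their encoding:

import xcdat

ds = xcdat.open_mfdataset(["file1.nc", "file2.nc"], decode_times=True)
# The encoding attrs are captured from the first file and re-attached in
# _postprocess_dataset(), so these lookups should not raise KeyError.
print(ds.time.encoding["units"], ds.time.encoding["calendar"])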

@@ -393,16 +406,41 @@ def decode_non_cf_time(dataset: xr.Dataset) -> xr.Dataset:
return ds
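
To illustrate how this decoder gets exercised (a sketch with a hypothetical file whose time units are "months since 2000-01-01", which xarray itself refuses to decode), `open_dataset()` takes the non-CF branch shown earlier:

ds = open_dataset("noncf_months.nc", decode_times=True)
# Internally this is roughly equivalent to:
ds_raw = xr.open_dataset("noncf_months.nc", decode_times=False)
ds_decoded = decode_non_cf_time(ds_raw)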


def _keep_time_encoding(paths: Paths) -> Dict[Hashable, Any]:
"""
Returns the time encoding attributes from the first dataset in a list of
paths.
Time encoding information is critical for several xCDAT operations such as
temporal averaging (e.g., uses the "calendar" attr). This function is a
workaround to the undesired xarray behavior/quirk with
`xr.open_mfdataset()`, which drops the `.encoding` dict from the final
merged dataset (refer to https://github.com/pydata/xarray/issues/2436).
Parameters
----------
paths: Paths
The paths to the dataset(s).
Returns
-------
Dict[Hashable, Any]
The time encoding dictionary.
"""
first_path = _get_first_path(paths)

# xcdat.open_dataset() is called instead of xr.open_dataset() because
# we want to handle decoding non-CF compliant time coordinates as well.
# FIXME: Remove `type: ignore` comment after properly handling the type
# annotations in `_get_first_path()`.
ds = open_dataset(first_path, decode_times=True, add_bounds=False) # type: ignore

time_coord = get_axis_coord(ds, "T")

return time_coord.encoding
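
A rough illustration of what this helper hands back, reusing the encoding values from the test above (file name hypothetical):

enc = _keep_time_encoding("file1.nc")
print(enc.get("units"))     # e.g. "hours since 2000-01-16 12:00:00"
print(enc.get("calendar"))  # e.g. "proleptic_gregorian"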


def _has_cf_compliant_time(paths: Paths) -> Optional[bool]:
"""Checks if a dataset has time coordinates with CF compliant units.
If the dataset does not contain a time dimension, None is returned.
@@ -432,19 +470,8 @@ def _has_cf_compliant_time(
performance because it is slower to combine all files and then check for CF
compliance.
"""
first_path = _get_first_path(paths)
ds = xr.open_dataset(first_path, decode_times=False)

if ds.cf.dims.get("T") is None:
return None
Expand All @@ -462,12 +489,45 @@ def _has_cf_compliant_time(
return cf_compliant
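
For illustration (hypothetical files and units strings), "since"-style units outside of NON_CF_TIME_UNITS pass the check:

_has_cf_compliant_time("cf.nc")       # units "days since 2000-01-01"   -> True
_has_cf_compliant_time("noncf.nc")    # units "months since 2000-01-01" -> False
_has_cf_compliant_time("no_time.nc")  # no time dimension               -> None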


def _get_first_path(path: Paths) -> Optional[Union[pathlib.Path, str]]:
"""Returns the first path from a list of paths.
Parameters
----------
path : Paths
A list of paths.
Returns
-------
str
Returns the first path from a list of paths.
"""
# FIXME: This function should throw an exception if the first file
# is not a supported type.
# FIXME: The `type: ignore` comments should be removed after properly
# handling the types.
first_file: Optional[Union[pathlib.Path, str]] = None

if isinstance(path, str) and "*" in path:
first_file = glob(path)[0]
elif isinstance(path, str) or isinstance(path, pathlib.Path):
first_file = path
elif isinstance(path, list):
if any(isinstance(sublist, list) for sublist in path):
first_file = path[0][0] # type: ignore
else:
first_file = path[0] # type: ignore

return first_file
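
A sketch of how the accepted path forms map to a first file (assuming the glob matches at least one file):

_get_first_path("data/*.nc")                   # first glob() match
_get_first_path(pathlib.Path("file1.nc"))      # the path itself
_get_first_path(["file1.nc", "file2.nc"])      # "file1.nc"
_get_first_path([["file1.nc"], ["file2.nc"]])  # "file1.nc"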


def _postprocess_dataset(
dataset: xr.Dataset,
data_var: Optional[str] = None,
center_times: bool = False,
add_bounds: bool = True,
lon_orient: Optional[Tuple[float, float]] = None,
time_encoding: Optional[Dict[Hashable, Any]] = None,
) -> xr.Dataset:
"""Post-processes a Dataset object.
@@ -494,6 +554,10 @@ def _postprocess_dataset(
* (-180, 180): represents [-180, 180) in math notation
* (0, 360): represents [0, 360) in math notation
time_encoding : Optional[Dict[Hashable, Any]], optional
    The encoding information for the decoded time coordinates (if the
    Dataset has a time axis), by default None.

Returns
-------
xr.Dataset
@@ -526,6 +590,10 @@
"This dataset does not have longitude coordinates to reorient."
)

if time_encoding is not None:
time_dim = get_axis_dim(dataset, "T")
dataset[time_dim].encoding = time_encoding

return dataset


