Skip to content

Commit

Permalink
Add NetCDF3 dtype coercion for unsigned integer types (#4018)
Browse files Browse the repository at this point in the history
* In netcdf3 backend, also coerce unsigned integer dtypes

* Adjust test for netcdf3 roundtrip to include coercion

This might be a bit too general for what is required at this point,
though ... 🤔

* Add test for failing dtype coercion

* Add What's New entry for issue #4014 and PR #4018

* Move netcdf3-specific test to NetCDF3Only class

Also uses a class variable for definition of netcdf3 formats now.

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
  • Loading branch information
blsqr and dcherian authored May 20, 2020
1 parent cb90d55 commit 5c04ebf
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 12 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ New Features
the :py:class:`~core.accessor_dt.DatetimeAccessor` (:pull:`3935`). This
feature requires cftime version 1.1.0 or greater. By
`Spencer Clark <https://github.com/spencerkclark>`_.
- For the netCDF3 backend, added dtype coercions for unsigned integer types.
(:issue:`4014`, :pull:`4018`)
By `Yunus Sevinchan <https://github.com/blsqr>`_
- :py:meth:`map_blocks` now accepts a ``template`` kwarg. This allows use cases
where the result of a computation could not be inferred automatically.
By `Deepak Cherian <https://github.com/dcherian>`_
Expand Down
26 changes: 19 additions & 7 deletions xarray/backends/netcdf3.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,14 @@

# These data-types aren't supported by netCDF3, so they are automatically
# coerced instead as indicated by the "coerce_nc3_dtype" function
_nc3_dtype_coercions = {"int64": "int32", "bool": "int8"}
_nc3_dtype_coercions = {
# Keys are the dtype names (``str(arr.dtype)``) that get coerced; values are
# the dtype names they are coerced to.
# All integer entries target a signed type; the test-suite expects the
# largest representable value of each of these source dtypes to be rejected
# with a ValueError ("could not safely cast") on write.
"int64": "int32",
"uint64": "int32",
"uint32": "int32",
"uint16": "int16",
"uint8": "int8",
# bool -> int8 is an upcast, so this coercion can never fail
"bool": "int8",
}

# encode all strings as UTF-8
STRING_ENCODING = "utf-8"
Expand All @@ -37,12 +44,17 @@
def coerce_nc3_dtype(arr):
"""Coerce an array to a data type that can be stored in a netCDF-3 file
This function performs the following dtype conversions:
int64 -> int32
bool -> int8
Data is checked for equality, or equivalence (non-NaN values) with
`np.allclose` with the default keyword arguments.
This function performs the dtype conversions as specified by the
``_nc3_dtype_coercions`` mapping:
int64 -> int32
uint64 -> int32
uint32 -> int32
uint16 -> int16
uint8 -> int8
bool -> int8
Data is checked for equality, or equivalence (non-NaN values), using
``(cast_array == original_array).all()``.
"""
dtype = str(arr.dtype)
if dtype in _nc3_dtype_coercions:
Expand Down
36 changes: 31 additions & 5 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
save_mfdataset,
)
from xarray.backends.common import robust_getitem
from xarray.backends.netcdf3 import _nc3_dtype_coercions
from xarray.backends.netCDF4_ import _extract_nc4_variable_encoding
from xarray.backends.pydap_ import PydapDataStore
from xarray.coding.variables import SerializationWarning
Expand Down Expand Up @@ -227,7 +228,27 @@ def __getitem__(self, key):


class NetCDF3Only:
pass
netcdf3_formats = ("NETCDF3_CLASSIC", "NETCDF3_64BIT")

@requires_scipy
def test_dtype_coercion_error(self):
    """Writing data that does not survive the netCDF3 dtype coercion must fail.

    For every coerced source dtype except ``bool`` and for every netCDF3
    format, a dataset holding the largest value representable in the source
    dtype is written, and a ``ValueError`` matching "could not safely cast"
    is expected.
    """
    # bool -> int8 is an upcast and therefore can never fail, so skip it
    lossy_dtypes = [name for name in _nc3_dtype_coercions if name != "bool"]

    for source_dtype in lossy_dtypes:
        for nc3_format in self.netcdf3_formats:
            # The largest representable value of the source dtype will no
            # longer compare equal after the coerced downcast
            largest = np.iinfo(source_dtype).max
            values = np.array([0, 1, 2, largest], dtype=source_dtype)
            dataset = Dataset({"x": ("t", values, {})})

            with create_tmp_file(allow_cleanup_failure=False) as path:
                with pytest.raises(ValueError, match="could not safely cast"):
                    dataset.to_netcdf(path, format=nc3_format)


class DatasetIOBase:
Expand Down Expand Up @@ -296,9 +317,14 @@ def test_write_store(self):
def check_dtypes_roundtripped(self, expected, actual):
for k in expected.variables:
expected_dtype = expected.variables[k].dtype
if isinstance(self, NetCDF3Only) and expected_dtype == "int64":
# downcast
expected_dtype = np.dtype("int32")

# For NetCDF3, the backend should perform dtype coercion
if (
isinstance(self, NetCDF3Only)
and str(expected_dtype) in _nc3_dtype_coercions
):
expected_dtype = np.dtype(_nc3_dtype_coercions[str(expected_dtype)])

actual_dtype = actual.variables[k].dtype
# TODO: check expected behavior for string dtypes more carefully
string_kinds = {"O", "S", "U"}
Expand Down Expand Up @@ -2156,7 +2182,7 @@ def test_cross_engine_read_write_netcdf3(self):
valid_engines.add("scipy")

for write_engine in valid_engines:
for format in ["NETCDF3_CLASSIC", "NETCDF3_64BIT"]:
for format in self.netcdf3_formats:
with create_tmp_file() as tmp_file:
data.to_netcdf(tmp_file, format=format, engine=write_engine)
for read_engine in valid_engines:
Expand Down

0 comments on commit 5c04ebf

Please sign in to comment.