diff --git a/doc/api.rst b/doc/api.rst index dbca0e2563f..433aa93c9de 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -18,6 +18,7 @@ Top-level functions broadcast concat merge + where set_options full_like zeros_like diff --git a/doc/computation.rst b/doc/computation.rst index 80dfc115d07..6185ca401f4 100644 --- a/doc/computation.rst +++ b/doc/computation.rst @@ -25,7 +25,7 @@ numpy) over all array values: .. ipython:: python - arr = xr.DataArray(np.random.randn(2, 3), + arr = xr.DataArray(np.random.RandomState(0).randn(2, 3), [('x', ['a', 'b']), ('y', [10, 20, 30])]) arr - 3 abs(arr) @@ -39,6 +39,12 @@ __ http://docs.scipy.org/doc/numpy/reference/ufuncs.html np.sin(arr) +Use :py:func:`~xarray.where` to conditionally switch between values: + +.. ipython:: python + + xr.where(arr > 0, 'positive', 'negative') + Data arrays also implement many :py:class:`numpy.ndarray` methods: .. ipython:: python diff --git a/doc/whats-new.rst b/doc/whats-new.rst index f66a32d8c06..7d06a49d486 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -32,6 +32,29 @@ Enhancements - Speed-up (x 100) of :py:func:`~xarray.conventions.decode_cf_datetime`. By `Christian Chwala `_. +- New function :py:func:`~xarray.where` for conditionally switching between + values in xarray objects, like :py:func:`numpy.where`: + + .. ipython:: + :verbatim: + + In [1]: import xarray as xr + + In [2]: arr = xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=('x', 'y')) + + In [3]: xr.where(arr % 2, 'odd', 'even') + Out[3]: + + array([['odd', 'even', 'odd'], + ['even', 'odd', 'even']], + dtype='`_. + Bug fixes ~~~~~~~~~ @@ -49,7 +72,7 @@ Bug fixes - Fix :py:func:`xarray.testing.assert_allclose` to actually use ``atol`` and ``rtol`` arguments when called on ``DataArray`` objects. - By `Stephan Hoyer `_. + By `Stephan Hoyer `_. ..
_whats-new.0.9.6: diff --git a/xarray/__init__.py b/xarray/__init__.py index 950ca3403ea..654ed77b28a 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -6,6 +6,7 @@ from .core.alignment import align, broadcast, broadcast_arrays from .core.common import full_like, zeros_like, ones_like from .core.combine import concat, auto_combine +from .core.computation import where from .core.extensions import (register_dataarray_accessor, register_dataset_accessor) from .core.variable import as_variable, Variable, IndexVariable, Coordinate diff --git a/xarray/core/accessors.py b/xarray/core/accessors.py index 7360b9764ae..59aca6b67f0 100644 --- a/xarray/core/accessors.py +++ b/xarray/core/accessors.py @@ -2,11 +2,9 @@ from __future__ import division from __future__ import print_function -from .common import is_datetime_like +from .dtypes import is_datetime_like from .pycompat import dask_array_type -from functools import partial - import numpy as np import pandas as pd @@ -147,4 +145,4 @@ def f(self, dtype=dtype): time = _tslib_field_accessor( "time", "Timestamps corresponding to datetimes", object - ) \ No newline at end of file + ) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index 5770ec2cd50..44ad5faf5e2 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -7,8 +7,9 @@ import numpy as np -from . import duck_array_ops, utils -from .common import _maybe_promote +from . import duck_array_ops +from . import dtypes +from . 
import utils from .indexing import get_indexer from .pycompat import iteritems, OrderedDict, suppress from .utils import is_full_slice, is_dict_like @@ -368,7 +369,7 @@ def var_indexers(var, indexers): if any_not_full_slices(assign_to): # there are missing values to in-fill data = var[assign_from].data - dtype, fill_value = _maybe_promote(var.dtype) + dtype, fill_value = dtypes.maybe_promote(var.dtype) if isinstance(data, np.ndarray): shape = tuple(new_sizes.get(dim, size) diff --git a/xarray/core/common.py b/xarray/core/common.py index d61e2cdb15f..090d33cc04a 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -5,7 +5,9 @@ import pandas as pd from .pycompat import basestring, suppress, dask_array_type, OrderedDict +from . import dtypes from . import formatting +from . import ops from .utils import SortedKeysDict, not_implemented, Frozen @@ -557,64 +559,74 @@ def resample(self, freq, dim, how='mean', skipna=None, closed=None, result = result.rename({RESAMPLE_DIM: dim.name}) return result - def where(self, cond, other=None, drop=False): - """Return an object of the same shape with all entries where cond is - True and all other entries masked. + def where(self, cond, other=dtypes.NA, drop=False): + """Filter elements from this object according to a condition. This operation follows the normal broadcasting and alignment rules that xarray uses for binary arithmetic. Parameters ---------- - cond : boolean DataArray or Dataset - other : unimplemented, optional - Unimplemented placeholder for compatibility with future - numpy / pandas versions + cond : DataArray or Dataset with boolean dtype + Locations at which to preserve this object's values. + other : scalar, DataArray or Dataset, optional + Value to use for locations in this object where ``cond`` is False. + By default, these locations are filled with NA.
drop : boolean, optional - Coordinate labels that only correspond to NA values should be - dropped + If True, coordinate labels that only correspond to False values of + the condition are dropped from the result. Mutually exclusive with + ``other``. Returns ------- - same type as caller or if drop=True same type as caller with dimensions - reduced for dim element where mask is True + Same type as caller. Examples -------- >>> import numpy as np >>> a = xr.DataArray(np.arange(25).reshape(5, 5), dims=('x', 'y')) - >>> a.where((a > 6) & (a < 18)) + >>> a.where(a.x + a.y < 4) - array([[ nan, nan, nan, nan, nan], - [ nan, nan, 7., 8., 9.], - [ 10., 11., 12., 13., 14.], - [ 15., 16., 17., nan, nan], + array([[ 0., 1., 2., 3., nan], + [ 5., 6., 7., nan, nan], + [ 10., 11., nan, nan, nan], + [ 15., nan, nan, nan, nan], [ nan, nan, nan, nan, nan]]) - Coordinates: - * y (y) int64 0 1 2 3 4 - * x (x) int64 0 1 2 3 4 - >>> a.where((a > 6) & (a < 18), drop=True) + Dimensions without coordinates: x, y + >>> a.where(a.x + a.y < 5, -1) - array([[ nan, nan, 7., 8., 9.], - [ 10., 11., 12., 13., 14.], - [ 15., 16., 17., nan, nan], - Coordinates: - * x (x) int64 1 2 3 - * y (y) int64 0 1 2 3 4 + array([[ 0, 1, 2, 3, 4], + [ 5, 6, 7, 8, -1], + [10, 11, 12, -1, -1], + [15, 16, -1, -1, -1], + [20, -1, -1, -1, -1]]) + Dimensions without coordinates: x, y + >>> a.where(a.x + a.y < 4, drop=True) + + array([[ 0., 1., 2., 3.], + [ 5., 6., 7., nan], + [ 10., 11., nan, nan], + [ 15., nan, nan, nan]]) + Dimensions without coordinates: x, y + + See also + -------- + numpy.where : corresponding numpy function + where : equivalent function """ - if other is not None: - raise NotImplementedError("The optional argument 'other' has not " - "yet been implemented") + from .alignment import align + from .dataarray import DataArray + from .dataset import Dataset if drop: - from .dataarray import DataArray - from .dataset import Dataset - from .alignment import align + if other is not dtypes.NA: + raise 
ValueError('cannot set `other` if drop=True') if not isinstance(cond, (Dataset, DataArray)): - raise TypeError("Cond argument is %r but must be a %r or %r" % + raise TypeError("cond argument is %r but must be a %r or %r" % (cond, Dataset, DataArray)) + # align so we can use integer indexing self, cond = align(self, cond) @@ -627,16 +639,11 @@ def where(self, cond, other=None, drop=False): # clip the data corresponding to coordinate dims that are not used nonzeros = zip(clipcond.dims, np.nonzero(clipcond.values)) indexers = {k: np.unique(v) for k, v in nonzeros} - outobj = self.isel(**indexers) - outcond = cond.isel(**indexers) - else: - outobj = self - outcond = cond - # preserve attributes - out = outobj._where(outcond) - out._copy_attrs_from(self) - return out + self = self.isel(**indexers) + cond = cond.isel(**indexers) + + return ops.where_method(self, cond, other) def close(self): """Close any files linked to this object @@ -658,42 +665,6 @@ def __exit__(self, exc_type, exc_value, traceback): __or__ = __div__ = __eq__ = __ne__ = not_implemented -def _maybe_promote(dtype): - """Simpler equivalent of pandas.core.common._maybe_promote""" - # N.B. these casting rules should match pandas - if np.issubdtype(dtype, float): - fill_value = np.nan - elif np.issubdtype(dtype, int): - # convert to floating point so NaN is valid - dtype = float - fill_value = np.nan - elif np.issubdtype(dtype, complex): - fill_value = np.nan + np.nan * 1j - elif np.issubdtype(dtype, np.datetime64): - fill_value = np.datetime64('NaT') - elif np.issubdtype(dtype, np.timedelta64): - fill_value = np.timedelta64('NaT') - else: - dtype = object - fill_value = np.nan - return np.dtype(dtype), fill_value - - -def _possibly_convert_objects(values): - """Convert arrays of datetime.datetime and datetime.timedelta objects into - datetime64 and timedelta64, according to the pandas convention. 
- """ - return np.asarray(pd.Series(values.ravel())).reshape(values.shape) - - -def _get_fill_value(dtype): - """Return a fill value that appropriately promotes types when used with - np.concatenate - """ - _, fill_value = _maybe_promote(dtype) - return fill_value - - def full_like(other, fill_value, dtype=None): """Return a new object with the same shape and type as a given object. @@ -761,10 +732,3 @@ def ones_like(other, dtype=None): """Shorthand for full_like(other, 1, dtype) """ return full_like(other, 1, dtype) - - -def is_datetime_like(dtype): - """Check if a dtype is a subclass of the numpy datetime types - """ - return (np.issubdtype(dtype, np.datetime64) or - np.issubdtype(dtype, np.timedelta64)) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index e866de2752a..784c722989c 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -19,6 +19,8 @@ _DEFAULT_FROZEN_SET = frozenset() _DEFAULT_FILL_VALUE = object() +_DEFAULT_NAME = object() +_JOINS_WITHOUT_FILL_VALUES = frozenset({'inner', 'exact'}) class _UFuncSignature(object): @@ -109,8 +111,8 @@ def result_name(objects): # type: List[object] -> Any # use the same naming heuristics as pandas: # https://github.com/blaze/blaze/issues/458#issuecomment-51936356 - names = {getattr(obj, 'name', None) for obj in objects} - names.discard(None) + names = {getattr(obj, 'name', _DEFAULT_NAME) for obj in objects} + names.discard(_DEFAULT_NAME) if len(names) == 1: name, = names else: @@ -220,11 +222,22 @@ def ordered_set_intersection(all_keys): return [key for key in all_keys[0] if key in intersection] +def assert_and_return_exact_match(all_keys): + first_keys = all_keys[0] + for keys in all_keys[1:]: + if keys != first_keys: + raise ValueError( + 'exact match required for all data variable names, but %r != %r' + % (keys, first_keys)) + return first_keys + + _JOINERS = { 'inner': ordered_set_intersection, 'outer': ordered_set_union, 'left': operator.itemgetter(0), 'right': 
operator.itemgetter(-1), + 'exact': assert_and_return_exact_match, } @@ -320,7 +333,8 @@ def apply_dataset_ufunc(func, *args, **kwargs): keep_attrs = kwargs.pop('keep_attrs', False) first_obj = args[0] # we'll copy attrs from this in case keep_attrs=True - if dataset_join != 'inner' and fill_value is _DEFAULT_FILL_VALUE: + if (dataset_join not in _JOINS_WITHOUT_FILL_VALUES and + fill_value is _DEFAULT_FILL_VALUE): raise TypeError('To apply an operation to datasets with different ', 'data variables, you must supply the ', 'dataset_fill_value argument.') @@ -750,3 +764,45 @@ def stack(objects, dim, new_coord): return variables_ufunc(*args) else: return array_ufunc(*args) + + +def where(cond, x, y): + """Return elements from `x` or `y` depending on `cond`. + + Performs xarray-like broadcasting across input arguments. + + Parameters + ---------- + cond : scalar, array, Variable, DataArray or Dataset with boolean dtype + When True, return values from `x`, otherwise returns values from `y`. + x, y : scalar, array, Variable, DataArray or Dataset + Values from which to choose. All dimension coordinates on these objects + must be aligned with each other and with `cond`. + + Returns + ------- + In priority order: Dataset, DataArray, Variable or array, whichever + type appears as an input argument. 
+ + Examples + -------- + + >>> cond = xr.DataArray([True, False], dims=['x']) + >>> x = xr.DataArray([1, 2], dims=['y']) + >>> xr.where(cond, x, 0) + + array([[1, 2], + [0, 0]]) + Dimensions without coordinates: x, y + + See also + -------- + numpy.where : corresponding numpy function + Dataset.where, DataArray.where : equivalent methods + """ + # alignment for three arguments is complicated, so don't support it yet + return apply_ufunc(duck_array_ops.where, + cond, x, y, + join='exact', + dataset_join='exact', + dask_array='allowed') diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index b02b954fb29..bffdbf10724 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -22,7 +22,8 @@ from .. import conventions from .alignment import align from .coordinates import DatasetCoordinates, LevelCoordinatesSource, Indexes -from .common import ImplementsDatasetReduce, BaseDataObject, is_datetime_like +from .common import ImplementsDatasetReduce, BaseDataObject +from .dtypes import is_datetime_like from .merge import (dataset_update_method, dataset_merge_method, merge_data_and_coords) from .utils import (Frozen, SortedKeysDict, maybe_wrap_array, hashable, diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py new file mode 100644 index 00000000000..f9d24d0204e --- /dev/null +++ b/xarray/core/dtypes.py @@ -0,0 +1,58 @@ +import numpy as np + + +class NA(object): + """Use as a sentinel value to indicate a dtype appropriate NA value.""" + + +def maybe_promote(dtype): + """Simpler equivalent of pandas.core.common._maybe_promote + + Parameters + ---------- + dtype : np.dtype + + Returns + ------- + dtype : Promoted dtype that can hold missing values. + fill_value : Valid missing value for the promoted dtype. + """ + # N.B. 
these casting rules should match pandas + if np.issubdtype(dtype, float): + fill_value = np.nan + elif np.issubdtype(dtype, int): + # convert to floating point so NaN is valid + dtype = float + fill_value = np.nan + elif np.issubdtype(dtype, complex): + fill_value = np.nan + np.nan * 1j + elif np.issubdtype(dtype, np.datetime64): + fill_value = np.datetime64('NaT') + elif np.issubdtype(dtype, np.timedelta64): + fill_value = np.timedelta64('NaT') + else: + dtype = object + fill_value = np.nan + return np.dtype(dtype), fill_value + + +def get_fill_value(dtype): + """Return an appropriate fill value for this dtype. + + Parameters + ---------- + dtype : np.dtype + + Returns + ------- + fill_value : Missing value corresponding to this dtype. + """ + _, fill_value = maybe_promote(dtype) + return fill_value + + +def is_datetime_like(dtype): + """Check if a dtype is a subclass of the numpy datetime types + """ + return (np.issubdtype(dtype, np.datetime64) or + np.issubdtype(dtype, np.timedelta64)) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 66d44ce547b..f44b459789f 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -16,6 +16,7 @@ import pandas as pd from . import npcompat +from . import dtypes from .pycompat import dask_array_type from .nputils import nanfirst, nanlast @@ -148,13 +149,16 @@ def count(data, axis=None): return sum(~isnull(data), axis=axis) -def where_method(data, cond, other=np.nan): - """Select values from this object that are True in cond. Everything else - gets masked with other. Follows normal broadcasting and alignment rules. 
- """ +def where_method(data, cond, other=dtypes.NA): + if other is dtypes.NA: + other = dtypes.get_fill_value(data.dtype) return where(cond, data, other) +def fillna(data, other): + return where(isnull(data), other, data) + + @contextlib.contextmanager def _ignore_warnings_if(condition): if condition: diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 2e31ae5bc2a..b5fb02a8083 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -5,12 +5,13 @@ import numpy as np import pandas as pd +from . import dtypes from . import duck_array_ops from . import nputils from . import ops from .combine import concat from .common import ( - ImplementsArrayReduce, ImplementsDatasetReduce, _maybe_promote, + ImplementsArrayReduce, ImplementsDatasetReduce, ) from .pycompat import range, zip, integer_types from .utils import hashable, peek_at, maybe_wrap_array, safe_cast_to_index @@ -44,27 +45,19 @@ def unique_value_groups(ar, sort=True): return values, groups -def _get_fill_value(dtype): - """Return a fill value that appropriately promotes types when used with - np.concatenate - """ - dtype, fill_value = _maybe_promote(dtype) - return fill_value - - def _dummy_copy(xarray_obj): from .dataset import Dataset from .dataarray import DataArray if isinstance(xarray_obj, Dataset): - res = Dataset(dict((k, _get_fill_value(v.dtype)) + res = Dataset(dict((k, dtypes.get_fill_value(v.dtype)) for k, v in xarray_obj.data_vars.items()), - dict((k, _get_fill_value(v.dtype)) + dict((k, dtypes.get_fill_value(v.dtype)) for k, v in xarray_obj.coords.items() if k not in xarray_obj.dims), xarray_obj.attrs) elif isinstance(xarray_obj, DataArray): - res = DataArray(_get_fill_value(xarray_obj.dtype), - dict((k, _get_fill_value(v.dtype)) + res = DataArray(dtypes.get_fill_value(xarray_obj.dtype), + dict((k, dtypes.get_fill_value(v.dtype)) for k, v in xarray_obj.coords.items() if k not in xarray_obj.dims), dims=[], @@ -387,16 +380,16 @@ def fillna(self, value): out = ops.fillna(self, 
value) return out - def where(self, cond): - """Return an object of the same shape with all entries where cond is - True and all other entries masked. - - This operation follows the normal broadcasting and alignment rules that - xarray uses for binary arithmetic. + def where(self, cond, other=dtypes.NA): + """Return elements from `self` or `other` depending on `cond`. Parameters ---------- - cond : DataArray or Dataset + cond : DataArray or Dataset with boolean dtype + Locations at which to preserve this object's values. + other : scalar, DataArray or Dataset, optional + Value to use for locations in this object where ``cond`` is False. + By default, inserts missing values. Returns ------- @@ -406,7 +399,7 @@ def where(self, cond): -------- Dataset.where """ - return self._where(cond) + return ops.where_method(self, cond, other) def _first_or_last(self, op, skipna, keep_attrs): if isinstance(self._group_indices[0], integer_types): diff --git a/xarray/core/ops.py b/xarray/core/ops.py index d816c341dbe..b840f26073c 100644 --- a/xarray/core/ops.py +++ b/xarray/core/ops.py @@ -15,6 +15,7 @@ import numpy as np import pandas as pd +from . import dtypes from . import duck_array_ops from .pycompat import PY3 from .nputils import array_eq, array_ne @@ -145,14 +146,40 @@ def fillna(data, other, join="left", dataset_join="left"): """ from .computation import apply_ufunc - def _fillna(data, other): - return duck_array_ops.where(duck_array_ops.isnull(data), other, data) - return apply_ufunc(_fillna, data, other, join=join, dask_array="allowed", + return apply_ufunc(duck_array_ops.fillna, data, other, + join=join, + dask_array="allowed", dataset_join=dataset_join, dataset_fill_value=np.nan, keep_attrs=True) +def where_method(self, cond, other=dtypes.NA): + """Return elements from `self` or `other` depending on `cond`. + + Parameters + ---------- + cond : DataArray or Dataset with boolean dtype + Locations at which to preserve this object's values.
+ other : scalar, DataArray or Dataset, optional + Value to use for locations in this object where ``cond`` is False. + By default, inserts missing values. + + Returns + ------- + Same type as caller. + """ + from .computation import apply_ufunc + # alignment for three arguments is complicated, so don't support it yet + join = 'inner' if other is dtypes.NA else 'exact' + return apply_ufunc(duck_array_ops.where_method, + self, cond, other, + join=join, + dataset_join=join, + dask_array='allowed', + keep_attrs=True) + + def _call_possibly_missing_method(arg, name, args, kwargs): try: method = getattr(arg, name) @@ -268,10 +295,6 @@ def inject_binary_ops(cls, inplace=False): for name, f in [('eq', array_eq), ('ne', array_ne)]: setattr(cls, op_str(name), cls._binary_op(f)) - # patch in where - f = _func_slash_method_wrapper(duck_array_ops.where_method, 'where') - setattr(cls, '_where', cls._binary_op(f)) - for name in NUM_BINARY_OPS: # only numeric operations have in-place and reflexive variants setattr(cls, op_str('r' + name), diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 121bce979ee..b1a2b1add74 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -12,6 +12,7 @@ from . import common from . import duck_array_ops +from . import dtypes from . import indexing from . import nputils from . import ops @@ -121,6 +122,13 @@ def _maybe_wrap_data(data): return data +def _possibly_convert_objects(values): + """Convert arrays of datetime.datetime and datetime.timedelta objects into + datetime64 and timedelta64, according to the pandas convention. + """ + return np.asarray(pd.Series(values.ravel())).reshape(values.shape) + + def as_compatible_data(data, fastpath=False): """Prepare and wrap data to put in a Variable. 
@@ -171,7 +179,7 @@ def as_compatible_data(data, fastpath=False): if isinstance(data, np.ma.MaskedArray): mask = np.ma.getmaskarray(data) if mask.any(): - dtype, fill_value = common._maybe_promote(data.dtype) + dtype, fill_value = dtypes.maybe_promote(data.dtype) data = np.asarray(data, dtype=dtype) data[mask] = fill_value else: @@ -179,7 +187,7 @@ def as_compatible_data(data, fastpath=False): if isinstance(data, np.ndarray): if data.dtype.kind == 'O': - data = common._possibly_convert_objects(data) + data = _possibly_convert_objects(data) elif data.dtype.kind == 'M': data = np.asarray(data, 'datetime64[ns]') elif data.dtype.kind == 'm': @@ -603,7 +611,7 @@ def _shift_one_dim(self, dim, count): keep = slice(None) trimmed_data = self[(slice(None),) * axis + (keep,)].data - dtype, fill_value = common._maybe_promote(self.dtype) + dtype, fill_value = dtypes.maybe_promote(self.dtype) shape = list(self.shape) shape[axis] = min(abs(count), shape[axis]) @@ -890,8 +898,8 @@ def unstack(self, **dimensions): def fillna(self, value): return ops.fillna(self, value) - def where(self, cond): - return self._where(cond) + def where(self, cond, other=dtypes.NA): + return ops.where_method(self, cond, other) def reduce(self, func, dim=None, axis=None, keep_attrs=False, allow_lazy=False, **kwargs): diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index f96e44dc390..424484b438a 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -4,6 +4,7 @@ import warnings from contextlib import contextmanager from distutils.version import LooseVersion +import re import numpy as np from numpy.testing import assert_array_equal @@ -182,6 +183,16 @@ def assertDataArrayAllClose(self, ar1, ar2, rtol=1e-05, atol=1e-08): assert_allclose(ar1, ar2, rtol=rtol, atol=atol) +@contextmanager +def raises_regex(error, pattern): + with pytest.raises(error) as excinfo: + yield + message = str(excinfo.value) + if not re.match(pattern, message): + raise AssertionError('exception %r did 
not match pattern %s' + % (excinfo.value, pattern)) + + class UnexpectedDataAccess(Exception): pass diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index cde0a92ddfa..76d12d51616 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -9,7 +9,7 @@ import xarray as xr from xarray.core.computation import ( - _UFuncSignature, broadcast_compat_data, collect_dict_values, + _UFuncSignature, result_name, broadcast_compat_data, collect_dict_values, join_dict_keys, ordered_set_intersection, ordered_set_union, unified_dim_sizes, apply_ufunc) @@ -36,6 +36,19 @@ def test_signature_properties(): assert _UFuncSignature([['x']]) != _UFuncSignature([['y']]) +def test_result_name(): + + class Named(object): + def __init__(self, name=None): + self.name = name + + assert result_name([1, 2]) is None + assert result_name([Named()]) is None + assert result_name([Named('foo'), 2]) == 'foo' + assert result_name([Named('foo'), Named('bar')]) is None + assert result_name([Named('foo'), Named()]) is None + + def test_ordered_set_union(): assert list(ordered_set_union([[1, 2]])) == [1, 2] assert list(ordered_set_union([[1, 2], [2, 1]])) == [1, 2] @@ -55,6 +68,8 @@ def test_join_dict_keys(): assert list(join_dict_keys(dicts, 'right')) == ['y', 'z'] assert list(join_dict_keys(dicts, 'inner')) == ['y'] assert list(join_dict_keys(dicts, 'outer')) == ['x', 'y', 'z'] + with pytest.raises(ValueError): + join_dict_keys(dicts, 'exact') with pytest.raises(KeyError): join_dict_keys(dicts, 'foobar') @@ -563,3 +578,10 @@ def dask_safe_identity(x): actual = dask_safe_identity(dataset) assert isinstance(actual['y'].data, da.Array) assert_identical(dataset, actual) + + +def test_where(): + cond = xr.DataArray([True, False], dims='x') + actual = xr.where(cond, 1, 0) + expected = xr.DataArray([1, 0], dims='x') + assert_identical(expected, actual) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 
950890e8ff0..e6a91b103ed 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -27,7 +27,7 @@ integer_types) from xarray.core.common import full_like -from . import (TestCase, InaccessibleArray, UnexpectedDataAccess, +from . import (TestCase, raises_regex, InaccessibleArray, UnexpectedDataAccess, requires_dask, source_ndarray) from xarray.tests import (assert_equal, assert_allclose, @@ -2667,17 +2667,33 @@ def test_where(self): self.assertEqual(actual.a.attrs, ds.a.attrs) # attrs - da = DataArray(range(5), name='a', attrs={'attr':'da'}) + da = DataArray(range(5), name='a', attrs={'attr': 'da'}) actual = da.where(da.values > 1) self.assertEqual(actual.name, 'a') self.assertEqual(actual.attrs, da.attrs) - ds = Dataset({'a': da}, attrs={'attr':'ds'}) + ds = Dataset({'a': da}, attrs={'attr': 'ds'}) actual = ds.where(ds > 0) self.assertEqual(actual.attrs, ds.attrs) self.assertEqual(actual.a.name, 'a') self.assertEqual(actual.a.attrs, ds.a.attrs) + def test_where_other(self): + ds = Dataset({'a': ('x', range(5))}, {'x': range(5)}) + expected = Dataset({'a': ('x', [-1, -1, 2, 3, 4])}, {'x': range(5)}) + actual = ds.where(ds > 1, -1) + assert_equal(expected, actual) + assert actual.a.dtype == int + + with raises_regex(ValueError, "cannot set"): + ds.where(ds > 1, other=0, drop=True) + + with raises_regex(ValueError, "indexes .* are not equal"): + ds.where(ds > 1, ds.isel(x=slice(3))) + + with raises_regex(ValueError, "exact match required"): + ds.where(ds > 1, ds.assign(b=2)) + def test_where_drop(self): # if drop=True