diff --git a/doc/computation.rst b/doc/computation.rst index 7ba66f6db8b..9cba5db2061 100644 --- a/doc/computation.rst +++ b/doc/computation.rst @@ -210,8 +210,8 @@ coordinates with the same name as a dimension, marked by ``*``) on objects used in binary operations. Similarly to pandas, this alignment is automatic for arithmetic on binary -operations. Note that unlike pandas, this the result of a binary operation is -by the *intersection* (not the union) of coordinate labels: +operations. The default result of a binary operation is by the *intersection* +(not the union) of coordinate labels: .. ipython:: python @@ -225,6 +225,15 @@ If the result would be empty, an error is raised instead: In [1]: arr[:2] + arr[2:] ValueError: no overlapping labels for some dimensions: ['x'] +However, one can explicitly change this default automatic alignment type ("inner") +via :py:func:`~xarray.set_options()` in context manager: + +.. ipython:: python + + with xr.set_options(arithmetic_join="outer"): + arr + arr[:1] + arr + arr[:1] + Before loops or performance critical code, it's a good idea to align arrays explicitly (e.g., by putting them in the same Dataset or using :py:func:`~xarray.align`) to avoid the overhead of repeated alignment with each diff --git a/doc/whats-new.rst b/doc/whats-new.rst index f001c5044ce..ce305a6aa24 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -44,6 +44,11 @@ Deprecations Enhancements ~~~~~~~~~~~~ +- Added the ability to change default automatic alignment (arithmetic_join="inner") + for binary operations via :py:func:`~xarray.set_options()` + (see :ref:`automatic alignment`). + By `Chun-Wei Yuan `_. + - Add checking of ``attr`` names and values when saving to netCDF, raising useful error messages if they are invalid. (:issue:`911`). By `Robin Wilson `_. diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index d7879625847..fdbb8c773f6 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -25,6 +25,7 @@ assert_unique_multiindex_level_names) from .formatting import format_item from .utils import decode_numpy_dict_values, ensure_us_time_resolution +from .options import OPTIONS def _infer_coords_and_dims(shape, coords, dims): @@ -1357,13 +1358,14 @@ def func(self, *args, **kwargs): return func @staticmethod - def _binary_op(f, reflexive=False, join='inner', **ignored_kwargs): + def _binary_op(f, reflexive=False, join=None, **ignored_kwargs): @functools.wraps(f) def func(self, other): if isinstance(other, (Dataset, groupby.GroupBy)): return NotImplemented if hasattr(other, 'indexes'): - self, other = align(self, other, join=join, copy=False) + align_type = OPTIONS['arithmetic_join'] if join is None else join + self, other = align(self, other, join=align_type, copy=False) other_variable = getattr(other, 'variable', other) other_coords = getattr(other, 'coords', None) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index c9fdf5ed40b..964bd3dbb7a 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -28,6 +28,7 @@ from .pycompat import (iteritems, basestring, OrderedDict, dask_array_type) from .combine import concat +from .options import OPTIONS # list of attributes of pd.DatetimeIndex that are ndarrays of time info @@ -2012,15 +2013,17 @@ def func(self, *args, **kwargs): return func @staticmethod - def _binary_op(f, reflexive=False, join='inner', fillna=False): + def _binary_op(f, reflexive=False, join=None, fillna=False): @functools.wraps(f) def func(self, other): if isinstance(other, groupby.GroupBy): return NotImplemented + align_type = OPTIONS['arithmetic_join'] if join is None else join if hasattr(other, 'indexes'): - self, other = align(self, other, join=join, copy=False) + self, other = align(self, other, join=align_type, copy=False) g = f if not reflexive else lambda x, y: f(y, x) - ds = self._calculate_binary_op(g, other, fillna=fillna) + ds = self._calculate_binary_op(g, other, join=align_type, + fillna=fillna) return ds return func @@ -2042,25 +2045,32 @@ def func(self, other): return self return func - def _calculate_binary_op(self, f, other, inplace=False, fillna=False): + def _calculate_binary_op(self, f, other, join='inner', + inplace=False, fillna=False): def apply_over_both(lhs_data_vars, rhs_data_vars, lhs_vars, rhs_vars): + if fillna and join != 'left': + raise ValueError('`fillna` must be accompanied by left join') if fillna and not set(rhs_data_vars) <= set(lhs_data_vars): raise ValueError('all variables in the argument to `fillna` ' 'must be contained in the original dataset') + if inplace and set(lhs_data_vars) != set(rhs_data_vars): + raise ValueError('datasets must have the same data variables ' + 'for in-place arithmetic operations: %s, %s' + % (list(lhs_data_vars), list(rhs_data_vars))) dest_vars = OrderedDict() + for k in lhs_data_vars: if k in rhs_data_vars: dest_vars[k] = f(lhs_vars[k], rhs_vars[k]) - elif inplace: - raise ValueError( - 'datasets must have the same data variables ' - 'for in-place arithmetic operations: %s, %s' - % (list(lhs_data_vars), list(rhs_data_vars))) - elif fillna: - # this shortcuts left alignment of variables for fillna - dest_vars[k] = lhs_vars[k] + elif join in ["left", "outer"]: + dest_vars[k] = (lhs_vars[k] if fillna else + f(lhs_vars[k], np.nan)) + for k in rhs_data_vars: + if k not in dest_vars and join in ["right", "outer"]: + dest_vars[k] = (rhs_vars[k] if fillna else + f(rhs_vars[k], np.nan)) return dest_vars if utils.is_dict_like(other) and not isinstance(other, Dataset): @@ -2080,7 +2090,6 @@ def apply_over_both(lhs_data_vars, rhs_data_vars, lhs_vars, rhs_vars): other_variable = getattr(other, 'variable', other) new_vars = OrderedDict((k, f(self.variables[k], other_variable)) for k in self.data_vars) - ds._variables.update(new_vars) return ds diff --git a/xarray/core/options.py b/xarray/core/options.py index 4b80b400f69..65264bf4919 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -1,14 +1,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -OPTIONS = {'display_width': 80} +OPTIONS = {'display_width': 80, + 'arithmetic_join': "inner"} class set_options(object): """Set options for xarray in a controlled context. - Currently, the only supported option is ``display_width``, which has a - default value of 80. + Currently, the only supported options are: + 1.) display_width: maximum terminal display width of data arrays. + Default=80. + 2.) arithmetic_join: dataarray/dataset alignment in binary operations. + Default='inner'. You can use ``set_options`` either as a context manager: diff --git a/xarray/test/test_dataarray.py b/xarray/test/test_dataarray.py index ce5605d9098..d5ee9851469 100644 --- a/xarray/test/test_dataarray.py +++ b/xarray/test/test_dataarray.py @@ -2278,3 +2278,17 @@ def test_dot(self): da.dot(dm.values) with self.assertRaisesRegexp(ValueError, 'no shared dimensions'): da.dot(DataArray(1)) + + def test_binary_op_join_setting(self): + dim = 'x' + align_type = "outer" + coords_l, coords_r = [0, 1, 2], [1, 2, 3] + missing_3 = xr.DataArray(coords_l, [(dim, coords_l)]) + missing_0 = xr.DataArray(coords_r, [(dim, coords_r)]) + with xr.set_options(arithmetic_join=align_type): + actual = missing_0 + missing_3 + missing_0_aligned, missing_3_aligned = xr.align(missing_0, + missing_3, + join=align_type) + expected = xr.DataArray([np.nan, 2, 4, np.nan], [(dim, [0, 1, 2, 3])]) + self.assertDataArrayEqual(actual, expected) diff --git a/xarray/test/test_dataset.py b/xarray/test/test_dataset.py index 9321bea60d4..5b8af6b6437 100644 --- a/xarray/test/test_dataset.py +++ b/xarray/test/test_dataset.py @@ -15,6 +15,7 @@ import numpy as np import pandas as pd +import xarray as xr import pytest from xarray import (align, broadcast, concat, merge, conventions, backends, @@ -2935,6 +2936,37 @@ def test_filter_by_attrs(self): for var in new_ds.data_vars: self.assertEqual(new_ds[var].height, '10 m') + def test_binary_op_join_setting(self): + # arithmetic_join applies to data array coordinates + missing_2 = xr.Dataset({'x':[0, 1]}) + missing_0 = xr.Dataset({'x':[1, 2]}) + with xr.set_options(arithmetic_join='outer'): + actual = missing_2 + missing_0 + expected = xr.Dataset({'x':[0, 1, 2]}) + self.assertDatasetEqual(actual, expected) + + # arithmetic join also applies to data_vars + ds1 = xr.Dataset({'foo': 1, 'bar': 2}) + ds2 = xr.Dataset({'bar': 2, 'baz': 3}) + expected = xr.Dataset({'bar': 4}) # default is inner joining + actual = ds1 + ds2 + self.assertDatasetEqual(actual, expected) + + with xr.set_options(arithmetic_join='outer'): + expected = xr.Dataset({'foo': np.nan, 'bar': 4, 'baz': np.nan}) + actual = ds1 + ds2 + self.assertDatasetEqual(actual, expected) + + with xr.set_options(arithmetic_join='left'): + expected = xr.Dataset({'foo': np.nan, 'bar': 4}) + actual = ds1 + ds2 + self.assertDatasetEqual(actual, expected) + + with xr.set_options(arithmetic_join='right'): + expected = xr.Dataset({'bar': 4, 'baz': np.nan}) + actual = ds1 + ds2 + self.assertDatasetEqual(actual, expected) + ### Py.test tests