Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Options to binary ops kwargs #1065

Merged
merged 10 commits into from
Nov 12, 2016
13 changes: 11 additions & 2 deletions doc/computation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,8 @@ coordinates with the same name as a dimension, marked by ``*``) on objects used
in binary operations.

Similarly to pandas, this alignment is automatic for arithmetic on binary
operations. Note that unlike pandas, this the result of a binary operation is
by the *intersection* (not the union) of coordinate labels:
operations. By default, the result of a binary operation is labeled by the
*intersection* (not the union) of coordinate labels:

.. ipython:: python

Expand All @@ -225,6 +225,15 @@ If the result would be empty, an error is raised instead:
In [1]: arr[:2] + arr[2:]
ValueError: no overlapping labels for some dimensions: ['x']

However, one can explicitly change this default automatic alignment type ("inner")
via :py:func:`~xarray.set_options()` in a context manager:

.. ipython:: python

with xr.set_options(arithmetic_join="outer"):
arr + arr[:1]
arr + arr[:1]

Before loops or performance critical code, it's a good idea to align arrays
explicitly (e.g., by putting them in the same Dataset or using
:py:func:`~xarray.align`) to avoid the overhead of repeated alignment with each
Expand Down
5 changes: 5 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ Deprecations

Enhancements
~~~~~~~~~~~~
- Added the ability to change the default automatic alignment (arithmetic_join="inner")
for binary operations via :py:func:`~xarray.set_options()`
(see :ref:`automatic alignment`).
By `Chun-Wei Yuan <https://github.com/chunweiyuan>`_.

- Add checking of ``attr`` names and values when saving to netCDF, raising useful
error messages if they are invalid. (:issue:`911`).
By `Robin Wilson <https://github.com/robintw>`_.
Expand Down
6 changes: 4 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
assert_unique_multiindex_level_names)
from .formatting import format_item
from .utils import decode_numpy_dict_values, ensure_us_time_resolution
from .options import OPTIONS


def _infer_coords_and_dims(shape, coords, dims):
Expand Down Expand Up @@ -1377,13 +1378,14 @@ def func(self, *args, **kwargs):
return func

@staticmethod
def _binary_op(f, reflexive=False, join='inner', **ignored_kwargs):
def _binary_op(f, reflexive=False, join=None, **ignored_kwargs):
@functools.wraps(f)
def func(self, other):
if isinstance(other, (Dataset, groupby.GroupBy)):
return NotImplemented
if hasattr(other, 'indexes'):
self, other = align(self, other, join=join, copy=False)
align_type = OPTIONS['arithmetic_join'] if join is None else join
self, other = align(self, other, join=align_type, copy=False)
other_variable = getattr(other, 'variable', other)
other_coords = getattr(other, 'coords', None)

Expand Down
37 changes: 22 additions & 15 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .pycompat import (iteritems, basestring, OrderedDict,
dask_array_type)
from .combine import concat
from .options import OPTIONS


# list of attributes of pd.DatetimeIndex that are ndarrays of time info
Expand Down Expand Up @@ -2018,15 +2019,17 @@ def func(self, *args, **kwargs):
return func

@staticmethod
def _binary_op(f, reflexive=False, join='inner', fillna=False):
def _binary_op(f, reflexive=False, join=None, fillna=False):
@functools.wraps(f)
def func(self, other):
if isinstance(other, groupby.GroupBy):
return NotImplemented
align_type = OPTIONS['arithmetic_join'] if join is None else join
if hasattr(other, 'indexes'):
self, other = align(self, other, join=join, copy=False)
self, other = align(self, other, join=align_type, copy=False)
g = f if not reflexive else lambda x, y: f(y, x)
ds = self._calculate_binary_op(g, other, fillna=fillna)
ds = self._calculate_binary_op(g, other, join=align_type,
fillna=fillna)
return ds
return func

Expand All @@ -2048,25 +2051,30 @@ def func(self, other):
return self
return func

def _calculate_binary_op(self, f, other, inplace=False, fillna=False):
def _calculate_binary_op(self, f, other, join='inner',
inplace=False, fillna=False):

def apply_over_both(lhs_data_vars, rhs_data_vars, lhs_vars, rhs_vars):
if fillna and join != 'left':
raise ValueError('`fillna` must be accompanied by left join')
if fillna and not set(rhs_data_vars) <= set(lhs_data_vars):
raise ValueError('all variables in the argument to `fillna` '
'must be contained in the original dataset')
if inplace and set(lhs_data_vars) != set(rhs_data_vars):
raise ValueError('datasets must have the same data variables '
'for in-place arithmetic operations: %s, %s'
% (list(lhs_data_vars), list(rhs_data_vars)))

dest_vars = OrderedDict()
for k in lhs_data_vars:
if k in rhs_data_vars:
dest_vars[k] = f(lhs_vars[k], rhs_vars[k])
elif inplace:
raise ValueError(
'datasets must have the same data variables '
'for in-place arithmetic operations: %s, %s'
% (list(lhs_data_vars), list(rhs_data_vars)))
elif fillna:
# this shortcuts left alignment of variables for fillna

for k in set(lhs_data_vars) & set(rhs_data_vars):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is important to maintain the order of data variables: it should be lhs variables first, then rhs-only variables.

So, probably should use a loop over lhs_data_vars, followed by rhs_data_vars.

dest_vars[k] = f(lhs_vars[k], rhs_vars[k])
if join in ["outer", "left"]:
for k in set(lhs_data_vars) - set(rhs_data_vars):
dest_vars[k] = lhs_vars[k]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was only the right default for fillna. For other operations, the result should be NaN instead (or, more specifically, the result of f(lhs_vars[k], np.nan)).

if join in ["outer", "right"]:
for k in set(rhs_data_vars) - set(lhs_data_vars):
dest_vars[k] = rhs_vars[k]
return dest_vars

if utils.is_dict_like(other) and not isinstance(other, Dataset):
Expand All @@ -2086,7 +2094,6 @@ def apply_over_both(lhs_data_vars, rhs_data_vars, lhs_vars, rhs_vars):
other_variable = getattr(other, 'variable', other)
new_vars = OrderedDict((k, f(self.variables[k], other_variable))
for k in self.data_vars)

ds._variables.update(new_vars)
return ds

Expand Down
7 changes: 6 additions & 1 deletion xarray/core/options.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
OPTIONS = {'display_width': 80}
OPTIONS = {'display_width': 80,
'arithmetic_join': "inner"}


class set_options(object):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you update the docstring for this class?

Expand All @@ -25,6 +26,10 @@ class set_options(object):
"""
def __init__(self, **kwargs):
self.old = OPTIONS.copy()
for key in kwargs:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please rebase or merge master -- I actually already added in something very similar
https://github.com/pydata/xarray/blob/master/xarray/core/options.py

if key not in OPTIONS:
raise KeyError("acceptable keys are: {}".\
format(', '.join(OPTIONS.keys())))
OPTIONS.update(kwargs)

def __enter__(self):
Expand Down
13 changes: 13 additions & 0 deletions xarray/test/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2268,3 +2268,16 @@ def test_dot(self):
da.dot(dm.values)
with self.assertRaisesRegexp(ValueError, 'no shared dimensions'):
da.dot(DataArray(1))

def test_binary_op_join_setting(self):
dim = 'x'
align_type = "outer"
coords_l, coords_r = [0, 1, 2], [1, 2, 3]
missing_3 = xr.DataArray(coords_l, [(dim, coords_l)])
missing_0 = xr.DataArray(coords_r, [(dim, coords_r)])
with xr.set_options(arithmetic_join=align_type):
actual = missing_0 + missing_3
missing_0_aligned, missing_3_aligned =\
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

line break after xr.align( instead of using the explicit \

xr.align(missing_0, missing_3, join=align_type)
expected = xr.DataArray([np.nan, 2, 4, np.nan], [(dim, [0, 1, 2, 3])])
self.assertDataArrayEqual(actual, expected)
35 changes: 35 additions & 0 deletions xarray/test/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import numpy as np
import pandas as pd
import xarray as xr
import pytest

from xarray import (align, broadcast, concat, merge, conventions, backends,
Expand Down Expand Up @@ -2916,6 +2917,40 @@ def test_filter_by_attrs(self):
for var in new_ds.data_vars:
self.assertEqual(new_ds[var].height, '10 m')

def test_binary_op_join_setting(self):
# arithmetic_join applies to data array coordinates
missing_2 = xr.Dataset({'x':[0, 1]})
missing_0 = xr.Dataset({'x':[1, 2]})
with xr.set_options(arithmetic_join='outer'):
actual = missing_2 + missing_0
expected = xr.Dataset({'x':[0, 1, 2]})
self.assertDatasetEqual(actual, expected)

# arithmetic join also applies to data_vars
ds1 = xr.Dataset({'foo': 1, 'bar': 2})
ds2 = xr.Dataset({'bar': 2, 'baz': 3})
expected = xr.Dataset({'bar': 4}) # default is inner joining
actual = ds1 + ds2
self.assertDatasetEqual(actual, expected)

with xr.set_options(arithmetic_join='outer'):
expected = xr.Dataset({'foo':1, 'bar': 4, 'baz': 3})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As noted above, should be xr.Dataset({'foo': nan, 'bar': 4, 'baz': nan}) for consistent outer join semantics

actual = ds1 + ds2
self.assertDatasetEqual(actual, expected)

with xr.set_options(arithmetic_join='left'):
expected = xr.Dataset({'foo':1, 'bar': 4})
actual = ds1 + ds2
self.assertDatasetEqual(actual, expected)

with xr.set_options(arithmetic_join='right'):
expected = xr.Dataset({'baz':3, 'bar': 4})
actual = ds1 + ds2
self.assertDatasetEqual(actual, expected)




Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

too much white space here -- by convention, leave only 2 blank lines


### Py.test tests

Expand Down