Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Options to binary ops kwargs #1065

Merged
merged 10 commits into from
Nov 12, 2016
33 changes: 19 additions & 14 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2024,11 +2024,12 @@ def _binary_op(f, reflexive=False, join=None, fillna=False):
def func(self, other):
if isinstance(other, groupby.GroupBy):
return NotImplemented
align_type = OPTIONS['arithmetic_join'] if join is None else join
if hasattr(other, 'indexes'):
align_type = OPTIONS['arithmetic_join'] if join is None else join
self, other = align(self, other, join=align_type, copy=False)
g = f if not reflexive else lambda x, y: f(y, x)
ds = self._calculate_binary_op(g, other, fillna=fillna)
ds = self._calculate_binary_op(g, other, join=align_type,
fillna=fillna)
return ds
return func

Expand All @@ -2050,25 +2051,30 @@ def func(self, other):
return self
return func

def _calculate_binary_op(self, f, other, inplace=False, fillna=False):
def _calculate_binary_op(self, f, other, join='inner',
inplace=False, fillna=False):

def apply_over_both(lhs_data_vars, rhs_data_vars, lhs_vars, rhs_vars):
if fillna and join != 'left':
raise ValueError('`fillna` must be accompanied by left join')
if fillna and not set(rhs_data_vars) <= set(lhs_data_vars):
raise ValueError('all variables in the argument to `fillna` '
'must be contained in the original dataset')
if inplace and set(lhs_data_vars) != set(rhs_data_vars):
raise ValueError('datasets must have the same data variables '
'for in-place arithmetic operations: %s, %s'
% (list(lhs_data_vars), list(rhs_data_vars)))

dest_vars = OrderedDict()
for k in lhs_data_vars:
if k in rhs_data_vars:
dest_vars[k] = f(lhs_vars[k], rhs_vars[k])
elif inplace:
raise ValueError(
'datasets must have the same data variables '
'for in-place arithmetic operations: %s, %s'
% (list(lhs_data_vars), list(rhs_data_vars)))
elif fillna:
# this shortcuts left alignment of variables for fillna

for k in set(lhs_data_vars) & set(rhs_data_vars):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is important to maintain the order of data variables: should be lhs first, then rhs only variables.

So, probably should use a loop over lhs_data_vars, followed by rhs_data_vars.

dest_vars[k] = f(lhs_vars[k], rhs_vars[k])
if join in ["outer", "left"]:
for k in set(lhs_data_vars) - set(rhs_data_vars):
dest_vars[k] = lhs_vars[k]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was only the right default for fillna. For other operations, should be NaN instead (or, more specifically, the result of f(lhs_vars[k], np.nan).

if join in ["outer", "right"]:
for k in set(rhs_data_vars) - set(lhs_data_vars):
dest_vars[k] = rhs_vars[k]
return dest_vars

if utils.is_dict_like(other) and not isinstance(other, Dataset):
Expand All @@ -2088,7 +2094,6 @@ def apply_over_both(lhs_data_vars, rhs_data_vars, lhs_vars, rhs_vars):
other_variable = getattr(other, 'variable', other)
new_vars = OrderedDict((k, f(self.variables[k], other_variable))
for k in self.data_vars)

ds._variables.update(new_vars)
return ds

Expand Down
4 changes: 4 additions & 0 deletions xarray/core/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ class set_options(object):
"""
def __init__(self, **kwargs):
self.old = OPTIONS.copy()
for key in kwargs:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please rebase or merge master -- I actually already added in something very similar
https://github.com/pydata/xarray/blob/master/xarray/core/options.py

if key not in OPTIONS:
raise KeyError("acceptable keys are: {}".\
format(', '.join(OPTIONS.keys())))
OPTIONS.update(kwargs)

def __enter__(self):
Expand Down
43 changes: 32 additions & 11 deletions xarray/test/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2918,17 +2918,38 @@ def test_filter_by_attrs(self):
self.assertEqual(new_ds[var].height, '10 m')

def test_binary_op_join_setting(self):
dim = 'x'
align_type = "outer"
coords_l, coords_r = [0, 1, 2], [1, 2, 3]
missing_3 = xr.DataArray(coords_l, [(dim, coords_l)]).to_dataset(name='a')
missing_0 = xr.DataArray(coords_r, [(dim, coords_r)]).to_dataset(name='a')
with xr.set_options(arithmetic_join=align_type):
experimental = missing_0 + missing_3
missing_0_aligned, missing_3_aligned =\
xr.align(missing_0, missing_3, join=align_type)
control = missing_0_aligned + missing_3_aligned
self.assertDatasetEqual(experimental, control)
# arithmetic_join applies to data array coordinates
missing_2 = xr.Dataset({'x':[0, 1]})
missing_0 = xr.Dataset({'x':[1, 2]})
with xr.set_options(arithmetic_join='outer'):
actual = missing_2 + missing_0
expected = xr.Dataset({'x':[0, 1, 2]})
self.assertDatasetEqual(actual, expected)

# arithmetic join also applies to data_vars
ds1 = xr.Dataset({'foo': 1, 'bar': 2})
ds2 = xr.Dataset({'bar': 2, 'baz': 3})
expected = xr.Dataset({'bar': 4}) # default is inner joining
actual = ds1 + ds2
self.assertDatasetEqual(actual, expected)

with xr.set_options(arithmetic_join='outer'):
expected = xr.Dataset({'foo':1, 'bar': 4, 'baz': 3})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As noted above, should be xr.Dataset({'foo': nan, 'bar': 4, 'baz': nan}) for consistent outer join semantics

actual = ds1 + ds2
self.assertDatasetEqual(actual, expected)

with xr.set_options(arithmetic_join='left'):
expected = xr.Dataset({'foo':1, 'bar': 4})
actual = ds1 + ds2
self.assertDatasetEqual(actual, expected)

with xr.set_options(arithmetic_join='right'):
expected = xr.Dataset({'baz':3, 'bar': 4})
actual = ds1 + ds2
self.assertDatasetEqual(actual, expected)




Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

too much white space here -- by convention, leave only 2 blank lines


### Py.test tests
Expand Down