-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Options to binary ops kwargs #1065
Changes from 7 commits
3b04fb0
643fa49
43df4d0
ebb2ad0
ab717fa
8aa969b
46ad36a
4a05ca4
a7b7497
88c37bc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,6 +25,7 @@ | |
from .pycompat import (iteritems, basestring, OrderedDict, | ||
dask_array_type) | ||
from .combine import concat | ||
from .options import OPTIONS | ||
|
||
|
||
# list of attributes of pd.DatetimeIndex that are ndarrays of time info | ||
|
@@ -2018,15 +2019,17 @@ def func(self, *args, **kwargs): | |
return func | ||
|
||
@staticmethod | ||
def _binary_op(f, reflexive=False, join='inner', fillna=False): | ||
def _binary_op(f, reflexive=False, join=None, fillna=False): | ||
@functools.wraps(f) | ||
def func(self, other): | ||
if isinstance(other, groupby.GroupBy): | ||
return NotImplemented | ||
align_type = OPTIONS['arithmetic_join'] if join is None else join | ||
if hasattr(other, 'indexes'): | ||
self, other = align(self, other, join=join, copy=False) | ||
self, other = align(self, other, join=align_type, copy=False) | ||
g = f if not reflexive else lambda x, y: f(y, x) | ||
ds = self._calculate_binary_op(g, other, fillna=fillna) | ||
ds = self._calculate_binary_op(g, other, join=align_type, | ||
fillna=fillna) | ||
return ds | ||
return func | ||
|
||
|
@@ -2048,25 +2051,30 @@ def func(self, other): | |
return self | ||
return func | ||
|
||
def _calculate_binary_op(self, f, other, inplace=False, fillna=False): | ||
def _calculate_binary_op(self, f, other, join='inner', | ||
inplace=False, fillna=False): | ||
|
||
def apply_over_both(lhs_data_vars, rhs_data_vars, lhs_vars, rhs_vars): | ||
if fillna and join != 'left': | ||
raise ValueError('`fillna` must be accompanied by left join') | ||
if fillna and not set(rhs_data_vars) <= set(lhs_data_vars): | ||
raise ValueError('all variables in the argument to `fillna` ' | ||
'must be contained in the original dataset') | ||
if inplace and set(lhs_data_vars) != set(rhs_data_vars): | ||
raise ValueError('datasets must have the same data variables ' | ||
'for in-place arithmetic operations: %s, %s' | ||
% (list(lhs_data_vars), list(rhs_data_vars))) | ||
|
||
dest_vars = OrderedDict() | ||
for k in lhs_data_vars: | ||
if k in rhs_data_vars: | ||
dest_vars[k] = f(lhs_vars[k], rhs_vars[k]) | ||
elif inplace: | ||
raise ValueError( | ||
'datasets must have the same data variables ' | ||
'for in-place arithmetic operations: %s, %s' | ||
% (list(lhs_data_vars), list(rhs_data_vars))) | ||
elif fillna: | ||
# this shortcuts left alignment of variables for fillna | ||
|
||
for k in set(lhs_data_vars) & set(rhs_data_vars): | ||
dest_vars[k] = f(lhs_vars[k], rhs_vars[k]) | ||
if join in ["outer", "left"]: | ||
for k in set(lhs_data_vars) - set(rhs_data_vars): | ||
dest_vars[k] = lhs_vars[k] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was only the right default for |
||
if join in ["outer", "right"]: | ||
for k in set(rhs_data_vars) - set(lhs_data_vars): | ||
dest_vars[k] = rhs_vars[k] | ||
return dest_vars | ||
|
||
if utils.is_dict_like(other) and not isinstance(other, Dataset): | ||
|
@@ -2086,7 +2094,6 @@ def apply_over_both(lhs_data_vars, rhs_data_vars, lhs_vars, rhs_vars): | |
other_variable = getattr(other, 'variable', other) | ||
new_vars = OrderedDict((k, f(self.variables[k], other_variable)) | ||
for k in self.data_vars) | ||
|
||
ds._variables.update(new_vars) | ||
return ds | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
OPTIONS = {'display_width': 80} | ||
OPTIONS = {'display_width': 80, | ||
'arithmetic_join': "inner"} | ||
|
||
|
||
class set_options(object): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you update the docstring for this class? |
||
|
@@ -25,6 +26,10 @@ class set_options(object): | |
""" | ||
def __init__(self, **kwargs): | ||
self.old = OPTIONS.copy() | ||
for key in kwargs: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please rebase or merge master -- I actually already added something very similar. |
||
if key not in OPTIONS: | ||
raise KeyError("acceptable keys are: {}".\ | ||
format(', '.join(OPTIONS.keys()))) | ||
OPTIONS.update(kwargs) | ||
|
||
def __enter__(self): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2268,3 +2268,16 @@ def test_dot(self): | |
da.dot(dm.values) | ||
with self.assertRaisesRegexp(ValueError, 'no shared dimensions'): | ||
da.dot(DataArray(1)) | ||
|
||
def test_binary_op_join_setting(self): | ||
dim = 'x' | ||
align_type = "outer" | ||
coords_l, coords_r = [0, 1, 2], [1, 2, 3] | ||
missing_3 = xr.DataArray(coords_l, [(dim, coords_l)]) | ||
missing_0 = xr.DataArray(coords_r, [(dim, coords_r)]) | ||
with xr.set_options(arithmetic_join=align_type): | ||
actual = missing_0 + missing_3 | ||
missing_0_aligned, missing_3_aligned =\ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. line break after |
||
xr.align(missing_0, missing_3, join=align_type) | ||
expected = xr.DataArray([np.nan, 2, 4, np.nan], [(dim, [0, 1, 2, 3])]) | ||
self.assertDataArrayEqual(actual, expected) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
|
||
import numpy as np | ||
import pandas as pd | ||
import xarray as xr | ||
import pytest | ||
|
||
from xarray import (align, broadcast, concat, merge, conventions, backends, | ||
|
@@ -2916,6 +2917,40 @@ def test_filter_by_attrs(self): | |
for var in new_ds.data_vars: | ||
self.assertEqual(new_ds[var].height, '10 m') | ||
|
||
def test_binary_op_join_setting(self): | ||
# arithmetic_join applies to data array coordinates | ||
missing_2 = xr.Dataset({'x':[0, 1]}) | ||
missing_0 = xr.Dataset({'x':[1, 2]}) | ||
with xr.set_options(arithmetic_join='outer'): | ||
actual = missing_2 + missing_0 | ||
expected = xr.Dataset({'x':[0, 1, 2]}) | ||
self.assertDatasetEqual(actual, expected) | ||
|
||
# arithmetic join also applies to data_vars | ||
ds1 = xr.Dataset({'foo': 1, 'bar': 2}) | ||
ds2 = xr.Dataset({'bar': 2, 'baz': 3}) | ||
expected = xr.Dataset({'bar': 4}) # default is inner joining | ||
actual = ds1 + ds2 | ||
self.assertDatasetEqual(actual, expected) | ||
|
||
with xr.set_options(arithmetic_join='outer'): | ||
expected = xr.Dataset({'foo':1, 'bar': 4, 'baz': 3}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As noted above, should be |
||
actual = ds1 + ds2 | ||
self.assertDatasetEqual(actual, expected) | ||
|
||
with xr.set_options(arithmetic_join='left'): | ||
expected = xr.Dataset({'foo':1, 'bar': 4}) | ||
actual = ds1 + ds2 | ||
self.assertDatasetEqual(actual, expected) | ||
|
||
with xr.set_options(arithmetic_join='right'): | ||
expected = xr.Dataset({'baz':3, 'bar': 4}) | ||
actual = ds1 + ds2 | ||
self.assertDatasetEqual(actual, expected) | ||
|
||
|
||
|
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Too much whitespace here -- by convention, leave only 2 blank lines. |
||
|
||
### Py.test tests | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is important to maintain the order of data variables: it should be `lhs` variables first, then `rhs`-only variables. So, this should probably use a loop over `lhs_data_vars`, followed by `rhs_data_vars`.