Skip to content

Commit

Permalink
Tests and fixes for binary ops
Browse files Browse the repository at this point in the history
  • Loading branch information
shoyer committed Sep 8, 2016
1 parent 4dfca28 commit 879ee32
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 17 deletions.
25 changes: 10 additions & 15 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def apply_dataarray_ufunc(func, *args, **kwargs):
name = result_name(args)
result_coords = build_output_coords(args, signature, new_coords)

data_vars = [getattr(a, 'variable') for a in args]
data_vars = [getattr(a, 'variable', a) for a in args]
result_var = func(*data_vars)

if signature.n_outputs > 1:
Expand All @@ -196,16 +196,11 @@ def join_dict_keys(objects, how='inner'):


def collect_dict_values(objects, keys, fill_value=None):
result_values = []
for key in keys:
values = []
for obj in objects:
if hasattr(obj, 'keys'):
values.append(obj.get(key, fill_value))
else:
values = tobj
result_values.append(values)
return result_values
return [[obj.get(key, fill_value)
if is_dict_like(obj)
else obj
for obj in objects]
for key in keys]


def apply_dataset_ufunc(func, *args, **kwargs):
Expand All @@ -220,8 +215,8 @@ def apply_dataset_ufunc(func, args, signature=None, join='inner',
fill_value = kwargs.pop('fill_value', None)
new_coords = kwargs.pop('new_coords', None)
if kwargs:
raise TypeError('apply_dataarray_ufunc() got unexpected keyword arguments: %s'
% list(kwargs))
raise TypeError('apply_dataarray_ufunc() got unexpected keyword '
'arguments: %s' % list(kwargs))

if signature is None:
signature = _default_signature(len(args))
Expand All @@ -230,10 +225,10 @@ def apply_dataset_ufunc(func, args, signature=None, join='inner',

list_of_coords = build_output_coords(args, signature, new_coords)

list_of_data_vars = [getattr(a, 'data_vars', {}) for a in args]
list_of_data_vars = [getattr(a, 'data_vars', a) for a in args]
names = join_dict_keys(list_of_data_vars, how=join)

list_of_variables = [getattr(a, 'variables', {}) for a in args]
list_of_variables = [getattr(a, 'variables', a) for a in args]
lists_of_args = collect_dict_values(list_of_variables, names, fill_value)

result_vars = OrderedDict()
Expand Down
51 changes: 49 additions & 2 deletions xarray/test/test_computation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections import OrderedDict
import operator

import numpy as np
import pytest
Expand Down Expand Up @@ -47,8 +48,8 @@ def test_join_dict_keys():


def test_collect_dict_values():
dicts = [{'x': 1, 'y': 2, 'z': 3}, {'z': 4}]
expected = [[1, 0], [2, 0], [3, 4]]
dicts = [{'x': 1, 'y': 2, 'z': 3}, {'z': 4}, 5]
expected = [[1, 0, 5], [2, 0, 5], [3, 4, 5]]
collected = collect_dict_values(dicts, ['x', 'y', 'z'], fill_value=0)
assert collected == expected

Expand All @@ -74,6 +75,52 @@ def test_apply_ufunc_identity():
assert_identical(output, dataset)


def test_apply_ufunc_two_inputs():
array = np.array([1, 2, 3])
variable = xr.Variable('x', array)
data_array = xr.DataArray(variable, [('x', -array)])
dataset = xr.Dataset({'y': variable}, {'x': -array})

zeros_array = np.zeros_like(array)
zeros_variable = xr.Variable('x', zeros_array)
zeros_data_array = xr.DataArray(zeros_variable, [('x', -array)])
zeros_dataset = xr.Dataset({'y': zeros_variable}, {'x': -array})

add = lambda a, b: xr.apply_ufunc(operator.add, a, b)

assert_array_equal(array, add(array, 0))
assert_array_equal(array, add(array, zeros_array))
assert_array_equal(array, add(0, array))
assert_array_equal(array, add(zeros_array, array))

assert_identical(variable, add(variable, 0))
assert_identical(variable, add(variable, zeros_array))
assert_identical(variable, add(variable, zeros_variable))
assert_identical(variable, add(0, variable))
assert_identical(variable, add(zeros_array, variable))
assert_identical(variable, add(zeros_variable, variable))

assert_identical(data_array, add(data_array, 0))
assert_identical(data_array, add(data_array, zeros_array))
assert_identical(data_array, add(data_array, zeros_variable))
assert_identical(data_array, add(data_array, zeros_data_array))
assert_identical(data_array, add(0, data_array))
assert_identical(data_array, add(zeros_array, data_array))
assert_identical(data_array, add(zeros_variable, data_array))
assert_identical(data_array, add(zeros_data_array, data_array))

assert_identical(dataset, add(dataset, 0))
assert_identical(dataset, add(dataset, zeros_array))
assert_identical(dataset, add(dataset, zeros_variable))
assert_identical(dataset, add(dataset, zeros_data_array))
assert_identical(dataset, add(dataset, zeros_dataset))
assert_identical(dataset, add(0, dataset))
assert_identical(dataset, add(zeros_array, dataset))
assert_identical(dataset, add(zeros_variable, dataset))
assert_identical(dataset, add(zeros_data_array, dataset))
assert_identical(dataset, add(zeros_dataset, dataset))


def test_apply_ufunc_two_outputs():
array = np.arange(10)
variable = xr.Variable('x', array)
Expand Down

0 comments on commit 879ee32

Please sign in to comment.