-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
einsum for xarray #1968
einsum for xarray #1968
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice!
xarray/core/computation.py
Outdated
subscripts = '' | ||
for ds in input_core_dims: | ||
subscripts += '...' + ''.join([dim_map[d] for d in ds]) + ',' | ||
subscripts = subscripts[:-1] # remove last comma |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would probably be cleaner to build up subscripts as a list and use ','.join(subscripts_list)
once at the end.
xarray/core/computation.py
Outdated
|
||
result = apply_ufunc(np.einsum, subscripts, *arrays, | ||
input_core_dims=[[]] + input_core_dims, | ||
output_core_dims=output_core_dims, dask='allowed') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think dask='parallelized'
is what you want here -- that generate the wrapper to do this with dask. This will require also determining the result data type, probably with dtypes.result_type
or even np.result_type
(we don't need support for non-numeric types in einsum, so I'm pretty sure NumPy's casting rules would work fine).
dask='allowed'
would be appropriate if np.einsum
already supported dask arrays (but it does not).
It's possible that a dask specific einsum
could be much more efficient than the auto-generated wrapper here, but certainly this is good enough for now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. I noticed that my current implementation is not very efficient for dask.
Maybe smaller number of input_core_dims
is better for dask?
I think I need some improvement.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dask='parallelized'
will only parallelize over broadcast dimensions, i.e., ones that don't appear in either input_core_dims
or output_core_dims
. So yes, it will probably be slow in many cases.
I'm still OK adding the non-optimal einsum for now and improving it later.
xarray/core/computation.py
Outdated
if len(arrays) < 2: | ||
raise TypeError('More than two arrays must be provided') | ||
|
||
if any(not hasattr(arr, 'dims') for arr in arrays): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Dataset also defines dims
. It's probably better to explicitly use an isinstance()
check.
xarray/core/dataarray.py
Outdated
[d for d in other.dims if d not in dims]) | ||
|
||
return type(self)(new_data, new_coords.variables, new_dims) | ||
# backward compat: if there is no shared dimension, we rais an Errror |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be better to eliminate this special case. Then users can understand DataArray.dot
as a simple short-cut for xarray.dot()
.
xarray/core/computation.py
Outdated
arrays = args | ||
if dims is None and isinstance(args[-1], (list, tuple, basestring)): | ||
dims = args[-1] | ||
arrays = args[:-1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it is better to require specifying dims
with a keyword argument.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Our previous dot
does not require dim
. This assumes to sum over along all the common dimensions.
I think dim=None
is not surprising.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, the default dims=None
should be OK. I meant that dims
should be a keyword only argument, not a required argument.
Here you are supporting xr.dot(a, b, 'x')
, where 'x'
denotes a dimension. I would require writing xr.dot(a, b, dim='x')
or omitting dim
altogether.
xarray/core/computation.py
Outdated
@@ -926,6 +926,86 @@ def earth_mover_distance(first_samples, | |||
return apply_array_ufunc(func, *args, dask=dask) | |||
|
|||
|
|||
def dot(*args, **kwargs): | |||
""" dot(*arrays, dims=None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dot(*arrays, *, dims=None)
is the way to write this with Python 3's keyword only arguments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we would keep this as dot(*arrays, **kwargs)
as we did not yet drop python 2 support?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was confused. def dot(*arrays, *, dims=None)
is not valid syntax in Python 3, either. (There can only be one single *
)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PEP3102 says we python 3 supports the form def dot(*arrays, dim=None)
.
return apply_ufunc(duck_array_ops.tensordot, *arrays, dask='allowed', | ||
input_core_dims=input_core_dims, | ||
output_core_dims=output_core_dims, | ||
kwargs={'axes': axes}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. I added a path for tensordot, which dask can compute more efficiently.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some feedback on the documentation (mostly grammar).
xarray/core/computation.py
Outdated
---------- | ||
arrays: multiple DataArrays | ||
arrays to compute. | ||
dims: tuple of strings, optional |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
str or tuple of strings
xarray/core/computation.py
Outdated
""" dot(*arrays, *, dims=None) | ||
|
||
einsum for xarray object, but providing simpler interface based on | ||
the array dimensions. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should lead with a more general description. Maybe:
Generalized dot product for xarray objects. Like np.einsum, but
provides a simpler interface based on array dimensions.
xarray/core/computation.py
Outdated
|
||
Parameters | ||
---------- | ||
arrays: multiple DataArrays |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
*arrays: DataArray objects
xarray/core/computation.py
Outdated
Parameters | ||
---------- | ||
arrays: multiple DataArrays | ||
arrays to compute. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Arrays
xarray/core/computation.py
Outdated
arrays: multiple DataArrays | ||
arrays to compute. | ||
dims: tuple of strings, optional | ||
Along which dimensions to be summed over. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Which dimensions to sum over.
xarray/core/computation.py
Outdated
|
||
Returns | ||
------- | ||
dot: same type to input. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably should just "DataArray"?
xarray/core/computation.py
Outdated
|
||
common_dims = set(arrays[0].dims) | ||
for arr in arrays[1:]: | ||
common_dims = common_dims.intersection(set(arr.dims)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a slightly different choice of default dimensions than np.einsum
:
np.einsum
sums over any dimensions that are defined in two over more inputs.- This sums only over dimensions that are defined on all inputs.
Should we switch this behavior to match einsum
?
xarray/core/computation.py
Outdated
dims=['a', 'b', 'c']) | ||
>>> da_c = xr.DataArray(np.arange(5 * 6).reshape(5, 6), dims=['c', 'd']) | ||
|
||
>>> dot(da_a, da_b, dims=['a', 'b']).dims |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These should use the full name xr.dot
.
xarray/core/computation.py
Outdated
dims = kwargs.pop('dims', None) | ||
|
||
if len(arrays) < 2: | ||
raise TypeError('More than one arrays must be provided') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need this special case? If not, let's remove this. For consistency, it is nice to use the same logic even for edge cases when possible. This makes it easier to think about the function.
In this case, I think a dot product of 1 array would consistently defined by summing over dimensions listed explicitly in dims
.
xarray/core/computation.py
Outdated
dims = kwargs.pop('dims', None) | ||
if len(kwargs) > 0: | ||
raise TypeError('Invalid keyward arguments {} are given'.format( | ||
kwargs.keys())) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
W1655 dict.keys referenced when not iterating
xarray/core/computation.py
Outdated
# find dimensions that exist in more than two arrays | ||
whole_dims = [] | ||
for arr in arrays: | ||
whole_dims += [d for d in arr.dims] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This might be a nice use for collections.Counter()
, e.g.,
dim_counts = Counter():
for arr in arrays:
dim_counts.update(arr.dims)
xarray/core/computation.py
Outdated
@@ -974,27 +977,30 @@ def dot(*arrays, **kwargs): | |||
dims = [dims] | |||
|
|||
common_dims = set(arrays[0].dims) | |||
all_dims = [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would it work to make all_dims
a set instead of a list? I think that would be slightly more efficient.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I want to keep the occurrence order in all_dims
, so that to move input_core_dims positions back to the original position.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, sounds good.
xarray/core/computation.py
Outdated
if len(arrays) < 2: | ||
raise TypeError('More than one arrays must be provided') | ||
if len(arrays) < 2 and dims is None: | ||
raise TypeError('dim must be provided for one array computation.') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If there's only one array, wouldn't dims
just be any repeated dimensions on the single array?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
xarray objects do not have any repeated dimensions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not strictly true: #1378 . That said, we certainly don't support repeated dims well right now.
Even if we banned repeated dimensions, I still think there's no harm in supporting the trivial xr.dot(array) -> array
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK. Updated.
xarray/core/computation.py
Outdated
@@ -926,6 +926,86 @@ def earth_mover_distance(first_samples, | |||
return apply_array_ufunc(func, *args, dask=dask) | |||
|
|||
|
|||
def dot(*args, **kwargs): | |||
""" dot(*arrays, dims=None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was confused. def dot(*arrays, *, dims=None)
is not valid syntax in Python 3, either. (There can only be one single *
)
xarray/core/computation.py
Outdated
common_dims = set(arrays[0].dims) | ||
all_dims = [] | ||
for arr in arrays[1:]: | ||
common_dims = common_dims.intersection(set(arr.dims)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be slightly more efficient to construct common_dims
with a single call to intersection
?
e.g.,
common_dims = set.intersection(*[set(arr.dims) for arr in arrays])
if len(kwargs) > 0: | ||
raise TypeError('Invalid keyward arguments {} are given'.format( | ||
list(kwargs.keys()))) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happens if you write xr.dot()
? I suppose we still need to raise an error for 0 arguments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's wait a little while to see if anyone else has feedback, e.g,. on the name. But this looks very nice to me!
xarray/core/computation.py
Outdated
@@ -968,15 +968,19 @@ def dot(*arrays, **kwargs): | |||
list(kwargs.keys()))) | |||
|
|||
if any(not isinstance(arr, DataArray) for arr in arrays): | |||
raise TypeError('Only xr.DataArray and xr.Variable are supported.') | |||
raise TypeError('Only xr.DataArray and xr.Variable are supported.' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should either update the error message or isinstance()
check here -- right now they are inconsistent.
xarray/core/computation.py
Outdated
list(kwargs.keys()))) | ||
|
||
if any(not isinstance(arr, DataArray) for arr in arrays): | ||
raise TypeError('Only xr.DataArray and xr.Variable are supported.' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Either a type checking or a docstring issue:
In [8]: v=xr.Variable(data=np.random.rand(3,4), dims=('a','b'))
In [9]: xr.dot(v,v)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-9-fac8e1cb222a> in <module>()
----> 1 xr.dot(v,v)
~/drive/workspace/xarray/xarray/core/computation.py in dot(*arrays, **kwargs)
970 if any(not isinstance(arr, DataArray) for arr in arrays):
971 raise TypeError('Only xr.DataArray and xr.Variable are supported.'
--> 972 'Given {}.'.format([type(arr) for arr in arrays]))
973
974 if len(arrays) == 0:
TypeError: Only xr.DataArray and xr.Variable are supported.Given [<class 'xarray.core.variable.Variable'>, <class 'xarray.core.variable.Variable'>].
raise TypeError('At least one array should be given.') | ||
|
||
if isinstance(dims, basestring): | ||
dims = (dims, ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FWIW you don't need the parentheses
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I personally like parentheses, as I think it is more descriptive.
xarray/core/computation.py
Outdated
if isinstance(dims, basestring): | ||
dims = (dims, ) | ||
elif isinstance(dims, list): | ||
dims = tuple(dims) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FWIW dims=tuple(dims)
doesn't create any copies if dims
is already a tuple, so you could skip the if isinstance
check
Thanks, @maxim-lian |
This is awesome. Beautiful code, immediately impactful, and the API is so simple - a testament to the benefits of named dims Thank you @fujiisoup ! |
Do you know why the tests are failing? Do you want me to have a look? The arrays look the same: https://travis-ci.org/pydata/xarray/jobs/350640898#L5182. Would |
I just noticed the test failings. |
I'm going to merge this tomorrow if there are no further comments. |
whats-new.rst
for all changes andapi.rst
for new API (remove if this change should not be visible to users, e.g., if it is an internal clean-up, or if this is part of a larger project that will be documented later)Currently, lazy-einsum for dask is not yet working.
@shoyer
I think
apply_ufunc
supports lazy computation, but I did not yet figure out how to do this.Can you give me a help?