einsum for xarray #1968

Merged
merged 14 commits into from
Mar 12, 2018

Conversation

Member

@fujiisoup commented Mar 6, 2018

  • Closes einsum for xarray #1951
  • Tests added
  • Tests passed (for all non-documentation changes)
  • Fully documented, including whats-new.rst for all changes and api.rst for new API (remove if this change should not be visible to users, e.g., if it is an internal clean-up, or if this is part of a larger project that will be documented later)

Currently, lazy einsum for dask is not yet working.

@shoyer
I think apply_ufunc supports lazy computation, but I have not yet figured out how to do this.
Could you give me some help?

Member

@shoyer left a comment

Very nice!

subscripts = ''
for ds in input_core_dims:
    subscripts += '...' + ''.join([dim_map[d] for d in ds]) + ','
subscripts = subscripts[:-1]  # remove last comma
Member

It would probably be cleaner to build up subscripts as a list and use ','.join(subscripts_list) once at the end.
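
A minimal sketch of that suggestion. The `dim_map` and `input_core_dims` values here are hypothetical, chosen only for illustration and not taken from the PR:

```python
# Hypothetical inputs for illustration only.
input_core_dims = [['a', 'b'], ['b', 'c']]
dim_map = {'a': 'i', 'b': 'j', 'c': 'k'}

# Collect one subscript term per input, then join once at the end,
# so there is no trailing comma to strip afterwards.
subscripts_list = ['...' + ''.join(dim_map[d] for d in ds)
                   for ds in input_core_dims]
subscripts = ','.join(subscripts_list)
print(subscripts)
```

Joining once avoids both the repeated string concatenation and the `[:-1]` slice.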


result = apply_ufunc(np.einsum, subscripts, *arrays,
                     input_core_dims=[[]] + input_core_dims,
                     output_core_dims=output_core_dims, dask='allowed')
Member

I think dask='parallelized' is what you want here -- that generates the wrapper to do this with dask. This will also require determining the result data type, probably with dtypes.result_type or even np.result_type (we don't need support for non-numeric types in einsum, so I'm pretty sure NumPy's casting rules would work fine).

dask='allowed' would be appropriate if np.einsum already supported dask arrays (but it does not).

It's possible that a dask specific einsum could be much more efficient than the auto-generated wrapper here, but certainly this is good enough for now.

Member Author

Thanks. I noticed that my current implementation is not very efficient for dask.
Maybe a smaller number of input_core_dims is better for dask?

I think it needs some improvement.

Member

dask='parallelized' will only parallelize over broadcast dimensions, i.e., ones that don't appear in either input_core_dims or output_core_dims. So yes, it will probably be slow in many cases.

I'm still OK adding the non-optimal einsum for now and improving it later.

if len(arrays) < 2:
    raise TypeError('More than two arrays must be provided')

if any(not hasattr(arr, 'dims') for arr in arrays):
Member

Dataset also defines dims. It's probably better to explicitly use an isinstance() check.
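
To illustrate the difference, here is a small sketch using hypothetical stand-in classes (not the real xarray types) that both expose a `dims` attribute:

```python
class DataArray(object):
    """Hypothetical stand-in, not the real xarray.DataArray."""
    def __init__(self, dims):
        self.dims = tuple(dims)


class Dataset(object):
    """Hypothetical stand-in that also exposes ``dims``."""
    def __init__(self, dims):
        self.dims = dict(dims)


def check_duck(arrays):
    # Attribute check: accepts anything that has a ``dims`` attribute.
    return all(hasattr(arr, 'dims') for arr in arrays)


def check_strict(arrays):
    # Explicit type check: only DataArray instances pass.
    return all(isinstance(arr, DataArray) for arr in arrays)


inputs = [DataArray(['x']), Dataset({'x': 3})]
duck_ok = check_duck(inputs)      # the Dataset slips through
strict_ok = check_strict(inputs)  # the Dataset is rejected
```

With the attribute check, the Dataset-like object is accepted even though the function is only meant for arrays; the `isinstance()` check rejects it.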

[d for d in other.dims if d not in dims])

return type(self)(new_data, new_coords.variables, new_dims)
# backward compat: if there is no shared dimension, we raise an Error
Member

It might be better to eliminate this special case. Then users can understand DataArray.dot as a simple short-cut for xarray.dot().

arrays = args
if dims is None and isinstance(args[-1], (list, tuple, basestring)):
    dims = args[-1]
    arrays = args[:-1]
Member

I think it is better to require specifying dims with a keyword argument.

Member Author

Our previous dot does not require dim; it sums over all the common dimensions.
I think dim=None is not surprising.

Member

I agree, the default dims=None should be OK. I meant that dims should be a keyword only argument, not a required argument.

Here you are supporting xr.dot(a, b, 'x'), where 'x' denotes a dimension. I would require writing xr.dot(a, b, dim='x') or omitting dim altogether.

@@ -926,6 +926,86 @@ def earth_mover_distance(first_samples,
    return apply_array_ufunc(func, *args, dask=dask)


def dot(*args, **kwargs):
    """ dot(*arrays, dims=None)
Member

dot(*arrays, *, dims=None) is the way to write this with Python 3's keyword only arguments.

Member Author

Maybe we should keep this as dot(*arrays, **kwargs), since we have not yet dropped Python 2 support?

Member

I was confused. def dot(*arrays, *, dims=None) is not valid syntax in Python 3, either. (There can only be one single *)

Member Author

PEP 3102 says Python 3 supports the form def dot(*arrays, dim=None).
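
As a quick check of the Python 3 form discussed in this thread, the sketch below uses a hypothetical stand-in `dot` that only reports how its arguments were bound; it is not the PR's implementation:

```python
def dot(*arrays, dims=None):
    """Hypothetical stand-in: reports how the arguments were bound."""
    return len(arrays), dims


# Any parameter after *arrays is keyword-only, so a trailing positional
# value is collected into ``arrays`` rather than bound to ``dims``.
positional = dot('a', 'b', 'x')
keyword = dot('a', 'b', dims='x')
print(positional, keyword)
```

This form is Python 3 only, which is why the `**kwargs` signature was under discussion while Python 2 was still supported.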

return apply_ufunc(duck_array_ops.tensordot, *arrays, dask='allowed',
                   input_core_dims=input_core_dims,
                   output_core_dims=output_core_dims,
                   kwargs={'axes': axes})
Member Author

Thanks. I added a path for tensordot, which dask can compute more efficiently.

Member

@shoyer left a comment

Some feedback on the documentation (mostly grammar).

----------
arrays: multiple DataArrays
arrays to compute.
dims: tuple of strings, optional
Member

str or tuple of strings

""" dot(*arrays, *, dims=None)

einsum for xarray object, but providing simpler interface based on
the array dimensions.
Member

We should lead with a more general description. Maybe:

Generalized dot product for xarray objects. Like np.einsum, but 
provides a simpler interface based on array dimensions.


Parameters
----------
arrays: multiple DataArrays
Member

*arrays: DataArray objects

Parameters
----------
arrays: multiple DataArrays
arrays to compute.
Member

Arrays

arrays: multiple DataArrays
arrays to compute.
dims: tuple of strings, optional
Along which dimensions to be summed over.
Member

Which dimensions to sum over.


Returns
-------
dot: same type to input.
Member

Probably should just be "DataArray"?


common_dims = set(arrays[0].dims)
for arr in arrays[1:]:
    common_dims = common_dims.intersection(set(arr.dims))
Member

This is a slightly different choice of default dimensions than np.einsum:

  • np.einsum sums over any dimensions that are defined in two or more inputs.
  • This sums only over dimensions that are defined on all inputs.

Should we switch this behavior to match einsum?
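
The two defaults can be compared on plain dimension names. This is only a sketch with hypothetical dimension tuples, following the description of np.einsum given above; it does not call xarray or NumPy:

```python
from collections import Counter

# Hypothetical dimension names for three inputs, chained a-b, b-c, c-d.
all_input_dims = [('a', 'b'), ('b', 'c'), ('c', 'd')]

# Default in the reviewed code: dimensions present on every input.
on_all_inputs = set.intersection(*[set(dims) for dims in all_input_dims])

# einsum-like default: dimensions present on two or more inputs.
dim_counts = Counter()
for dims in all_input_dims:
    dim_counts.update(dims)
on_two_or_more = {d for d, count in dim_counts.items() if count > 1}

print(sorted(on_all_inputs), sorted(on_two_or_more))
```

For a chain of three arrays, no dimension is shared by all inputs, so the first rule sums over nothing, while the second sums over `b` and `c`.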

dims=['a', 'b', 'c'])
>>> da_c = xr.DataArray(np.arange(5 * 6).reshape(5, 6), dims=['c', 'd'])

>>> dot(da_a, da_b, dims=['a', 'b']).dims
Member

These should use the full name xr.dot.

dims = kwargs.pop('dims', None)

if len(arrays) < 2:
    raise TypeError('More than one arrays must be provided')
Member

Do we need this special case? If not, let's remove this. For consistency, it is nice to use the same logic even for edge cases when possible. This makes it easier to think about the function.

In this case, I think a dot product of 1 array would be consistently defined by summing over the dimensions listed explicitly in dims.

dims = kwargs.pop('dims', None)
if len(kwargs) > 0:
    raise TypeError('Invalid keyward arguments {} are given'.format(
        kwargs.keys()))
Contributor

W1655 dict.keys referenced when not iterating
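
One way to address this kind of warning, sketched with a hypothetical stand-in `dot` that only validates its keyword arguments. The assumption here is that the warning concerns `kwargs.keys()` being a view object rather than a list on Python 3; the message spelling is also tidied in this sketch:

```python
def dot(*arrays, **kwargs):
    """Hypothetical stand-in that only validates keyword arguments."""
    dims = kwargs.pop('dims', None)
    if kwargs:
        # sorted(kwargs) builds a plain list of the leftover names on
        # both Python 2 and Python 3, so the message formats the same.
        raise TypeError('Invalid keyword arguments {} are given'.format(
            sorted(kwargs)))
    return arrays, dims


try:
    dot(1, 2, dim='x')  # misspelled keyword: 'dim' instead of 'dims'
except TypeError as err:
    message = str(err)

print(message)
```

The same `kwargs.pop` pattern also emulates a keyword-only `dims` argument on Python 2.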

# find dimensions that exist in more than two arrays
whole_dims = []
for arr in arrays:
    whole_dims += [d for d in arr.dims]
Member

This might be a nice use for collections.Counter(), e.g.,

dim_counts = Counter()
for arr in arrays:
    dim_counts.update(arr.dims)

@@ -974,27 +977,30 @@ def dot(*arrays, **kwargs):
dims = [dims]

common_dims = set(arrays[0].dims)
all_dims = []
Member

Would it work to make all_dims a set instead of a list? I think that would be slightly more efficient.

Member Author

@fujiisoup Mar 8, 2018

I want to keep the occurrence order in all_dims, so that I can move the input_core_dims back to their original positions.

Member

OK, sounds good.
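
A small sketch of the ordering point in this thread, using hypothetical dimension names rather than real arrays:

```python
# Hypothetical dimension names for two inputs.
arrays_dims = [('time', 'x', 'y'), ('y', 'x', 'z')]

all_dims = []
for dims in arrays_dims:
    # Append only unseen names, so all_dims keeps first-occurrence order.
    all_dims += [d for d in dims if d not in all_dims]

print(all_dims)
```

A set would hold the same names but gives no guarantee about iteration order, so positions derived from it could not be mapped back reliably.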

if len(arrays) < 2:
    raise TypeError('More than one arrays must be provided')
if len(arrays) < 2 and dims is None:
    raise TypeError('dim must be provided for one array computation.')
Member

If there's only one array, wouldn't dims just be any repeated dimensions on the single array?

Member Author

xarray objects do not have any repeated dimensions.

Member

This is not strictly true: #1378 . That said, we certainly don't support repeated dims well right now.

Even if we banned repeated dimensions, I still think there's no harm in supporting the trivial xr.dot(array) -> array.

Member Author

OK. Updated.

common_dims = set(arrays[0].dims)
all_dims = []
for arr in arrays[1:]:
    common_dims = common_dims.intersection(set(arr.dims))
Member

It might be slightly more efficient to construct common_dims with a single call to intersection?

e.g.,
common_dims = set.intersection(*[set(arr.dims) for arr in arrays])

if len(kwargs) > 0:
    raise TypeError('Invalid keyward arguments {} are given'.format(
        list(kwargs.keys())))

Member

What happens if you write xr.dot()? I suppose we still need to raise an error for 0 arguments.

Member

@shoyer left a comment

Let's wait a little while to see if anyone else has feedback, e.g., on the name. But this looks very nice to me!

@@ -968,15 +968,19 @@ def dot(*arrays, **kwargs):
list(kwargs.keys())))

if any(not isinstance(arr, DataArray) for arr in arrays):
    raise TypeError('Only xr.DataArray and xr.Variable are supported.')
    raise TypeError('Only xr.DataArray and xr.Variable are supported.'
Member

We should either update the error message or isinstance() check here -- right now they are inconsistent.

list(kwargs.keys())))

if any(not isinstance(arr, DataArray) for arr in arrays):
    raise TypeError('Only xr.DataArray and xr.Variable are supported.'
Collaborator

Either a type checking or a docstring issue:

In [8]: v=xr.Variable(data=np.random.rand(3,4), dims=('a','b'))

In [9]: xr.dot(v,v)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-9-fac8e1cb222a> in <module>()
----> 1 xr.dot(v,v)

~/drive/workspace/xarray/xarray/core/computation.py in dot(*arrays, **kwargs)
    970     if any(not isinstance(arr, DataArray) for arr in arrays):
    971         raise TypeError('Only xr.DataArray and xr.Variable are supported.'
--> 972                         'Given {}.'.format([type(arr) for arr in arrays]))
    973
    974     if len(arrays) == 0:

TypeError: Only xr.DataArray and xr.Variable are supported.Given [<class 'xarray.core.variable.Variable'>, <class 'xarray.core.variable.Variable'>].

    raise TypeError('At least one array should be given.')

if isinstance(dims, basestring):
    dims = (dims, )
Collaborator

FWIW you don't need the parentheses

Member Author

I personally like parentheses, as I think it is more descriptive.

if isinstance(dims, basestring):
    dims = (dims, )
elif isinstance(dims, list):
    dims = tuple(dims)
Collaborator

FWIW dims=tuple(dims) doesn't create any copies if dims is already a tuple, so you could skip the if isinstance check
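
A sketch of that simplification, using a hypothetical helper `normalize_dims` and Python 3's `str` in place of `basestring`. The string case still needs its own branch, because `tuple()` would split a string into characters:

```python
def normalize_dims(dims):
    """Hypothetical helper: return dims as a tuple of names."""
    if isinstance(dims, str):
        # A bare string must be wrapped; tuple('ab') would split it
        # into ('a', 'b').
        return (dims,)
    # tuple() accepts lists and tuples alike; on CPython an existing
    # tuple is returned as the same object rather than copied.
    return tuple(dims)


from_str = normalize_dims('ab')
from_list = normalize_dims(['a', 'b'])
original = ('a', 'b')
from_tuple = normalize_dims(original)
print(from_str, from_list, from_tuple is original)
```

The identity result for tuples is CPython behaviour; other implementations may copy, which would still be correct.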

@fujiisoup
Member Author

Thanks, @maxim-lian.
I added xr.Variable support to xr.dot.

@max-sixty
Collaborator

This is awesome. Beautiful code, immediately impactful, and the API is so simple - a testament to the benefits of named dims

Thank you @fujiisoup !

@max-sixty
Collaborator

Do you know why the tests are failing? Do you want me to have a look?

The arrays look the same: https://travis-ci.org/pydata/xarray/jobs/350640898#L5182. Would assert_close help?

@fujiisoup
Member Author

I just noticed the test failures.
This was a bug caused by the undefined iteration order of sets.
Fixed. Thanks :)

@shoyer shoyer mentioned this pull request Mar 9, 2018
3 tasks
@fujiisoup
Member Author

I'm going to merge this tomorrow if there are no further comments.

@fujiisoup fujiisoup merged commit 8271dff into pydata:master Mar 12, 2018
@fujiisoup fujiisoup deleted the einsum branch March 12, 2018 06:42
4 participants