
Missing array-api support for some stats functions? #8834

Open
4 of 5 tasks
jeremiah-corrado opened this issue Mar 14, 2024 · 9 comments
Labels: array API standard (Support for the Python array API standard), topic-arrays (related to flexible array support), upstream issue

Comments

@jeremiah-corrado

What happened?

When creating a DataArray using an (array-api compliant) Arkouda Array, some numpy operations are not being directed to their array-api counterparts. For example, calling mean on the DataArray produces an error indicating that numpy.nanmean could not be called on the array object, but calling sum on the DataArray successfully calls arkouda's array-api implementation of sum.

As a temporary workaround, I've made a shim for mean that matches numpy's interface and calls the array-api compliant mean, annotating it with an implements decorator as described in this section of numpy's documentation:

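import numpy as np
from typing import Optional, Tuple, Union

# Array is arkouda's array-api Array class; implements_numpy is the registration
# decorator from the full code, following the `implements` pattern in NumPy's
# __array_function__ documentation.
@implements_numpy(np.nanmean)
@implements_numpy(np.mean)
def mean_shim(x: Array, axis=None, dtype=None, out=None, keepdims=False):
    # dtype and out are accepted for NumPy signature compatibility but ignored.
    # Note: this routes np.nanmean to a plain mean, so NaNs are not skipped.
    return mean(x, axis=axis, keepdims=keepdims)

def mean(
    x: Array,
    /,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    keepdims: bool = False,
) -> Array:
    ...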

*full code here
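For reference, the `implements` pattern from NumPy's __array_function__ (NEP 18) documentation can be exercised without arkouda. This is a sketch under stated assumptions: ToyArray, HANDLED_FUNCTIONS, implements, toy_mean and toy_nanmean are hypothetical names, not arkouda's code, and the toy ignores axis, dtype, out and keepdims. It shows why np.mean and np.nanmean only work once an override is registered: for any unregistered function, __array_function__ returns NotImplemented and NumPy raises the same kind of "no implementation found" TypeError that appears in the traceback.

```python
import numpy as np

# Registry of NumPy functions this toy array type overrides (NEP 18 pattern).
HANDLED_FUNCTIONS = {}


def implements(np_function):
    """Register an __array_function__ override for np_function."""
    def decorator(func):
        HANDLED_FUNCTIONS[np_function] = func
        return func
    return decorator


class ToyArray:
    """Hypothetical stand-in for a non-NumPy array (wraps a flat Python list)."""

    def __init__(self, values):
        self.values = list(values)

    def __array_function__(self, func, types, args, kwargs):
        # Defer if another array type is involved.
        if not all(issubclass(t, ToyArray) for t in types):
            return NotImplemented
        # Unregistered functions make NumPy raise "no implementation found".
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)


@implements(np.mean)
def toy_mean(x, axis=None, dtype=None, out=None, keepdims=False):
    return sum(x.values) / len(x.values)


@implements(np.nanmean)
def toy_nanmean(x, axis=None, dtype=None, out=None, keepdims=False):
    # v == v is False only for NaN, so this drops missing values.
    kept = [v for v in x.values if v == v]
    return sum(kept) / len(kept)
```

Calling np.mean(ToyArray([1.0, 2.0, 3.0])) is routed to toy_mean, while a function that was never registered, such as np.median, still raises a TypeError.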

What did you expect to happen?

My expectation was that Xarray would call the array-api version of mean, as it does for sum. Note that I've run into a similar error with some other stats functions (like var), but haven't yet tested a wide selection of Xarray's API to see what is and isn't working.

It's also very possible that the Arkouda implementation is missing something that would allow those calls to be properly redirected to arkouda's array-api, or that I've misunderstood Xarray's requirements for array-api support, but I'm not sure what that would be.

Any help with getting to the bottom of what's going wrong here would be much appreciated!

Minimal Complete Verifiable Example

import arkouda as ak
import arkouda.array_api as Array
import xarray as xr
import numpy as np

ak.connect()

# random data
r = np.random.uniform(0, 1.0, 1000)

# calculate sum and mean using numpy
npD = xr.DataArray(
    name="R",
    data=r,
    dims=["x"],
)

print(npD.sum("x"))
print(npD.mean("x"))

# calculate sum and mean using arkouda
akD = xr.DataArray(
    name="R",
    data=Array.asarray(r),
    dims=["x"],
)

print(akD.sum("x"))
print(akD.mean("x"))

MVCE confirmation

  • Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray.
  • Complete example — the example is self-contained, including all data and the text of any traceback.
  • Verifiable example — the example copy & pastes into an IPython prompt or Binder notebook, returning the result.
  • New issue — a search of GitHub Issues suggests this is not a duplicate.
  • Recent environment — the issue occurs with the latest version of xarray and its dependencies.

Relevant log output

    _         _                   _
   / \   _ __| | _____  _   _  __| | __ _ 
  / _ \ | '__| |/ / _ \| | | |/ _` |/ _` |
 / ___ \| |  |   < (_) | |_| | (_| | (_| |
/_/   \_\_|  |_|\_\___/ \__,_|\__,_|\__,_|
                                          

Client Version: v2024.02.02+50.g4a04bc92b.dirty
arkouda/arkouda/array_api/__init__.py:275: UserWarning: The arkouda.array_api submodule is still experimental.
  warnings.warn("The arkouda.array_api submodule is still experimental.")
connected to arkouda server tcp://*:5555
<xarray.DataArray 'R' ()> Size: 8B
array(488.98327671)
<xarray.DataArray 'R' ()> Size: 8B
array(0.48898328)
making array, shape: (1000,),  dtype: float64
<xarray.DataArray 'R' ()> Size: 8B
array([488.98327670815462])

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[1], line 29
     22 akD = xr.DataArray(
     23     name="R",
     24     data=Array.asarray(r),
     25     dims=["x"],
     26 )
     28 print(akD.sum("x"))
---> 29 print(akD.mean("x"))

File virtualenv/lib/python3.11/site-packages/xarray/core/_aggregations.py:1664, in DataArrayAggregations.mean(self, dim, skipna, keep_attrs, **kwargs)
   1589 def mean(
   1590     self,
   1591     dim: Dims = None,
   (...)
   1595     **kwargs: Any,
   1596 ) -> Self:
   1597     """
   1598     Reduce this DataArray's data by applying ``mean`` along some dimension(s).
   1599 
   (...)
   1662     array(nan)
   1663     """
-> 1664     return self.reduce(
   1665         duck_array_ops.mean,
   1666         dim=dim,
   1667         skipna=skipna,
   1668         keep_attrs=keep_attrs,
   1669         **kwargs,
   1670     )

File virtualenv/lib/python3.11/site-packages/xarray/core/dataarray.py:3768, in DataArray.reduce(self, func, dim, axis, keep_attrs, keepdims, **kwargs)
   3724 def reduce(
   3725     self,
   3726     func: Callable[..., Any],
   (...)
   3732     **kwargs: Any,
   3733 ) -> Self:
   3734     """Reduce this array by applying `func` along some dimension(s).
   3735 
   3736     Parameters
   (...)
   3765         summarized data and the indicated dimension(s) removed.
   3766     """
-> 3768     var = self.variable.reduce(func, dim, axis, keep_attrs, keepdims, **kwargs)
   3769     return self._replace_maybe_drop_dims(var)

File virtualenv/lib/python3.11/site-packages/xarray/core/variable.py:1618, in Variable.reduce(self, func, dim, axis, keep_attrs, keepdims, **kwargs)
   1611 keep_attrs_ = (
   1612     _get_keep_attrs(default=False) if keep_attrs is None else keep_attrs
   1613 )
   1615 # Noe that the call order for Variable.mean is
   1616 #    Variable.mean -> NamedArray.mean -> Variable.reduce
   1617 #    -> NamedArray.reduce
-> 1618 result = super().reduce(
   1619     func=func, dim=dim, axis=axis, keepdims=keepdims, **kwargs
   1620 )
   1622 # return Variable always to support IndexVariable
   1623 return Variable(
   1624     result.dims, result._data, attrs=result._attrs if keep_attrs_ else None
   1625 )

File virtualenv/lib/python3.11/site-packages/xarray/namedarray/core.py:888, in NamedArray.reduce(self, func, dim, axis, keepdims, **kwargs)
    884     if isinstance(axis, tuple) and len(axis) == 1:
    885         # unpack axis for the benefit of functions
    886         # like np.argmin which can't handle tuple arguments
    887         axis = axis[0]
--> 888     data = func(self.data, axis=axis, **kwargs)
    889 else:
    890     data = func(self.data, **kwargs)

File virtualenv/lib/python3.11/site-packages/xarray/core/duck_array_ops.py:649, in mean(array, axis, skipna, **kwargs)
    647     return _to_pytimedelta(mean_timedeltas, unit="us") + offset
    648 else:
--> 649     return _mean(array, axis=axis, skipna=skipna, **kwargs)

File virtualenv/lib/python3.11/site-packages/xarray/core/duck_array_ops.py:416, in _create_nan_agg_method.<locals>.f(values, axis, skipna, **kwargs)
    414     with warnings.catch_warnings():
    415         warnings.filterwarnings("ignore", "All-NaN slice encountered")
--> 416         return func(values, axis=axis, **kwargs)
    417 except AttributeError:
    418     if not is_duck_dask_array(values):

File virtualenv/lib/python3.11/site-packages/xarray/core/nanops.py:131, in nanmean(a, axis, dtype, out)
    126 with warnings.catch_warnings():
    127     warnings.filterwarnings(
    128         "ignore", r"Mean of empty slice", category=RuntimeWarning
    129     )
--> 131     return np.nanmean(a, axis=axis, dtype=dtype)

TypeError: no implementation found for 'numpy.nanmean' on types that implement __array_function__: [<class 'arkouda.array_api._array_object.Array'>]

Anything else we need to know?

The arkouda array-api implementation is still under development (I'm the main person working on it atm), so unfortunately the example I've provided is not easily reproducible, but I'd be happy to run any tests/experiments on my local environment where I have it set up.

Environment

INSTALLED VERSIONS

commit: None
python: 3.11.7 (main, Dec 4 2023, 18:10:11) [Clang 15.0.0 (clang-1500.1.0.2.5)]
python-bits: 64
OS: Darwin
OS-release: 22.6.0
machine: arm64
processor: arm
byteorder: little
LC_ALL: None
LANG: None
LOCALE: (None, 'UTF-8')
libhdf5: 1.12.2
libnetcdf: None

xarray: 2024.2.0
pandas: 2.2.1
numpy: 1.26.4
scipy: 1.12.0
netCDF4: None
pydap: None
h5netcdf: None
h5py: 3.10.0
Nio: None
zarr: None
cftime: None
nc_time_axis: None
iris: None
bottleneck: None
dask: None
distributed: None
matplotlib: 3.8.3
cartopy: None
seaborn: None
numbagg: None
fsspec: 2024.2.0
cupy: None
pint: None
sparse: None
flox: None
numpy_groupies: None
setuptools: 69.0.2
pip: 24.0
conda: None
pytest: None
mypy: None
IPython: 8.22.2
sphinx: None

@jeremiah-corrado jeremiah-corrado added bug needs triage Issue that has not been reviewed by xarray team member labels Mar 14, 2024

welcome bot commented Mar 14, 2024

Thanks for opening your first issue here at xarray! Be sure to follow the issue template!
If you have an idea for a solution, we would really welcome a Pull Request with proposed changes.
See the Contributing Guide for more.
It may take us a while to respond here, but we really value your contribution. Contributors like you help make xarray better.
Thank you!

@jhamman jhamman added the topic-arrays related to flexible array support label Mar 14, 2024
@keewis
Collaborator

keewis commented Mar 14, 2024

thanks for the report, and we'd definitely like to be able to use reductions while ignoring nan values. However, as far as I can tell the array API does not include nan* functions (or anything equivalent), so we're out of luck there.

Note that you can opt out of the nan-skipping using skipna=False, which should properly dispatch to the arkouda implementation (if it doesn't, that would be a bug). This will only have an effect for dtypes that can represent missing values, like float*, complex*, and datetime*.
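A simplified, hypothetical model of the behaviour described here: sketch_mean and NAN_CAPABLE_KINDS are illustrative names, not xarray's internals (the traceback shows the real path running through duck_array_ops.mean and nanops.nanmean). With the default skipna=None, NaN-capable dtypes take the NaN-skipping route, which needs np.nanmean; skipna=False, or a dtype that cannot hold NaN, takes the plain route, which can use the array's own namespace.

```python
import numpy as np

# Dtype kinds treated as NaN-capable in this sketch: float ("f") and complex ("c").
# Datetimes are omitted; xarray handles them through a separate branch.
NAN_CAPABLE_KINDS = "fc"


def sketch_mean(values, skipna=None):
    """Hypothetical, simplified model of how skipna selects a reduction."""
    use_nan_path = skipna or (skipna is None and values.dtype.kind in NAN_CAPABLE_KINDS)
    if use_nan_path:
        # NaN-skipping path: relies on a nanmean, which the array API standard
        # lacks, so non-NumPy arrays need an __array_function__ override here.
        return np.nanmean(values)
    # Plain path: use the array's own namespace when it advertises one.
    xp = values.__array_namespace__() if hasattr(values, "__array_namespace__") else np
    return xp.mean(values)
```

For example, sketch_mean(np.array([1.0, np.nan, 3.0])) skips the NaN, while passing skipna=False propagates it.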

@keewis keewis added upstream issue array API standard Support for the Python array API standard and removed needs triage Issue that has not been reviewed by xarray team member labels Mar 14, 2024
@dcherian
Contributor

If you do have a nanmean, implementing __array_function__ on arkouda arrays will make it work.

@dcherian dcherian removed the bug label Mar 14, 2024
@jeremiah-corrado
Author

jeremiah-corrado commented Mar 14, 2024

Note that you can opt out of the nan-skipping using skipna=False, which should properly dispatch to the arkouda implementation

Okay, that makes sense. And it looks like that's working now, thanks!

However, doing the same thing for std and var is not working. I'm hitting:

TypeError: std() got an unexpected keyword argument 'ddof'

It looks like this is because the array api std calls it correction instead of ddof. Is this a case where a compatibility shim would be needed, or would this be considered a bug?
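One way such a compatibility shim could look, as a sketch only: translate_std_kwargs, array_api_std and ToyNamespace are hypothetical names, and ToyNamespace is a pure-Python stand-in for an array-API namespace that handles 1-D lists. The shim just renames NumPy's ddof to the array API's correction before calling xp.std; both define the divisor as N minus that value, so the rename preserves meaning.

```python
import math
from typing import Any, Dict, Sequence


def translate_std_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Rename NumPy's `ddof` keyword to the array API's `correction`."""
    out = dict(kwargs)  # copy so the caller's dict is not mutated
    if "ddof" in out:
        out["correction"] = out.pop("ddof")
    return out


def array_api_std(xp: Any, x: Any, **kwargs: Any) -> Any:
    """Call xp.std with NumPy-style keywords translated to array-API names."""
    return xp.std(x, **translate_std_kwargs(kwargs))


class ToyNamespace:
    """Hypothetical stand-in for an array-API namespace (1-D lists only)."""

    @staticmethod
    def std(x: Sequence[float], *, correction: float = 0.0) -> float:
        n = len(x)
        mean = sum(x) / n
        # Divisor is N - correction, matching NumPy's N - ddof.
        return math.sqrt(sum((v - mean) ** 2 for v in x) / (n - correction))
```

With this, array_api_std(ToyNamespace, [1.0, 2.0, 3.0, 4.0], ddof=1) gives the sample standard deviation, exactly as ToyNamespace.std(..., correction=1) would.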

If you do have a nanmean, implementing __array_function__ on arkouda arrays will make it work.

There isn't currently, but I'm thinking that will be a good solution for now (at least until something like nanmean and friends get added to the array API).
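If arkouda does add a nanmean to register through __array_function__, one option is to compose it from functions that are in the array API standard (isnan, where, zeros_like, ones_like, sum). This is a sketch under that assumption: nanmean_from_primitives is a hypothetical name, and NumPy is used as the namespace only so the example runs without an arkouda server.

```python
import numpy as np


def nanmean_from_primitives(xp, x, axis=None, keepdims=False):
    """NaN-skipping mean built only from array-API-standard functions.

    `xp` is an array-API namespace (NumPy in the demo below).
    """
    mask = xp.isnan(x)  # True where the value is missing
    filled = xp.where(mask, xp.zeros_like(x), x)  # NaNs contribute 0 to the sum
    counts = xp.where(mask, xp.zeros_like(x), xp.ones_like(x))  # 1 per non-NaN value
    total = xp.sum(filled, axis=axis, keepdims=keepdims)
    n = xp.sum(counts, axis=axis, keepdims=keepdims)
    return total / n  # all-NaN slices give 0/0, i.e. NaN
```

Applied along axis=0 to [[1.0, nan], [3.0, 5.0]], this returns [2.0, 5.0], matching np.nanmean.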

@keewis
Collaborator

keewis commented Mar 14, 2024

would this be considered a bug?

indeed, that's a bug. See #8566 and #8573 (edit: and #7243).

@jeremiah-corrado
Copy link
Author

okay, that's good to know, thank you!

@TomNicholas
Member

If you do have a nanmean, implementing __array_function__ on arkouda arrays will make it work.

According to @tomwhite in cubed-dev/cubed#469 (comment) it's not quite as simple as this:

Yes, of course. With this change Xarray now thinks that it can use NdArrayLikeIndexingAdapter for indexing, rather than ArrayApiIndexingAdapter, see

if hasattr(array, "__array_function__"):
    return NdArrayLikeIndexingAdapter(array)
if hasattr(array, "__array_namespace__"):
    return ArrayApiIndexingAdapter(array)

Switching the order here (i.e. check for __array_namespace__ before __array_function__) would work in this case, but I'm not sure if it has other implications.

Switching the order seems okay to me... It should only matter for types that implement both. And if a type supports __array_namespace__ ideally we should be prioritizing that as it is the newer standard.

One question is whether there are any types of array indexing operations that xarray supports via __array_function__ that aren't in the array API standard?
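The reordering under discussion can be sketched as follows. The two adapter classes are empty stubs that reuse xarray's class names purely for illustration; choose_adapter and the toy types are hypothetical, and the NumPy special case that xarray applies earlier is omitted. A type implementing both protocols (as arkouda's Array does, per the traceback) now gets the array API adapter, while a type with only __array_function__ still gets the NumPy-like one.

```python
class NdArrayLikeIndexingAdapter:
    """Stub standing in for xarray's adapter of the same name."""

    def __init__(self, array):
        self.array = array


class ArrayApiIndexingAdapter:
    """Stub standing in for xarray's adapter of the same name."""

    def __init__(self, array):
        self.array = array


def choose_adapter(array):
    # (xarray special-cases NumPy arrays before these checks; omitted here.)
    # Check the newer standard first so types implementing both protocols
    # are indexed through the array API adapter.
    if hasattr(array, "__array_namespace__"):
        return ArrayApiIndexingAdapter(array)
    if hasattr(array, "__array_function__"):
        return NdArrayLikeIndexingAdapter(array)
    raise TypeError(f"unsupported array type: {type(array).__name__}")


class BothProtocols:
    """Toy type implementing both protocols."""

    def __array_namespace__(self, api_version=None):
        return None

    def __array_function__(self, func, types, args, kwargs):
        return NotImplemented


class ArrayFunctionOnly:
    """Toy type implementing only __array_function__."""

    def __array_function__(self, func, types, args, kwargs):
        return NotImplemented
```

As noted in the thread, the order only changes the outcome for types that implement both protocols.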

@dcherian
Contributor

One question is whether there are any types of array indexing operations that xarray supports via __array_function__ that aren't in the array API standard?

NDArrayLikeIndexingAdapter is just NumpyIndexingAdapter so we're assuming those array types support numpy-like indexing.

the array API adapter doesn't support vectorized indexing yet (#8667) but we can add a fallback with indexing.explicit_indexing_adapter that all the backends use to harmonize indexing APIs.
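For readers unfamiliar with the term: in xarray's terminology, vectorized indexing is NumPy-style pointwise integer-array indexing, as opposed to orthogonal (outer) indexing, which selects every combination of the given row and column indices. A short NumPy illustration of the difference (the variable names are illustrative only):

```python
import numpy as np

a = np.arange(12).reshape(3, 4)
rows, cols = [0, 1], [1, 2]

# Vectorized (pointwise) indexing: pairs rows[i] with cols[i], result shape (2,).
vectorized = a[rows, cols]

# Orthogonal (outer) indexing: every listed row crossed with every listed
# column, result shape (2, 2).
orthogonal = a[np.ix_(rows, cols)]
```

Here vectorized holds the elements at (0, 1) and (1, 2), while orthogonal is the 2x2 block of rows 0-1 and columns 1-2.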

@tomwhite
Contributor

Switching the order seems okay to me... It should only matter for types that implement both. And if a type supports __array_namespace__ ideally we should be prioritizing that as it is the newer standard.

Agreed. NumPy 2 supports both but should not be affected since NumPy is special-cased earlier in the function.

Projects
None yet
Development

No branches or pull requests

6 participants