ENH: Should we allow querying certain implementation details #499

Closed
seberg opened this issue Oct 20, 2022 · 11 comments · Fixed by #689

@seberg
Contributor

seberg commented Oct 20, 2022

I am wondering if it would make sense to standardize a bit of information about the exported namespace. For example

  • immutability: If arrays are always immutable, sometimes a copy may not be needed?
  • Is it interesting to know whether in-place operations are actually in-place?

I somewhat suspect that such information will be useful eventually, but not sure that we have anything explicit right now that would be.

@asmeurer
Member

Should this be a query on the namespace or on an array object?

@seberg
Contributor Author

seberg commented Oct 21, 2022

I was thinking of the namespace, but maybe immutability would need to be specific to the array object rather than the namespace... I suppose we can probably close this for now, though. I think it may be useful to "bag" details, but right now I am not sure if users actually have a need for any of this information :).

@kgryte
Contributor

kgryte commented Nov 3, 2022

@seberg Can you think of any practical use cases where an array API consumer would/should branch based on the implementation info you mention in the OP?

E.g., what would be an example of when an array API consumer should be concerned with whether an array library supports mutation? You mention determining whether a copy is needed; any practical examples of where this would be applicable and a downstream library responsibility rather than in the array library itself?

@seberg
Contributor Author

seberg commented Nov 3, 2022

The reason I was wondering was the discussion about ensuring you have a copy:

api = getnamespace
if not api.mutable:
    # ensure user input cannot be mutated:
    arr = arr.copy()

# continue using in-place ops:
arr += 1
arr *= 5

@seberg
Contributor Author

seberg commented Nov 3, 2022

In the meeting there were a few other things that came up and may be interesting:

  • Support for non-native byte order
  • Actually supported dtypes
  • Supported devices

@rgommers
Member

rgommers commented Nov 4, 2022

api = getnamespace
if not api.mutable:
    # ensure user input cannot be mutated:
    arr = arr.copy()

I assume the second line was supposed to be if api.mutable:? Either way, not sure this is a convincing example, because for immutable arrays copy() should be a no-op, so you may as well leave out the check.

For the other things, that seems like potentially useful info when using a single library. But when you're using the standard API, I'm wondering what you are going to do with something like "yes, supports non-native byte order". It doesn't come back anywhere else in the API, so it seems more like an FYI.

The actual dtypes and devices I can imagine using; they can be passed on to dtype= or device= keywords.
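
A minimal sketch of how queried dtypes and devices could feed the dtype= and device= keywords. The attributes supported_dtypes and supported_devices, the helper pick_dtype, and the toy namespace are all hypothetical names made up for illustration; nothing in this thread defines them.

```python
from types import SimpleNamespace


def pick_dtype(supported, preferred=("float64", "float32")):
    """Return the first preferred dtype name that the library reports."""
    for name in preferred:
        if name in supported:
            return name
    raise ValueError(f"none of {preferred} is supported")


# Toy stand-in for an array namespace (hypothetical attributes, plain dicts
# instead of arrays) so the sketch runs without any array library installed.
toy_xp = SimpleNamespace(
    supported_dtypes=("float32", "int32"),
    supported_devices=("cpu",),
    zeros=lambda shape, dtype, device: {
        "shape": shape, "dtype": dtype, "device": device
    },
)

# The query results are passed straight on to the creation function.
dtype = pick_dtype(toy_xp.supported_dtypes)
arr = toy_xp.zeros((2, 3), dtype=dtype, device=toy_xp.supported_devices[0])
```

Here the consumer never branches on which library it is talking to; it only forwards what the query reported.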

@rgommers
Member

rgommers commented Mar 9, 2023

I was reminded of this issue because of gh-609. It's still a pain that we have data-dependent features in the API that are optional to implement. So to write portable and performant code across libraries that do and don't implement them, we need something like:

if xp.supports_dynamic_shapes:
    y = x[mask]
    z = xp.nonzero(x)
    ...
else:
    # longer workarounds
    ...

This is the most concrete case that needs library-specific querying I think. Unless we can avoid any optional features at all in the standard, those should be introspectable somehow.

@honno
Member

honno commented May 31, 2023

Another reason to support querying supported dtypes: JAX may or may not support double-precision dtypes, depending on whether they're enabled. In practice the respective dtype objects (int64, float64, complex128) always exist in the jax.numpy namespace, but when they are used, the resulting arrays are returned with the respective single-precision dtype unless double precision is enabled. For example:

>>> from jax import numpy as jnp
>>> jnp.int64
>>> jnp.asarray(1, dtype=jnp.int64)
UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in asarray is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
Array(1, dtype=int32)

Currently in Hypothesis and array-api-tests we essentially use hasattr(<array_namespace>, "<dtype>") to check if a dtype is supported. This worked with torch not supporting any unsigned int but uint8, e.g. hasattr(torch, "uint32") == False, but obviously doesn't work for JAX. I can't think of a clean "library-agnostic" way to check for what JAX is doing either.

Just as torch is unlikely to support other uints anytime soon, JAX might not change its behaviour either. From @jakevdp in #582 (comment):

JAX only allows 64-bit values when explicitly enabled; see Double (64-bit) Precision. This was an early design decision that the team recognizes as non-ideal, but it has proven difficult to change because so many users depend on the bit truncation behavior and enjoy the accelerator-friendly type safety it confers.
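
One possible workaround for the hasattr limitation described above is a round-trip probe: create a tiny array with the requested dtype and check whether that dtype survives. This is only a sketch, not part of the standard and not necessarily clean. The helper dtype_is_usable and the toy truncating namespace are hypothetical and only imitate the behaviour described here with strings as dtypes; whether the comparison works as intended with a given real library is an assumption.

```python
import warnings
from types import SimpleNamespace


def dtype_is_usable(xp, name):
    """Return True if creating an array with xp.<name> keeps that dtype."""
    dtype = getattr(xp, name, None)
    if dtype is None:
        return False  # the plain hasattr-style check already fails
    with warnings.catch_warnings():
        # A library may only warn when it truncates, so silence warnings here.
        warnings.simplefilter("ignore")
        result = xp.asarray(1, dtype=dtype)
    return result.dtype == dtype


def _truncating_asarray(value, dtype):
    # Toy behaviour: 64-bit requests are narrowed to their 32-bit counterpart.
    narrowed = {"int64": "int32", "float64": "float32"}.get(dtype, dtype)
    return SimpleNamespace(value=value, dtype=narrowed)


# Toy namespace for illustration only; int64 exists but is not honoured.
toy_xp = SimpleNamespace(int32="int32", int64="int64",
                         asarray=_truncating_asarray)

usable = {name: dtype_is_usable(toy_xp, name)
          for name in ("int32", "int64", "uint32")}
```

The probe covers both failure modes mentioned in the comment: a dtype that is missing from the namespace, and a dtype that exists but is replaced by a narrower one when used.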

@asmeurer
Member

Maybe we should standardize that libraries are allowed to leave a subset of the dtypes unsupported. Right now all dtypes are required, and libraries like PyTorch and JAX are strictly out of compliance by not including them.

@jakevdp

jakevdp commented Jun 1, 2023

JAX does include all dtypes – just run the tests with JAX_ENABLE_X64=true.

@honno
Member

honno commented Jun 1, 2023

JAX does include all dtypes – just run the tests with JAX_ENABLE_X64=true.

Right, and this could be just fine for practical applications of JAX + Array API. My concern is that downstream libraries need to know they're dealing with JAX so they can enable this (or avoid double-precision dtypes), which somewhat undermines the point of using this spec.
