RFC: require that dtypes obey Python hashing rules #582
That seems fine to me to explicitly specify.

```python
>>> import numpy as np
>>> np.float32 == 'float32'
False
>>> np.dtype(np.float32) == 'float32'
True
```

Only the first example is relevant for the array API standard, so I think this will be fine to specify since NumPy already complies. There is a problem in NumPy with this one, however:

```python
>>> np.dtype(np.float64) == float
True
```

That can be considered a clear bug though; it should be fixed in NumPy.
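To spell out why that last comparison is a bug under Python's rules, here is a quick check (assuming NumPy 1.x semantics):

```python
>>> import numpy as np
>>> d = np.dtype(np.float64)
>>> d == float                    # compares equal to the builtin type...
True
>>> hash(d) == hash(float)        # ...but the hashes differ
False
>>> {d: "f8"}.get(float) is None  # so "equal" keys miss in dict lookups
True
```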
So you're saying that
Agreed. What about This also violates Python's hashing invariant.
I agree that it's a bug technically. Not 100% sure that the NumPy team will want that changed, but I hope so (and a proposal for a major release is in the works, so that could go into it). For the array API standard it's not an issue, because there is no
That's more for the NumPy issue tracker, but if it were up to me then yes. For this issue tracker, I'm +1 on adopting language in the standard like: "All objects in this standard must adhere to the following requirement (as required by Python itself): objects which compare equal have the same hash value".
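For illustration, a minimal sketch of a dtype object that honors this contract (an invented class, not any library's actual implementation):

```python
class DType:
    """A minimal dtype object honoring Python's eq/hash contract."""

    def __init__(self, name: str):
        self.name = name

    def __eq__(self, other):
        # Refuse comparison with strings or Python types; equal objects
        # are then guaranteed equal hashes by construction.
        if not isinstance(other, DType):
            return NotImplemented
        return self.name == other.name

    def __hash__(self):
        # Derived from the same data as __eq__, so x == y implies
        # hash(x) == hash(y).
        return hash(self.name)
```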
That would be amazing. That's exactly what I was hoping for.
Okay, thanks for explaining. If the above language were adopted, NumPy could implement that by making
Let's give it a bit of time to see if anyone sees a reason not to add such a requirement. I can open a PR after the holidays.
Just noticed this comment. It is currently an issue in NumPy's implementation of the Array API:

```python
import numpy.array_api as xp

xp.float32 == xp.float32.type  # True!
```

This is because With the language you suggested above, NumPy would be forced to do this to become compliant 😄.
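Presumably the hash side of the invariant fails here too (checked against NumPy 1.x, where `numpy.array_api` still exists):

```python
>>> import numpy.array_api as xp           # NumPy 1.22-1.26 only
>>> xp.float32 == xp.float32.type          # dtype compares equal to the scalar type...
True
>>> hash(xp.float32) == hash(xp.float32.type)  # ...but hashes differently
False
```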
Same thing here, I think. NumPy will probably reject this for their own namespace. Incidentally, I assume you want
There is no
No, definitely not. No objects from two different libraries should ever compare equal, unless they're indeed the same object.
Ok! Thanks for explaining.
So to do things like checking that two arrays have the same dtype, or creating a NumPy array that has the same dtype as a JAX array, we'll need mappings like:

```python
m = {jax.array_api.float32: np.array_api.float32, ...}
```

And code like

```python
np.array_api.ones_like(some_jax_array)  # works today, in either direction
```

is impossible, yes? You need:

```python
np.array_api.ones(some_jax_array.shape, dtype=m[some_jax_array.dtype])
```
Having to use library-specific constructs should not be needed; if so, we're missing an API, I'd say. More importantly: mixing arrays from different libraries like this is a bit of an anti-pattern. You can't do much with that; neither library has kernels for functions that use both array types, so you're probably relying on implicit conversion of one to the other. So in this case, let me assume that `x` is a NumPy array and `y` is a JAX array:

```python
# First retrieve the namespace you want to work with
xp = x.__array_namespace__()

# Use DLPack or the buffer protocol to convert a CPU JAX array to a NumPy array
y = xp.asarray(y)

# Now we can compare dtypes:
if x.dtype == y.dtype == xp.float32:
    ...  # same dtypes, do stuff

# Or, similarly:
if xp.isdtype(x.dtype, xp.float32) and xp.isdtype(y.dtype, xp.float32):
    ...
```
Yes indeed.

I'm actually a little surprised JAX accepts numpy arrays. It seems to go against its philosophy; TensorFlow, PyTorch and CuPy will all raise. When you call a JAX function with a list, it raises; JAX is also annotating its array inputs accordingly:

```python
>>> jnp.sin([1, 2, 3])
...
TypeError: sin requires ndarray or scalar arguments, got <class 'list'> at position 0
```

All this stuff is bug-prone:

```python
>>> jnp.sin(np.array([1, 2, 3]))
Array([0.84147096, 0.9092974 , 0.14112   ], dtype=float32)

>>> jnp.sin(np.ma.array([1, 2, 3], mask=[True, False, True]))  # bug in user code here, because JAX silently discards the mask
Array([0.84147096, 0.9092974 , 0.14112   ], dtype=float32)

>>> np.sin(np.ma.array([1, 2, 3], mask=[True, False, True]))
masked_array(data=[--, 0.9092974268256816, --],
             mask=[ True, False,  True],
       fill_value=1e+20)
```
Okay, makes sense. I haven't been very conscious about this because (as you pointed out) Jax implicitly converts. I will be more careful.
I think this is where I'm confused. Somehow NumPy has to know what its equivalent dtypes are for JAX's dtypes even though they don't compare equal? Or will it produce a NumPy array with a JAX dtype? As this seems to work:

```python
In [12]: x = jnp.ones(10, jnp.bfloat16)

In [14]: np.asarray(x)
Out[14]: array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=bfloat16)
```

Very interesting. I wonder what the JAX team would say.
NumPy knows the dtype, as does JAX. This conversion uses the Python buffer protocol or DLPack, both of which are protocols explicitly meant for exchanging data in a reliable way (that includes dtype, shape, endianness, etc.). So the
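A small sketch of that exchange, assuming a CPU JAX array and recent versions of both libraries (`np.from_dlpack` needs NumPy >= 1.22, and the JAX array must implement `__dlpack__`):

```python
import numpy as np
import jax.numpy as jnp

x = jnp.ones(3, dtype=jnp.float32)

# DLPack carries dtype, shape, strides, and device along with the data
y = np.from_dlpack(x)
print(y.dtype)  # float32

# np.asarray takes the buffer-protocol / __array__ route to the same result
z = np.asarray(x)
print(z.dtype)  # float32
```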
Let's try to find out :) This section of the JAX docs only explains why JAX doesn't accept list/tuple/etc., but I cannot find an explanation of why it does accept numpy arrays and scalars. @shoyer or @jakevdp, would you be able to comment on why JAX implements a limited form of "array-like"?

Also, in addition to the bug with masked arrays above, here is another bug:

```python
>>> jnp.sin(np.float64(1.5))  # silent precision loss here, downcasting to float32
Array(0.997495, dtype=float32)
>>> jax.__version__
'0.4.1'
```
In that case, should there be a way to convert dtypes using either the buffer protocol or DLPack? Something more efficient than:

```python
def x_to_y_dtype(some_xp_dtype: DType, yp: ArrayInterface) -> DType:
    xp = some_xp_dtype.__array_interface__  # doesn't exist
    x = xp.ones((), dtype=some_xp_dtype)
    y = yp.asarray(x)
    return y.dtype
```

Should dtypes have a
No, those protocols are specifically for exchanging data (strided arrays/buffers). A dtype without data isn't very meaningful. You could exchange a size-1 array if needed, or a
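Following that suggestion, a hedged rewrite of the helper from the previous comment, exchanging an empty array instead of relying on a dtype protocol (the function name is mine, not a standard API):

```python
import numpy as np
import jax.numpy as jnp

def jax_dtype_to_numpy(jax_dtype) -> np.dtype:
    # Exchange a size-0 array; the dtype travels with it, the data is free
    probe = jnp.empty((0,), dtype=jax_dtype)
    return np.asarray(probe).dtype

print(jax_dtype_to_numpy(jnp.float32))  # float32
```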
I understand, but in order to exchange data, they have to be able to convert dtypes. So, that dtype conversion is happening somehow, and I was just wondering if that conversion can be accessed by the user.
It's not user-accessible; it's all under the hood. Specifically for JAX you have a shortcut, because it reuses NumPy dtypes directly:

```python
>>> type(jnp.float32)
<class 'jax._src.numpy.lax_numpy._ScalarMeta'>
>>> type(jnp.float32.dtype)
<class 'numpy.dtype[float32]'>
```
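In practice that means the `.dtype` attribute can be handed straight to NumPy (a JAX-specific shortcut relying on this implementation detail, not on anything the standard guarantees):

```python
>>> import numpy as np
>>> import jax.numpy as jnp
>>> np.ones(3, dtype=jnp.float32.dtype)
array([1., 1., 1.], dtype=float32)
```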
(Thanks for all the patient explanations!)
JAX avoids implicit conversion of Python sequences, because it can hide severe performance issues: when something like a large nested list is passed, every element may end up being handled as a separate value. On the other hand, a NumPy array can be converted cheaply and unambiguously.
This is working as intended: JAX only allows 64-bit values when explicitly enabled; see Double (64-bit) Precision. This was an early design decision that the team recognizes as non-ideal, but it has proven difficult to change because so many users depend on the bit truncation behavior and enjoy the accelerator-friendly type safety it confers.
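For reference, a minimal sketch of turning that mode on (the flag must be set before any JAX computation runs):

```python
import jax
jax.config.update("jax_enable_x64", True)

import numpy as np
import jax.numpy as jnp
print(jnp.sin(np.float64(1.5)).dtype)  # float64 once x64 mode is enabled
```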
The reason JAX defines these scalar objects is that it made the early design choice to not distinguish between scalars and zero-dimensional arrays. We could have defined simple functions named
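A quick illustration of that choice (the exact array class shown in the repr varies across JAX versions):

```python
>>> import jax.numpy as jnp
>>> x = jnp.float32(1.0)  # looks like a scalar constructor...
>>> x.ndim                # ...but returns a zero-dimensional array
0
>>> x
Array(1., dtype=float32)
```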
To my knowledge, this bug has never come up (probably because masked arrays are so rarely used in practice). I'll raise it in the JAX repo.
Thanks for the context @jakevdp!
I knew that 64-bit precision must be explicitly enabled, but this is still surely a bug? The expected behavior is an exception asking the user either to explicitly downcast if the precision loss is fine, or to enable 64-bit precision. Or at the very least emit a warning. Silent downcasting is terrible - it may be okay for deep learning, but it typically isn't for general-purpose numerical/scientific computing.
That might be unfortunate when converting a CPU program to a GPU one? It might be nice to be able to enable a flag that makes this into a runtime error. That way I can remove all of my unintentional JAX/NumPy array mixing.
I think you hit on the key point here: there are different communities with different requirements, and JAX attempts, maybe clumsily, to serve them all. If you are doing deep learning and care about performance over potential precision loss, you can set

It's a difficult problem to solve well in a single package: it's worth noting that NumPy's answer to requests to serve the needs of deep learning is essentially no, which is a defensible choice given the package's early design decisions.
This has been one of the few NumPy things that I dislike (and that would be moot for the Array API). In NumPy,

```python
>>> type(np.float32)
<class 'type'>
```

whereas

```python
>>> type(np.dtype(np.float32))
<class 'numpy.dtype[float32]'>
```

The former is needed, IIUC, only because of the need to construct NumPy scalars. Once NumPy removes this concept (how about NumPy 2.0, @seberg? 🙂) we can (and should) make them equivalent!
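For contrast, here is what that scalar-construction half of the duality does today (NumPy 1.x behavior):

```python
>>> import numpy as np
>>> s = np.float32(1.0)
>>> type(s)                    # a NumPy scalar...
<class 'numpy.float32'>
>>> isinstance(s, np.ndarray)  # ...which is not an array at all
False
```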
I do not really want to touch removing scalars from NumPy; maybe someone more confident about it can push for such a thing... Maybe to be clear: to change NumPy here, I see no other way than the following (I think this is what Ralf said):
If you remove scalars, then I don't see another way, so you can put it into
I'd argue that if there's any design that could bring us closer to full compliance with the standard in the main namespace, we should consider it, and removing scalars in favor of 0-D arrays is one of them. It's been a source of confusion with no obvious gain except for keeping legacy code working. It's been made clear that no accelerator library would support it. Also, removing scalars would keep the type promotion lattice cleaner. So,
Yes.
Not at all noisy 🙂
All I care about is 1. eventual compliance, and 2. reducing both user confusion and developer (you) workload 🙂 If this is something that could take one full developer-year to do, so be it.
You can change NumPy relatively easily. The problem is dealing with whether pandas and others need involved changes. So the issue with scalars (and to some degree this in general) is that it is very much holistic: I can zoom in on NumPy and give you a branch where scalars may still be in the code base but should never be created... (I am also admittedly the one person who hates everything about NumPy scalars, but isn't sure that scalars themselves are all that bad.)
Scalars themselves aren't that bad, if only they weren't created by operations like

I have a sense that it's doable in principle, but that it's one step too far for NumPy 2.0.
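For example (one common way scalars leak out today, whether or not these are the exact operations meant above):

```python
>>> import numpy as np
>>> a = np.arange(3.0)
>>> type(a[0])     # indexing returns a scalar, not a 0-D array
<class 'numpy.float64'>
>>> type(a.sum())  # reductions do too
<class 'numpy.float64'>
```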
Yes, I would be willing to experiment with the "getting scalars more right" part. But also yes: even that needs at least testing to have confidence that it would be 2.0-scoped (i.e. few enough users actually notice, and if they do, mostly in harmless ways).
I'd be interested in helping out with an effort like this. But I don't think I can be one of the two "champions" (see here) for this one; I already signed up for enough other stuff for NumPy 2.0.
Going back to the original discussion, another annoying thing NumPy does is

which has tripped us up a few times in the test suite.
Would it be possible in NumPy to make
I love this idea. This would be a step towards what Leo wanted above: "any design that could bring us closer to full compliance in the main namespace with the standard". I think if we don't do what you're suggesting, it will be a source of confusion that
(And |
Python's documentation promises: "The only required property is that objects which compare equal have the same hash value…" However, NumPy dtypes do not follow this requirement. As discussed in numpy/numpy#7242, dtype objects, their types, and their names all compare equal while hashing differently. Could the Array API promise that this will no longer be the case?
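A concrete consequence of that mismatch (checked against NumPy 1.x):

```python
>>> import numpy as np
>>> d = np.dtype('float64')
>>> d == 'float64'
True
>>> hash(d) == hash('float64')
False
>>> {d: 1}['float64']  # equal keys that hash differently break dict lookup
Traceback (most recent call last):
  ...
KeyError: 'float64'
```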