Numpy-style indexed update support. #101
When you use import jax.numpy as np
import numpy as onp  # original numpy
x = onp.zeros((10, 3))  # original numpy array, supports indexed assignment/mutation
for i in range(10):
    x[i, :] = (i, 2 * i, 3 * i)  # one value per column; a 2-tuple would not broadcast into 3 columns
# now we can pass x into any jax.numpy function
Why does
We're thinking about better solutions, and getting feedback from users is exactly the kind of information we could use!
That was all outside the context of JAX function transformations, like
Below are some comments that I think are only tangentially related.
If you want to express something like an in-place update in
There's a second challenge here that would again only come up if you're using
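For comparison, here is a minimal sketch of building the same array without NumPy mutation, using the functional `.at[...].set(...)` update available in later JAX releases (it is not the API discussed at this point in the thread). The function name `build_rows` is hypothetical.

```python
import numpy as onp
import jax
import jax.numpy as jnp

def build_rows(n):
    # JAX arrays are immutable: each .at[...].set(...) returns a new array
    # rather than modifying x in place.
    x = jnp.zeros((n, 3))
    for i in range(n):  # Python loop, unrolled while tracing; fine for small n
        x = x.at[i, :].set(jnp.array([i, 2 * i, 3 * i], dtype=x.dtype))
    return x

# n must be static because it determines the output shape and the loop length
x_jax = jax.jit(build_rows, static_argnums=0)(10)

# Reference result built by mutating a plain NumPy array
x_np = onp.zeros((10, 3))
for i in range(10):
    x_np[i, :] = (i, 2 * i, 3 * i)
```

For large `n`, a vectorized expression such as `jnp.arange(n)[:, None] * jnp.array([1, 2, 3])` produces the same values without unrolling a loop.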
If I understand correctly, adding
Does that make sense? If you say a bit more about your intended use case for indexed assignment, or give a simple example, we might be able to provide more recommendations. It would also help inform our design for when arrays created by
There may be a simple way we can support this specific behavior (namely, in-place indexed assignment outside the context of
…NumPy-style indexed updates. Create a new library `jax.ops` for user-facing ops that don't exist in NumPy or SciPy. Progress on issue jax-ml#101. Fixes jax-ml#122. Reenable some disabled TPU indexing tests that now pass.
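The `jax.ops` indexed-update helpers created by this commit were later superseded by the `.at` property on JAX arrays. A minimal sketch of the out-of-place equivalents, assuming a recent JAX release; the function name `set_at` is hypothetical.

```python
import jax
import jax.numpy as jnp

x = jnp.zeros(5)

# Out-of-place equivalents of NumPy's x[idx] = y and x[idx] += y;
# x itself is never modified.
y_set = x.at[1].set(10.0)        # like x[1] = 10
y_add = x.at[2:4].add(1.0)       # like x[2:4] += 1

# Repeated indices accumulate with .add (NumPy's x[[0, 0]] += 1 would add only once)
y_dup = x.at[jnp.array([0, 0])].add(1.0)

@jax.jit
def set_at(x, i, v):
    # i is a traced integer here; .at[i].set still works inside jit
    return x.at[i].set(v)

z = set_at(x, 3, 7.0)
```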
We now recommend
Hi all,
import jax
import jax.numpy as jnp
key = jax.random.PRNGKey(42)
size = 4
start = jax.random.randint(key, (), 0, 10 - size)  # scalar start index in [0, 10 - size)
x = jnp.ones((10,))
# Dynamic slicing (not jittable: works eagerly, but under jax.jit the slice bounds are traced)
y = x[start:start + size]
# Use fancy indexing instead of slicing: gather indices start, start + 1, ..., start + size - 1
funky_arange = lambda start, size: start + jnp.cumsum(jnp.ones((size,), jnp.int32)) - 1
y = x[funky_arange(start, size)]
I had to define
I am quite new to JAX, so I'm not entirely sure when or where this could break. What do you think, @mattjj?
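A minimal sketch checking that the gather-based approach traces under `jax.jit` when `size` is static. Because `size` is a Python int, `jnp.arange(size)` has a fixed shape at trace time and can stand in for the cumsum trick; only `start` is traced. The function name `gather_window` is hypothetical.

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=2)
def gather_window(x, start, size):
    # size is static, so the index vector has a known shape while tracing;
    # start is traced, so the actual indices are computed at run time.
    idx = start + jnp.arange(size)
    return x[idx]  # fancy indexing lowers to a gather

x = jnp.arange(10.0)
w = gather_window(x, jnp.int32(3), 4)
```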
This situation (dynamic start with static size) is exactly what
from jax import lax
lax.dynamic_slice(x, [start], [size])
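A minimal sketch of `lax.dynamic_slice` under `jax.jit`, assuming a scalar integer `start` and a static Python-int `size`; the function name `window` is hypothetical. It also shows that an out-of-range start is clamped so the whole window stays in bounds.

```python
from functools import partial
import jax
import jax.numpy as jnp
from jax import lax

@partial(jax.jit, static_argnums=2)
def window(x, start, size):
    # start may be traced; slice sizes must be static Python ints
    return lax.dynamic_slice(x, (start,), (size,))

x = jnp.arange(10.0)
w = window(x, jnp.int32(3), 4)
# start=8 with size=4 would run past the end, so the start is clamped to 6
w_edge = window(x, jnp.int32(8), 4)
```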
Hi, I use NumPy to read data. In plain NumPy, x_train[(i - 1) * 10000 : i * 10000, :, :, :] = data works fine. However, with jax.numpy it raises ValueError: assignment destination is read-only. Also, np.set_printoptions() is not implemented.
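A minimal sketch of two workarounds for chunked loading, assuming each chunk arrives as a NumPy array. Shapes are shrunk for illustration, and `load_chunk` is a hypothetical stand-in for the real data reader.

```python
import numpy as onp
import jax.numpy as jnp

n_chunks, chunk, shape = 3, 4, (2, 2, 1)  # illustrative sizes, not the original 10000

def load_chunk(i):
    # Hypothetical reader: returns a (chunk, *shape) float32 NumPy array filled with i
    return onp.full((chunk,) + shape, float(i), dtype=onp.float32)

# Option 1: fill a mutable NumPy buffer, then convert to a JAX array once
x_np = onp.zeros((n_chunks * chunk,) + shape, dtype=onp.float32)
for i in range(1, n_chunks + 1):
    x_np[(i - 1) * chunk : i * chunk] = load_chunk(i)
x_train = jnp.asarray(x_np)

# Option 2: out-of-place update of a JAX array with .at[...].set(...)
x_train2 = jnp.zeros((n_chunks * chunk,) + shape, dtype=jnp.float32)
for i in range(1, n_chunks + 1):
    x_train2 = x_train2.at[(i - 1) * chunk : i * chunk].set(load_chunk(i))
```

Outside `jit`, each `.at[...].set(...)` produces a new array, so Option 1 is usually the cheaper choice for plain data loading.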