
Numpy-style indexed update support. #101

Closed
meet-cjli opened this issue Dec 13, 2018 · 5 comments

Labels: enhancement (New feature or request), question (Questions for the JAX team)

meet-cjli commented Dec 13, 2018

Hi, I use numpy to read data. In plain numpy, x_train[(i - 1) * 10000 : i * 10000, :, :, :] = data works fine. However, in jax.numpy it raises ValueError: assignment destination is read-only. Also, np.set_printoptions() is not implemented.

mattjj commented Dec 13, 2018

When you use jax.numpy as np, array creation functions like np.zeros return immutable arrays, unlike regular numpy arrays. If you want to use regular, mutable numpy ndarrays outside of the context of JAX differentiation or jit compilation, perhaps to build up a training dataset using indexed assignment like this, one option is to import numpy as onp and use that separately from jax.numpy as np. That is, you can write things like this:

import jax.numpy as np
import numpy as onp  # original numpy

x = onp.zeros((10, 2))  # original numpy array, supports indexed assignment/mutation
for i in range(10):
  x[i, :] = (i, 2 * i)

# now we can pass x into any jax.numpy function

Why does jax.numpy create immutable arrays at all? We could have our jax.numpy array-creation functions return regular mutable ndarrays in some cases, though there's a tradeoff here: when should array-creation functions create normal CPU-backed ndarrays, and when should they create device-backed (e.g. GPU-backed) immutable XLA arrays? By importing the original numpy as onp, you can control when you're working with the original numpy versus our device-backed numpy. The downside is that you have to juggle two numpys in your code.

We're thinking about better solutions, and getting feedback from users is exactly the kind of information we could use!

That was all outside the context of JAX function transformations, like grad and jit. If you want to use those, you have to work with jax.numpy and not the original numpy. Those transformations are much easier to implement without mutation like indexed assignment, though there's a chance JAX might support it someday. See the What's supported section of the README for more information. If you want convenient in-place assignment syntax, you can try creating wrapper objects that override __setitem__ like those in jax.experimental.lapax, but be careful: updated values have to be passed back out of any jit compiled computation.
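As a rough illustration of the wrapper idea (a hypothetical helper, not the actual jax.experimental.lapax API; the name MutableWrapper is made up here):

```python
import numpy as onp  # original numpy

class MutableWrapper:
    """Hypothetical wrapper: buffers indexed assignments in a plain
    mutable numpy array, then hands back a copy to pass into JAX."""
    def __init__(self, shape, dtype=onp.float32):
        self._buf = onp.zeros(shape, dtype)

    def __setitem__(self, idx, value):
        self._buf[idx] = value  # mutation happens on the host buffer

    def value(self):
        return self._buf.copy()  # safe to hand to jax.numpy functions

w = MutableWrapper((3, 2))
for i in range(3):
    w[i, :] = (i, 2 * i)
# w.value() is [[0, 0], [1, 2], [2, 4]]
```

As noted above, the caveat still applies: any values updated this way have to be passed back out of a jit compiled computation explicitly.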

Below are some comments that I think are only tangentially related.

If you want to express something like an in-place update in jit compiled code, you can do it by using lax.dynamic_update_slice; when using jit the compiler will turn any lax.dynamic_update_slice calls into in-place updates (unless it knows it can do something better!). There's no docstring for that function yet, but it very closely models the XLA DynamicUpdateSlice HLO.

There's a second challenge that would again only come up if you're using @jit on code that includes expressions like your example: we also don't yet support expressions like x[i:i+5] (even on the right-hand side of assignments) when i is dynamic (in the sense that it's an abstract value during jit tracing). We're working on a design to support that, but in the meantime, for slice-based indexed expressions like that one, your options are to

  1. use lax.dynamic_slice to read blocks out of arrays (HLO docs here),
  2. use reshape together with indexing like x[i], or
  3. don't jit compile functions that include indexing expressions like that (since slice-based indexing works fine when not under a jit decorator).
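For example, option 1 might look like this (a sketch; the helper name read_block is made up, and the slice size must be a static constant):

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def read_block(x, i):
    # Reads a length-3 block starting at dynamic index i;
    # the start index may be traced, but the size (3) is static.
    return lax.dynamic_slice(x, (i,), (3,))

x = jnp.arange(10)
block = read_block(x, 2)
# block is [2 3 4]
```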

If I understand correctly, adding np.set_printoptions is a separate issue. I can take care of that now.

Does that make sense? If you say a bit more about your intended use case for indexed assignment, or give a simple example, we might be able to provide more recommendations. It would also help inform our design for when arrays created by jax.numpy could support in-place mutation.

@mattjj mattjj added the question Questions for the JAX team label Dec 13, 2018
@mattjj mattjj self-assigned this Dec 13, 2018
@mattjj mattjj added the enhancement New feature or request label Dec 13, 2018

mattjj commented Dec 13, 2018

There may be a simple way we can support this specific behavior (namely, in-place indexed assignment outside the context of jit or grad, which I'm guessing is the case here based on the example expression provided). I'll update this issue with progress.

@hawkinsp hawkinsp changed the title ValueError: assignment destination is read-only Numpy-style indexed update support. Feb 22, 2019
hawkinsp added a commit to hawkinsp/jax that referenced this issue Mar 4, 2019
…NumPy-style indexed updates.

Create a new library `jax.ops` for user-facing ops that don't exist in NumPy or SciPy.

Progress on issue jax-ml#101. Fixes jax-ml#122.

Reenable some disabled TPU indexing tests that now pass.
hawkinsp commented

We now recommend jax.ops.index_add and jax.ops.index_update wherever an in-place indexed update would have been used. Hope that helps!

yardenas commented Dec 2, 2021

Hi all,
Another workaround for partially dynamic slices, where the length is static but the start is dynamic, is the following:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
size = 4
start = jax.random.randint(key, (), 0, 10 - size)  # scalar start index

x = jnp.ones((10,))

# Dynamic slicing (not jittable when start is traced)
y = x[start:start + size]

# Use fancy indexing instead of slicing
funky_arange = lambda start, size: start + jnp.cumsum(jnp.ones((size,), jnp.int32)) - 1
y = x[funky_arange(start, size)]

I had to define funky_arange to bypass

FilteredStackTrace: jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

I am quite new to JAX, so not entirely sure exactly when/where this could break, what do you think? @mattjj

jakevdp commented Dec 3, 2021

This situation (dynamic start with static size) is exactly what lax.dynamic_slice is designed for. I believe that function would be more direct and efficient:

from jax import lax
lax.dynamic_slice(x, [start], [size])
