
RFC: add materialize to materialize lazy arrays #839

Open

lucascolley opened this issue Aug 31, 2024 · 19 comments

Comments

lucascolley (Contributor) commented Aug 31, 2024

Preface

I do not think I am the best person to champion this effort, as I am far from the most informed person here on lazy arrays. I'm probably missing important things, but I would like to start this discussion as I think it is an important topic.

The problem

The problem of mixing lazy execution with computation that requires data-dependent properties is discussed in detail elsewhere:

A possible solution

Add the function materialize(x: Array) to the top level of the API. Behaviour (see the sketch after this list):

  • for eagerly-executed arrays, this would be a no-op
  • for lazy arrays, this would force computation such that the data is available in the returned array (which is of the same array type?)
  • for "100% lazy" arrays (Handling materialization of lazy arrays #748 (comment)), this would raise an exception

Prior art

Concerns

  • I think the main concern is whether eager-only libraries will agree to adding a no-op into the API. There is precedent for that type of change (e.g. device kwargs in NumPy), but perhaps this is too obtrusive?
  • As far as I can tell there isn't a standard way to do this across lazy libraries. Does JAX just do this automatically when it would be needed? Do other libraries have this capability?

Alternatives

  1. Do nothing. The easy option, but it leaves us unable to support lazy arrays when data-dependent properties are used in computation (maybe that is okay?)
  2. An alternative API. Maybe spelled like compute* or a method on the array object. Maybe with options for partial materialization (if that's a thing)?

cc @TomNicholas @hameerabbasi @rgommers

kgryte added the RFC, API extension, Needs Discussion, and topic: Lazy/Graph labels on Sep 3, 2024
asmeurer (Member) commented Sep 3, 2024

A question is whether it's appropriate for an array API consuming library to materialize a lazy graph "behind the user's back", as it were. Such an operation could be quite expensive in general, and a user might be surprised to find that a seemingly innocuous function from something like scipy is doing this.

On the other hand, if an algorithm fundamentally depends on array values in a loop, there's no way it can be implemented without something like this. So maybe the answer is we should just provide guidance that functions that use materialize should clearly document that they do, or maybe even disable it unless given some sort of explicit flag by the user.

asmeurer (Member) commented Sep 3, 2024

If I understand https://data-apis.org/array-api/draft/design_topics/lazy_eager.html correctly, the primary APIs that require materialization are __bool__, __int__, __float__, etc.

So another option here would be to add a compute flag to a potential item API #815.

hameerabbasi (Contributor) commented

A couple of thoughts:

  1. Regarding partial materialization, I don't think a separate API is necessary for that. Given there are already methods to select the elements one wants from an array (indexing and so on), one can call materialize on the result of those (see the sketch after this list). Whether or not that backfills the data of the indexed array is entirely the library's choice.
  2. Regarding the "do nothing" option, it's kind of tricky -- sometimes an explicit barrier to compute the data is needed to "help" optimization of the compute graph.
  3. Adding compute to the item API is something I'm -0.7 on if it's the only API. Ideally, you would compute only the part of the array that's needed, and doing it element by element in item would be excruciatingly slow in many cases.
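
As a hypothetical illustration of point 1, assuming the materialize function proposed above, partial materialization falls out of composing indexing with materialize:

# Only the selected part needs to be computed; whether the original lazy
# array x also gets its data backfilled is up to the library.
first_rows = materialize(x[:100])
positives = materialize(x[x > 0])   # data-dependent shape, resolved here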

TomNicholas commented

Thank you for starting this discussion @lucascolley, and thanks for tagging me!

function materialize(x: Array)

Note that the signature here should probably be more like materialize(*arrs: Array, **kwargs), as both dask.compute and cubed.compute can compute multiple lazy arrays at once. You often want to do that in order to handle common intermediate arrays efficiently. The **kwargs is also important because different parallel execution frameworks will require specifying different configuration options at runtime.
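
For reference, a small Dask illustration of the multi-array form (dask.compute and the chunks argument are real; the sizes here are arbitrary). Computing both results in one call lets the scheduler share the common intermediate a + 1:

import dask
import dask.array as da

a = da.random.random((10_000, 10_000), chunks=(1_000, 1_000))
total = (a + 1).sum()
average = (a + 1).mean()

# one call, one traversal of the shared graph
total_val, average_val = dask.compute(total, average)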

Prior art

Adding some more here:

  • Cubed:
    • Cubed (cc @tomwhite) deliberately copies the lazy API of dask.array. It too has a compute function for multiple arrays (as well as a .compute method on single arrays). An intermediate Plan object is created which is analogous to dask's task graph, but consists entirely of lazy arrays (as opposed to arbitrary functions that can go in a dask task graph). Normally computed results are either plotted or serialized to Zarr on-disk.
  • Xarray's "Lazy Indexing Classes"
    • Internally Xarray has an implementation of lazy indexing, which it uses to make sure that only data that is necessary for the result is loaded from on-disk files. Note this is completely separate from (and predates) dask's integration with xarray. It's not a fully-fledged lazy array API implementation, but there have been discussions about how it could become one (Lazy indexing arrays as a stand-alone package pydata/xarray#5081).
  • JAX:
    • I'm not super familiar with JAX but if I understand correctly JAX is effectively lazy but distributed over "local XLA devices", which act a little bit like dask chunks. jax.pmap and jax.vmap can be used to distribute computation over the devices, and there are methods which trigger array computation, e.g. .block_until_ready() (see the snippet after this list).
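
A minimal eager-mode JAX sketch of that last point, for illustration (block_until_ready is a real method on JAX arrays; the shapes are arbitrary):

import jax.numpy as jnp

x = jnp.ones((2_000, 2_000))
y = (x @ x).sum()        # dispatched asynchronously to the device
y.block_until_ready()    # wait until the concrete value is available
print(float(y))          # converting to a Python float would also block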

If I understand https://data-apis.org/array-api/draft/design_topics/lazy_eager.html correctly, the primary APIs that require materialization are __bool__, __int__, __float__, etc.

I don't think this is the primary API at all; it's just an interesting special case where the return type is out of our hands. The primary API is, as @lucascolley says, a materialize/compute function. For example, in Xarray we basically compute if:

  • the user asks to see numeric .values,
  • the user tries to .plot their array,
  • the user wants to save their array to disk (e.g. as netCDF/Zarr).
  1. Adding compute to the item API is something I'm -0.7 on if it's the only API. Ideally, you would compute only the part of the array that's needed, and doing it element by element in item would be excruciatingly slow in many cases.

I agree - in Xarray we very rarely use .item; instead we are normally computing whole arrays (which have often been lazily indexed beforehand, so we are just computing a subarray of the original array).

I also agree with the other 2 points @hameerabbasi just made.

Xarray has a new abstraction over dask, cubed (and maybe soon JAX) called a "ChunkManager" (though I think we will rename it to a ComputeManager to better reflect its responsibilities). This can be understood as a way to create lazy arrays (.from_array), a way to distribute computation over them (.apply_gufunc), and a way to trigger computation of them (.compute).

lucascolley (Contributor, Author) commented

Thanks all for the comments! Just tagging @jakevdp also who can maybe shed some light on JAX.

jakevdp commented Sep 5, 2024

I don't think materialize would be useful for JAX. In that case, it would either be a no-op (during eager execution) or an unconditional error (during traced execution). When JAX is tracing a function, there is no concrete array to materialize!

hameerabbasi (Contributor) commented Sep 5, 2024

Note that the signature here should probably be more like materialize(*arrs: Array, **kwargs), as both dask.compute and cubed.compute can compute multiple lazy arrays at once. You often want to do that in order to handle common intermediate arrays efficiently.

I'm +1 on an API that allows simultaneous materialisation of multiple arrays, although I'd spell it slightly differently.

  1. To preserve type safety and avoid special cases that require wrapping/unwrapping of tuples, I'd make the type of the first argument Array | Iterable[Array].
  2. The supplied kwargs are likely to be implementation-specific. Since the standard is a least common denominator, I feel this won't make it in. This, of course, doesn't prevent implementations from adding their own versions.

With this in mind, the signature I'd propose is materialize(x: Array | Iterable[Array]) -> Array | tuple[Array, ...]
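
A rough sketch of that overload, using the standard's __array_namespace__ attribute to distinguish a single array from an iterable of arrays (_materialize_one is a hypothetical per-array hook, not an existing API):

def materialize(x):
    if hasattr(x, "__array_namespace__"):         # a single standard-compliant array
        return _materialize_one(x)
    return tuple(_materialize_one(a) for a in x)  # Iterable[Array] -> tuple[Array, ...]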

jakevdp commented Sep 5, 2024

One thing I'm unclear on: what is the difference between materialized and non-materialized arrays in terms of the array API? What array API operations can you do on one, but not on the other?

hameerabbasi (Contributor) commented

One thing I'm unclear on: what is the difference between materialized and non-materialized arrays in terms of the array API? What array API operations can you do on one, but not on the other?

I feel this is more driven by use-cases and performance. Some of these are outlined in #748 (comment) and #728.

lucascolley (Contributor, Author) commented

What array API operations can you do on one, but not on the other?

One we bumped into in SciPy is xp.unique_values (#834)

asmeurer (Member) commented Sep 7, 2024

One thing I'm unclear on: what is the difference between materialized and non-materialized arrays in terms of the array API? What array API operations can you do on one, but not on the other?

As I mentioned above, __bool__, __int__, __float__, __complex__, and __index__ will not work on unmaterialized arrays, unless the library is willing to automatically materialize them.

Additionally, the APIs that have data-dependent shapes are unique_all(), unique_counts(), unique_inverse(), unique_values(), nonzero(), and repeat() when the repeats argument is an array, as well as boolean array indexing. But a lazy library would not necessarily disallow these. The shape of an array cannot be computed statically when one of these operations is used, but the array API allows None in shapes (as Dask does, for instance). The issue in #834 is not so much from using unique_values but from trying to call int on the resulting None dimension.
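
A small Dask illustration of that distinction (unknown sizes show up as nan in the shape until computed):

import numpy as np
import dask.array as da

x = da.from_array(np.array([0, 1, 1, 2, 0]), chunks=2)
u = da.unique(x)     # data-dependent shape: still lazy, still allowed
print(u.shape)       # (nan,) -- the size is unknown until computed
y = x[x > 0]         # boolean indexing also produces an unknown dimension
int(u.shape[0])      # this conversion is what fails, not unique() itself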

jakevdp commented Sep 9, 2024

OK, thanks for the clarification. In that case, materialize is unimplementable in JAX, because materializing an array involves exiting a transformation context, and that can't be done by anything except exiting the context of the transformed function.

rgommers (Member) commented Sep 18, 2024

I'd like to add a different perspective, based on execution models. I think we have fundamentally three kinds:

  1. Eager execution model
  2. A fully lazy execution (or graph export) model
  3. A hybrid lazy/eager execution model

(1) Eager execution model

Examples of implementations:

  • NumPy
  • CuPy
  • PyTorch (eager mode)
  • JAX (eager mode)
  • dpctl (it's async eager with SYCL queueing primitives being available to the end user IIUC)

Any "execute or materialize now" API would be a no-op.

(2) Fully lazy execution model

Examples of implementations:

  • JAX (JIT mode, e.g. jax.jit and other transforms)
  • PyTorch (model export or graph modes, e.g. with torch.export and AOTInductor)
  • ndonnx

Any "execute or materialize now" API would need to raise an exception.

(3) Hybrid lazy/eager execution model

Examples of implementations:

  • PyTorch (in modes where it can "graph break", e.g. torch.compile)
  • Dask
  • Xarray
  • Cubed
  • MLX

This is the only mode where an "execute or materialize now" API may be needed. This is not a given though, which is clear from PyTorch not having any such .compute() or .materialize() API. Xarray and Cubed copy the Dask model, so I'll continue discussing only Dask and PyTorch.

As pointed out by @asmeurer above, there are only a few APIs that cannot be kept lazy (__int__, __bool__ & co., because Python language semantics force evaluation and the return of actual int/bool objects rather than duck types). For everything else, like .values, things can be kept lazy in principle.

For PyTorch, the way things work in hybrid mode is that if actual values are needed, the computation is done automatically. No syntax is needed for this. And there doesn't seem to be much of a downside to this. EDIT: see https://pytorch.org/docs/stable/export.html#existing-frameworks for a short summary of various PyTorch execution models.

MLX is in the middle: it does have syntax to trigger evaluation (.eval()), but it auto-computes when needed (see https://ml-explore.github.io/mlx/build/html/usage/lazy_evaluation.html#when-to-evaluate). Its documentation says that .eval() may be useful for explicitly inserting graph breaks to avoid the graph becoming too large, which may be costly.
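
A minimal MLX sketch of that behaviour, based on my reading of the linked docs (mx.eval and .item() are the relevant entry points):

import mlx.core as mx

a = mx.random.uniform(shape=(1024, 1024))
b = (a @ a).sum()   # lazy: only the graph is recorded
mx.eval(b)          # explicit evaluation / graph break
print(b.item())     # implicit evaluation would also be triggered here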

Dask chooses to require .compute() to trigger execution, and if that's not added by the user it raises an exception instead of auto-executing. I think the main rationale is that execution is expensive, so the exception is a reminder to the user to think about whether they can rewrite their code to stay fully lazy. Forcing the user to add .compute() is the user saying "I thought about it, and yes this is what I really want".

There is another important difference between PyTorch (and fully lazy libraries like JAX/ndonnx as well) vs. Dask I think:

  • PyTorch & co all have full context on what they're keeping lazy by turning it into a graph and what outputs are needed later on, and once something executes it stays in memory. If a model is too large, one will run out of memory.
  • Dask on the other hand has less information and may discard intermediate results. Discarding intermediate results is based on heuristics, and has the benefit of allowing scaling to larger data sets, at the cost of sometimes having to re-compute variables that were discarded. It offers more execution-related syntax like .persist() and chunking to allow the user to give more hints about what to keep in memory or discard.

My current assessment is:

  • A .materialize()/.compute() API is not a good idea for the Array API Standard, since:
    • It doesn't do anything useful for fully eager or fully lazy implementations,
    • It doesn't do anything useful for PyTorch and potentially other libraries with a hybrid mode with enough context either.
    • @jakevdp already assessed it as not implementable for JAX.
    • I would also much prefer not to see this in NumPy.
  • The Dask choice is historical and pretty unlikely to change, but I'd argue that the "trigger the user" isn't a very good argument. This could fairly easily be implemented as a diagnostic mode instead (a global setting to change the default behavior, or something like cython -a to highlight where execution is triggered).
  • The Dask model is also fairly ad-hoc and non-deterministic, and:
    • For situations where execution must happen (like for bool()), there is little value in forcing users to add .materialize() to the code.
    • For situations where things can be kept lazy in principle but the library doesn't currently support that, the use of .materialize() would be library-specific and making one library happy would unnecessarily break laziness for another library. This is obviously a bad thing.
  • There is also a pragmatic naming issue if we'd consider adding something to the standard: Dask has .compute(), MLX has .eval(), others have nothing (and outside of array/tensor land, Polars chose .collect()).
    • These methods are fine to have for any library, but in my opinion belong outside of the standard.
  • For the SciPy example in gh-834 there isn't much to gain here and it's not about unique_values - I commented on that in gh-834#comment.
  • I haven't put PyData Sparse in the above list, since it's very much in flux right now. @hameerabbasi I'd strongly consider using a JAX/PyTorch-like model though, where you have an explicit function like jax.jit/torch.compile which scopes what is being compiled because it applies to a function. I think that completely removes the need for any syntax in the standard. gh-748 doesn't actually explain what the problem is with that - let's discuss there if needed.

Now we obviously do have an issue with Dask/Xarray/Cubed that we need to understand better and find a solution for. It's a hard puzzle. That seems to require more thought, and perhaps a higher-bandwidth conversation soon. The ad-hoc-ness is (as far as I understand it - I could well be missing something of course) going to remain a fundamental problem for any attempt at standardization. I'd be curious to hear from @TomNicholas or anyone else with more knowledge about Dask why something like a user opt-in to auto-trigger compute whenever possible isn't a good solution.

lucascolley (Contributor, Author) commented

I'd be curious to hear from @TomNicholas or anyone else with more knowledge about Dask why something like a user opt-in to auto-trigger compute whenever possible isn't a good solution.

@lithomas1 asked this in dask/dask#11356 and the response from @phofl was

I would rather not add any option like this and then have users stumble into it.

@fjetter said in dask/dask#11298 (comment)

in our experience these implicit compute patterns are strong anti patterns [1] that are surprising people and are actively harmful in many cases. We are preferring a non surprising and easy to use API and are willing to sacrifice compatibility for this if necessary.

[1] Think about an array of a couple TB in a remote storage that is being loaded entirely just to allow simple indexing as shown in this example. If we called compute_chunk_sizes ourselves, all of this would happen (and cost money) just because a user did arr[[0, 1]]

rgommers (Member) commented

Thanks for the pointers @lucascolley. So that seems to be a fairly conclusive "we have some experience and won't do that" - which is fair enough. A few thoughts on those discussions:

  • The answer seems to be motivated primarily by the behavior for distributed arrays, and throws lazy and distributed into the same bucket. Lazy and distributed execution are very different though, and the latter is significantly more complex and requires more primitives.
  • From the standard's perspective, we've tried to avoid including functions that are known to be problematic for distributed libraries (e.g., we have mean but not median for that reason - although we added a few super-common ones like sort), but there are certainly other things one needs that we haven't given any thought at all so far (e.g., chunking/sharding). So when we talk about lazy execution here, it's about single-node lazy execution.
  • Dask shines mostly for distributed use cases, so if the rationale is something like: "this is too much of a footgun for distributed usage, and we want to keep the single node and distributed APIs as similar as possible to make it easy to scale up, hence we don't want auto-compute" that is certainly a reasonable position I think.
  • Not having any auto-compute is not a problem from the standard's perspective. Right now Dask still does have some auto-compute, but if that were all removed it would put Dask more or less in the same position as jax.jit and ndonnx for how well it could run standard-compliant code without any modifications.
    • And then to move it back to a hybrid mode, a user or library author could choose to special-case Dask if they really wanted that, with something like:
def compute(x):
    if is_dask_array(x):   # hypothetical helper to detect Dask arrays
        x = x.compute()    # bind the result: .compute() returns a new array, it is not in-place
    return x

def some_func(x):
    if compute(x).shape[0] > 5:
        # we couldn't avoid the `if` conditional in this logic
        ...

lucascolley (Contributor, Author) commented

Thanks Ralf, that makes sense. I'm pretty convinced that we don't want to add materialize now, so I'll close this issue.

As @asmeurer mentioned previously, we still need to decide in SciPy whether we are comfortable with doing int(np.asarray(xp.unique_values(...))) (perhaps with DLPack in the future) to force this materialisation, or whether we should let it error. But we can cross that bridge when the Dask PR is ready, which won't be before a new Dask release and an array-api-compat release with 2023.12 support.

lucascolley closed this as not planned on Sep 18, 2024
rgommers (Member) commented

Thanks Lucas. I'll reopen this for now to signal we're not done with this discussion. I've given my input, but at least @hameerabbasi and @TomNicholas seem to have needs that perhaps aren't met yet. We may also want to improve the documentation around this topic.

rgommers reopened this on Sep 18, 2024
asmeurer (Member) commented

in our experience these implicit compute patterns are strong anti patterns [1] that are surprising people and are actively harmful in many cases. We are preferring a non surprising and easy to use API and are willing to sacrifice compatibility for this if necessary.

This aligns with the point I was trying to make above (#839 (comment)), which is that a library like scipy calling compute() would effectively be automatic compute from the point of view of the end-user. I think of the array API and especially libraries like scipy as "extending" an array library's array toolbox. So to an end-user, calling dask.array.mean(x) is no different from calling scipy.stats.hmean(x). One just happens to come from a different library. But semantically one implicitly calling compute() would have the same pitfalls as the other.

So I think that if scipy encounters this situation in one of its functions, it should either do nothing, i.e., require the user to materialize the array themselves before calling the function, or register the function itself as a lazy function (but how that would work would be array library dependent).

lucascolley (Contributor, Author) commented

I think it should be possible to use the introspection API to add different modes, where we raise errors by default but a user can opt in to allowing us to force computation. The same can be said for device transfers via DLPack.
