
Use case study: Intheon #403

Closed
chkothe opened this issue Mar 12, 2022 · 9 comments

chkothe commented Mar 12, 2022

We're very happy to see progress in getting the array API standardized -- obviously a monumental undertaking when all is considered.

We've been using backend-agnostic numpy-style code for over a year in research and now in production, and are gradually rolling it out across a multi-megabyte codebase (supporting mainly numpy, cupy, pytorch, tensorflow, and jax; also dask, though we're not using that at the moment), so I thought I'd share our user story in case it's helpful. It's a similar use case to what the array API addresses, but it was built with the current numpy workalike APIs in mind.

Our codebase has an entrypoint equivalent to get_namespace, although we call these backends, and a typical use case looks like:

def sqrt_sum(arr1, arr2):
    be = backend_for(arr1, arr2)  # resolve the namespace from the argument types
    return be.sqrt(be.asarray(arr1) + be.asarray(arr2))

For each of the backends we have a (usually thin) compatibility layer that adds any missing functions or fixes up issues with the function signature. In our case, backend_for looks at __array_priority__ to return the namespace for the highest-priority array, although we rarely use it with more than one array (but it means the above function accepts multiple array types and promotes according to np < dask < {jax, tf, cupy} < torch). Getting a namespace this way is about as fast as doing .T on a 1x1 numpy array (thanks to some caching), so we use it in even the smallest subroutines.
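
For reference, a minimal sketch of what such a resolver can look like, assuming a hand-maintained registry from array type to namespace (the names _BACKENDS and _CACHE are illustrative, not our actual code):

    import numpy as np

    _BACKENDS = {np.ndarray: np}  # array type -> namespace; add cupy, torch, ... as needed
    _CACHE = {}

    def backend_for(*arrays):
        # Cache on the tuple of argument types so repeated lookups are cheap.
        key = tuple(type(a) for a in arrays)
        ns = _CACHE.get(key)
        if ns is None:
            # Defer to the namespace of the highest-priority array.
            winner = max(arrays, key=lambda a: getattr(a, "__array_priority__", 0.0))
            ns = _CACHE[key] = _BACKENDS[type(winner)]
        return ns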

We use a lot of the API surface of these backends, and typically the most compute-intensive subroutines have an option to choose a preferred backend (via a function backend_get(shorthand)) out of the subset that supports the necessary ops, or to keep the same backend. We've been extremely impressed with the compatibility of cupy, the performance of torch (notably also its cpu arrays), and we have a few places where we prefer tf or jax because they might have a faster implementation of a critical op or parallelize better (e.g. jax). We find that even over a half-year time span, things have evolved rapidly enough that the ideal backend pick for a function may change from one to another (one of the reasons we aim for that much flexibility).

We traced our API usage and found that perhaps 90% of our backend-agnostic call sites would be covered by the array API as it exists now. There are a few places that use different aliases for the same functionality (due to some forms being more popular with current backends than others), the most frequent issues being absolute(x) and concatenate(x); other examples are arccos(x) (all our backends) vs acos(x) (no backend). We also frequently use convenience shorthands like hstack(x), vstack(x), ravel(x), or x.flatten(), but those could be substituted easily (or provided in a wrapper), e.g. as sketched below.
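
For instance, assuming xp is the standard namespace and 2-D inputs, the substitutions can look like:

    xp.concat([a, b], axis=0)  # vstack((a, b)) for 2-D inputs
    xp.concat([a, b], axis=1)  # hstack((a, b)) for 2-D inputs
    xp.reshape(x, (-1,))       # ravel(x); x.flatten() additionally forces a copy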

We found a few omissions that would require a bit more code rewriting, among others the functions minimum(a,b), maximum(a,b), and clip(x, lower, upper) (presumably that would turn into where(x > upper, upper, where(x < lower, lower, x)), as sketched below). Also we frequently use moveaxis(x,a,b) and swapaxes(x,a,b), e.g., in linear algebra on stacks of matrices (>30 call sites for us). All of these are supported by the above 6 backends already, and they're fortunately pretty trivial. Our code uses real(x) in some places since some of the implementations might return spurious complex numbers; that may be reason enough to at least partially specify that function already now. Also reciprocal(x) occurs frequently in our code; I guess that's from a suspicion that writing 1/x may fail to use the reciprocal instruction if it's available.
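
Spelled out, that clip() workaround might look like the following sketch (xp is the namespace; the bounds are wrapped with asarray since strict implementations may reject raw Python scalars in where()):

    def clip(xp, x, lower, upper):
        # Elementwise clamp built from two nested where() calls.
        lo, hi = xp.asarray(lower), xp.asarray(upper)
        return xp.where(x > hi, hi, xp.where(x < lo, lo, x))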

A few things that we use have no substitute at this point, unfortunately, namely einsum(), lstsq(x,y), eig(x), and sqrtm() (though the latter could be implemented via eigh). We hope that these eventually find their way (back) into the API. I realize that lstsq was removed as per a previous discussion (and it's understandable given that the API is a bit crufty), but then our code base has 26 unique call sites of that alone, since we're dealing mostly with engineering and stats. One might reasonably assume that backends that already have optimized implementations of it (all 6 do, and torch/tf support batched matrices) will provide it anyway in their array API namespace. However, we do worry that, given that it has been deliberately removed, we can't be sure that some of the existing backends won't be encouraged to minimize their maintenance surface and drop functions like that from their formerly numpy-compatible (soon array API compatible) namespace, forcing users like us to deal with their raw EagerTensor, xla_extension.DeviceArray, or whatever it may be called, and go find the functionality in whichever ancestral tensorflow namespace it may have been buried in before. We're wondering if a tradeoff could be made where, e.g., some of the rarely-used outputs could be marked as a "reserved" placeholder and allowed to hold unspecified values (e.g., None) until perhaps at some future date the API specifies them. There's also the option to go the same route as with svd, where some arguments were removed in the interest of simplicity. On the plus side, it's good to see diag retired in favor of diagonal (especially so in the age of batched matrices).

Other than that, for multi-GPU we use a backend-provided context manager where available (torch, tf, cupy), a custom context manager where it's not (jax), and a no-op context manager for numpy & dask (usage looks like with be.select_device(id):). That's because passing device= through all compute functions down into the individual array creation calls (from arange to eye) just isn't all that practical with a large and deeply nested codebase, and it's easy to overlook call sites, causing hidden performance bugs that only turn up on multi-accelerator runs -- however, since the user can write their own context manager (with a stack in thread-local storage and wrappers around the array creation functions), that can be worked around with some effort, as sketched below.
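
A minimal sketch of that thread-local workaround (select_device and the wrapped creation function are illustrative names, and this assumes the namespace's creation functions accept device=):

    import threading
    from contextlib import contextmanager

    _state = threading.local()

    @contextmanager
    def select_device(device_id):
        # Push the device onto a per-thread stack for the duration of the block.
        stack = getattr(_state, "devices", None)
        if stack is None:
            stack = _state.devices = []
        stack.append(device_id)
        try:
            yield
        finally:
            stack.pop()

    def current_device(default=None):
        stack = getattr(_state, "devices", [])
        return stack[-1] if stack else default

    def zeros(xp, shape, **kwargs):
        # Wrapped creation function that injects the ambient device.
        kwargs.setdefault("device", current_device())
        return xp.zeros(shape, **kwargs)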

Lastly, our indexed array accesses in our main codebase (the parts that we hope to eventually port) look like the following:

x[:, :, slice,:]        # similar expressions occur hundreds of times with variable counts of :'s
x[:,:,slice,:,slice,:]  # hundreds of times, but can fall back to chained x[:,:,:,:,slice,:][:,:,slice,:,:,:]
x[:,:,indices,:]        # similar uses occur dozens of times
x[:, an_int, :, :]      # probably a few dozen times
x[:, bools, :, :]       # a few to a dozen times
x[:, reverse_slice,:]   # occasional (not supported by current pytorch)
x[indices, :, indices]  # never (we don't use numpy advanced indexing with multiple arrays)
x[:, None, slice]       # used as equivalent to np.newaxis

We use a high-level array wrapper (similar in spirit to xarray) that supports arbitrarily strided views and allows writes into those views, which results in low-level calls (in the guts of the array class) equivalent to the form:

be.reshape(be.transpose(x, order), shape)[:,slice,:,:,slice,:] = y  # invoked similarly from hundreds of places (indirectly)

... that's because we spend much of our time dealing with multi-way tensors (e.g., neural data) that have axes such as space, time, frequency, instance, statistic, or feature (often 3-5 at a time), and most subroutines are agnostic to the presence or order of most axes except for one or two, so they create views that move those few to specific places and then read/write through the transposed view. Our way of dealing with backends that don't support that is to not enable them for those functions (and to have feature flags on the backends for reverse indexing, slice assignment, and mesh indexing support to catch cases where we do).
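
To illustrate the write-through behavior we rely on, here's a small numpy-only example (using moveaxis rather than the transpose/reshape pair above, since reshape only yields a view when the strides permit):

    import numpy as np

    x = np.zeros((2, 3, 4, 5))
    view = np.moveaxis(x, 2, 0)  # shape (4, 2, 3, 5); shares memory with x
    view[1:3, :, :, :] = 1.0     # the write lands in x[:, :, 1:3, :]
    assert x[:, :, 1:3, :].sum() == x.sum() == 60.0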

I wasn't sure if this is the right place to report relevant API usage "in the field", hopefully it is.

@rgommers (Member)

Hi @chkothe, thanks so much for the very detailed write-up of your use case! This is a good place to share.

My overall takeaway from the statistics and details you provide is that we've gotten things mostly right and our minimalist API surface covers a pleasantly large amount of what you need. However you also point out some gaps that would be helpful to address, so let me try and respond to each and see which ones we need to follow up on:

We use a lot of the API surface of these backends, and typically the most compute-intensive subroutines have an option to choose a preferred backend [...]

This part surprised me. I'd expect that converting between array types per function call will be quite expensive, introducing both overhead and possible bugs. Do you have some examples for functions where you do or don't allow this?

the most frequent issues being absolute(x) and concatenate(x),

concatenate vs. concat I understand (we went with conciseness + Python's naming), but in NumPy abs and absolute are aliases, and np.abs is the much more commonly used one. So for absolute just making changes in your code base seems okay?

other examples being arccos(x) (all our backends) vs acos(x) (no backend).

TensorFlow and PyTorch (and C99) use acos. NumPy also tends to follow C in such naming, but the arc* functions ended up differently for historical reasons. I think the choice is reasonable.

We also frequently use convenience shorthands like hstack(x), vstack(x), ravel(x) or x.flatten(), but those could be substituted easily (or provided in a wrapper).

There is something here that we still need to follow up on (not enough hours in a day ...): can we popularize/adopt einops as a better API for this sort of manipulation? data-apis/consortium-feedback#3

We found a few omissions that would require a bit more code rewriting, among others the functions minimum(a,b), maximum(a,b), and clip(x, lower, upper) (presumably that would turn into where(x > upper, upper, where(x < lower, lower, x))).

  • I'm fairly sure minimum and maximum came up somewhere else recently and they have good support across array libraries. I am just failing to find the relevant discussion right now.
  • clip is fairly heavily used and the use of where to emulate it is a bit cumbersome - this may be worth splitting off as a separate issue to discuss standardizing it (or have a good reason for why not).

Also we frequently use moveaxis(x,a,b) and swapaxes(x,a,b),

Both of these are good candidates to add I'd say. Let's open a separate issue for them.

Also reciprocal(x) occurs frequently in our code, I guess that's from a suspicion that writing 1/x may fail to use the reciprocal instruction if it's available.

I'm not sure if this is actually guaranteed - for example the numpy docs just say "calculates 1/x". reciprocal is implemented by all libraries though - it was just not considered for inclusion yet because it's fairly niche (I just checked SciPy for example: just 2 instances of np.reciprocal).

A few things that we use have no substitute at this point, unfortunately, namely einsum(), lstsq(x,y), eig(x), and sqrtm() (though the latter could be implemented via eigh). We hope that these eventually find their way (back) into the API.

However, we do worry that, given that it has been deliberately removed, we can't be sure that some of the existing backends won't be encouraged to minimize their maintenance surface and drop functions like that from their formerly numpy-compatible

I very much doubt this will happen. I have so far not seen any suggestions like this, nor do I think it's a valid reason to remove functionality. Each library is going to provide a superset of what's in the standard. Deprecation and removal of existing functionality requires a good reason, like "confuses users", "there are now superior alternatives", "it's broken", etc.

We're wondering if a tradeoff could be made where, e.g., some of the rarely-used outputs could be marked as a "reserved" placeholder and allowed to hold unspecified values (e.g., None) until perhaps at some future date the API specifies them. There's also the option to go the same route as with svd, where some arguments were removed in the interest of simplicity.

The svd route seems preferred here. Or just transitioning to a new name with saner behavior, just like we left out unique and replaced it with (in that case multiple) new functions.

Other than that, for multi-GPU we use a backend-provided context manager where available (torch, tf, cupy), a custom context manager where it's not (jax), and a no-op context manager for numpy & dask (usage looks like with be.select_device(id):). That's because passing device= through all compute functions down into the individual array creation calls (from arange to eye) just isn't all that practical

We went back and forth on that quite a bit. I think we need to see this work in practice for a bit during adoption of the standard in array-consuming libraries. There were certainly arguments for adopting a context manager. One pain point I remember is that semantics were hard to agree on - for example PyTorch has a rule of never allowing implicit device transfers, while TensorFlow does allow that.

Lastly, our indexed array accesses in our main codebase (the parts that we hope to eventually port) look like the following:

That's super useful, thanks. I think there's work in progress to propose adding single-integer-array indexing to the specification, which I think addresses the most common gap you're seeing.

chkothe (Author) commented Mar 17, 2022

Thanks for the detailed feedback! I agree with the overall assessment.

This part surprised me. I'd expect that converting between array types per function call will be quite expensive, introducing both overhead and possible bugs. Do you have some examples for functions where you do or don't allow this?

Yes, and to be more clear: for the typical few-liner subroutine it would be too much overhead -- the places where we do that mainly involve iterative solvers that can take tens of seconds (or more) to run on large data, where the conversion overhead is small. Some examples of that are robust estimators, large-scale least-squares interpolation, machine learning code, or costly sliding-window ops on time series (so the analogy would be the typical high-level scikit-learn or statsmodels method). Another case is workflows where only the few most expensive steps run on the GPU, meaning that a conversion happens anyway, and then one may as well pick the fastest backend.
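
As a toy illustration of the pattern (backend_get is the resolver mentioned earlier; numpy stands in for the preferred backend, and the registry is illustrative):

    import numpy as np

    def backend_get(shorthand):
        return {"np": np}[shorthand]  # illustrative registry

    def richardson_solve(a, b, prefer="np", iters=100, omega=0.1):
        xp = backend_get(prefer)
        a, b = xp.asarray(a), xp.asarray(b)  # pay the conversion cost once
        x = xp.zeros_like(b)
        for _ in range(iters):               # every iteration stays on-backend
            x = x + omega * (b - a @ x)
        return x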

So for absolute just making changes in your code base seems okay?

Yes for sure -- those are quick search-and-replace changes. We're also fine with putting light wrappers around namespaces to fix up small gaps like those where needed, and have been doing that all along (that may be necessary if one wishes to have side-by-side support for backends that have and have not already transitioned to the array API). Overall I agree that concise and uniform naming seems preferable.

  • Re moveaxis/swapaxes: I checked, and those are by now available in TF (move, swap) and PyTorch (move, swap). I fear I won't have much time to spin out separate issues and follow up on them at this point (my apologies!), but just having those on the collective radar is hopefully a good start.

  • Re minimum / maximum / clip: right, and another reason would be that common GPUs have dedicated/efficient support for some typical element-wise clipping (saturating) ops (e.g., NVIDIA, AMD; it's my understanding that they can even be zero-cost), and it'd be nice if the API didn't leave that silicon unused. Of course it wouldn't be guaranteed that those strengths were leveraged by any actual implementations (though one may hope that XLA et al. would have an easier time doing a good job when presented with an easy-to-transform expression compared to nested where clauses).

can we popularize/adopt einops as a better API for this sort of manipulation? data-apis/consortium-feedback#3

My two cents are that having those added to the API seems appealing for newly written code, at least in places where it makes numpy easier to use or follow (requires some case-by-case judgment). When porting existing numpy code, I'd ideally be able to search and replace np. with xp. for the most part, without much rethinking/reengineering. Likewise, when reviewing a PR that does such a port, it's a much lower burden if it's a light syntactic change (especially when unit tests are incomplete).

Thanks for addressing the lstsq concern. As long as there's awareness that there are users like us out there who more or less depend on it, we should be covered for now (and presumably more people would come out of the woodwork if that ever broke in existing implementations).

We went back and forth on that quite a bit. I think we need to see this work in practice for a bit during adoption of the standard in array-consuming libraries. There were certainly arguments for adopting a context manager. One pain point I remember is that semantics were hard to agree on - for example PyTorch has a rule of never allowing implicit device transfers, while TensorFlow does allow that.

Re context managers -- yes, those look like a candidate for a later extension (seems like a potential tar pit that could slow down standardization). After all, there's sort of an escape hatch that allows users to retrofit one that controls the device= argument of the creation functions if they need it (using function wrappers), and that may even be the most sensible scope for that. I could imagine that someone might develop a utility library around the array API that fills in some bits like that, and otherwise broadens the numpy compatibility by implementing more common aliases for ops etc. (easier to do things like that in a take-it-or-leave-it library than in a standard).

That's super useful, thanks. I think there's work in progress to propose adding single-integer-array indexing to the specification, which I think addresses the most common gap you're seeing.

Yeah, and I'm not even sure our single-integer-array indexing use cases are all that relevant (oftentimes we could just as well use a 1-element slice). What's more important to us is that implementations continue to preserve the numpy-style view semantics and write-through capabilities that they already have (and one may hope that more implementations achieve this in case it's low-hanging fruit for them).

leofang (Contributor) commented Mar 21, 2022

Hi @chkothe, thanks for your very detailed report! It's valuable, and as a CuPy contributor I'm happy to hear that both CuPy and the array API help your work.

Sorry to digress here, though, as I was hoping to discuss this offline but couldn't find your contact. Could you kindly share why you need eig() instead of eigh()? What are your use cases? On the CUDA math library side (which I am part of), this need (ex: cupy/cupy#3255) was raised multiple times internally, but the cuSOLVER team considered the routine that would back eig() both mathematically ill-defined and much less used in the wild. If you can help us justify/motivate it, that would be very nice (so we can unblock you and many other libraries like CuPy) 🙂

chkothe (Author) commented Mar 21, 2022

Thanks for following up on that! Emailed you.

shoyer (Contributor) commented Jul 21, 2022

  • sqrtm (and related functions like expm, logm, etc.) are very hard to implement, so I'm fairly certain we don't want that one in the standard.

These are indeed tricky to implement, but probably not significantly harder than matrix factorization? Yes, they are in SciPy rather than NumPy, but that's a somewhat artificial distinction. Many (though not all) of these can be found in PyTorch, JAX, and TensorFlow.

My two cents is that it would be valuable if the standard specified the interface for these functions, even if not every library is going to implement them. It's really not a big deal to need to look up a compatibility table for advanced linear algebra functionality.

The other option is to encourage people to write their own shims for specific backends, which is not terrible but goes a little against the spirit of the array standard.

@asmeurer (Member)

linalg is already an optional extension in the standard. Maybe it would also be helpful to state that individual functions in an extension might not be implemented?

shoyer (Contributor) commented Jul 22, 2022

linalg is already an optional extension in the standard. Maybe it would also be helpful to state that individual functions in an extension might not be implemented?

Yes, definitely!

@rgommers (Member)

Maybe it would also be helpful to state that individual functions in an extension might not be implemented?

Hmm, I'm not so sure I agree with that. Unfortunately https://data-apis.org/array-api/latest/extensions/index.html does not explain this yet - we should fix that. I had a look through older discussions, and it wasn't conclusively resolved whether this is okay or not. What was resolved is that an extension is optional. My working assumption was that if the extension is implemented, it is complete. If every single function in an extension is optional, then this will be very tricky to use from a user perspective. Checking whether linalg exists on the namespace returned by __array_namespace__() is a lot easier (also for a downstream library to document as required yes/no).
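
For illustration, such a capability check is a one-liner (sketch):

    def has_linalg(x):
        # Probe the array's namespace for the optional linalg extension.
        return hasattr(x.__array_namespace__(), "linalg")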

My two cents is that it would be valuable if the standard specified the interface for these functions, even if not every library is going to implement them. It's really not a big deal to need to look up a compatibility table for advanced linear algebra functionality.

I had a look at expm, which is probably the most widely used one of this set of functions.

For the four libraries that do have an implementation, the signatures don't match. So it's also a matter of effort whether we'd like to do this. Maybe rather than adding it to the linalg extension, there should be a separate list of functions for which folks want to align on an interface, but which isn't part of the standard? Not sure exactly what that should look like (separate html docs section, some qualifier/status for an API, or ...), but I'd be keen on not adding so much that it becomes ever harder to be compliant.

rgommers added a commit to rgommers/array-api that referenced this issue Aug 4, 2022
Note that an alternative idea discussed earlier was a separate function
`get_extension()`, but IIRC that was not considered a good idea.
Now that we have module-level `__getattr__`'s, it should not be
a problem for any library to use a specified name like `linalg` or
`fft`.

Whether the functions in the `linalg` extensions should all be required
to exist or not was discussed in data-apisgh-403 (no clear conclusion there).
One option discussed there to deal with hard to implement or more niche
APIs is to create a separate status/label or some other way to track
desired signatures (details to be worked out if we want to go that way).
jakirkham (Member) commented Sep 23, 2022

clip is fairly heavily used and the use of where to emulate it is a bit cumbersome - this may be worth splitting off as a separate issue to discuss standardizing it (or have a good reason for why not).

Raised in issue ( #482 )

moveaxis: discussed in "Add specifications for array manipulation functions" #42 (comment); it was left out initially because PyTorch and TensorFlow didn't have it.
swapaxes: pointed out in that same PR as missing, but not discussed there further. TensorFlow also doesn't have it, but PyTorch does.
Both of these are good candidates to add I'd say. Let's open a separate issue for them.

Raised issue ( #483 )

rgommers added a commit that referenced this issue Nov 28, 2022
@data-apis data-apis locked and limited conversation to collaborators Apr 1, 2024
@kgryte kgryte converted this issue into discussion #769 Apr 1, 2024
@kgryte kgryte removed the use case label Apr 4, 2024

This issue was moved to a discussion. You can continue the conversation there.
