
Cohort statistics without numba #885

Closed
timothymillar opened this issue Aug 9, 2022 · 7 comments
Labels
question Further information is requested

Comments

@timothymillar
Collaborator

I've been giving a bit more thought to how we can calculate cohort statistics (see #730) without using numba and realized that we could potentially use the same solution mentioned in https://github.com/pystatgen/sgkit/pull/114#issuecomment-678921281. That approach uses only generic array operations, making it very portable (e.g. for #803). However, it does involve passing the data through a higher-dimensional array ((..., samples) -> (..., samples, cohorts) -> (..., cohorts)).

A simple example:

import numpy as np
import xarray as xr

n_samples = 50
n_variants = 1000
n_cohorts = 3

sample_cohort = np.random.randint(-1, n_cohorts, size=n_samples)  # -1 = sample belongs to no cohort
sample_values = np.random.rand(n_variants, n_samples)
cohort_membership = sample_cohort[:, None] == np.arange(n_cohorts)  # one-hot encoding of cohort membership

ds = xr.Dataset({
    "sample_cohort": (["samples"], sample_cohort),
    "sample_values": (["variant", "samples"], sample_values),
    "cohort_membership": (["samples", "cohorts"], cohort_membership),
})

# create an array of shape ("variants", "samples", "cohorts") and then reduce to ("variants", "cohorts")
ds["cohort_sum"] = (ds["sample_values"] * ds["cohort_membership"]).sum(dim="samples")
ds["cohort_mean"] = ds["cohort_sum"] / ds["cohort_membership"].sum(dim="samples")

The additional memory usage of this approach may be acceptable when the number of cohorts is small. It could also be mitigated by using a sparse encoding for "cohort_membership" and its derivatives.
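
To sketch why sparsity helps (an illustrative addition, assuming scipy is available): the sum over samples of values times one-hot membership is just a matrix product, (variants, samples) @ (samples, cohorts) -> (variants, cohorts), so "cohort_membership" can be stored as a scipy.sparse matrix and the (..., samples, cohorts) intermediate is never materialised:

import scipy.sparse

# store the one-hot membership sparsely; each column flags one cohort's members
membership_sparse = scipy.sparse.csr_matrix(cohort_membership.astype(np.float64))
cohort_sum = (membership_sparse.T @ sample_values.T).T  # dense (variants, cohorts)
cohort_size = np.asarray(membership_sparse.sum(axis=0)).ravel()
cohort_mean = cohort_sum / cohort_size  # NaN for empty cohorts, as in the dense version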

I guess this comes down to: where do we want to strike the balance between performance and simplicity?

@timothymillar timothymillar added the question Further information is requested label Aug 9, 2022
@jeromekelleher
Collaborator

I'd be inclined to take an "if it ain't broke, don't fix it" view on this @timothymillar - after all, we depend pretty much entirely on numba, and I'm not sure the library will work in any meaningful way without it?

The approach does look quite straightforward though, I agree.

@tomwhite
Collaborator

tomwhite commented Aug 9, 2022

It would be interesting to see differences in timing, although as @jeromekelleher says numba is used throughout the codebase.

Also, #803 is a long-term thing that will get fixed, and is not really a priority use case for us, just something to keep an eye on. And there is progress happening on that front too, see emscripten-forge/recipes#168.

@hammer
Contributor

hammer commented Aug 9, 2022

@tomwhite how does numba interact with cubed? Are there any difficulties or slowdowns it causes on any of the backends?

@tomwhite
Collaborator

tomwhite commented Aug 9, 2022

Numba should work with Cubed by using map_blocks (just like Dask) - although I haven't tried that yet. Numba works well with Dask in my experience.
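
For illustration, a minimal sketch of that pattern (an assumed usage example, not from the thread): a numba-jitted kernel applied once per chunk via Dask's map_blocks.

import dask.array as da
import numba
import numpy as np

@numba.njit
def kernel(block):
    # toy per-chunk computation standing in for a numba-compiled statistic
    return block * 2.0

arr = da.random.random((1_000_000, 50), chunks=(100_000, 50))
result = arr.map_blocks(kernel, dtype=np.float64)  # the jitted kernel runs on each chunk
result.compute()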

That said, I'm not against this proposal if the performance is acceptable.

@tomwhite
Collaborator

tomwhite commented Aug 9, 2022

Numba does cause issues with package compatibility - for example we can't use the latest NumPy version (1.23), which has implications for the version of cyvcf2 we can use...

@timothymillar
Collaborator Author

Some crude timings using n_samples=50, n_variants=1_000_000, and a variable number of cohorts. Cohorts were randomly assigned (identically for each implementation) and some samples were randomly assigned to no cohort (-1).

%%timeit
# current implementation in numba
cohort_sum(sample_values, sample_cohort, n_cohorts).compute()

%%timeit
# via one-hot encoding using numpy
(sample_values[..., None] * cohort_membership[None, ...]).sum(axis=1)

%%timeit
# via one-hot encoding using xarray
(ds["sample_values"] * ds["cohort_membership"]).sum(dim="samples")

Times in seconds:

cohorts   numba   numpy   xarray
1         0.50    0.12    0.29
3         0.51    1.8     2.3
5         0.50    2.2     3.1
7         0.51    2.5     3.9
9         0.51    3.0     4.7

So probably not an acceptable performance hit for any more than a few cohorts: the one-hot intermediate has shape (variants, samples, cohorts), so the cost grows roughly linearly with the number of cohorts, whereas the numba implementation makes a single pass over the samples. I thought it was at least worth documenting this approach as it may be useful elsewhere. Thanks for the discussion all!

@jeromekelleher
Collaborator

Very nice, thanks for looking into this @timothymillar! Numba is awesome 😍
