
[WIP] count_allele_calls #114

Merged · 16 commits · Aug 25, 2020

Conversation

@timothymillar (Collaborator) commented Aug 16, 2020

See issue #85

This implements an additional jitted function, count_call_alleles_ndarray_jit, for ndarrays only rather than doing it all in dask.
This approach seems to be in line with the goals outlined here, but I'm happy to replicate the approach of the original count_alleles function if that is preferred.

Likewise, I can rewrite count_alleles using this approach, which should improve performance (mainly on chunked arrays, due to njit(..., nogil=True)).

I haven't added numba to requirements.txt or setup.py because that is done in #76.

I guess we can add numba now for CI and fix the conflict later.
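Roughly, such an ndarray-only jitted counter might look like this (a hypothetical sketch; the function name is from this PR, but the body is illustrative and uses the -1 sentinel, whereas the PR itself also consults the mask variable):

import numpy as np
from numba import njit

@njit(nogil=True)
def count_call_alleles_ndarray_jit(g, n_alleles):
    # g: (variants, samples, ploidy) int8 genotype calls, with -1 meaning missing
    out = np.zeros((g.shape[0], g.shape[1], n_alleles), dtype=np.int32)
    for i in range(g.shape[0]):
        for j in range(g.shape[1]):
            for k in range(g.shape[2]):
                a = g[i, j, k]
                if a >= 0:  # skip the missing sentinel
                    out[i, j, a] += 1
    return out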

@eric-czech (Collaborator) commented Aug 16, 2020

Thanks for picking this up @timothymillar!

I've been hesitant to wade into numba for things like counting (I put a few thoughts at https://github.com/pystatgen/sgkit/issues/49#issuecomment-674561391), but I agree it's probably the right choice here.

It's a shame that there doesn't seem to be a good way to do this with Dask/Xarray, but if we do lose the freedom in array backends, it would be nice to still support CPU/GPU easily. Have you thought at all about using a gufunc for this, which would make it easier to compile for either? I think the whole module could collapse down to something like this in that case, where using the sentinel missing value also simplifies things vs. using the separate mask variable:

import dask.array as da
import numba
import xarray as xr
from xarray import DataArray

@numba.guvectorize([numba.void(numba.int8[:], numba.int32[:], numba.int32[:])], '(n),(k)->(k)')
def count_alleles(x, _, out):
    # zero the output, then count each non-missing (>= 0) allele call
    out[:] = 0
    for v in x:
        if v >= 0:
            out[v] += 1
            
def count_call_alleles(ds) -> DataArray:
    G = da.asarray(ds.call_genotype)
    # This array is only necessary to tell dask/numba what the 
    # dimensions and dtype are for the output array
    O = da.empty(G.shape[:2] + (ds.dims['alleles'],), dtype='int32')
    O = O.rechunk(G.chunks[:2] + (-1,))
    return xr.DataArray(
        count_alleles(G, O),
        dims=('variants', 'samples', 'alleles'),
        name='call_allele_count'
    )

def count_variant_alleles(ds) -> DataArray:
    return (
        count_call_alleles(ds)
        .sum(dim='samples')
        .rename('variant_allele_count')
    )

If there isn't a big performance difference between a custom kernel/function and a sum over per-call allele counts, I'd also be a proponent of simply removing the count_variant_alleles function and leaving that to users in the future.

@timothymillar (Collaborator, Author)

Have you thought at all about using a gufunc for this

Good idea, I'll give it a go

where using the sentinel missing value also simplifies things vs using the separate mask variable

I was a little uncertain about the application of the mask, as there seem to be two options when it comes to a user masking out additional values (which I assume is an intended feature at some point):

  1. The user updates the mask (via some API) but the genotype arrays are unchanged, hence any function has to use the mask itself, either by updating a copy of the genotype calls (as in the current implementation of count_alleles) or by using the mask directly, as in this PR.
  2. The user updates the mask (via some API) and then applies it to their Dataset, which replaces the (alleles within) genotype calls with the sentinel value in the original Dataset (or in a copy).

Essentially, should functions working on a Dataset trust the mask or the sentinel values?

@eric-czech (Collaborator)

Essentially, should functions working on a Dataset trust the mask or the sentinel values?

The mask is only a convenience on ds.call_genotype < 0, so I think it's best to use the sentinel values until both are needed.

when it comes to a user masking out additional values (which I assume is an intended feature at some point)

We haven't talked about that yet (feel free to file an issue), but afaik Dask still doesn't support assignment, so I see that process as the following two steps (a sketch follows the list):

  1. An operation defines a transformation of the data array that makes additional values equal to the missing sentinel
  2. The mask is redefined as mask = arr < 0
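Since assignment isn't available, step 1 presumably goes through something like da.where (an illustrative sketch, not an agreed API; some_condition is a placeholder for whatever predicate selects the values to mask):

import dask.array as da

arr = da.asarray(ds["call_genotype"].data)
arr = da.where(some_condition, -1, arr)  # write the missing sentinel instead of assigning in place
mask = arr < 0                           # the mask is then rederived from the sentinel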

@timothymillar (Collaborator, Author)

@eric-czech The gufunc version is working nicely, but I'm having some issues satisfying the CI with the use of guvectorize.
Do you know what's causing the Sphinx doc issue?

A couple of notes about the implementation:

  • The gufunc returns an array with dtype uint8. I don't think this will be an issue, but if it is, the dtype of the dummy array can be used to indicate the output dtype.
  • I'm not sure about the docstring style preference, as there is a bit of variation in the repo.
  • I reversed your signature from '(n),(k)->(k)' to '(k),(n)->(n)' because the genotype vector has a length of ploidy, and k is commonly used for ploidy in the (polyploid) literature.

@eric-czech (Collaborator) commented Aug 19, 2020

Nice @timothymillar! This looks great.

The gufunc is returning an array with dtype uint8, I don't think this will be an issue but if it is then the dtype of the dummy array can be used to indicate the output dtype.

Perfect, makes sense.

I'm not sure on the docstring style preference as there is a bit of variation in the repo.

Me neither. I started using the Dask style like that, but then switched to referencing our ArrayLike type once that was added, as did @tomwhite. I was waiting until Sphinx was up and running before trying to overhaul the signatures, but https://github.com/pystatgen/sgkit/pull/124 would change a lot of it anyhow. I'm not sure what the standard should be until some more dust settles. Let us know if you have any thoughts on it.

I reversed your signature from '(n),(k)->(k)' to '(k),(n)->(n)' because the genotype vector has a length of ploidy and k is commonly used for ploidy in (polyploid) literature.

Good call!

Do you know what's causing the Sphinx doc issue?

Hmm I do not. Does it go away if you do from numba import guvectorize and use the decorator that way? I'm not sure if the issue is with that or the numba.int8[:] signature annotations. Perhaps it will work with strings like int8[:] instead?
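For example, the reversed '(k),(n)->(n)' signature with string type annotations would look something like this (a sketch, not the exact code in this PR):

from numba import guvectorize

@guvectorize(["void(int8[:], uint8[:], uint8[:])"], "(k),(n)->(n)", nopython=True)
def count_alleles(g, _, out):
    # g is one genotype call of length k (ploidy); out has length n (alleles)
    out[:] = 0
    for a in g:
        if a >= 0:
            out[a] += 1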

n_alleles = ds.dims["alleles"]
G = da.asarray(ds["call_genotype"])
shape = (G.chunks[0], G.chunks[1], n_alleles)
N = da.empty(n_alleles, dtype=np.uint8)
Collaborator:

Good idea using this instead! For the subsequent map_blocks call it might be a good idea to ensure this has only a single chunk so that the blocks broadcast. I can't imagine why anyone would ever configure the default chunk size to be so small that n_alleles items would result in multiple chunks, but guarding against it is probably a good idea.
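For instance (an untested sketch; da.empty accepts a chunks argument):

N = da.empty(n_alleles, dtype=np.uint8, chunks=-1)  # chunks=-1 forces a single chunk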

Collaborator (Author):

I can't imagine why anyone would ever configure the default chunk size to be so small that n_alleles items would result in multiple chunks

Neither, but this edge case should be handled correctly by this PR; see below.

Collaborator:

Ah I see. Thoughts on a da.asarray(ds["call_genotype"].chunk(chunks=dict(ploidy=-1))) then? I think ploidy chunks of size one could actually be pretty common when concatenating haplotype arrays.

I could also see the argument for leaving it up to users to interpret those error messages and rechunk themselves, so that they have to think through the performance implications, but I had already started being defensive about that in the GWAS regression methods. My rationale was that it's better to optimize for a less frustrating user experience than for making performance concerns prominent (and we could always add warnings for that later).

Collaborator:

Ohh nm, I see why it doesn't matter now (and the test for it). Ignore that.

shape = (G.chunks[0], G.chunks[1], n_alleles)
N = da.empty(n_alleles, dtype=np.uint8)
return xr.DataArray(
da.map_blocks(count_alleles, G, N, chunks=shape, drop_axis=2, new_axis=2),
@eric-czech (Collaborator) commented Aug 19, 2020:

Out of curiosity, what keeps this from working as count_alleles(G, N) instead? Do the block shapes need to be identical?

Collaborator (Author):

count_alleles(G, N) results in an error if G is chunked in the ploidy dimension. See comment below for an example.

Collaborator:

👍

@hammer mentioned this pull request Aug 19, 2020
@timothymillar (Collaborator, Author) commented Aug 19, 2020

Here's a minimal example explaining the use of map_blocks.

Setup:

import numpy as np
import dask.array as da

from sgkit.stats.aggregation import count_alleles

n_alleles = 2
genotypes = np.array(
    [[[ 0,  0],
      [ 0,  1],
      [ 1,  0]],
     [[-1,  0],
      [ 0, -1],
      [-1, -1]]], dtype=np.int8)

N = da.empty(n_alleles, dtype=np.uint8)
G = da.asarray(genotypes).rechunk((1,1,1))  # unlikely chunking

(All of the options below work correctly if G is not chunked in dimension 2, the ploidy dimension.)

Option 1: calling count_alleles directly:

count_alleles(G, N).compute()

Results in error:

ValueError: Core dimension `'k'` consists of multiple chunks. To fix, rechunk into a single chunk along this dimension or set `allow_rechunk=True`, but beware that this may increase memory usage significantly.

I didn't actually explore the use of allow_rechunk=True, but even if it achieves the same result as the approach below, I prefer the more explicit use of map_blocks.
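For completeness, that option would presumably look like the following (an untested sketch, reusing G, N, and count_alleles from the setup above):

da.apply_gufunc(count_alleles, "(k),(n)->(n)", G, N, allow_rechunk=True).compute()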

Option 2: naive use of map_blocks (my first attempt):

shape = (G.chunks[0], G.chunks[1], n_alleles)
da.map_blocks(count_alleles, G, N, chunks=shape).compute()

Results in incorrect array dimensions:

array([[[1, 0, 1, 0],
        [1, 0, 0, 1],
        [0, 1, 1, 0]],

       [[0, 0, 1, 0],
        [1, 0, 0, 0],
        [0, 0, 0, 0]]], dtype=uint8)

Option 3: using the map_blocks kwargs drop_axis and new_axis to "forget" the ploidy dimension's shape:

shape = (G.chunks[0], G.chunks[1], n_alleles)
da.map_blocks(count_alleles, G, N, chunks=shape, drop_axis=2, new_axis=2).compute()

Result:

array([[[2, 0],
        [1, 1],
        [1, 1]],

       [[1, 0],
        [1, 0],
        [0, 0]]], dtype=uint8)

@timothymillar (Collaborator, Author)

Does it go away if you do from numba import guvectorize and use the decorator that way? I'm not sure if the issue is with that or the numba.int8[:] signature annotations. Perhaps it will work with strings like int8[:] instead?

One of those seems to have fixed it, thanks for the help!

nopython=True,
)
def count_alleles(g: ArrayLike, _: ArrayLike, out: ArrayLike) -> None:
"""Generaliszed U-function for computing per sample allele counts.
Collaborator:

nit: spelling

Collaborator (Author):

Thanks, not my strong point.

>>> import sgkit as sg
>>> from sgkit.testing import simulate_genotype_call_dataset
>>> ds = simulate_genotype_call_dataset(n_variant=4, n_sample=2, seed=1)
>>> ds['call_genotype'].to_series().unstack().astype(str).apply('/'.join, axis=1).unstack() # doctest: +NORMALIZE_WHITESPACE
@eric-czech (Collaborator) commented Aug 19, 2020:

This is a good place for @tomwhite's code in https://github.com/pystatgen/sgkit/pull/58 now, fyi.

@eric-czech (Collaborator)

This is good to go as far as I'm concerned. @tomwhite / @ravwojdyla / @jeromekelleher, could one of you take a look as well? Two or more approvals seem appropriate for this one.

@timothymillar (Collaborator, Author) commented Aug 24, 2020

Just for reference, the following is a straightforward implementation using only xarray and dask, but its performance is much worse than the gufunc version:

import dask.array as da

def count_call_alleles(ds):
    G = ds["call_genotype"]
    n_allele = ds.dims["alleles"]
    G = G.expand_dims(dim="alleles", axis=-1)
    I = da.arange(n_allele, dtype='int8')
    A = G == I          # one-hot encoding of alleles (missing -1 never matches)
    AC = A.sum(axis=-2) # sum over the ploidy dimension
    return AC
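The broadcasting trick is easiest to see with plain numpy (a small illustrative example, not code from this PR):

import numpy as np

g = np.array([0, 1, -1], dtype=np.int8)  # one genotype call (ploidy 3, one allele missing)
alleles = np.arange(2, dtype=np.int8)    # allele indices [0, 1]
onehot = g[:, None] == alleles           # (ploidy, alleles) boolean one-hot encoding
onehot.sum(axis=0)                       # -> array([1, 1]); the -1 sentinel never matches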

@jeromekelleher (Collaborator) left a comment:

LGTM too, although I don't understand the details of how this is interacting with numba.

@tomwhite, @ravwojdyla I think we should have one more vote here.

@tomwhite (Collaborator) left a comment:

+1 this looks great @timothymillar.

@eric-czech (Collaborator)

Just for reference, the following is a straightforward implementation using only xarray and dask

Clever! I bet that's still much faster than da.apply_along_axis(np.bincount, axis=2).
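That alternative might look roughly like this (a sketch, assuming G and n_alleles as in the earlier example; note np.bincount rejects negative values, so the -1 sentinel has to be filtered out per call):

import numpy as np
import dask.array as da

def _bincount_call(g):
    # drop the missing sentinel, then pad the counts out to n_alleles
    return np.bincount(g[g >= 0], minlength=n_alleles)

AC = da.apply_along_axis(_bincount_call, 2, G, dtype=np.uint8, shape=(n_alleles,))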

Thanks again for picking this up @timothymillar, nicely done. I'll set it to merge.

@eric-czech added the auto-merge label Aug 24, 2020
@ravwojdyla (Collaborator)

@Mergifyio refresh

@mergify bot (Contributor) commented Aug 24, 2020

Command refresh: success
