Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] count_allele_calls #114

Merged
merged 16 commits into from
Aug 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ xarray
dask[array]
scipy
numba
zarr
zarr
5 changes: 3 additions & 2 deletions sgkit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
)
from .display import display_genotypes
from .io.vcfzarr_reader import read_vcfzarr
from .stats.aggregation import count_alleles
from .stats.aggregation import count_call_alleles, count_variant_alleles
from .stats.association import gwas_linear_regression
from .stats.hwe import hardy_weinberg_test
from .stats.regenie import regenie
Expand All @@ -19,7 +19,8 @@
"DIM_SAMPLE",
"DIM_VARIANT",
"create_genotype_call_dataset",
"count_alleles",
"count_variant_alleles",
"count_call_alleles",
"create_genotype_dosage_dataset",
"display_genotypes",
"gwas_linear_regression",
Expand Down
144 changes: 107 additions & 37 deletions sgkit/stats/aggregation.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,105 @@
import dask.array as da
import numpy as np
import xarray as xr
from numba import guvectorize
from xarray import DataArray, Dataset

from ..typing import ArrayLike

def count_alleles(ds: Dataset) -> DataArray:

@guvectorize(  # type: ignore
    [
        "void(int8[:], uint8[:], uint8[:])",
        "void(int16[:], uint8[:], uint8[:])",
        "void(int32[:], uint8[:], uint8[:])",
        "void(int64[:], uint8[:], uint8[:])",
    ],
    "(k),(n)->(n)",
    nopython=True,
)
def count_alleles(g: ArrayLike, _: ArrayLike, out: ArrayLike) -> None:
    """Generalized ufunc counting the alleles observed in one genotype call.

    The core signature is ``(k),(n)->(n)``: a call of length ``k`` plus a
    length-``n`` placeholder produce a length-``n`` vector of counts.

    Parameters
    ----------
    g : array_like
        Genotype call of shape (ploidy,) holding integer allele indices.
        Negative values mark a missing allele and are not counted.
    _ : array_like
        Placeholder of dtype `uint8` and shape (alleles,). Its values are
        never read; it only fixes the length of the output.
    out : array_like
        Output buffer of dtype `uint8` and shape (alleles,). It is zeroed
        and then filled in place so that ``out[a]`` is the number of times
        allele ``a`` occurs in ``g``.

    Notes
    -----
    Allele indices are used directly as positions in ``out`` with no upper
    bound check, so callers must ensure every non-negative value in ``g``
    is smaller than the number of alleles.

    """
    out[:] = 0
    for idx in range(len(g)):
        allele = g[idx]
        if allele < 0:
            # Missing allele: contributes to no count.
            continue
        out[allele] += 1


def count_call_alleles(ds: Dataset) -> DataArray:
"""Compute per sample allele counts from genotype calls.

Parameters
----------
ds : Dataset
Genotype call dataset such as from
`sgkit.create_genotype_call_dataset`.

Returns
-------
call_allele_count : DataArray
Allele counts with shape (variants, samples, alleles) and values
corresponding to the number of non-missing occurrences
of each allele.

Examples
--------

>>> import sgkit as sg
>>> from sgkit.testing import simulate_genotype_call_dataset
>>> ds = simulate_genotype_call_dataset(n_variant=4, n_sample=2, seed=1)
>>> sg.display_genotypes(ds) # doctest: +NORMALIZE_WHITESPACE
samples S0 S1
variants
0 1/0 1/0
1 1/0 1/1
2 0/1 1/0
3 0/0 0/0

>>> sg.count_call_alleles(ds).values # doctest: +NORMALIZE_WHITESPACE
array([[[1, 1],
[1, 1]],
<BLANKLINE>
[[1, 1],
[0, 2]],
<BLANKLINE>
[[1, 1],
[1, 1]],
<BLANKLINE>
[[2, 0],
[2, 0]]], dtype=uint8)
"""
n_alleles = ds.dims["alleles"]
G = da.asarray(ds["call_genotype"])
shape = (G.chunks[0], G.chunks[1], n_alleles)
N = da.empty(n_alleles, dtype=np.uint8)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea using this instead! For the subsequent map_blocks call it might be a good idea to ensure this has only a single chunk so that the blocks broadcast. I can't imagine why anyone would ever configure the default chunk size to be so small that n_alleles items would result in multiple chunks, but guarding against it is probably a good idea.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

> I can't imagine why anyone would ever configure the default chunk size to be so small that n_alleles items would result in multiple chunks

Neither can I, but this edge case should be handled correctly by this PR; see below.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see. Thoughts on a da.asarray(ds["call_genotype"].chunk(chunks=dict(ploidy=-1))) then? I think ploidy chunks of size one could actually be pretty common when concatenating haplotype arrays.

I could also see the argument in leaving it up to the user to interpret those error messages and rechunk themselves so that they have to think through the performance implications, but I had already started being defensive about that in the GWAS regression methods. My rationale was that it's better to optimize for a less frustrating user experience than for making performance concerns prominent (and we could always add warnings for that later).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, never mind: I see now why it doesn't matter (and the test for it). Ignore that.

return xr.DataArray(
da.map_blocks(count_alleles, G, N, chunks=shape, drop_axis=2, new_axis=2),
Copy link
Collaborator

@eric-czech eric-czech Aug 19, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of curiosity, what keeps this from working as count_alleles(G, N) instead? Do the block shapes need to be identical?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

count_alleles(G, N) results in an error if G is chunked in the ploidy dimension. See comment below for an example.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

dims=("variants", "samples", "alleles"),
name="call_allele_count",
)


def count_variant_alleles(ds: Dataset) -> DataArray:
"""Compute allele count from genotype calls.

Parameters
Expand All @@ -26,46 +121,21 @@ def count_alleles(ds: Dataset) -> DataArray:
>>> import sgkit as sg
>>> from sgkit.testing import simulate_genotype_call_dataset
>>> ds = simulate_genotype_call_dataset(n_variant=4, n_sample=2, seed=1)
>>> ds['call_genotype'].to_series().unstack().astype(str).apply('/'.join, axis=1).unstack() # doctest: +NORMALIZE_WHITESPACE
samples 0 1
>>> sg.display_genotypes(ds) # doctest: +NORMALIZE_WHITESPACE
samples S0 S1
variants
0 1/0 1/0
1 1/0 1/1
2 0/1 1/0
3 0/0 0/0
0 1/0 1/0
1 1/0 1/1
2 0/1 1/0
3 0/0 0/0

>>> sg.count_alleles(ds).values # doctest: +NORMALIZE_WHITESPACE
>>> sg.count_variant_alleles(ds).values # doctest: +NORMALIZE_WHITESPACE
array([[2, 2],
[1, 3],
[2, 2],
[4, 0]])
[4, 0]], dtype=uint64)
"""
# Count each allele index individually as a 1D vector and
# restack into new alleles dimension with same order
G = ds["call_genotype"].stack(calls=("samples", "ploidy"))
M = ds["call_genotype_mask"].stack(calls=("samples", "ploidy"))
n_variant, n_allele = G.shape[0], ds.dims["alleles"]
max_allele = n_allele + 1

# Recode missing values as max allele index
G = xr.where(M, n_allele, G) # type: ignore[no-untyped-call]
G = da.asarray(G)

# Count allele indexes within each block
CT = da.map_blocks(
lambda x: np.apply_along_axis(np.bincount, 1, x, minlength=max_allele),
G,
chunks=(G.chunks[0], max_allele),
return xr.DataArray(
count_call_alleles(ds).sum(dim="samples").rename("variant_allele_count"),
dims=("variants", "alleles"),
)
assert CT.shape == (n_variant, G.numblocks[1] * max_allele)

# Stack the column blocks on top of each other
CTS = da.stack([CT.blocks[:, i] for i in range(CT.numblocks[1])])
assert CTS.shape == (CT.numblocks[1], n_variant, max_allele)

# Sum over column blocks and slice off allele
# index corresponding to missing values
AC = CTS.sum(axis=0)[:, :n_allele]
assert AC.shape == (n_variant, n_allele)

return DataArray(data=AC, dims=("variants", "alleles"), name="variant_allele_count")
128 changes: 112 additions & 16 deletions sgkit/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import xarray as xr
from xarray import Dataset

from sgkit.stats.aggregation import count_alleles
from sgkit.stats.aggregation import count_call_alleles, count_variant_alleles
from sgkit.testing import simulate_genotype_call_dataset
from sgkit.typing import ArrayLike

Expand All @@ -20,23 +20,23 @@ def get_dataset(calls: ArrayLike, **kwargs: Any) -> Dataset:
return ds


def test_count_alleles__single_variant_single_sample():
ac = count_alleles(get_dataset([[[1, 0]]]))
def test_count_variant_alleles__single_variant_single_sample():
ac = count_variant_alleles(get_dataset([[[1, 0]]]))
np.testing.assert_equal(ac, np.array([[1, 1]]))


def test_count_alleles__multi_variant_single_sample():
ac = count_alleles(get_dataset([[[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]]))
def test_count_variant_alleles__multi_variant_single_sample():
ac = count_variant_alleles(get_dataset([[[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]]))
np.testing.assert_equal(ac, np.array([[2, 0], [1, 1], [1, 1], [0, 2]]))


def test_count_alleles__single_variant_multi_sample():
ac = count_alleles(get_dataset([[[0, 0], [1, 0], [0, 1], [1, 1]]]))
def test_count_variant_alleles__single_variant_multi_sample():
ac = count_variant_alleles(get_dataset([[[0, 0], [1, 0], [0, 1], [1, 1]]]))
np.testing.assert_equal(ac, np.array([[4, 4]]))


def test_count_alleles__multi_variant_multi_sample():
ac = count_alleles(
def test_count_variant_alleles__multi_variant_multi_sample():
ac = count_variant_alleles(
get_dataset(
[
[[0, 0], [0, 0], [0, 0]],
Expand All @@ -49,8 +49,8 @@ def test_count_alleles__multi_variant_multi_sample():
np.testing.assert_equal(ac, np.array([[6, 0], [5, 1], [2, 4], [0, 6]]))


def test_count_alleles__missing_data():
ac = count_alleles(
def test_count_variant_alleles__missing_data():
ac = count_variant_alleles(
get_dataset(
[
[[-1, -1], [-1, -1], [-1, -1]],
Expand All @@ -63,8 +63,8 @@ def test_count_alleles__missing_data():
np.testing.assert_equal(ac, np.array([[0, 0], [2, 1], [1, 2], [0, 6]]))


def test_count_alleles__higher_ploidy():
ac = count_alleles(
def test_count_variant_alleles__higher_ploidy():
ac = count_variant_alleles(
get_dataset(
[
[[-1, -1, 0], [-1, -1, 1], [-1, -1, 2]],
Expand All @@ -77,12 +77,108 @@ def test_count_alleles__higher_ploidy():
np.testing.assert_equal(ac, np.array([[1, 1, 1, 0], [1, 2, 2, 1]]))


def test_count_alleles__chunked():
def test_count_variant_alleles__chunked():
rs = np.random.RandomState(0)
calls = rs.randint(0, 1, size=(50, 10, 2))
ds = get_dataset(calls)
ac1 = count_alleles(ds)
ac1 = count_variant_alleles(ds)
# Coerce from numpy to multiple chunks in all dimensions
ds["call_genotype"] = ds["call_genotype"].chunk(chunks=(5, 5, 1)) # type: ignore[arg-type]
ac2 = count_alleles(ds)
ac2 = count_variant_alleles(ds)
xr.testing.assert_equal(ac1, ac2) # type: ignore[no-untyped-call]


def test_count_call_alleles__single_variant_single_sample():
    """A single diploid call 1/0 gives one copy of each of the two alleles."""
    ds = get_dataset([[[1, 0]]])
    expected = np.array([[[1, 1]]])
    np.testing.assert_equal(count_call_alleles(ds), expected)


def test_count_call_alleles__multi_variant_single_sample():
    """Four variants, one sample: counts are reported per variant."""
    genotypes = [[[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]]
    expected = np.array([[[2, 0]], [[1, 1]], [[1, 1]], [[0, 2]]])
    observed = count_call_alleles(get_dataset(genotypes))
    np.testing.assert_equal(observed, expected)


def test_count_call_alleles__single_variant_multi_sample():
    """One variant, four samples: counts are reported per sample."""
    genotypes = [[[0, 0], [1, 0], [0, 1], [1, 1]]]
    expected = np.array([[[2, 0], [1, 1], [1, 1], [0, 2]]])
    observed = count_call_alleles(get_dataset(genotypes))
    np.testing.assert_equal(observed, expected)


def test_count_call_alleles__multi_variant_multi_sample():
    """Four variants by three samples, no missing alleles."""
    genotypes = [
        [[0, 0], [0, 0], [0, 0]],
        [[0, 0], [0, 0], [0, 1]],
        [[1, 1], [0, 1], [1, 0]],
        [[1, 1], [1, 1], [1, 1]],
    ]
    expected = np.array(
        [
            [[2, 0], [2, 0], [2, 0]],
            [[2, 0], [2, 0], [1, 1]],
            [[0, 2], [1, 1], [1, 1]],
            [[0, 2], [0, 2], [0, 2]],
        ]
    )
    observed = count_call_alleles(get_dataset(genotypes))
    np.testing.assert_equal(observed, expected)


def test_count_call_alleles__missing_data():
    """Alleles encoded as -1 are missing and must not be counted."""
    genotypes = [
        [[-1, -1], [-1, -1], [-1, -1]],
        [[-1, -1], [0, 0], [-1, 1]],
        [[1, 1], [-1, -1], [-1, 0]],
        [[1, 1], [1, 1], [1, 1]],
    ]
    expected = np.array(
        [
            [[0, 0], [0, 0], [0, 0]],
            [[0, 0], [2, 0], [0, 1]],
            [[0, 2], [0, 0], [1, 0]],
            [[0, 2], [0, 2], [0, 2]],
        ]
    )
    observed = count_call_alleles(get_dataset(genotypes))
    np.testing.assert_equal(observed, expected)


def test_count_call_alleles__higher_ploidy():
    """Triploid calls over four alleles, including missing alleles."""
    genotypes = [
        [[-1, -1, 0], [-1, -1, 1], [-1, -1, 2]],
        [[0, 1, 2], [1, 2, 3], [-1, -1, -1]],
    ]
    expected = np.array(
        [
            [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]],
            [[1, 1, 1, 0], [0, 1, 1, 1], [0, 0, 0, 0]],
        ]
    )
    ds = get_dataset(genotypes, n_allele=4, n_ploidy=3)
    np.testing.assert_equal(count_call_alleles(ds), expected)


def test_count_call_alleles__chunked():
    """Per-call allele counts must not depend on how the genotypes are chunked.

    The counts computed from the in-memory (numpy-backed) dataset are compared
    with the counts computed after re-chunking ``call_genotype`` along every
    dimension, including ploidy.
    """
    rs = np.random.RandomState(0)
    # ``randint`` excludes its upper bound: (0, 1) would only ever draw 0 and
    # give an all-zero genotype array, so use (0, 2) to get both alleles.
    calls = rs.randint(0, 2, size=(50, 10, 2))
    ds = get_dataset(calls)
    ac1 = count_call_alleles(ds)
    # Coerce from numpy to multiple chunks in all dimensions
    ds["call_genotype"] = ds["call_genotype"].chunk(chunks=(5, 5, 1))  # type: ignore[arg-type]
    ac2 = count_call_alleles(ds)
    xr.testing.assert_equal(ac1, ac2)  # type: ignore[no-untyped-call]