Skip to content

Commit

Permalink
Implement count_call_alleles #85
Browse files Browse the repository at this point in the history
* Implement count_call_alleles #85

* Fixes for mypy and black

* Add dependency on numba

* gufunc implementation of count_alleles

* Fix count alleles bug for chunking

* Fix doctest for count_call_alleles

* Docstring for gufunc

* Remove duplication in setup.cfg

* Fixes for pre-commit

* Add ignore type checking for guvectorize decorator

* Explicit import of guvectorize

* Use display_genotypes in aggregation docstrings

* Numpy style docstring for count_alleles

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
timothymillar and mergify[bot] authored Aug 25, 2020
1 parent 84f79c2 commit 320ebc9
Show file tree
Hide file tree
Showing 4 changed files with 223 additions and 56 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ xarray
dask[array]
scipy
numba
zarr
zarr
5 changes: 3 additions & 2 deletions sgkit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
)
from .display import display_genotypes
from .io.vcfzarr_reader import read_vcfzarr
from .stats.aggregation import count_alleles
from .stats.aggregation import count_call_alleles, count_variant_alleles
from .stats.association import gwas_linear_regression
from .stats.hwe import hardy_weinberg_test
from .stats.regenie import regenie
Expand All @@ -19,7 +19,8 @@
"DIM_SAMPLE",
"DIM_VARIANT",
"create_genotype_call_dataset",
"count_alleles",
"count_variant_alleles",
"count_call_alleles",
"create_genotype_dosage_dataset",
"display_genotypes",
"gwas_linear_regression",
Expand Down
144 changes: 107 additions & 37 deletions sgkit/stats/aggregation.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,105 @@
import dask.array as da
import numpy as np
import xarray as xr
from numba import guvectorize
from xarray import DataArray, Dataset

from ..typing import ArrayLike

def count_alleles(ds: Dataset) -> DataArray:

@guvectorize(  # type: ignore
    [
        "void(int8[:], uint8[:], uint8[:])",
        "void(int16[:], uint8[:], uint8[:])",
        "void(int32[:], uint8[:], uint8[:])",
        "void(int64[:], uint8[:], uint8[:])",
    ],
    "(k),(n)->(n)",
    nopython=True,
)
def count_alleles(g: ArrayLike, _: ArrayLike, out: ArrayLike) -> None:
    """Generalized universal function (gufunc) counting alleles in one call.

    Parameters
    ----------
    g : array_like
        Genotype call of shape (ploidy,) containing alleles encoded as
        type `int` with values < 0 indicating a missing allele.
    _ : array_like
        Dummy variable of type `uint8` and shape (alleles,). Its values
        are never read; only its length matters, because the gufunc
        layout ``(k),(n)->(n)`` sizes the output to match it.

    Returns
    -------
    ac : ndarray
        Allele counts of type `uint8` with shape (alleles,) and values
        corresponding to the number of non-missing occurrences of each
        allele.

    Notes
    -----
    Each non-negative allele value is used directly as an index into the
    output. No upper-bound check is made here, so callers must ensure
    every allele value is smaller than the length of the dummy array.
    """
    # The output buffer is supplied by the gufunc machinery; clear it first.
    out[:] = 0
    for allele in g:
        # Negative values encode missing alleles and are not counted.
        if allele < 0:
            continue
        out[allele] += 1


def count_call_alleles(ds: Dataset) -> DataArray:
    """Compute per sample allele counts from genotype calls.

    Parameters
    ----------
    ds : Dataset
        Genotype call dataset such as from
        `sgkit.create_genotype_call_dataset`. It must provide a
        ``call_genotype`` variable and an ``alleles`` dimension.

    Returns
    -------
    call_allele_count : DataArray
        Allele counts with shape (variants, samples, alleles) and values
        corresponding to the number of non-missing occurrences
        of each allele.

    Examples
    --------
    >>> import sgkit as sg
    >>> from sgkit.testing import simulate_genotype_call_dataset
    >>> ds = simulate_genotype_call_dataset(n_variant=4, n_sample=2, seed=1)
    >>> sg.display_genotypes(ds) # doctest: +NORMALIZE_WHITESPACE
    samples    S0   S1
    variants
    0         1/0  1/0
    1         1/0  1/1
    2         0/1  1/0
    3         0/0  0/0

    >>> sg.count_call_alleles(ds).values # doctest: +NORMALIZE_WHITESPACE
    array([[[1, 1],
            [1, 1]],
    <BLANKLINE>
           [[1, 1],
            [0, 2]],
    <BLANKLINE>
           [[1, 1],
            [1, 1]],
    <BLANKLINE>
           [[2, 0],
            [2, 0]]], dtype=uint8)
    """
    # `Dataset.sizes` is the name -> length mapping. Indexing `Dataset.dims`
    # as a mapping is deprecated in newer xarray releases.
    n_alleles = ds.sizes["alleles"]
    G = da.asarray(ds["call_genotype"])
    # Output keeps the chunking of the first two axes of the genotype array
    # and has a single chunk spanning all alleles on the last axis.
    shape = (G.chunks[0], G.chunks[1], n_alleles)
    # Dummy array: only its length is used, to tell the gufunc how many
    # alleles to count. Its (uninitialised) values are never read.
    N = da.empty(n_alleles, dtype=np.uint8)
    return xr.DataArray(
        # The third input axis is dropped and replaced by the alleles axis.
        da.map_blocks(count_alleles, G, N, chunks=shape, drop_axis=2, new_axis=2),
        dims=("variants", "samples", "alleles"),
        name="call_allele_count",
    )


def count_variant_alleles(ds: Dataset) -> DataArray:
"""Compute allele count from genotype calls.
Parameters
Expand All @@ -26,46 +121,21 @@ def count_alleles(ds: Dataset) -> DataArray:
>>> import sgkit as sg
>>> from sgkit.testing import simulate_genotype_call_dataset
>>> ds = simulate_genotype_call_dataset(n_variant=4, n_sample=2, seed=1)
>>> ds['call_genotype'].to_series().unstack().astype(str).apply('/'.join, axis=1).unstack() # doctest: +NORMALIZE_WHITESPACE
samples 0 1
>>> sg.display_genotypes(ds) # doctest: +NORMALIZE_WHITESPACE
samples S0 S1
variants
0 1/0 1/0
1 1/0 1/1
2 0/1 1/0
3 0/0 0/0
0 1/0 1/0
1 1/0 1/1
2 0/1 1/0
3 0/0 0/0
>>> sg.count_alleles(ds).values # doctest: +NORMALIZE_WHITESPACE
>>> sg.count_variant_alleles(ds).values # doctest: +NORMALIZE_WHITESPACE
array([[2, 2],
[1, 3],
[2, 2],
[4, 0]])
[4, 0]], dtype=uint64)
"""
# Count each allele index individually as a 1D vector and
# restack into new alleles dimension with same order
G = ds["call_genotype"].stack(calls=("samples", "ploidy"))
M = ds["call_genotype_mask"].stack(calls=("samples", "ploidy"))
n_variant, n_allele = G.shape[0], ds.dims["alleles"]
max_allele = n_allele + 1

# Recode missing values as max allele index
G = xr.where(M, n_allele, G) # type: ignore[no-untyped-call]
G = da.asarray(G)

# Count allele indexes within each block
CT = da.map_blocks(
lambda x: np.apply_along_axis(np.bincount, 1, x, minlength=max_allele),
G,
chunks=(G.chunks[0], max_allele),
return xr.DataArray(
count_call_alleles(ds).sum(dim="samples").rename("variant_allele_count"),
dims=("variants", "alleles"),
)
assert CT.shape == (n_variant, G.numblocks[1] * max_allele)

# Stack the column blocks on top of each other
CTS = da.stack([CT.blocks[:, i] for i in range(CT.numblocks[1])])
assert CTS.shape == (CT.numblocks[1], n_variant, max_allele)

# Sum over column blocks and slice off allele
# index corresponding to missing values
AC = CTS.sum(axis=0)[:, :n_allele]
assert AC.shape == (n_variant, n_allele)

return DataArray(data=AC, dims=("variants", "alleles"), name="variant_allele_count")
128 changes: 112 additions & 16 deletions sgkit/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import xarray as xr
from xarray import Dataset

from sgkit.stats.aggregation import count_alleles
from sgkit.stats.aggregation import count_call_alleles, count_variant_alleles
from sgkit.testing import simulate_genotype_call_dataset
from sgkit.typing import ArrayLike

Expand All @@ -20,23 +20,23 @@ def get_dataset(calls: ArrayLike, **kwargs: Any) -> Dataset:
return ds


def test_count_alleles__single_variant_single_sample():
ac = count_alleles(get_dataset([[[1, 0]]]))
def test_count_variant_alleles__single_variant_single_sample():
ac = count_variant_alleles(get_dataset([[[1, 0]]]))
np.testing.assert_equal(ac, np.array([[1, 1]]))


def test_count_alleles__multi_variant_single_sample():
ac = count_alleles(get_dataset([[[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]]))
def test_count_variant_alleles__multi_variant_single_sample():
ac = count_variant_alleles(get_dataset([[[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]]))
np.testing.assert_equal(ac, np.array([[2, 0], [1, 1], [1, 1], [0, 2]]))


def test_count_alleles__single_variant_multi_sample():
ac = count_alleles(get_dataset([[[0, 0], [1, 0], [0, 1], [1, 1]]]))
def test_count_variant_alleles__single_variant_multi_sample():
ac = count_variant_alleles(get_dataset([[[0, 0], [1, 0], [0, 1], [1, 1]]]))
np.testing.assert_equal(ac, np.array([[4, 4]]))


def test_count_alleles__multi_variant_multi_sample():
ac = count_alleles(
def test_count_variant_alleles__multi_variant_multi_sample():
ac = count_variant_alleles(
get_dataset(
[
[[0, 0], [0, 0], [0, 0]],
Expand All @@ -49,8 +49,8 @@ def test_count_alleles__multi_variant_multi_sample():
np.testing.assert_equal(ac, np.array([[6, 0], [5, 1], [2, 4], [0, 6]]))


def test_count_alleles__missing_data():
ac = count_alleles(
def test_count_variant_alleles__missing_data():
ac = count_variant_alleles(
get_dataset(
[
[[-1, -1], [-1, -1], [-1, -1]],
Expand All @@ -63,8 +63,8 @@ def test_count_alleles__missing_data():
np.testing.assert_equal(ac, np.array([[0, 0], [2, 1], [1, 2], [0, 6]]))


def test_count_alleles__higher_ploidy():
ac = count_alleles(
def test_count_variant_alleles__higher_ploidy():
ac = count_variant_alleles(
get_dataset(
[
[[-1, -1, 0], [-1, -1, 1], [-1, -1, 2]],
Expand All @@ -77,12 +77,108 @@ def test_count_alleles__higher_ploidy():
np.testing.assert_equal(ac, np.array([[1, 1, 1, 0], [1, 2, 2, 1]]))


def test_count_alleles__chunked():
def test_count_variant_alleles__chunked():
rs = np.random.RandomState(0)
calls = rs.randint(0, 1, size=(50, 10, 2))
ds = get_dataset(calls)
ac1 = count_alleles(ds)
ac1 = count_variant_alleles(ds)
# Coerce from numpy to multiple chunks in all dimensions
ds["call_genotype"] = ds["call_genotype"].chunk(chunks=(5, 5, 1)) # type: ignore[arg-type]
ac2 = count_alleles(ds)
ac2 = count_variant_alleles(ds)
xr.testing.assert_equal(ac1, ac2) # type: ignore[no-untyped-call]


def test_count_call_alleles__single_variant_single_sample():
    # A single call carrying one copy each of alleles 1 and 0.
    ds = get_dataset([[[1, 0]]])
    expected = np.array([[[1, 1]]])
    np.testing.assert_equal(count_call_alleles(ds), expected)


def test_count_call_alleles__multi_variant_single_sample():
    # Four variants, one sample each, covering every two-allele combination.
    calls = [[[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]]
    expected = np.array([[[2, 0]], [[1, 1]], [[1, 1]], [[0, 2]]])
    np.testing.assert_equal(count_call_alleles(get_dataset(calls)), expected)


def test_count_call_alleles__single_variant_multi_sample():
    # One variant, four samples: counts must stay separate per sample.
    calls = [[[0, 0], [1, 0], [0, 1], [1, 1]]]
    expected = np.array([[[2, 0], [1, 1], [1, 1], [0, 2]]])
    np.testing.assert_equal(count_call_alleles(get_dataset(calls)), expected)


def test_count_call_alleles__multi_variant_multi_sample():
    # Four variants by three samples, no missing alleles.
    calls = [
        [[0, 0], [0, 0], [0, 0]],
        [[0, 0], [0, 0], [0, 1]],
        [[1, 1], [0, 1], [1, 0]],
        [[1, 1], [1, 1], [1, 1]],
    ]
    expected = np.array(
        [
            [[2, 0], [2, 0], [2, 0]],
            [[2, 0], [2, 0], [1, 1]],
            [[0, 2], [1, 1], [1, 1]],
            [[0, 2], [0, 2], [0, 2]],
        ]
    )
    np.testing.assert_equal(count_call_alleles(get_dataset(calls)), expected)


def test_count_call_alleles__missing_data():
    # Alleles encoded as -1 are missing and must not contribute to any count.
    calls = [
        [[-1, -1], [-1, -1], [-1, -1]],
        [[-1, -1], [0, 0], [-1, 1]],
        [[1, 1], [-1, -1], [-1, 0]],
        [[1, 1], [1, 1], [1, 1]],
    ]
    expected = np.array(
        [
            [[0, 0], [0, 0], [0, 0]],
            [[0, 0], [2, 0], [0, 1]],
            [[0, 2], [0, 0], [1, 0]],
            [[0, 2], [0, 2], [0, 2]],
        ]
    )
    np.testing.assert_equal(count_call_alleles(get_dataset(calls)), expected)


def test_count_call_alleles__higher_ploidy():
    # Three alleles per call and four possible alleles, with some missing.
    calls = [
        [[-1, -1, 0], [-1, -1, 1], [-1, -1, 2]],
        [[0, 1, 2], [1, 2, 3], [-1, -1, -1]],
    ]
    ds = get_dataset(calls, n_allele=4, n_ploidy=3)
    expected = np.array(
        [
            [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]],
            [[1, 1, 1, 0], [0, 1, 1, 1], [0, 0, 0, 0]],
        ]
    )
    np.testing.assert_equal(count_call_alleles(ds), expected)


def test_count_call_alleles__chunked():
    """Counts from a chunked genotype array must match the unchunked result."""
    rs = np.random.RandomState(0)
    # The upper bound of `randint` is exclusive: `randint(0, 1)` only ever
    # yields 0, which would make every call homozygous for allele 0. Use 2
    # so both alleles 0 and 1 appear in the comparison.
    calls = rs.randint(0, 2, size=(50, 10, 2))
    ds = get_dataset(calls)
    ac1 = count_call_alleles(ds)
    # Coerce from numpy to multiple chunks in all dimensions
    ds["call_genotype"] = ds["call_genotype"].chunk(chunks=(5, 5, 1))  # type: ignore[arg-type]
    ac2 = count_call_alleles(ds)
    xr.testing.assert_equal(ac1, ac2)  # type: ignore[no-untyped-call]

0 comments on commit 320ebc9

Please sign in to comment.