Skip to content

Commit

Permalink
Add merge parameter to variant_stats
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Aug 31, 2020
1 parent cc4f4cd commit 88c08b9
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 13 deletions.
28 changes: 17 additions & 11 deletions sgkit/stats/aggregation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing import Any, Dict, Hashable

import dask.array as da
import numpy as np
import xarray as xr
from numba import guvectorize
from typing_extensions import Literal
from xarray import Dataset
Expand Down Expand Up @@ -181,7 +184,7 @@ def genotype_count(ds: Dataset, dim: Dimension) -> Dataset:
n_het = ~(n_hom_alt | n_hom_ref)
# This would 0 out the `het` case with any missing calls
agg = lambda x: xr.where(M, False, x).sum(dim=dim) # type: ignore[no-untyped-call]
return xr.Dataset(
return Dataset(
{
f"{odim}_n_het": agg(n_het), # type: ignore[no-untyped-call]
f"{odim}_n_hom_ref": agg(n_hom_ref), # type: ignore[no-untyped-call]
Expand All @@ -192,26 +195,29 @@ def genotype_count(ds: Dataset, dim: Dimension) -> Dataset:


def allele_frequency(ds: Dataset) -> Dataset:
AC = count_variant_alleles(ds)
data_vars: Dict[Hashable, Any] = {}
# only compute variant allele count if not already in dataset
if "variant_allele_count" in ds:
AC = ds["variant_allele_count"]
else:
AC = count_variant_alleles(ds, merge=False)["variant_allele_count"]
data_vars["variant_allele_count"] = AC

M = ds["call_genotype_mask"].stack(calls=("samples", "ploidy"))
AN = (~M).sum(dim="calls") # type: ignore
assert AN.shape == (ds.dims["variants"],)

return xr.Dataset(
{
"variant_allele_count": AC,
"variant_allele_total": AN,
"variant_allele_frequency": AC / AN,
}
)
data_vars["variant_allele_total"] = AN
data_vars["variant_allele_frequency"] = AC / AN
return Dataset(data_vars)


def variant_stats(ds: Dataset) -> Dataset:
return xr.merge(
def variant_stats(ds: Dataset, merge: bool = True) -> Dataset:
new_ds = xr.merge(
[
call_rate(ds, dim="samples"),
genotype_count(ds, dim="samples"),
allele_frequency(ds),
]
)
return ds.merge(new_ds) if merge else new_ds
12 changes: 10 additions & 2 deletions sgkit/tests/test_aggregation.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from typing import Any

import numpy as np
import pytest
import xarray as xr
from xarray import Dataset

from sgkit.stats.aggregation import count_call_alleles, count_variant_alleles, variant_stats
from sgkit.stats.aggregation import (
count_call_alleles,
count_variant_alleles,
variant_stats,
)
from sgkit.testing import simulate_genotype_call_dataset
from sgkit.typing import ArrayLike

Expand Down Expand Up @@ -204,10 +209,13 @@ def test_count_call_alleles__chunked():
xr.testing.assert_equal(ac1, ac2) # type: ignore[no-untyped-call]


def test_variant_stats():
@pytest.mark.parametrize("precompute_variant_allele_count", [False, True])
def test_variant_stats(precompute_variant_allele_count):
ds = get_dataset(
[[[1, 0], [-1, -1]], [[1, 0], [1, 1]], [[0, 1], [1, 0]], [[-1, -1], [0, 0]]]
)
if precompute_variant_allele_count:
ds = count_variant_alleles(ds)
vs = variant_stats(ds)

np.testing.assert_equal(vs["variant_n_called"], np.array([1, 2, 2, 1]))
Expand Down

0 comments on commit 88c08b9

Please sign in to comment.