Skip to content

Commit

Permalink
Add merge=True to count allele functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Aug 31, 2020
1 parent 84a2e5e commit 6b8d74a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
16 changes: 12 additions & 4 deletions sgkit/stats/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,17 @@ def count_alleles(g: ArrayLike, _: ArrayLike, out: ArrayLike) -> None:
out[a] += 1


def count_call_alleles(ds: Dataset) -> Dataset:
def count_call_alleles(ds: Dataset, merge: bool = True) -> Dataset:
"""Compute per sample allele counts from genotype calls.
Parameters
----------
ds : Dataset
Genotype call dataset such as from
`sgkit.create_genotype_call_dataset`.
merge : bool
If True, merge the input dataset and the computed variables into
a single dataset, otherwise return only the computed variables.
Returns
-------
Expand Down Expand Up @@ -91,7 +94,7 @@ def count_call_alleles(ds: Dataset) -> Dataset:
G = da.asarray(ds["call_genotype"])
shape = (G.chunks[0], G.chunks[1], n_alleles)
N = da.empty(n_alleles, dtype=np.uint8)
return Dataset(
new_ds = Dataset(
{
"call_allele_count": (
("variants", "samples", "alleles"),
Expand All @@ -101,16 +104,20 @@ def count_call_alleles(ds: Dataset) -> Dataset:
)
}
)
return ds.merge(new_ds) if merge else new_ds


def count_variant_alleles(ds: Dataset) -> Dataset:
def count_variant_alleles(ds: Dataset, merge: bool = True) -> Dataset:
"""Compute allele count from genotype calls.
Parameters
----------
ds : Dataset
Genotype call dataset such as from
`sgkit.create_genotype_call_dataset`.
merge : bool
If True, merge the input dataset and the computed variables into
a single dataset, otherwise return only the computed variables.
Returns
-------
Expand Down Expand Up @@ -139,11 +146,12 @@ def count_variant_alleles(ds: Dataset) -> Dataset:
[2, 2],
[4, 0]], dtype=uint64)
"""
return Dataset(
new_ds = Dataset(
{
"variant_allele_count": (
("variants", "alleles"),
count_call_alleles(ds)["call_allele_count"].sum(dim="samples"),
)
}
)
return ds.merge(new_ds) if merge else new_ds
8 changes: 8 additions & 0 deletions sgkit/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def get_dataset(calls: ArrayLike, **kwargs: Any) -> Dataset:

def test_count_variant_alleles__single_variant_single_sample():
ds = count_variant_alleles(get_dataset([[[1, 0]]]))
assert "call_genotype" in ds
ac = ds["variant_allele_count"]
np.testing.assert_equal(ac, np.array([[1, 1]]))

Expand Down Expand Up @@ -94,6 +95,13 @@ def test_count_variant_alleles__chunked():
xr.testing.assert_equal(ac1, ac2) # type: ignore[no-untyped-call]


def test_count_variant_alleles__no_merge():
ds = count_variant_alleles(get_dataset([[[1, 0]]]), merge=False)
assert "call_genotype" not in ds
ac = ds["variant_allele_count"]
np.testing.assert_equal(ac, np.array([[1, 1]]))


def test_count_call_alleles__single_variant_single_sample():
ds = count_call_alleles(get_dataset([[[1, 0]]]))
ac = ds["call_allele_count"]
Expand Down

0 comments on commit 6b8d74a

Please sign in to comment.