From 6b8d74a28f0048d64f8bd849b6dc3c8718cc1ccd Mon Sep 17 00:00:00 2001
From: Tom White
Date: Mon, 31 Aug 2020 11:19:36 +0100
Subject: [PATCH] Add merge=True to count allele functions.

---
 sgkit/stats/aggregation.py      | 16 ++++++++++++----
 sgkit/tests/test_aggregation.py |  8 ++++++++
 2 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/sgkit/stats/aggregation.py b/sgkit/stats/aggregation.py
index 388c3f309..353ec6c49 100644
--- a/sgkit/stats/aggregation.py
+++ b/sgkit/stats/aggregation.py
@@ -44,7 +44,7 @@ def count_alleles(g: ArrayLike, _: ArrayLike, out: ArrayLike) -> None:
             out[a] += 1
 
 
-def count_call_alleles(ds: Dataset) -> Dataset:
+def count_call_alleles(ds: Dataset, merge: bool = True) -> Dataset:
     """Compute per sample allele counts from genotype calls.
 
     Parameters
@@ -52,6 +52,9 @@ def count_call_alleles(ds: Dataset) -> Dataset:
     ds : Dataset
         Genotype call dataset such as from
         `sgkit.create_genotype_call_dataset`.
+    merge : bool
+        If True, merge the input dataset and the computed variables into
+        a single dataset, otherwise return only the computed variables.
 
     Returns
     -------
@@ -91,7 +94,7 @@ def count_call_alleles(ds: Dataset) -> Dataset:
     G = da.asarray(ds["call_genotype"])
     shape = (G.chunks[0], G.chunks[1], n_alleles)
     N = da.empty(n_alleles, dtype=np.uint8)
-    return Dataset(
+    new_ds = Dataset(
         {
             "call_allele_count": (
                 ("variants", "samples", "alleles"),
@@ -101,9 +104,10 @@ def count_call_alleles(ds: Dataset) -> Dataset:
             )
         }
     )
+    return ds.merge(new_ds) if merge else new_ds
 
 
-def count_variant_alleles(ds: Dataset) -> Dataset:
+def count_variant_alleles(ds: Dataset, merge: bool = True) -> Dataset:
     """Compute allele count from genotype calls.
 
     Parameters
@@ -111,6 +115,9 @@ def count_variant_alleles(ds: Dataset) -> Dataset:
     ds : Dataset
         Genotype call dataset such as from
         `sgkit.create_genotype_call_dataset`.
+    merge : bool
+        If True, merge the input dataset and the computed variables into
+        a single dataset, otherwise return only the computed variables.
 
     Returns
     -------
@@ -139,7 +146,7 @@ def count_variant_alleles(ds: Dataset) -> Dataset:
            [2, 2],
            [4, 0]], dtype=uint64)
     """
-    return Dataset(
+    new_ds = Dataset(
         {
             "variant_allele_count": (
                 ("variants", "alleles"),
@@ -147,3 +154,4 @@ def count_variant_alleles(ds: Dataset) -> Dataset:
             )
         }
     )
+    return ds.merge(new_ds) if merge else new_ds
diff --git a/sgkit/tests/test_aggregation.py b/sgkit/tests/test_aggregation.py
index 1cdebb9bb..291b81aca 100644
--- a/sgkit/tests/test_aggregation.py
+++ b/sgkit/tests/test_aggregation.py
@@ -22,6 +22,7 @@ def get_dataset(calls: ArrayLike, **kwargs: Any) -> Dataset:
 
 def test_count_variant_alleles__single_variant_single_sample():
     ds = count_variant_alleles(get_dataset([[[1, 0]]]))
+    assert "call_genotype" in ds
     ac = ds["variant_allele_count"]
     np.testing.assert_equal(ac, np.array([[1, 1]]))
 
@@ -94,6 +95,13 @@ def test_count_variant_alleles__chunked():
     xr.testing.assert_equal(ac1, ac2)  # type: ignore[no-untyped-call]
 
 
+def test_count_variant_alleles__no_merge():
+    ds = count_variant_alleles(get_dataset([[[1, 0]]]), merge=False)
+    assert "call_genotype" not in ds
+    ac = ds["variant_allele_count"]
+    np.testing.assert_equal(ac, np.array([[1, 1]]))
+
+
 def test_count_call_alleles__single_variant_single_sample():
     ds = count_call_alleles(get_dataset([[[1, 0]]]))
     ac = ds["call_allele_count"]