diff --git a/docs/api.rst b/docs/api.rst index 6a075f813..16d055d7f 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -27,6 +27,7 @@ Methods gwas_linear_regression hardy_weinberg_test regenie + variant_stats Utilities ========= diff --git a/requirements.txt b/requirements.txt index 0793ee0e8..0b0721530 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,5 +2,6 @@ numpy xarray dask[array] scipy +typing-extensions numba zarr diff --git a/setup.cfg b/setup.cfg index af85ef741..abe82e850 100644 --- a/setup.cfg +++ b/setup.cfg @@ -61,7 +61,7 @@ ignore = profile = black default_section = THIRDPARTY known_first_party = sgkit -known_third_party = dask,fire,glow,hail,hypothesis,invoke,numba,numpy,pandas,pkg_resources,pyspark,pytest,setuptools,sgkit_plink,xarray,yaml,zarr +known_third_party = dask,fire,glow,hail,hypothesis,invoke,numba,numpy,pandas,pkg_resources,pyspark,pytest,setuptools,sgkit_plink,typing_extensions,xarray,yaml,zarr multi_line_output = 3 include_trailing_comma = True force_grid_wrap = 0 diff --git a/sgkit/__init__.py b/sgkit/__init__.py index 807bf7293..9a90dc1f0 100644 --- a/sgkit/__init__.py +++ b/sgkit/__init__.py @@ -8,7 +8,7 @@ ) from .display import display_genotypes from .io.vcfzarr_reader import read_vcfzarr -from .stats.aggregation import count_call_alleles, count_variant_alleles +from .stats.aggregation import count_call_alleles, count_variant_alleles, variant_stats from .stats.association import gwas_linear_regression from .stats.hwe import hardy_weinberg_test from .stats.regenie import regenie @@ -27,4 +27,5 @@ "read_vcfzarr", "regenie", "hardy_weinberg_test", + "variant_stats", ] diff --git a/sgkit/stats/aggregation.py b/sgkit/stats/aggregation.py index bd08fa656..7859fde75 100644 --- a/sgkit/stats/aggregation.py +++ b/sgkit/stats/aggregation.py @@ -1,11 +1,17 @@ +from typing import Any, Dict, Hashable + import dask.array as da import numpy as np +import xarray as xr from numba import guvectorize +from typing_extensions import Literal from xarray import Dataset from sgkit.typing import ArrayLike from sgkit.utils import merge_datasets +Dimension = Literal["samples", "variants"] + @guvectorize( # type: ignore [ @@ -162,3 +168,91 @@ def count_variant_alleles(ds: Dataset, merge: bool = True) -> Dataset: } ) return merge_datasets(ds, new_ds) if merge else new_ds + + +def _swap(dim: Dimension) -> Dimension: + return "samples" if dim == "variants" else "variants" + + +def call_rate(ds: Dataset, dim: Dimension) -> Dataset: + odim = _swap(dim)[:-1] + n_called = (~ds["call_genotype_mask"].any(dim="ploidy")).sum(dim=dim) + return xr.Dataset( + {f"{odim}_n_called": n_called, f"{odim}_call_rate": n_called / ds.dims[dim]} + ) + + +def genotype_count(ds: Dataset, dim: Dimension) -> Dataset: + odim = _swap(dim)[:-1] + M, G = ds["call_genotype_mask"].any(dim="ploidy"), ds["call_genotype"] + n_hom_ref = (G == 0).all(dim="ploidy") + n_hom_alt = ((G > 0) & (G[..., 0] == G)).all(dim="ploidy") + n_non_ref = (G > 0).any(dim="ploidy") + n_het = ~(n_hom_alt | n_hom_ref) + # This would 0 out the `het` case with any missing calls + agg = lambda x: xr.where(M, False, x).sum(dim=dim) # type: ignore[no-untyped-call] + return Dataset( + { + f"{odim}_n_het": agg(n_het), # type: ignore[no-untyped-call] + f"{odim}_n_hom_ref": agg(n_hom_ref), # type: ignore[no-untyped-call] + f"{odim}_n_hom_alt": agg(n_hom_alt), # type: ignore[no-untyped-call] + f"{odim}_n_non_ref": agg(n_non_ref), # type: ignore[no-untyped-call] + } + ) + + +def allele_frequency(ds: Dataset) -> Dataset: + data_vars: Dict[Hashable, Any] = {} + # only compute variant allele count if not already in dataset + if "variant_allele_count" in ds: + AC = ds["variant_allele_count"] + else: + AC = count_variant_alleles(ds, merge=False)["variant_allele_count"] + data_vars["variant_allele_count"] = AC + + M = ds["call_genotype_mask"].stack(calls=("samples", "ploidy")) + AN = (~M).sum(dim="calls") # type: ignore + assert AN.shape == (ds.dims["variants"],) + + data_vars["variant_allele_total"] = AN + data_vars["variant_allele_frequency"] = AC / AN + return Dataset(data_vars) + + +def variant_stats(ds: Dataset, merge: bool = True) -> Dataset: + """Compute quality control variant statistics from genotype calls. + + Parameters + ---------- + ds : Dataset + Genotype call dataset such as from + `sgkit.create_genotype_call_dataset`. + merge : bool, optional + If True (the default), merge the input dataset and the computed + output variables into a single dataset. Output variables will + overwrite any input variables with the same name, and a warning + will be issued in this case. + If False, return only the computed output variables. + + Returns + ------- + Dataset + A dataset containing the following variables: + - `variant_n_called` (variants): The number of samples with called genotypes. + - `variant_call_rate` (variants): The fraction of samples with called genotypes. + - `variant_n_het` (variants): The number of samples with heterozygous calls. + - `variant_n_hom_ref` (variants): The number of samples with homozygous reference calls. + - `variant_n_hom_alt` (variants): The number of samples with homozygous alternate calls. + - `variant_n_non_ref` (variants): The number of samples that are not homozygous reference calls. + - `variant_allele_count` (variants, alleles): The number of occurrences of each allele. + - `variant_allele_total` (variants): The number of occurrences of all alleles. + - `variant_allele_frequency` (variants, alleles): The frequency of occurence of each allele. + """ + new_ds = xr.merge( + [ + call_rate(ds, dim="samples"), + genotype_count(ds, dim="samples"), + allele_frequency(ds), + ] + ) + return merge_datasets(ds, new_ds) if merge else new_ds diff --git a/sgkit/tests/test_aggregation.py b/sgkit/tests/test_aggregation.py index 291b81aca..8951f842b 100644 --- a/sgkit/tests/test_aggregation.py +++ b/sgkit/tests/test_aggregation.py @@ -1,10 +1,15 @@ from typing import Any import numpy as np +import pytest import xarray as xr from xarray import Dataset -from sgkit.stats.aggregation import count_call_alleles, count_variant_alleles +from sgkit.stats.aggregation import ( + count_call_alleles, + count_variant_alleles, + variant_stats, +) from sgkit.testing import simulate_genotype_call_dataset from sgkit.typing import ArrayLike @@ -202,3 +207,29 @@ def test_count_call_alleles__chunked(): ds["call_genotype"] = ds["call_genotype"].chunk(chunks=(5, 5, 1)) # type: ignore[arg-type] ac2 = count_call_alleles(ds) xr.testing.assert_equal(ac1, ac2) # type: ignore[no-untyped-call] + + +@pytest.mark.parametrize("precompute_variant_allele_count", [False, True]) +def test_variant_stats(precompute_variant_allele_count): + ds = get_dataset( + [[[1, 0], [-1, -1]], [[1, 0], [1, 1]], [[0, 1], [1, 0]], [[-1, -1], [0, 0]]] + ) + if precompute_variant_allele_count: + ds = count_variant_alleles(ds) + vs = variant_stats(ds) + + np.testing.assert_equal(vs["variant_n_called"], np.array([1, 2, 2, 1])) + np.testing.assert_equal(vs["variant_call_rate"], np.array([0.5, 1.0, 1.0, 0.5])) + np.testing.assert_equal(vs["variant_n_hom_ref"], np.array([0, 0, 0, 1])) + np.testing.assert_equal(vs["variant_n_hom_alt"], np.array([0, 1, 0, 0])) + np.testing.assert_equal(vs["variant_n_het"], np.array([1, 1, 2, 0])) + np.testing.assert_equal(vs["variant_n_non_ref"], np.array([1, 2, 2, 0])) + np.testing.assert_equal(vs["variant_n_non_ref"], np.array([1, 2, 2, 0])) + np.testing.assert_equal( + vs["variant_allele_count"], np.array([[1, 1], [1, 3], [2, 2], [2, 0]]) + ) + np.testing.assert_equal(vs["variant_allele_total"], np.array([2, 4, 4, 2])) + np.testing.assert_equal( + vs["variant_allele_frequency"], + np.array([[0.5, 0.5], [0.25, 0.75], [0.5, 0.5], [1, 0]]), + )