Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Summary stats #102

Merged
merged 8 commits into from
Sep 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Methods
gwas_linear_regression
hardy_weinberg_test
regenie
variant_stats

Utilities
=========
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ numpy
xarray
dask[array]
scipy
typing-extensions
numba
zarr
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ ignore =
profile = black
default_section = THIRDPARTY
known_first_party = sgkit
known_third_party = dask,fire,glow,hail,hypothesis,invoke,numba,numpy,pandas,pkg_resources,pyspark,pytest,setuptools,sgkit_plink,xarray,yaml,zarr
known_third_party = dask,fire,glow,hail,hypothesis,invoke,numba,numpy,pandas,pkg_resources,pyspark,pytest,setuptools,sgkit_plink,typing_extensions,xarray,yaml,zarr
multi_line_output = 3
include_trailing_comma = True
force_grid_wrap = 0
Expand Down
3 changes: 2 additions & 1 deletion sgkit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
)
from .display import display_genotypes
from .io.vcfzarr_reader import read_vcfzarr
from .stats.aggregation import count_call_alleles, count_variant_alleles
from .stats.aggregation import count_call_alleles, count_variant_alleles, variant_stats
from .stats.association import gwas_linear_regression
from .stats.hwe import hardy_weinberg_test
from .stats.regenie import regenie
Expand All @@ -27,4 +27,5 @@
"read_vcfzarr",
"regenie",
"hardy_weinberg_test",
"variant_stats",
]
94 changes: 94 additions & 0 deletions sgkit/stats/aggregation.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from typing import Any, Dict, Hashable

import dask.array as da
import numpy as np
import xarray as xr
from numba import guvectorize
from typing_extensions import Literal
from xarray import Dataset

from sgkit.typing import ArrayLike
from sgkit.utils import merge_datasets

Dimension = Literal["samples", "variants"]


@guvectorize( # type: ignore
[
Expand Down Expand Up @@ -162,3 +168,91 @@ def count_variant_alleles(ds: Dataset, merge: bool = True) -> Dataset:
}
)
return merge_datasets(ds, new_ds) if merge else new_ds


def _swap(dim: Dimension) -> Dimension:
return "samples" if dim == "variants" else "variants"


def call_rate(ds: Dataset, dim: Dimension) -> Dataset:
odim = _swap(dim)[:-1]
n_called = (~ds["call_genotype_mask"].any(dim="ploidy")).sum(dim=dim)
return xr.Dataset(
{f"{odim}_n_called": n_called, f"{odim}_call_rate": n_called / ds.dims[dim]}
)


def genotype_count(ds: Dataset, dim: Dimension) -> Dataset:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is now fixed.

odim = _swap(dim)[:-1]
M, G = ds["call_genotype_mask"].any(dim="ploidy"), ds["call_genotype"]
n_hom_ref = (G == 0).all(dim="ploidy")
n_hom_alt = ((G > 0) & (G[..., 0] == G)).all(dim="ploidy")
n_non_ref = (G > 0).any(dim="ploidy")
n_het = ~(n_hom_alt | n_hom_ref)
# This would 0 out the `het` case with any missing calls
agg = lambda x: xr.where(M, False, x).sum(dim=dim) # type: ignore[no-untyped-call]
return Dataset(
{
f"{odim}_n_het": agg(n_het), # type: ignore[no-untyped-call]
f"{odim}_n_hom_ref": agg(n_hom_ref), # type: ignore[no-untyped-call]
f"{odim}_n_hom_alt": agg(n_hom_alt), # type: ignore[no-untyped-call]
f"{odim}_n_non_ref": agg(n_non_ref), # type: ignore[no-untyped-call]
}
)


def allele_frequency(ds: Dataset) -> Dataset:
data_vars: Dict[Hashable, Any] = {}
# only compute variant allele count if not already in dataset
if "variant_allele_count" in ds:
AC = ds["variant_allele_count"]
else:
AC = count_variant_alleles(ds, merge=False)["variant_allele_count"]
data_vars["variant_allele_count"] = AC

M = ds["call_genotype_mask"].stack(calls=("samples", "ploidy"))
AN = (~M).sum(dim="calls") # type: ignore
assert AN.shape == (ds.dims["variants"],)

data_vars["variant_allele_total"] = AN
data_vars["variant_allele_frequency"] = AC / AN
return Dataset(data_vars)


def variant_stats(ds: Dataset, merge: bool = True) -> Dataset:
"""Compute quality control variant statistics from genotype calls.

Parameters
----------
ds : Dataset
Genotype call dataset such as from
`sgkit.create_genotype_call_dataset`.
merge : bool, optional
If True (the default), merge the input dataset and the computed
output variables into a single dataset. Output variables will
overwrite any input variables with the same name, and a warning
will be issued in this case.
If False, return only the computed output variables.

Returns
-------
Dataset
A dataset containing the following variables:
- `variant_n_called` (variants): The number of samples with called genotypes.
- `variant_call_rate` (variants): The fraction of samples with called genotypes.
- `variant_n_het` (variants): The number of samples with heterozygous calls.
- `variant_n_hom_ref` (variants): The number of samples with homozygous reference calls.
- `variant_n_hom_alt` (variants): The number of samples with homozygous alternate calls.
- `variant_n_non_ref` (variants): The number of samples that are not homozygous reference calls.
- `variant_allele_count` (variants, alleles): The number of occurrences of each allele.
- `variant_allele_total` (variants): The number of occurrences of all alleles.
- `variant_allele_frequency` (variants, alleles): The frequency of occurence of each allele.
"""
new_ds = xr.merge(
[
call_rate(ds, dim="samples"),
genotype_count(ds, dim="samples"),
allele_frequency(ds),
]
)
return merge_datasets(ds, new_ds) if merge else new_ds
33 changes: 32 additions & 1 deletion sgkit/tests/test_aggregation.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from typing import Any

import numpy as np
import pytest
import xarray as xr
from xarray import Dataset

from sgkit.stats.aggregation import count_call_alleles, count_variant_alleles
from sgkit.stats.aggregation import (
count_call_alleles,
count_variant_alleles,
variant_stats,
)
from sgkit.testing import simulate_genotype_call_dataset
from sgkit.typing import ArrayLike

Expand Down Expand Up @@ -202,3 +207,29 @@ def test_count_call_alleles__chunked():
ds["call_genotype"] = ds["call_genotype"].chunk(chunks=(5, 5, 1)) # type: ignore[arg-type]
ac2 = count_call_alleles(ds)
xr.testing.assert_equal(ac1, ac2) # type: ignore[no-untyped-call]


@pytest.mark.parametrize("precompute_variant_allele_count", [False, True])
def test_variant_stats(precompute_variant_allele_count):
ds = get_dataset(
[[[1, 0], [-1, -1]], [[1, 0], [1, 1]], [[0, 1], [1, 0]], [[-1, -1], [0, 0]]]
)
if precompute_variant_allele_count:
ds = count_variant_alleles(ds)
vs = variant_stats(ds)

np.testing.assert_equal(vs["variant_n_called"], np.array([1, 2, 2, 1]))
np.testing.assert_equal(vs["variant_call_rate"], np.array([0.5, 1.0, 1.0, 0.5]))
np.testing.assert_equal(vs["variant_n_hom_ref"], np.array([0, 0, 0, 1]))
np.testing.assert_equal(vs["variant_n_hom_alt"], np.array([0, 1, 0, 0]))
np.testing.assert_equal(vs["variant_n_het"], np.array([1, 1, 2, 0]))
np.testing.assert_equal(vs["variant_n_non_ref"], np.array([1, 2, 2, 0]))
np.testing.assert_equal(vs["variant_n_non_ref"], np.array([1, 2, 2, 0]))
np.testing.assert_equal(
vs["variant_allele_count"], np.array([[1, 1], [1, 3], [2, 2], [2, 0]])
)
np.testing.assert_equal(vs["variant_allele_total"], np.array([2, 4, 4, 2]))
np.testing.assert_equal(
vs["variant_allele_frequency"],
np.array([[0.5, 0.5], [0.25, 0.75], [0.5, 0.5], [1, 0]]),
)