Skip to content

Commit

Permalink
Use SgkitVariables in the computation functions
Browse files Browse the repository at this point in the history
  • Loading branch information
ravwojdyla committed Sep 24, 2020
1 parent 4d9aabb commit 54b56f1
Show file tree
Hide file tree
Showing 9 changed files with 249 additions and 95 deletions.
20 changes: 7 additions & 13 deletions sgkit/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import xarray as xr

from . import variables
from .typing import ArrayLike
from .utils import check_array_like

Expand Down Expand Up @@ -57,11 +58,6 @@ def create_genotype_call_dataset(
-------
The dataset of genotype calls.
"""
check_array_like(variant_contig, kind="i", ndim=1)
check_array_like(variant_position, kind="i", ndim=1)
check_array_like(variant_alleles, kind={"S", "O"}, ndim=2)
check_array_like(sample_id, kind={"U", "O"}, ndim=1)
check_array_like(call_genotype, kind="i", ndim=3)
data_vars: Dict[Hashable, Any] = {
"variant_contig": ([DIM_VARIANT], variant_contig),
"variant_position": ([DIM_VARIANT], variant_position),
Expand All @@ -83,7 +79,9 @@ def create_genotype_call_dataset(
check_array_like(variant_id, kind={"U", "O"}, ndim=1)
data_vars["variant_id"] = ([DIM_VARIANT], variant_id)
attrs: Dict[Hashable, Any] = {"contigs": variant_contig_names}
return xr.Dataset(data_vars=data_vars, attrs=attrs)
return variables.validate(
xr.Dataset(data_vars=data_vars, attrs=attrs), *data_vars.keys()
)


def create_genotype_dosage_dataset(
Expand Down Expand Up @@ -132,12 +130,6 @@ def create_genotype_dosage_dataset(
The dataset of genotype calls.
"""
check_array_like(variant_contig, kind="i", ndim=1)
check_array_like(variant_position, kind="i", ndim=1)
check_array_like(variant_alleles, kind={"S", "O"}, ndim=2)
check_array_like(sample_id, kind={"U", "O"}, ndim=1)
check_array_like(call_dosage, kind="f", ndim=2)
check_array_like(call_genotype_probability, kind="f", ndim=3)
data_vars: Dict[Hashable, Any] = {
"variant_contig": ([DIM_VARIANT], variant_contig),
"variant_position": ([DIM_VARIANT], variant_position),
Expand All @@ -158,4 +150,6 @@ def create_genotype_dosage_dataset(
check_array_like(variant_id, kind={"U", "O"}, ndim=1)
data_vars["variant_id"] = ([DIM_VARIANT], variant_id)
attrs: Dict[Hashable, Any] = {"contigs": variant_contig_names}
return xr.Dataset(data_vars=data_vars, attrs=attrs)
return variables.validate(
xr.Dataset(data_vars=data_vars, attrs=attrs), *data_vars.keys()
)
104 changes: 82 additions & 22 deletions sgkit/stats/aggregation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Hashable
from typing import Any, Dict, Hashable, Optional

import dask.array as da
import numpy as np
Expand All @@ -7,6 +7,7 @@
from typing_extensions import Literal
from xarray import Dataset

from sgkit import variables
from sgkit.typing import ArrayLike
from sgkit.utils import conditional_merge_datasets

Expand Down Expand Up @@ -51,7 +52,9 @@ def count_alleles(g: ArrayLike, _: ArrayLike, out: ArrayLike) -> None:
out[a] += 1


def count_call_alleles(ds: Dataset, merge: bool = True) -> Dataset:
def count_call_alleles(
ds: Dataset, merge: bool = True, *, call_genotype: str = "call_genotype"
) -> Dataset:
"""Compute per sample allele counts from genotype calls.
Parameters
Expand All @@ -60,11 +63,12 @@ def count_call_alleles(ds: Dataset, merge: bool = True) -> Dataset:
Genotype call dataset such as from
`sgkit.create_genotype_call_dataset`.
merge
(optional)
If True (the default), merge the input dataset and the computed
output variables into a single dataset, otherwise return only
the computed output variables.
See :ref:`dataset_merge` for more details.
call_genotype
Input variable name holding call_genotype as defined by `sgkit.variables.call_genotype`
Returns
-------
Expand Down Expand Up @@ -99,8 +103,9 @@ def count_call_alleles(ds: Dataset, merge: bool = True) -> Dataset:
[[2, 0],
[2, 0]]], dtype=uint8)
"""
variables.validate(ds, {call_genotype: variables.call_genotype})
n_alleles = ds.dims["alleles"]
G = da.asarray(ds["call_genotype"])
G = da.asarray(ds[call_genotype])
shape = (G.chunks[0], G.chunks[1], n_alleles)
N = da.empty(n_alleles, dtype=np.uint8)
new_ds = Dataset(
Expand All @@ -113,10 +118,14 @@ def count_call_alleles(ds: Dataset, merge: bool = True) -> Dataset:
)
}
)
return conditional_merge_datasets(ds, new_ds, merge)
return variables.validate(
conditional_merge_datasets(ds, new_ds, merge), "call_allele_count"
)


def count_variant_alleles(ds: Dataset, merge: bool = True) -> Dataset:
def count_variant_alleles(
ds: Dataset, merge: bool = True, *, call_genotype: str = "call_genotype"
) -> Dataset:
"""Compute allele count from genotype calls.
Parameters
Expand All @@ -129,6 +138,8 @@ def count_variant_alleles(ds: Dataset, merge: bool = True) -> Dataset:
output variables into a single dataset, otherwise return only
the computed output variables.
See :ref:`dataset_merge` for more details.
call_genotype
Input variable name holding call_genotype as defined by `sgkit.variables.call_genotype`
Returns
-------
Expand Down Expand Up @@ -160,28 +171,34 @@ def count_variant_alleles(ds: Dataset, merge: bool = True) -> Dataset:
{
"variant_allele_count": (
("variants", "alleles"),
count_call_alleles(ds)["call_allele_count"].sum(dim="samples"),
count_call_alleles(ds, call_genotype=call_genotype)[
"call_allele_count"
].sum(dim="samples"),
)
}
)
return conditional_merge_datasets(ds, new_ds, merge)
return variables.validate(
conditional_merge_datasets(ds, new_ds, merge), "variant_allele_count"
)


def _swap(dim: Dimension) -> Dimension:
return "samples" if dim == "variants" else "variants"


def call_rate(ds: Dataset, dim: Dimension) -> Dataset:
def call_rate(ds: Dataset, dim: Dimension, call_genotype_mask: str) -> Dataset:
odim = _swap(dim)[:-1]
n_called = (~ds["call_genotype_mask"].any(dim="ploidy")).sum(dim=dim)
n_called = (~ds[call_genotype_mask].any(dim="ploidy")).sum(dim=dim)
return xr.Dataset(
{f"{odim}_n_called": n_called, f"{odim}_call_rate": n_called / ds.dims[dim]}
)


def genotype_count(ds: Dataset, dim: Dimension) -> Dataset:
def genotype_count(
ds: Dataset, dim: Dimension, call_genotype: str, call_genotype_mask: str
) -> Dataset:
odim = _swap(dim)[:-1]
M, G = ds["call_genotype_mask"].any(dim="ploidy"), ds["call_genotype"]
M, G = ds[call_genotype_mask].any(dim="ploidy"), ds[call_genotype]
n_hom_ref = (G == 0).all(dim="ploidy")
n_hom_alt = ((G > 0) & (G[..., 0] == G)).all(dim="ploidy")
n_non_ref = (G > 0).any(dim="ploidy")
Expand All @@ -198,16 +215,24 @@ def genotype_count(ds: Dataset, dim: Dimension) -> Dataset:
)


def allele_frequency(ds: Dataset) -> Dataset:
def allele_frequency(
ds: Dataset,
call_genotype: str,
call_genotype_mask: str,
variant_allele_count: Optional[str],
) -> Dataset:
data_vars: Dict[Hashable, Any] = {}
# only compute variant allele count if not already in dataset
if "variant_allele_count" in ds:
AC = ds["variant_allele_count"]
if variant_allele_count is not None:
variables.validate(ds, {variant_allele_count: variables.variant_allele_count})
AC = ds[variant_allele_count]
else:
AC = count_variant_alleles(ds, merge=False)["variant_allele_count"]
AC = count_variant_alleles(ds, merge=False, call_genotype=call_genotype)[
"variant_allele_count"
]
data_vars["variant_allele_count"] = AC

M = ds["call_genotype_mask"].stack(calls=("samples", "ploidy"))
M = ds[call_genotype_mask].stack(calls=("samples", "ploidy"))
AN = (~M).sum(dim="calls") # type: ignore
assert AN.shape == (ds.dims["variants"],)

Expand All @@ -216,14 +241,30 @@ def allele_frequency(ds: Dataset) -> Dataset:
return Dataset(data_vars)


def variant_stats(ds: Dataset, merge: bool = True) -> Dataset:
def variant_stats(
ds: Dataset,
*,
call_genotype_mask: str = "call_genotype_mask",
call_genotype: str = "call_genotype",
variant_allele_count: Optional[str] = None,
merge: bool = True,
) -> Dataset:
"""Compute quality control variant statistics from genotype calls.
Parameters
----------
ds
Genotype call dataset such as from
`sgkit.create_genotype_call_dataset`.
call_genotype
Input variable name holding call_genotype.
As defined by `sgkit.variables.call_genotype`.
call_genotype_mask
Input variable name holding call_genotype_mask.
As defined by `sgkit.variables.call_genotype_mask`
variant_allele_count
Optional name of the input variable holding variant_allele_count,
as defined by `sgkit.variables.variant_allele_count`.
merge
If True (the default), merge the input dataset and the computed
output variables into a single dataset, otherwise return only
Expand All @@ -244,11 +285,30 @@ def variant_stats(ds: Dataset, merge: bool = True) -> Dataset:
- `variant_allele_total` (variants): The number of occurrences of all alleles.
- `variant_allele_frequency` (variants, alleles): The frequency of occurence of each allele.
"""
variables.validate(
ds,
{
call_genotype: variables.call_genotype,
call_genotype_mask: variables.call_genotype_mask,
},
)
new_ds = xr.merge(
[
call_rate(ds, dim="samples"),
genotype_count(ds, dim="samples"),
allele_frequency(ds),
call_rate(ds, dim="samples", call_genotype_mask=call_genotype_mask),
genotype_count(
ds,
dim="samples",
call_genotype=call_genotype,
call_genotype_mask=call_genotype_mask,
),
allele_frequency(
ds,
call_genotype=call_genotype,
call_genotype_mask=call_genotype_mask,
variant_allele_count=variant_allele_count,
),
]
)
return conditional_merge_datasets(ds, new_ds, merge)
return variables.validate(
conditional_merge_datasets(ds, new_ds, merge), *new_ds.variables.keys()
)
31 changes: 17 additions & 14 deletions sgkit/stats/association.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dask.array import Array, stats
from xarray import Dataset

from .. import variables
from ..typing import ArrayLike
from ..utils import conditional_merge_datasets
from .utils import concat_2d
Expand Down Expand Up @@ -136,21 +137,14 @@ def gwas_linear_regression(
ds
Dataset containing necessary dependent and independent variables.
dosage
Dosage variable name where "dosage" array can contain represent
one of several possible quantities, e.g.:
- Alternate allele counts
- Recessive or dominant allele encodings
- True dosages as computed from imputed or probabilistic variant calls
- Any other custom encoding in a user-defined variable
Name of genetic dosage variable.
As defined by `sgkit.variables.dosage`.
covariates
Covariate variable names, must correspond to 1 or 2D dataset
variables of shape (samples[, covariates]). All covariate arrays
will be concatenated along the second axis (columns).
Names of covariate variables (1D or 2D).
As defined by `sgkit.variables.covariates`.
traits
Trait (e.g. phenotype) variable names, must all be continuous and
correspond to 1 or 2D dataset variables of shape (samples[, traits]).
2D trait arrays will be assumed to contain separate traits within columns
and concatenated to any 1D traits along the second axis (columns).
Names of trait variables (1D or 2D).
As defined by `sgkit.variables.traits`.
add_intercept
Add intercept term to covariate set, by default True.
merge
Expand Down Expand Up @@ -197,6 +191,13 @@ def gwas_linear_regression(
if isinstance(traits, str):
traits = [traits]

variables.validate(
ds,
{dosage: variables.dosage},
{c: variables.covariates for c in covariates},
{t: variables.traits for t in traits},
)

G = _get_loop_covariates(ds, dosage=dosage)

X = da.asarray(concat_2d(ds[list(covariates)], dims=("samples", "covariates")))
Expand All @@ -220,4 +221,6 @@ def gwas_linear_regression(
"variant_p_value": (("variants", "traits"), res.p_value),
}
)
return conditional_merge_datasets(ds, new_ds, merge)
return variables.validate(
conditional_merge_datasets(ds, new_ds, merge), *new_ds.variables.keys()
)
26 changes: 23 additions & 3 deletions sgkit/stats/hwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from numpy import ndarray
from xarray import Dataset

from sgkit import variables
from sgkit.utils import conditional_merge_datasets


Expand Down Expand Up @@ -124,7 +125,10 @@ def hardy_weinberg_p_value_vec(

def hardy_weinberg_test(
ds: Dataset,
*,
genotype_counts: Optional[Hashable] = None,
call_genotype: str = "call_genotype",
call_genotype_mask: str = "call_genotype_mask",
merge: bool = True,
) -> Dataset:
"""Exact test for HWE as described in Wigginton et al. 2005 [1].
Expand All @@ -140,6 +144,12 @@ def hardy_weinberg_test(
where `N` is equal to the number of variants and the 3 columns contain
heterozygous, homozygous reference, and homozygous alternate counts
(in that order) across all samples for a variant.
call_genotype
Input variable name holding call_genotype.
As defined by `sgkit.variables.call_genotype`.
call_genotype_mask
Input variable name holding call_genotype_mask.
As defined by `sgkit.variables.call_genotype_mask`
merge
If True (the default), merge the input dataset and the computed
output variables into a single dataset, otherwise return only
Expand Down Expand Up @@ -175,15 +185,25 @@ def hardy_weinberg_test(
raise NotImplementedError("HWE test only implemented for biallelic genotypes")
# Use precomputed genotype counts if provided
if genotype_counts is not None:
variables.validate(ds, {genotype_counts: variables.genotype_counts})
obs = list(da.asarray(ds[genotype_counts]).T)
# Otherwise compute genotype counts from calls
else:
variables.validate(
ds,
{
call_genotype_mask: variables.call_genotype_mask,
call_genotype: variables.call_genotype,
},
)
# TODO: Use API genotype counting function instead, e.g.
# https://github.com/pystatgen/sgkit/issues/29#issuecomment-656691069
M = ds["call_genotype_mask"].any(dim="ploidy")
AC = xr.where(M, -1, ds["call_genotype"].sum(dim="ploidy")) # type: ignore[no-untyped-call]
M = ds[call_genotype_mask].any(dim="ploidy")
AC = xr.where(M, -1, ds[call_genotype].sum(dim="ploidy")) # type: ignore[no-untyped-call]
cts = [1, 0, 2] # arg order: hets, hom1, hom2
obs = [da.asarray((AC == ct).sum(dim="samples")) for ct in cts]
p = da.map_blocks(hardy_weinberg_p_value_vec_jit, *obs)
new_ds = xr.Dataset({"variant_hwe_p_value": ("variants", p)})
return conditional_merge_datasets(ds, new_ds, merge)
return variables.validate(
conditional_merge_datasets(ds, new_ds, merge), "variant_hwe_p_value"
)
Loading

0 comments on commit 54b56f1

Please sign in to comment.