diff --git a/sgkit/model.py b/sgkit/model.py index 409a57a97..4a1dd5d62 100644 --- a/sgkit/model.py +++ b/sgkit/model.py @@ -3,6 +3,7 @@ import numpy as np import xarray as xr +from . import variables from .typing import ArrayLike from .utils import check_array_like @@ -57,11 +58,6 @@ def create_genotype_call_dataset( ------- The dataset of genotype calls. """ - check_array_like(variant_contig, kind="i", ndim=1) - check_array_like(variant_position, kind="i", ndim=1) - check_array_like(variant_alleles, kind={"S", "O"}, ndim=2) - check_array_like(sample_id, kind={"U", "O"}, ndim=1) - check_array_like(call_genotype, kind="i", ndim=3) data_vars: Dict[Hashable, Any] = { "variant_contig": ([DIM_VARIANT], variant_contig), "variant_position": ([DIM_VARIANT], variant_position), @@ -83,7 +79,9 @@ def create_genotype_call_dataset( check_array_like(variant_id, kind={"U", "O"}, ndim=1) data_vars["variant_id"] = ([DIM_VARIANT], variant_id) attrs: Dict[Hashable, Any] = {"contigs": variant_contig_names} - return xr.Dataset(data_vars=data_vars, attrs=attrs) + return variables.validate( + xr.Dataset(data_vars=data_vars, attrs=attrs), *data_vars.keys() + ) def create_genotype_dosage_dataset( @@ -132,12 +130,6 @@ def create_genotype_dosage_dataset( The dataset of genotype calls. """ - check_array_like(variant_contig, kind="i", ndim=1) - check_array_like(variant_position, kind="i", ndim=1) - check_array_like(variant_alleles, kind={"S", "O"}, ndim=2) - check_array_like(sample_id, kind={"U", "O"}, ndim=1) - check_array_like(call_dosage, kind="f", ndim=2) - check_array_like(call_genotype_probability, kind="f", ndim=3) data_vars: Dict[Hashable, Any] = { "variant_contig": ([DIM_VARIANT], variant_contig), "variant_position": ([DIM_VARIANT], variant_position), @@ -158,4 +150,6 @@ def create_genotype_dosage_dataset( check_array_like(variant_id, kind={"U", "O"}, ndim=1) data_vars["variant_id"] = ([DIM_VARIANT], variant_id) attrs: Dict[Hashable, Any] = {"contigs": variant_contig_names} - return xr.Dataset(data_vars=data_vars, attrs=attrs) + return variables.validate( + xr.Dataset(data_vars=data_vars, attrs=attrs), *data_vars.keys() + ) diff --git a/sgkit/stats/aggregation.py b/sgkit/stats/aggregation.py index 0cc428a4f..058f28175 100644 --- a/sgkit/stats/aggregation.py +++ b/sgkit/stats/aggregation.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Hashable +from typing import Any, Dict, Hashable, Optional import dask.array as da import numpy as np @@ -7,6 +7,7 @@ from typing_extensions import Literal from xarray import Dataset +from sgkit import variables from sgkit.typing import ArrayLike from sgkit.utils import conditional_merge_datasets @@ -51,7 +52,9 @@ def count_alleles(g: ArrayLike, _: ArrayLike, out: ArrayLike) -> None: out[a] += 1 -def count_call_alleles(ds: Dataset, merge: bool = True) -> Dataset: +def count_call_alleles( + ds: Dataset, merge: bool = True, *, call_genotype: str = "call_genotype" +) -> Dataset: """Compute per sample allele counts from genotype calls. Parameters @@ -60,11 +63,12 @@ def count_call_alleles(ds: Dataset, merge: bool = True) -> Dataset: Genotype call dataset such as from `sgkit.create_genotype_call_dataset`. merge - (optional) If True (the default), merge the input dataset and the computed output variables into a single dataset, otherwise return only the computed output variables. See :ref:`dataset_merge` for more details. + call_genotype + Input variable name holding call_genotype as defined by `sgkit.variables.call_genotype` Returns ------- @@ -99,8 +103,9 @@ def count_call_alleles(ds: Dataset, merge: bool = True) -> Dataset: [[2, 0], [2, 0]]], dtype=uint8) """ + variables.validate(ds, {call_genotype: variables.call_genotype}) n_alleles = ds.dims["alleles"] - G = da.asarray(ds["call_genotype"]) + G = da.asarray(ds[call_genotype]) shape = (G.chunks[0], G.chunks[1], n_alleles) N = da.empty(n_alleles, dtype=np.uint8) new_ds = Dataset( @@ -113,10 +118,14 @@ def count_call_alleles(ds: Dataset, merge: bool = True) -> Dataset: ) } ) - return conditional_merge_datasets(ds, new_ds, merge) + return variables.validate( + conditional_merge_datasets(ds, new_ds, merge), "call_allele_count" + ) -def count_variant_alleles(ds: Dataset, merge: bool = True) -> Dataset: +def count_variant_alleles( + ds: Dataset, merge: bool = True, *, call_genotype: str = "call_genotype" +) -> Dataset: """Compute allele count from genotype calls. Parameters @@ -129,6 +138,8 @@ def count_variant_alleles(ds: Dataset, merge: bool = True) -> Dataset: output variables into a single dataset, otherwise return only the computed output variables. See :ref:`dataset_merge` for more details. + call_genotype + Input variable name holding call_genotype as defined by `sgkit.variables.call_genotype` Returns ------- @@ -160,28 +171,34 @@ def count_variant_alleles(ds: Dataset, merge: bool = True) -> Dataset: { "variant_allele_count": ( ("variants", "alleles"), - count_call_alleles(ds)["call_allele_count"].sum(dim="samples"), + count_call_alleles(ds, call_genotype=call_genotype)[ + "call_allele_count" + ].sum(dim="samples"), ) } ) - return conditional_merge_datasets(ds, new_ds, merge) + return variables.validate( + conditional_merge_datasets(ds, new_ds, merge), "variant_allele_count" + ) def _swap(dim: Dimension) -> Dimension: return "samples" if dim == "variants" else "variants" -def call_rate(ds: Dataset, dim: Dimension) -> Dataset: +def call_rate(ds: Dataset, dim: Dimension, call_genotype_mask: str) -> Dataset: odim = _swap(dim)[:-1] - n_called = (~ds["call_genotype_mask"].any(dim="ploidy")).sum(dim=dim) + n_called = (~ds[call_genotype_mask].any(dim="ploidy")).sum(dim=dim) return xr.Dataset( {f"{odim}_n_called": n_called, f"{odim}_call_rate": n_called / ds.dims[dim]} ) -def genotype_count(ds: Dataset, dim: Dimension) -> Dataset: +def genotype_count( + ds: Dataset, dim: Dimension, call_genotype: str, call_genotype_mask: str +) -> Dataset: odim = _swap(dim)[:-1] - M, G = ds["call_genotype_mask"].any(dim="ploidy"), ds["call_genotype"] + M, G = ds[call_genotype_mask].any(dim="ploidy"), ds[call_genotype] n_hom_ref = (G == 0).all(dim="ploidy") n_hom_alt = ((G > 0) & (G[..., 0] == G)).all(dim="ploidy") n_non_ref = (G > 0).any(dim="ploidy") @@ -198,16 +215,24 @@ def genotype_count(ds: Dataset, dim: Dimension) -> Dataset: ) -def allele_frequency(ds: Dataset) -> Dataset: +def allele_frequency( + ds: Dataset, + call_genotype: str, + call_genotype_mask: str, + variant_allele_count: Optional[str], +) -> Dataset: data_vars: Dict[Hashable, Any] = {} # only compute variant allele count if not already in dataset - if "variant_allele_count" in ds: - AC = ds["variant_allele_count"] + if variant_allele_count is not None: + variables.validate(ds, {variant_allele_count: variables.variant_allele_count}) + AC = ds[variant_allele_count] else: - AC = count_variant_alleles(ds, merge=False)["variant_allele_count"] + AC = count_variant_alleles(ds, merge=False, call_genotype=call_genotype)[ + "variant_allele_count" + ] data_vars["variant_allele_count"] = AC - M = ds["call_genotype_mask"].stack(calls=("samples", "ploidy")) + M = ds[call_genotype_mask].stack(calls=("samples", "ploidy")) AN = (~M).sum(dim="calls") # type: ignore assert AN.shape == (ds.dims["variants"],) @@ -216,7 +241,14 @@ def allele_frequency(ds: Dataset) -> Dataset: return Dataset(data_vars) -def variant_stats(ds: Dataset, merge: bool = True) -> Dataset: +def variant_stats( + ds: Dataset, + *, + call_genotype_mask: str = "call_genotype_mask", + call_genotype: str = "call_genotype", + variant_allele_count: Optional[str] = None, + merge: bool = True, +) -> Dataset: """Compute quality control variant statistics from genotype calls. Parameters @@ -224,6 +256,15 @@ def variant_stats(ds: Dataset, merge: bool = True) -> Dataset: ds Genotype call dataset such as from `sgkit.create_genotype_call_dataset`. + call_genotype + Input variable name holding call_genotype. + As defined by `sgkit.variables.call_genotype`. + call_genotype_mask + Input variable name holding call_genotype_mask. + As defined by `sgkit.variables.call_genotype_mask` + variant_allele_count + Optional name of the input variable holding variant_allele_count, + as defined by `sgkit.variables.variant_allele_count`. merge If True (the default), merge the input dataset and the computed output variables into a single dataset, otherwise return only @@ -244,11 +285,30 @@ def variant_stats(ds: Dataset, merge: bool = True) -> Dataset: - `variant_allele_total` (variants): The number of occurrences of all alleles. - `variant_allele_frequency` (variants, alleles): The frequency of occurence of each allele. """ + variables.validate( + ds, + { + call_genotype: variables.call_genotype, + call_genotype_mask: variables.call_genotype_mask, + }, + ) new_ds = xr.merge( [ - call_rate(ds, dim="samples"), - genotype_count(ds, dim="samples"), - allele_frequency(ds), + call_rate(ds, dim="samples", call_genotype_mask=call_genotype_mask), + genotype_count( + ds, + dim="samples", + call_genotype=call_genotype, + call_genotype_mask=call_genotype_mask, + ), + allele_frequency( + ds, + call_genotype=call_genotype, + call_genotype_mask=call_genotype_mask, + variant_allele_count=variant_allele_count, + ), ] ) - return conditional_merge_datasets(ds, new_ds, merge) + return variables.validate( + conditional_merge_datasets(ds, new_ds, merge), *new_ds.variables.keys() + ) diff --git a/sgkit/stats/association.py b/sgkit/stats/association.py index ba9bf3eb4..b78351ad8 100644 --- a/sgkit/stats/association.py +++ b/sgkit/stats/association.py @@ -7,6 +7,7 @@ from dask.array import Array, stats from xarray import Dataset +from .. import variables from ..typing import ArrayLike from ..utils import conditional_merge_datasets from .utils import concat_2d @@ -136,21 +137,14 @@ def gwas_linear_regression( ds Dataset containing necessary dependent and independent variables. dosage - Dosage variable name where "dosage" array can contain represent - one of several possible quantities, e.g.: - - Alternate allele counts - - Recessive or dominant allele encodings - - True dosages as computed from imputed or probabilistic variant calls - - Any other custom encoding in a user-defined variable + Name of genetic dosage variable. + As defined by `sgkit.variables.dosage`. covariates - Covariate variable names, must correspond to 1 or 2D dataset - variables of shape (samples[, covariates]). All covariate arrays - will be concatenated along the second axis (columns). + Names of covariate variables (1D or 2D). + As defined by `sgkit.variables.covariates`. traits - Trait (e.g. phenotype) variable names, must all be continuous and - correspond to 1 or 2D dataset variables of shape (samples[, traits]). - 2D trait arrays will be assumed to contain separate traits within columns - and concatenated to any 1D traits along the second axis (columns). + Names of trait variables (1D or 2D). + As defined by `sgkit.variables.traits`. add_intercept Add intercept term to covariate set, by default True. merge @@ -197,6 +191,13 @@ def gwas_linear_regression( if isinstance(traits, str): traits = [traits] + variables.validate( + ds, + {dosage: variables.dosage}, + {c: variables.covariates for c in covariates}, + {t: variables.traits for t in traits}, + ) + G = _get_loop_covariates(ds, dosage=dosage) X = da.asarray(concat_2d(ds[list(covariates)], dims=("samples", "covariates"))) @@ -220,4 +221,6 @@ def gwas_linear_regression( "variant_p_value": (("variants", "traits"), res.p_value), } ) - return conditional_merge_datasets(ds, new_ds, merge) + return variables.validate( + conditional_merge_datasets(ds, new_ds, merge), *new_ds.variables.keys() + ) diff --git a/sgkit/stats/hwe.py b/sgkit/stats/hwe.py index 61cff16d9..93230df02 100644 --- a/sgkit/stats/hwe.py +++ b/sgkit/stats/hwe.py @@ -7,6 +7,7 @@ from numpy import ndarray from xarray import Dataset +from sgkit import variables from sgkit.utils import conditional_merge_datasets @@ -124,7 +125,10 @@ def hardy_weinberg_p_value_vec( def hardy_weinberg_test( ds: Dataset, + *, genotype_counts: Optional[Hashable] = None, + call_genotype: str = "call_genotype", + call_genotype_mask: str = "call_genotype_mask", merge: bool = True, ) -> Dataset: """Exact test for HWE as described in Wigginton et al. 2005 [1]. @@ -140,6 +144,12 @@ def hardy_weinberg_test( where `N` is equal to the number of variants and the 3 columns contain heterozygous, homozygous reference, and homozygous alternate counts (in that order) across all samples for a variant. + call_genotype + Input variable name holding call_genotype. + As defined by `sgkit.variables.call_genotype`. + call_genotype_mask + Input variable name holding call_genotype_mask. + As defined by `sgkit.variables.call_genotype_mask` merge If True (the default), merge the input dataset and the computed output variables into a single dataset, otherwise return only @@ -175,15 +185,25 @@ def hardy_weinberg_test( raise NotImplementedError("HWE test only implemented for biallelic genotypes") # Use precomputed genotype counts if provided if genotype_counts is not None: + variables.validate(ds, {genotype_counts: variables.genotype_counts}) obs = list(da.asarray(ds[genotype_counts]).T) # Otherwise compute genotype counts from calls else: + variables.validate( + ds, + { + call_genotype_mask: variables.call_genotype_mask, + call_genotype: variables.call_genotype, + }, + ) # TODO: Use API genotype counting function instead, e.g. # https://github.com/pystatgen/sgkit/issues/29#issuecomment-656691069 - M = ds["call_genotype_mask"].any(dim="ploidy") - AC = xr.where(M, -1, ds["call_genotype"].sum(dim="ploidy")) # type: ignore[no-untyped-call] + M = ds[call_genotype_mask].any(dim="ploidy") + AC = xr.where(M, -1, ds[call_genotype].sum(dim="ploidy")) # type: ignore[no-untyped-call] cts = [1, 0, 2] # arg order: hets, hom1, hom2 obs = [da.asarray((AC == ct).sum(dim="samples")) for ct in cts] p = da.map_blocks(hardy_weinberg_p_value_vec_jit, *obs) new_ds = xr.Dataset({"variant_hwe_p_value": ("variants", p)}) - return conditional_merge_datasets(ds, new_ds, merge) + return variables.validate( + conditional_merge_datasets(ds, new_ds, merge), "variant_hwe_p_value" + ) diff --git a/sgkit/stats/pc_relate.py b/sgkit/stats/pc_relate.py index b7d2b9133..162ff11b2 100644 --- a/sgkit/stats/pc_relate.py +++ b/sgkit/stats/pc_relate.py @@ -3,6 +3,7 @@ import dask.array as da import xarray as xr +from sgkit import variables from sgkit.typing import ArrayLike from sgkit.utils import conditional_merge_datasets @@ -21,13 +22,25 @@ def _impute_genotype_call_with_variant_mean( return imputed_call_g -def _collapse_ploidy(ds: xr.Dataset) -> Tuple[xr.DataArray, xr.DataArray]: - call_g_mask = ds["call_genotype_mask"].any(dim="ploidy") - call_g = xr.where(call_g_mask, -1, ds["call_genotype"].sum(dim="ploidy")) # type: ignore[no-untyped-call] +def _collapse_ploidy( + ds: xr.Dataset, + call_genotype: str = "call_genotype", + call_genotype_mask: str = "call_genotype_mask", +) -> Tuple[xr.DataArray, xr.DataArray]: + call_g_mask = ds[call_genotype_mask].any(dim="ploidy") + call_g = xr.where(call_g_mask, -1, ds[call_genotype].sum(dim="ploidy")) # type: ignore[no-untyped-call] return call_g, call_g_mask -def pc_relate(ds: xr.Dataset, maf: float = 0.01, merge: bool = True) -> xr.Dataset: +def pc_relate( + ds: xr.Dataset, + maf: float = 0.01, + *, + call_genotype: str = "call_genotype", + call_genotype_mask: str = "call_genotype_mask", + sample_pcs: str = "sample_pcs", + merge: bool = True +) -> xr.Dataset: """Compute PC-Relate as described in Conomos, et al. 2016 [1]. This method computes the kinship coefficient matrix. The kinship coefficient for @@ -59,9 +72,9 @@ def pc_relate(ds: xr.Dataset, maf: float = 0.01, merge: bool = True) -> xr.Datas ds Dataset containing (S = num samples, V = num variants, D = ploidy, PC = num PC) - - genotype calls: "call_genotype" (SxVxD) - - genotype calls mask: "call_genotype_mask" (SxVxD) - - sample PCs: "sample_pcs" (PCxS) + - genotype calls: (SxVxD) + - genotype calls mask: (SxVxD) + - sample PCs: (PCxS) maf individual minor allele frequency filter. If an individual's estimated individual-specific minor allele frequency at a SNP is less than this value, @@ -72,6 +85,15 @@ def pc_relate(ds: xr.Dataset, maf: float = 0.01, merge: bool = True) -> xr.Datas output variables into a single dataset, otherwise return only the computed output variables. See :ref:`dataset_merge` for more details. + call_genotype + Input variable name holding call_genotype. + As defined by `sgkit.variables.call_genotype`. + call_genotype_mask + Input variable name holding call_genotype_mask. + As defined by `sgkit.variables.call_genotype_mask` + sample_pcs + Input variable name holding sample_pcs. + As defined by `sgkit.variables.sample_pcs` Warnings -------- @@ -106,19 +128,21 @@ def pc_relate(ds: xr.Dataset, maf: float = 0.01, merge: bool = True) -> xr.Datas raise ValueError("PC Relate only works for diploid genotypes") if "alleles" in ds.dims and ds.dims["alleles"] != 2: raise ValueError("PC Relate only works for biallelic genotypes") - if "call_genotype" not in ds: - raise ValueError("Input dataset must contain call_genotype") - if "call_genotype_mask" not in ds: - raise ValueError("Input dataset must contain call_genotype_mask") - if "sample_pcs" not in ds: - raise ValueError("Input dataset must contain sample_pcs variable") - - call_g, call_g_mask = _collapse_ploidy(ds) + variables.validate( + ds, + { + call_genotype: variables.call_genotype, + call_genotype_mask: variables.call_genotype_mask, + sample_pcs: variables.sample_pcs, + }, + ) + + call_g, call_g_mask = _collapse_ploidy(ds, call_genotype, call_genotype_mask) imputed_call_g = _impute_genotype_call_with_variant_mean(call_g, call_g_mask) # 𝔼[gs|V] = 1β0 + Vβ, where 1 is a length _s_ vector of 1s, and β = (β1,...,βD)^T # is a length D vector of regression coefficients for each of the PCs - pcs = ds["sample_pcs"] + pcs = ds[sample_pcs] pcsi = da.concatenate([da.ones((1, pcs.shape[1]), dtype=pcs.dtype), pcs], axis=0) # Note: dask qr decomp requires no chunking in one dimension, and because number of # components should be smaller than number of samples in most cases, we disable @@ -147,4 +171,6 @@ def pc_relate(ds: xr.Dataset, maf: float = 0.01, merge: bool = True) -> xr.Datas # NOTE: phi is of shape (S x S), S = num samples assert phi.shape == (call_g.shape[1],) * 2 new_ds = xr.Dataset({"pc_relate_phi": (("sample_x", "sample_y"), phi)}) - return conditional_merge_datasets(ds, new_ds, merge) + return variables.validate( + conditional_merge_datasets(ds, new_ds, merge), "pc_relate_phi" + ) diff --git a/sgkit/stats/popgen.py b/sgkit/stats/popgen.py index 416732bb2..ca41496e5 100644 --- a/sgkit/stats/popgen.py +++ b/sgkit/stats/popgen.py @@ -1,16 +1,17 @@ -from typing import Hashable - import dask.array as da import numpy as np import xarray as xr from xarray import DataArray, Dataset +from .. import variables from .aggregation import count_variant_alleles def diversity( ds: Dataset, - allele_counts: Hashable = "variant_allele_count", + *, + call_genotype: str = "call_genotype", + allele_counts: str = "variant_allele_count", ) -> DataArray: """Compute diversity from allele counts. @@ -27,8 +28,12 @@ def diversity( ---------- ds Genotype call dataset. + call_genotype + Input variable name holding call_genotype. + As defined by `sgkit.variables.call_genotype`. allele_counts - allele counts to use or calculate. + allele counts to use or calculate, as defined by + `variables.variant_allele_count` Returns ------- @@ -37,8 +42,10 @@ def diversity( if len(ds.samples) < 2: return xr.DataArray(np.nan) if allele_counts not in ds: - ds_new = count_variant_alleles(ds) + variables.validate(ds, {call_genotype: variables.call_genotype}) + ds_new = count_variant_alleles(ds, call_genotype=call_genotype) else: + variables.validate(ds, {allele_counts: variables.variant_allele_count}) ds_new = ds ac = ds_new[allele_counts] an = ac.sum(axis=1) @@ -52,7 +59,9 @@ def diversity( def divergence( ds1: Dataset, ds2: Dataset, - allele_counts: Hashable = "variant_allele_count", + *, + call_genotype: str = "call_genotype", + allele_counts: str = "variant_allele_count", ) -> DataArray: """Compute divergence between two genotype call datasets. @@ -62,21 +71,29 @@ def divergence( Genotype call dataset. ds2 Genotype call dataset. + call_genotype + Input variable name holding call_genotype. + As defined by `sgkit.variables.call_genotype`. allele_counts - allele counts to use or calculate. + allele counts to use or calculate, as defined by + `variables.variant_allele_count` Returns ------- divergence value between the two datasets. """ if allele_counts not in ds1: + variables.validate(ds1, {call_genotype: variables.call_genotype}) ds1_new = count_variant_alleles(ds1) else: + variables.validate(ds1, {allele_counts: variables.variant_allele_count}) ds1_new = ds1 ac1 = ds1_new[allele_counts] if allele_counts not in ds2: + variables.validate(ds2, {call_genotype: variables.call_genotype}) ds2_new = count_variant_alleles(ds2) else: + variables.validate(ds2, {allele_counts: variables.variant_allele_count}) ds2_new = ds2 ac2 = ds2_new[allele_counts] an1 = ds1_new[allele_counts].sum(axis=1) @@ -92,7 +109,9 @@ def divergence( def Fst( ds1: Dataset, ds2: Dataset, - allele_counts: Hashable = "variant_allele_count", + *, + call_genotype: str = "call_genotype", + allele_counts: str = "variant_allele_count", ) -> DataArray: """Compute Fst between two genotype call datasets. @@ -102,15 +121,21 @@ def Fst( Genotype call dataset. ds2 Genotype call dataset. + call_genotype + Input variable name holding call_genotype. + As defined by `sgkit.variables.call_genotype`. allele_counts - allele counts to use or calculate. + allele counts to use or calculate, as defined by + `variables.variant_allele_count` Returns ------- fst value between the two datasets. """ - total_div = diversity(ds1) + diversity(ds2) - gs = divergence(ds1, ds2) + total_div = diversity( + ds1, call_genotype=call_genotype, allele_counts=allele_counts + ) + diversity(ds2, call_genotype=call_genotype, allele_counts=allele_counts) + gs = divergence(ds1, ds2, call_genotype=call_genotype, allele_counts=allele_counts) den = total_div + 2 * gs # type: ignore[operator] fst = 1 - (2 * total_div / den) return fst # type: ignore[no-any-return] @@ -118,7 +143,9 @@ def Fst( def Tajimas_D( ds: Dataset, - allele_counts: Hashable = "variant_allele_count", + *, + call_genotype: str = "call_genotype", + allele_counts: str = "variant_allele_count", ) -> DataArray: """Compute Tajimas' D for a genotype call dataset. @@ -126,8 +153,12 @@ def Tajimas_D( ---------- ds Genotype call dataset. + call_genotype + Input variable name holding call_genotype. + As defined by `sgkit.variables.call_genotype`. allele_counts - allele counts to use or calculate. + allele counts to use or calculate, as defined by + `variables.variant_allele_count` Returns ------- @@ -135,8 +166,10 @@ def Tajimas_D( """ if allele_counts not in ds: + variables.validate(ds, {call_genotype: variables.call_genotype}) ds_new = count_variant_alleles(ds) else: + variables.validate(ds, {allele_counts: variables.variant_allele_count}) ds_new = ds ac = ds_new[allele_counts] diff --git a/sgkit/stats/regenie.py b/sgkit/stats/regenie.py index 576cf5f28..946e8b484 100644 --- a/sgkit/stats/regenie.py +++ b/sgkit/stats/regenie.py @@ -7,6 +7,7 @@ from numpy import ndarray from xarray import Dataset +from .. import variables from ..typing import ArrayLike from ..utils import conditional_merge_datasets, split_array_chunks from .utils import ( @@ -730,6 +731,7 @@ def regenie( dosage: str, covariates: Union[str, Sequence[str]], traits: Union[str, Sequence[str]], + variant_contig: str = "variant_contig", variant_block_size: Optional[Union[int, Tuple[int, ...]]] = None, sample_block_size: Optional[Union[int, Tuple[int, ...]]] = None, alphas: Optional[Sequence[float]] = None, @@ -753,10 +755,16 @@ def regenie( ---------- dosage Name of genetic dosage variable. + As defined by `sgkit.variables.dosage`. covariates Names of covariate variables (1D or 2D). + As defined by `sgkit.variables.covariates`. traits Names of trait variables (1D or 2D). + As defined by `sgkit.variables.traits`. + variant_contig + Name of the variant contig input variable. + As definied by `sgkit.variables.variant_contig`. variant_block_size Number of variants in each block. If int, this describes the number of variants in each block @@ -800,16 +808,18 @@ def regenie( A dataset containing the following variables: - `base_prediction` (blocks, alphas, samples, outcomes): Stage 1 - predictions from ridge regression reduction . + predictions from ridge regression reduction. As defined by + `sgkit.variables.base_prediction`. - `meta_prediction` (samples, outcomes): Stage 2 predictions from the best meta estimator trained on the out-of-sample Stage 1 - predictions. + predictions. As defined by `sgkit.variables.meta_prediction`. - `loco_prediction` (contigs, samples, outcomes): LOCO predictions resulting from Stage 2 predictions ignoring effects for variant blocks on held out contigs. This will be absent if the - data provided does not contain at least 2 contigs. + data provided does not contain at least 2 contigs. As defined by + `sgkit.variables.loco_prediction`. Raises ------ @@ -851,10 +861,18 @@ def regenie( covariates = [covariates] if isinstance(traits, str): traits = [traits] + + variables.validate( + ds, + {dosage: variables.dosage, variant_contig: variables.variant_contig}, + {c: variables.covariates for c in covariates}, + {t: variables.traits for t in traits}, + ) + G = ds[dosage] X = da.asarray(concat_2d(ds[list(covariates)], dims=("samples", "covariates"))) Y = da.asarray(concat_2d(ds[list(traits)], dims=("samples", "traits"))) - contigs = ds["variant_contig"] + contigs = ds[variant_contig] new_ds = regenie_transform( G.T, X, @@ -868,4 +886,6 @@ def regenie( orthogonalize=orthogonalize, **kwargs, ) - return conditional_merge_datasets(ds, new_ds, merge) + return variables.validate( + conditional_merge_datasets(ds, new_ds, merge), *new_ds.variables.keys() + ) diff --git a/sgkit/tests/test_aggregation.py b/sgkit/tests/test_aggregation.py index 8951f842b..b4a069155 100644 --- a/sgkit/tests/test_aggregation.py +++ b/sgkit/tests/test_aggregation.py @@ -216,7 +216,9 @@ def test_variant_stats(precompute_variant_allele_count): ) if precompute_variant_allele_count: ds = count_variant_alleles(ds) - vs = variant_stats(ds) + vs = variant_stats(ds, variant_allele_count="variant_allele_count") + else: + vs = variant_stats(ds) np.testing.assert_equal(vs["variant_n_called"], np.array([1, 2, 2, 1])) np.testing.assert_equal(vs["variant_call_rate"], np.array([0.5, 1.0, 1.0, 0.5])) diff --git a/sgkit/tests/test_pc_relate.py b/sgkit/tests/test_pc_relate.py index 8655f18e0..3817325ae 100644 --- a/sgkit/tests/test_pc_relate.py +++ b/sgkit/tests/test_pc_relate.py @@ -27,17 +27,13 @@ def test_pc_relate__genotype_inputs_checks() -> None: pc_relate(g_non_biallelic) g_no_pcs = simulate_genotype_call_dataset(100, 10) - with pytest.raises( - ValueError, match="Input dataset must contain sample_pcs variable" - ): + with pytest.raises(ValueError, match="sample_pcs not present"): pc_relate(g_no_pcs) - with pytest.raises(ValueError, match="Input dataset must contain call_genotype"): + with pytest.raises(ValueError, match="call_genotype not present"): pc_relate(g_no_pcs.drop_vars("call_genotype")) - with pytest.raises( - ValueError, match="Input dataset must contain call_genotype_mask" - ): + with pytest.raises(ValueError, match="call_genotype_mask not present"): pc_relate(g_no_pcs.drop_vars("call_genotype_mask"))