diff --git a/docs/api.rst b/docs/api.rst index cba4d3737..3ef72a174 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -54,41 +54,41 @@ Variables .. autosummary:: :toctree: generated/ - variables.base_prediction - variables.call_allele_count - variables.call_dosage - variables.call_dosage_mask - variables.call_genotype - variables.call_genotype_mask - variables.call_genotype_phased - variables.call_genotype_probability - variables.call_genotype_probability_mask - variables.covariates - variables.dosage - variables.genotype_counts - variables.loco_prediction - variables.meta_prediction - variables.pc_relate_phi - variables.sample_id - variables.sample_pcs - variables.traits - variables.variant_allele - variables.variant_allele_count - variables.variant_allele_frequency - variables.variant_allele_total - variables.variant_beta - variables.variant_call_rate - variables.variant_contig - variables.variant_hwe_p_value - variables.variant_id - variables.variant_n_called - variables.variant_n_het - variables.variant_n_hom_alt - variables.variant_n_hom_ref - variables.variant_n_non_ref - variables.variant_p_value - variables.variant_position - variables.variant_t_value + variables.base_prediction_spec + variables.call_allele_count_spec + variables.call_dosage_spec + variables.call_dosage_mask_spec + variables.call_genotype_spec + variables.call_genotype_mask_spec + variables.call_genotype_phased_spec + variables.call_genotype_probability_spec + variables.call_genotype_probability_mask_spec + variables.covariates_spec + variables.dosage_spec + variables.genotype_counts_spec + variables.loco_prediction_spec + variables.meta_prediction_spec + variables.pc_relate_phi_spec + variables.sample_id_spec + variables.sample_pcs_spec + variables.traits_spec + variables.variant_allele_spec + variables.variant_allele_count_spec + variables.variant_allele_frequency_spec + variables.variant_allele_total_spec + variables.variant_beta_spec + variables.variant_call_rate_spec + variables.variant_contig_spec + variables.variant_hwe_p_value_spec + variables.variant_id_spec + variables.variant_n_called_spec + variables.variant_n_het_spec + variables.variant_n_hom_alt_spec + variables.variant_n_hom_ref_spec + variables.variant_n_non_ref_spec + variables.variant_p_value_spec + variables.variant_position_spec + variables.variant_t_value_spec Utilities ========= diff --git a/sgkit/model.py b/sgkit/model.py index c9be8dd3b..c88d36e0c 100644 --- a/sgkit/model.py +++ b/sgkit/model.py @@ -5,7 +5,6 @@ from . import variables from .typing import ArrayLike -from .utils import check_array_like DIM_VARIANT = "variants" DIM_SAMPLE = "samples" @@ -70,13 +69,11 @@ def create_genotype_call_dataset( ), } if call_genotype_phased is not None: - check_array_like(call_genotype_phased, kind="b", ndim=2) data_vars["call_genotype_phased"] = ( [DIM_VARIANT, DIM_SAMPLE], call_genotype_phased, ) if variant_id is not None: - check_array_like(variant_id, kind={"U", "O"}, ndim=1) data_vars["variant_id"] = ([DIM_VARIANT], variant_id) attrs: Dict[Hashable, Any] = {"contigs": variant_contig_names} return variables.validate(xr.Dataset(data_vars=data_vars, attrs=attrs)) @@ -145,7 +142,6 @@ def create_genotype_dosage_dataset( ), } if variant_id is not None: - check_array_like(variant_id, kind={"U", "O"}, ndim=1) data_vars["variant_id"] = ([DIM_VARIANT], variant_id) attrs: Dict[Hashable, Any] = {"contigs": variant_contig_names} return variables.validate(xr.Dataset(data_vars=data_vars, attrs=attrs)) diff --git a/sgkit/stats/aggregation.py b/sgkit/stats/aggregation.py index 669c3cddf..e2fd278f4 100644 --- a/sgkit/stats/aggregation.py +++ b/sgkit/stats/aggregation.py @@ -53,7 +53,7 @@ def count_alleles(g: ArrayLike, _: ArrayLike, out: ArrayLike) -> None: def count_call_alleles( - ds: Dataset, *, call_genotype: str = "call_genotype", merge: bool = True + ds: Dataset, *, call_genotype: str = variables.call_genotype, merge: bool = True ) -> Dataset: """Compute per sample allele counts from genotype calls. @@ -64,7 +64,7 @@ def count_call_alleles( :func:`sgkit.create_genotype_call_dataset`. call_genotype Input variable name holding call_genotype as defined by - :data:`sgkit.variables.call_genotype` + :data:`sgkit.variables.call_genotype_spec` merge If True (the default), merge the input dataset and the computed output variables into a single dataset, otherwise return only @@ -104,14 +104,14 @@ def count_call_alleles( [[2, 0], [2, 0]]], dtype=uint8) """ - variables.validate(ds, {call_genotype: variables.call_genotype}) + variables.validate(ds, {call_genotype: variables.call_genotype_spec}) n_alleles = ds.dims["alleles"] G = da.asarray(ds[call_genotype]) shape = (G.chunks[0], G.chunks[1], n_alleles) N = da.empty(n_alleles, dtype=np.uint8) new_ds = Dataset( { - "call_allele_count": ( + variables.call_allele_count: ( ("variants", "samples", "alleles"), da.map_blocks( count_alleles, G, N, chunks=shape, drop_axis=2, new_axis=2 @@ -123,7 +123,7 @@ def count_call_alleles( def count_variant_alleles( - ds: Dataset, *, call_genotype: str = "call_genotype", merge: bool = True + ds: Dataset, *, call_genotype: str = variables.call_genotype, merge: bool = True ) -> Dataset: """Compute allele count from genotype calls. @@ -134,7 +134,7 @@ def count_variant_alleles( :func:`sgkit.create_genotype_call_dataset`. call_genotype Input variable name holding call_genotype as defined by - :data:`sgkit.variables.call_genotype` + :data:`sgkit.variables.call_genotype_spec` merge If True (the default), merge the input dataset and the computed output variables into a single dataset, otherwise return only @@ -169,10 +169,10 @@ def count_variant_alleles( """ new_ds = Dataset( { - "variant_allele_count": ( + variables.variant_allele_count: ( ("variants", "alleles"), count_call_alleles(ds, call_genotype=call_genotype)[ - "call_allele_count" + variables.call_allele_count ].sum(dim="samples"), ) } @@ -222,28 +222,30 @@ def allele_frequency( data_vars: Dict[Hashable, Any] = {} # only compute variant allele count if not already in dataset if variant_allele_count is not None: - variables.validate(ds, {variant_allele_count: variables.variant_allele_count}) + variables.validate( + ds, {variant_allele_count: variables.variant_allele_count_spec} + ) AC = ds[variant_allele_count] else: AC = count_variant_alleles(ds, merge=False, call_genotype=call_genotype)[ - "variant_allele_count" + variables.variant_allele_count ] - data_vars["variant_allele_count"] = AC + data_vars[variables.variant_allele_count] = AC M = ds[call_genotype_mask].stack(calls=("samples", "ploidy")) AN = (~M).sum(dim="calls") # type: ignore assert AN.shape == (ds.dims["variants"],) - data_vars["variant_allele_total"] = AN - data_vars["variant_allele_frequency"] = AC / AN + data_vars[variables.variant_allele_total] = AN + data_vars[variables.variant_allele_frequency] = AC / AN return Dataset(data_vars) def variant_stats( ds: Dataset, *, - call_genotype_mask: str = "call_genotype_mask", - call_genotype: str = "call_genotype", + call_genotype_mask: str = variables.call_genotype_mask, + call_genotype: str = variables.call_genotype, variant_allele_count: Optional[str] = None, merge: bool = True, ) -> Dataset: @@ -256,13 +258,13 @@ def variant_stats( :func:`sgkit.create_genotype_call_dataset`. call_genotype Input variable name holding call_genotype. - Defined by :data:`sgkit.variables.call_genotype`. + Defined by :data:`sgkit.variables.call_genotype_spec`. call_genotype_mask Input variable name holding call_genotype_mask. - Defined by :data:`sgkit.variables.call_genotype_mask` + Defined by :data:`sgkit.variables.call_genotype_mask_spec` variant_allele_count Optional name of the input variable holding variant_allele_count, - as defined by :data:`sgkit.variables.variant_allele_count`. + as defined by :data:`sgkit.variables.variant_allele_count_spec`. merge If True (the default), merge the input dataset and the computed output variables into a single dataset, otherwise return only @@ -273,30 +275,30 @@ def variant_stats( ------- A dataset containing the following variables: - - :data:`sgkit.variables.variant_n_called` (variants): + - :data:`sgkit.variables.variant_n_called_spec` (variants): The number of samples with called genotypes. - - :data:`sgkit.variables.variant_call_rate` (variants): + - :data:`sgkit.variables.variant_call_rate_spec` (variants): The fraction of samples with called genotypes. - - :data:`sgkit.variables.variant_n_het` (variants): + - :data:`sgkit.variables.variant_n_het_spec` (variants): The number of samples with heterozygous calls. - - :data:`sgkit.variables.variant_n_hom_ref` (variants): + - :data:`sgkit.variables.variant_n_hom_ref_spec` (variants): The number of samples with homozygous reference calls. - - :data:`sgkit.variables.variant_n_hom_alt` (variants): + - :data:`sgkit.variables.variant_n_hom_alt_spec` (variants): The number of samples with homozygous alternate calls. - - :data:`sgkit.variables.variant_n_non_ref` (variants): + - :data:`sgkit.variables.variant_n_non_ref_spec` (variants): The number of samples that are not homozygous reference calls. - - :data:`sgkit.variables.variant_allele_count` (variants, alleles): + - :data:`sgkit.variables.variant_allele_count_spec` (variants, alleles): The number of occurrences of each allele. - - :data:`sgkit.variables.variant_allele_total` (variants): + - :data:`sgkit.variables.variant_allele_total_spec` (variants): The number of occurrences of all alleles. - - :data:`sgkit.variables.variant_allele_frequency` (variants, alleles): + - :data:`sgkit.variables.variant_allele_frequency_spec` (variants, alleles): The frequency of occurrence of each allele. """ variables.validate( ds, { - call_genotype: variables.call_genotype, - call_genotype_mask: variables.call_genotype_mask, + call_genotype: variables.call_genotype_spec, + call_genotype_mask: variables.call_genotype_mask_spec, }, ) new_ds = xr.merge( diff --git a/sgkit/stats/association.py b/sgkit/stats/association.py index 5c0e0a917..0fa6a9ae4 100644 --- a/sgkit/stats/association.py +++ b/sgkit/stats/association.py @@ -104,11 +104,13 @@ def linear_regression( return LinearRegressionResult(beta=B, t_value=T, p_value=P) -def _get_loop_covariates(ds: Dataset, dosage: Optional[str] = None) -> Array: +def _get_loop_covariates( + ds: Dataset, call_genotype: str, dosage: Optional[str] = None +) -> Array: if dosage is None: # TODO: This should be (probably gwas-specific) allele # count with sex chromosome considerations - G = ds["call_genotype"].sum(dim="ploidy") # pragma: no cover + G = ds[call_genotype].sum(dim="ploidy") # pragma: no cover else: G = ds[dosage] return da.asarray(G.data) @@ -121,6 +123,7 @@ def gwas_linear_regression( covariates: Union[str, Sequence[str]], traits: Union[str, Sequence[str]], add_intercept: bool = True, + call_genotype: str = variables.call_genotype, merge: bool = True, ) -> Dataset: """Run linear regression to identify continuous trait associations with genetic variants. @@ -138,15 +141,18 @@ def gwas_linear_regression( Dataset containing necessary dependent and independent variables. dosage Name of genetic dosage variable. - Defined by :data:`sgkit.variables.dosage`. + Defined by :data:`sgkit.variables.dosage_spec`. covariates Names of covariate variables (1D or 2D). - Defined by :data:`sgkit.variables.covariates`. + Defined by :data:`sgkit.variables.covariates_spec`. traits Names of trait variables (1D or 2D). - Defined by :data:`sgkit.variables.traits`. + Defined by :data:`sgkit.variables.traits_spec`. add_intercept Add intercept term to covariate set, by default True. + call_genotype + Input variable name holding call_genotype. + Defined by :data:`sgkit.variables.call_genotype_spec`. merge If True (the default), merge the input dataset and the computed output variables into a single dataset, otherwise return only @@ -193,12 +199,12 @@ def gwas_linear_regression( variables.validate( ds, - {dosage: variables.dosage}, - {c: variables.covariates for c in covariates}, - {t: variables.traits for t in traits}, + {dosage: variables.dosage_spec}, + {c: variables.covariates_spec for c in covariates}, + {t: variables.traits_spec for t in traits}, ) - G = _get_loop_covariates(ds, dosage=dosage) + G = _get_loop_covariates(ds, dosage=dosage, call_genotype=call_genotype) X = da.asarray(concat_2d(ds[list(covariates)], dims=("samples", "covariates"))) if add_intercept: @@ -216,9 +222,9 @@ def gwas_linear_regression( res = linear_regression(G.T, X, Y) new_ds = xr.Dataset( { - "variant_beta": (("variants", "traits"), res.beta), - "variant_t_value": (("variants", "traits"), res.t_value), - "variant_p_value": (("variants", "traits"), res.p_value), + variables.variant_beta: (("variants", "traits"), res.beta), + variables.variant_t_value: (("variants", "traits"), res.t_value), + variables.variant_p_value: (("variants", "traits"), res.p_value), } ) return conditional_merge_datasets(ds, variables.validate(new_ds), merge) diff --git a/sgkit/stats/hwe.py b/sgkit/stats/hwe.py index dfde97d3c..52e9a8f6f 100644 --- a/sgkit/stats/hwe.py +++ b/sgkit/stats/hwe.py @@ -127,8 +127,8 @@ def hardy_weinberg_test( ds: Dataset, *, genotype_counts: Optional[Hashable] = None, - call_genotype: str = "call_genotype", - call_genotype_mask: str = "call_genotype_mask", + call_genotype: str = variables.call_genotype, + call_genotype_mask: str = variables.call_genotype_mask, merge: bool = True, ) -> Dataset: """Exact test for HWE as described in Wigginton et al. 2005 [1]. @@ -146,10 +146,10 @@ def hardy_weinberg_test( (in that order) across all samples for a variant. call_genotype Input variable name holding call_genotype. - Defined by :data:`sgkit.variables.call_genotype`. + Defined by :data:`sgkit.variables.call_genotype_spec`. call_genotype_mask Input variable name holding call_genotype_mask. - Defined by :data:`sgkit.variables.call_genotype_mask` + Defined by :data:`sgkit.variables.call_genotype_mask_spec` merge If True (the default), merge the input dataset and the computed output variables into a single dataset, otherwise return only @@ -185,15 +185,15 @@ def hardy_weinberg_test( raise NotImplementedError("HWE test only implemented for biallelic genotypes") # Use precomputed genotype counts if provided if genotype_counts is not None: - variables.validate(ds, {genotype_counts: variables.genotype_counts}) + variables.validate(ds, {genotype_counts: variables.genotype_counts_spec}) obs = list(da.asarray(ds[genotype_counts]).T) # Otherwise compute genotype counts from calls else: variables.validate( ds, { - call_genotype_mask: variables.call_genotype_mask, - call_genotype: variables.call_genotype, + call_genotype_mask: variables.call_genotype_mask_spec, + call_genotype: variables.call_genotype_spec, }, ) # TODO: Use API genotype counting function instead, e.g. @@ -203,5 +203,5 @@ def hardy_weinberg_test( cts = [1, 0, 2] # arg order: hets, hom1, hom2 obs = [da.asarray((AC == ct).sum(dim="samples")) for ct in cts] p = da.map_blocks(hardy_weinberg_p_value_vec_jit, *obs) - new_ds = xr.Dataset({"variant_hwe_p_value": ("variants", p)}) + new_ds = xr.Dataset({variables.variant_hwe_p_value: ("variants", p)}) return conditional_merge_datasets(ds, variables.validate(new_ds), merge) diff --git a/sgkit/stats/pc_relate.py b/sgkit/stats/pc_relate.py index b7aa66816..865cff8ea 100644 --- a/sgkit/stats/pc_relate.py +++ b/sgkit/stats/pc_relate.py @@ -24,8 +24,8 @@ def _impute_genotype_call_with_variant_mean( def _collapse_ploidy( ds: xr.Dataset, - call_genotype: str = "call_genotype", - call_genotype_mask: str = "call_genotype_mask", + call_genotype: str = variables.call_genotype, + call_genotype_mask: str = variables.call_genotype_mask, ) -> Tuple[xr.DataArray, xr.DataArray]: call_g_mask = ds[call_genotype_mask].any(dim="ploidy") call_g = xr.where(call_g_mask, -1, ds[call_genotype].sum(dim="ploidy")) # type: ignore[no-untyped-call] @@ -36,9 +36,9 @@ def pc_relate( ds: xr.Dataset, *, maf: float = 0.01, - call_genotype: str = "call_genotype", - call_genotype_mask: str = "call_genotype_mask", - sample_pcs: str = "sample_pcs", + call_genotype: str = variables.call_genotype, + call_genotype_mask: str = variables.call_genotype_mask, + sample_pcs: str = variables.sample_pcs, merge: bool = True ) -> xr.Dataset: """Compute PC-Relate as described in Conomos, et al. 2016 [1]. @@ -82,13 +82,13 @@ def pc_relate( The default value is 0.01. Must be between (0.0, 0.1). call_genotype Input variable name holding call_genotype. - Defined by :data:`sgkit.variables.call_genotype`. + Defined by :data:`sgkit.variables.call_genotype_spec`. call_genotype_mask Input variable name holding call_genotype_mask. - Defined by :data:`sgkit.variables.call_genotype_mask` + Defined by :data:`sgkit.variables.call_genotype_mask_spec` sample_pcs Input variable name holding sample_pcs. - Defined by :data:`sgkit.variables.sample_pcs` + Defined by :data:`sgkit.variables.sample_pcs_spec` merge If True (the default), merge the input dataset and the computed output variables into a single dataset, otherwise return only @@ -102,7 +102,7 @@ def pc_relate( Returns ------- Dataset containing (S = num samples): - :data:`sgkit.variables.pc_relate_phi`: (S,S) ArrayLike + :data:`sgkit.variables.pc_relate_phi_spec`: (S,S) ArrayLike pairwise recent kinship coefficient matrix as float in [-0.5, 0.5]. References @@ -131,9 +131,9 @@ def pc_relate( variables.validate( ds, { - call_genotype: variables.call_genotype, - call_genotype_mask: variables.call_genotype_mask, - sample_pcs: variables.sample_pcs, + call_genotype: variables.call_genotype_spec, + call_genotype_mask: variables.call_genotype_mask_spec, + sample_pcs: variables.sample_pcs_spec, }, ) @@ -170,5 +170,5 @@ def pc_relate( phi = gramian(centered_af) / gramian(stddev) # NOTE: phi is of shape (S x S), S = num samples assert phi.shape == (call_g.shape[1],) * 2 - new_ds = xr.Dataset({"pc_relate_phi": (("sample_x", "sample_y"), phi)}) + new_ds = xr.Dataset({variables.pc_relate_phi: (("sample_x", "sample_y"), phi)}) return conditional_merge_datasets(ds, variables.validate(new_ds), merge) diff --git a/sgkit/stats/popgen.py b/sgkit/stats/popgen.py index 77d174f7d..e01e413f3 100644 --- a/sgkit/stats/popgen.py +++ b/sgkit/stats/popgen.py @@ -10,8 +10,8 @@ def diversity( ds: Dataset, *, - call_genotype: str = "call_genotype", - allele_counts: str = "variant_allele_count", + call_genotype: str = variables.call_genotype, + allele_counts: str = variables.variant_allele_count, ) -> DataArray: """Compute diversity from allele counts. @@ -30,10 +30,10 @@ def diversity( Genotype call dataset. call_genotype Input variable name holding call_genotype. - Defined by :data:`sgkit.variables.call_genotype`. + Defined by :data:`sgkit.variables.call_genotype_spec`. allele_counts allele counts to use or calculate, as defined by - :data:`sgkit.variables.variant_allele_count` + :data:`sgkit.variables.variant_allele_count_spec` Returns ------- @@ -42,10 +42,10 @@ def diversity( if len(ds.samples) < 2: return xr.DataArray(np.nan) if allele_counts not in ds: - variables.validate(ds, {call_genotype: variables.call_genotype}) + variables.validate(ds, {call_genotype: variables.call_genotype_spec}) ds_new = count_variant_alleles(ds, call_genotype=call_genotype) else: - variables.validate(ds, {allele_counts: variables.variant_allele_count}) + variables.validate(ds, {allele_counts: variables.variant_allele_count_spec}) ds_new = ds ac = ds_new[allele_counts] an = ac.sum(axis=1) @@ -60,8 +60,8 @@ def divergence( ds1: Dataset, ds2: Dataset, *, - call_genotype: str = "call_genotype", - allele_counts: str = "variant_allele_count", + call_genotype: str = variables.call_genotype, + allele_counts: str = variables.variant_allele_count, ) -> DataArray: """Compute divergence between two genotype call datasets. @@ -73,27 +73,27 @@ def divergence( Genotype call dataset. call_genotype Input variable name holding call_genotype. - Defined by :data:`sgkit.variables.call_genotype`. + Defined by :data:`sgkit.variables.call_genotype_spec`. allele_counts allele counts to use or calculate, as defined by - :data:`sgkit.variables.variant_allele_count` + :data:`sgkit.variables.variant_allele_count_spec` Returns ------- divergence value between the two datasets. """ if allele_counts not in ds1: - variables.validate(ds1, {call_genotype: variables.call_genotype}) + variables.validate(ds1, {call_genotype: variables.call_genotype_spec}) ds1_new = count_variant_alleles(ds1) else: - variables.validate(ds1, {allele_counts: variables.variant_allele_count}) + variables.validate(ds1, {allele_counts: variables.variant_allele_count_spec}) ds1_new = ds1 ac1 = ds1_new[allele_counts] if allele_counts not in ds2: - variables.validate(ds2, {call_genotype: variables.call_genotype}) + variables.validate(ds2, {call_genotype: variables.call_genotype_spec}) ds2_new = count_variant_alleles(ds2) else: - variables.validate(ds2, {allele_counts: variables.variant_allele_count}) + variables.validate(ds2, {allele_counts: variables.variant_allele_count_spec}) ds2_new = ds2 ac2 = ds2_new[allele_counts] an1 = ds1_new[allele_counts].sum(axis=1) @@ -110,8 +110,8 @@ def Fst( ds1: Dataset, ds2: Dataset, *, - call_genotype: str = "call_genotype", - allele_counts: str = "variant_allele_count", + call_genotype: str = variables.call_genotype, + allele_counts: str = variables.variant_allele_count, ) -> DataArray: """Compute Fst between two genotype call datasets. @@ -123,10 +123,10 @@ def Fst( Genotype call dataset. call_genotype Input variable name holding call_genotype. - Defined by :data:`sgkit.variables.call_genotype`. + Defined by :data:`sgkit.variables.call_genotype_spec`. allele_counts allele counts to use or calculate, as defined by - :data:`sgkit.variables.variant_allele_count` + :data:`sgkit.variables.variant_allele_count_spec` Returns ------- @@ -144,8 +144,8 @@ def Fst( def Tajimas_D( ds: Dataset, *, - call_genotype: str = "call_genotype", - allele_counts: str = "variant_allele_count", + call_genotype: str = variables.call_genotype, + allele_counts: str = variables.variant_allele_count, ) -> DataArray: """Compute Tajimas' D for a genotype call dataset. @@ -155,20 +155,20 @@ def Tajimas_D( Genotype call dataset. call_genotype Input variable name holding call_genotype. - Defined by :data:`sgkit.variables.call_genotype`. + Defined by :data:`sgkit.variables.call_genotype_spec`. allele_counts allele counts to use or calculate, as defined by - :data:`sgkit.variables.variant_allele_count` + :data:`sgkit.variables.variant_allele_count_spec` Returns ------- Tajimas' D value. """ if allele_counts not in ds: - variables.validate(ds, {call_genotype: variables.call_genotype}) + variables.validate(ds, {call_genotype: variables.call_genotype_spec}) ds_new = count_variant_alleles(ds) else: - variables.validate(ds, {allele_counts: variables.variant_allele_count}) + variables.validate(ds, {allele_counts: variables.variant_allele_count_spec}) ds_new = ds ac = ds_new[allele_counts] diff --git a/sgkit/stats/regenie.py b/sgkit/stats/regenie.py index b7d1a365b..954de4026 100644 --- a/sgkit/stats/regenie.py +++ b/sgkit/stats/regenie.py @@ -708,16 +708,16 @@ def regenie_transform( YP3 = _stage_3(B2, YP1, X, Y, contigs, variant_chunk_start) data_vars: Dict[Hashable, Any] = {} - data_vars["base_prediction"] = xr.DataArray( + data_vars[variables.base_prediction] = xr.DataArray( YP1, dims=("blocks", "alphas", "samples", "outcomes"), attrs={"description": DESC_BASE_PRED}, ) - data_vars["meta_prediction"] = xr.DataArray( + data_vars[variables.meta_prediction] = xr.DataArray( YP2, dims=("samples", "outcomes"), attrs={"description": DESC_META_PRED} ) if YP3 is not None: - data_vars["loco_prediction"] = xr.DataArray( + data_vars[variables.loco_prediction] = xr.DataArray( YP3, dims=("contigs", "samples", "outcomes"), attrs={"description": DESC_LOCO_PRED}, @@ -731,7 +731,7 @@ def regenie( dosage: str, covariates: Union[str, Sequence[str]], traits: Union[str, Sequence[str]], - variant_contig: str = "variant_contig", + variant_contig: str = variables.variant_contig, variant_block_size: Optional[Union[int, Tuple[int, ...]]] = None, sample_block_size: Optional[Union[int, Tuple[int, ...]]] = None, alphas: Optional[Sequence[float]] = None, @@ -755,16 +755,16 @@ def regenie( ---------- dosage Name of genetic dosage variable. - Defined by :data:`sgkit.variables.dosage`. + Defined by :data:`sgkit.variables.dosage_spec`. covariates Names of covariate variables (1D or 2D). - Defined by :data:`sgkit.variables.covariates`. + Defined by :data:`sgkit.variables.covariates_spec`. traits Names of trait variables (1D or 2D). - Defined by :data:`sgkit.variables.traits`. + Defined by :data:`sgkit.variables.traits_spec`. variant_contig Name of the variant contig input variable. - Definied by :data:`sgkit.variables.variant_contig`. + Definied by :data:`sgkit.variables.variant_contig_spec`. variant_block_size Number of variants in each block. If int, this describes the number of variants in each block @@ -809,17 +809,17 @@ def regenie( - `base_prediction` (blocks, alphas, samples, outcomes): Stage 1 predictions from ridge regression reduction. Defined by - :data:`sgkit.variables.base_prediction`. + :data:`sgkit.variables.base_prediction_spec`. - `meta_prediction` (samples, outcomes): Stage 2 predictions from the best meta estimator trained on the out-of-sample Stage 1 - predictions. Defined by :data:`sgkit.variables.meta_prediction`. + predictions. Defined by :data:`sgkit.variables.meta_prediction_spec`. - `loco_prediction` (contigs, samples, outcomes): LOCO predictions resulting from Stage 2 predictions ignoring effects for variant blocks on held out contigs. This will be absent if the data provided does not contain at least 2 contigs. Defined by - :data:`sgkit.variables.loco_prediction`. + :data:`sgkit.variables.loco_prediction_spec`. Raises ------ @@ -864,9 +864,9 @@ def regenie( variables.validate( ds, - {dosage: variables.dosage, variant_contig: variables.variant_contig}, - {c: variables.covariates for c in covariates}, - {t: variables.traits for t in traits}, + {dosage: variables.dosage_spec, variant_contig: variables.variant_contig_spec}, + {c: variables.covariates_spec for c in covariates}, + {t: variables.traits_spec for t in traits}, ) G = ds[dosage] diff --git a/sgkit/tests/test_variables.py b/sgkit/tests/test_variables.py index c0e5c3b39..d7f2ecbfc 100644 --- a/sgkit/tests/test_variables.py +++ b/sgkit/tests/test_variables.py @@ -27,10 +27,15 @@ def test_variables__no_spec(dummy_ds: xr.Dataset) -> None: def test_variables__validate_by_name(dummy_ds: xr.Dataset) -> None: spec = ArrayLikeSpec("foo", kind="i", ndim=1) try: - SgkitVariables.register_variable(spec) + assert "foo" not in SgkitVariables.registered_variables + name, spec_b = SgkitVariables.register_variable(spec) + assert "foo" in SgkitVariables.registered_variables + assert name == "foo" + assert spec_b == spec variables.validate(dummy_ds, "foo") finally: SgkitVariables.registered_variables.pop("foo", None) + assert "foo" not in SgkitVariables.registered_variables def test_variables__validate_by_dummy_spec(dummy_ds: xr.Dataset) -> None: diff --git a/sgkit/variables.py b/sgkit/variables.py index da7807e2c..4a5f14ce4 100644 --- a/sgkit/variables.py +++ b/sgkit/variables.py @@ -1,20 +1,24 @@ import logging from dataclasses import dataclass -from typing import Dict, Hashable, Mapping, Set, Union, overload +from typing import Dict, Hashable, Mapping, Set, Tuple, Union, overload import xarray as xr logger = logging.getLogger(__name__) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class Spec: """Root type Spec""" default_name: str + # Note: we want to prevent dev/users from mistakenly + # using Spec as a hashable obj in dict, xr.Dataset + __hash__ = None # type: ignore[assignment] -@dataclass(frozen=True) + +@dataclass(frozen=True, eq=False) class ArrayLikeSpec(Spec): """ArrayLike type spec""" @@ -22,6 +26,99 @@ class ArrayLikeSpec(Spec): ndim: Union[None, int, Set[int]] = None +class SgkitVariables: + """Holds registry of Sgkit variables, and can validate a dataset against a spec""" + + registered_variables: Dict[Hashable, Spec] = {} + + @classmethod + def register_variable(cls, spec: Spec) -> Tuple[str, Spec]: + """Register variable spec""" + if spec.default_name in cls.registered_variables: + raise ValueError(f"`{spec.default_name}` already registered") + cls.registered_variables[spec.default_name] = spec + return spec.default_name, spec + + @classmethod + @overload + def _validate( + cls, + xr_dataset: xr.Dataset, + *specs: Mapping[Hashable, Spec], + ) -> xr.Dataset: + """ + Validate that xr_dataset contains array(s) of interest with alternative + variable name(s). To validate all variables in the dataset, skip `specs`. + """ + ... + + @classmethod + @overload + def _validate(cls, xr_dataset: xr.Dataset, *specs: Spec) -> xr.Dataset: + """ + Validate that xr_dataset contains array(s) of interest with default + variable name(s). To validate all variables in the dataset, skip `specs`. + """ + ... + + @classmethod + @overload + def _validate(cls, xr_dataset: xr.Dataset, *specs: Hashable) -> xr.Dataset: + """ + Validate that xr_dataset contains array(s) of interest with variable + name(s). Variable must be registered in `SgkitVariables.registered_variables`. + To validate all variables in the dataset, skip `specs`. + """ + ... + + @classmethod + def _validate( + cls, + xr_dataset: xr.Dataset, + *specs: Union[Spec, Mapping[Hashable, Spec], Hashable], + ) -> xr.Dataset: + if len(specs) == 0: + specs = tuple(xr_dataset.variables.keys()) + logger.debug(f"No specs provided, will validate all variables: {specs}") + for s in specs: + if isinstance(s, Spec): + cls._check_field(xr_dataset, s, s.default_name) + elif isinstance(s, Mapping): + for fname, field_spec in s.items(): + cls._check_field(xr_dataset, field_spec, fname) + elif s: + try: + field_spec = cls.registered_variables[s] + except KeyError: + raise ValueError(f"No array spec registered for {s}") + cls._check_field(xr_dataset, field_spec, field_spec.default_name) + return xr_dataset + + @classmethod + def _check_field( + cls, xr_dataset: xr.Dataset, field_spec: Spec, field: Hashable + ) -> None: + from sgkit.utils import check_array_like + + assert isinstance( + field_spec, ArrayLikeSpec + ), "ArrayLikeSpec is the only currently supported variable spec" + + if field not in xr_dataset: + raise ValueError(f"{field} not present in {xr_dataset}") + try: + check_array_like( + xr_dataset[field], kind=field_spec.kind, ndim=field_spec.ndim + ) + except (TypeError, ValueError) as e: + raise ValueError( + f"{field} does not match the spec, see the error above for more detail" + ) from e + + +validate = SgkitVariables._validate +"""Shortcut for the SgkitVariables.validate""" + """ We define xr.Dataset variables used in the sgkit methods below, these definitions: @@ -38,46 +135,70 @@ class ArrayLikeSpec(Spec): specific page. """ -base_prediction = ArrayLikeSpec("base_prediction", ndim=4, kind="f") +base_prediction, base_prediction_spec = SgkitVariables.register_variable( + ArrayLikeSpec("base_prediction", ndim=4, kind="f") +) """ REGENIE's base prediction (blocks, alphas, samples, outcomes). Stage 1 predictions from ridge regression reduction. """ -call_allele_count = ArrayLikeSpec("call_allele_count", ndim=3, kind="u") +call_allele_count, call_allele_count_spec = SgkitVariables.register_variable( + ArrayLikeSpec("call_allele_count", ndim=3, kind="u") +) """ Allele counts. With shape (variants, samples, alleles) and values corresponding to the number of non-missing occurrences of each allele. """ -call_dosage = ArrayLikeSpec("call_dosage", kind="f", ndim=2) +call_dosage, call_dosage_spec = SgkitVariables.register_variable( + ArrayLikeSpec("call_dosage", kind="f", ndim=2) +) """Dosages, encoded as floats, with NaN indicating a missing value.""" -call_dosage_mask = ArrayLikeSpec("call_dosage_mask", kind="b", ndim=2) +call_dosage_mask, call_dosage_mask_spec = SgkitVariables.register_variable( + ArrayLikeSpec("call_dosage_mask", kind="b", ndim=2) +) """TODO""" -call_genotype = ArrayLikeSpec("call_genotype", kind="i", ndim=3) +call_genotype, call_genotype_spec = SgkitVariables.register_variable( + ArrayLikeSpec("call_genotype", kind="i", ndim=3) +) """ Call genotype. Encoded as allele values (0 for the reference, 1 for the first allele, 2 for the second allele), or -1 to indicate a missing value. """ -call_genotype_mask = ArrayLikeSpec("call_genotype_mask", kind="b", ndim=3) +call_genotype_mask, call_genotype_mask_spec = SgkitVariables.register_variable( + ArrayLikeSpec("call_genotype_mask", kind="b", ndim=3) +) """TODO""" -call_genotype_phased = ArrayLikeSpec("call_genotype_phased", kind="b", ndim=2) +call_genotype_phased, call_genotype_phased_spec = SgkitVariables.register_variable( + ArrayLikeSpec("call_genotype_phased", kind="b", ndim=2) +) """ A flag for each call indicating if it is phased or not. If omitted all calls are unphased. """ -call_genotype_probability = ArrayLikeSpec("call_genotype_probability", kind="f", ndim=3) +( + call_genotype_probability, + call_genotype_probability_spec, +) = SgkitVariables.register_variable( + ArrayLikeSpec("call_genotype_probability", kind="f", ndim=3) +) """TODO""" -call_genotype_probability_mask = ArrayLikeSpec( - "call_genotype_probability_mask", kind="b", ndim=3 +( + call_genotype_probability_mask, + call_genotype_probability_mask_spec, +) = SgkitVariables.register_variable( + ArrayLikeSpec("call_genotype_probability_mask", kind="b", ndim=3) ) """TODO""" -covariates = ArrayLikeSpec("covariates", ndim={1, 2}) +covariates, covariates_spec = SgkitVariables.register_variable( + ArrayLikeSpec("covariates", ndim={1, 2}) +) """ Covariate variable names. Must correspond to 1 or 2D dataset variables of shape (samples[, covariates]). All covariate arrays will be concatenated along the second axis (columns). """ -dosage = ArrayLikeSpec("dosage") +dosage, dosage_spec = SgkitVariables.register_variable(ArrayLikeSpec("dosage")) """ Dosage variable name. Where "dosage" array can contain represent one of several possible quantities, e.g.: @@ -86,163 +207,123 @@ class ArrayLikeSpec(Spec): - True dosages as computed from imputed or probabilistic variant calls - Any other custom encoding in a user-defined variable """ -genotype_counts = ArrayLikeSpec("genotype_counts", ndim=2, kind="i") +genotype_counts, genotype_counts_spec = SgkitVariables.register_variable( + ArrayLikeSpec("genotype_counts", ndim=2, kind="i") +) """ Genotype counts. Must correspond to an (`N`, 3) array where `N` is equal to the number of variants and the 3 columns contain heterozygous, homozygous reference, and homozygous alternate counts (in that order) across all samples for a variant. """ -loco_prediction = ArrayLikeSpec("loco_prediction", ndim=3, kind="f") +loco_prediction, loco_prediction_spec = SgkitVariables.register_variable( + ArrayLikeSpec("loco_prediction", ndim=3, kind="f") +) """ REGENIE's loco_prediction (contigs, samples, outcomes). LOCO predictions resulting from Stage 2 predictions ignoring effects for variant blocks on held out contigs. This will be absent if the data provided does not contain at least 2 contigs. """ -meta_prediction = ArrayLikeSpec("meta_prediction", ndim=2, kind="f") +meta_prediction, meta_prediction_spec = SgkitVariables.register_variable( + ArrayLikeSpec("meta_prediction", ndim=2, kind="f") +) """ REGENIE's meta_prediction (samples, outcomes). Stage 2 predictions from the best meta estimator trained on the out-of-sample Stage 1 predictions. """ -pc_relate_phi = ArrayLikeSpec("pc_relate_phi", ndim=2, kind="f") +pc_relate_phi, pc_relate_phi_spec = SgkitVariables.register_variable( + ArrayLikeSpec("pc_relate_phi", ndim=2, kind="f") +) """PC Relate kinship coefficient matrix.""" -sample_id = ArrayLikeSpec("sample_id", kind={"U", "O"}, ndim=1) +sample_id, sample_id_spec = SgkitVariables.register_variable( + ArrayLikeSpec("sample_id", kind={"U", "O"}, ndim=1) +) """The unique identifier of the sample.""" -sample_pcs = ArrayLikeSpec("sample_pcs", ndim=2, kind="f") +sample_pcs, sample_pcs_spec = SgkitVariables.register_variable( + ArrayLikeSpec("sample_pcs", ndim=2, kind="f") +) """Sample PCs (PCxS).""" -traits = ArrayLikeSpec("traits", ndim={1, 2}) +traits, traits_spec = SgkitVariables.register_variable( + ArrayLikeSpec("traits", ndim={1, 2}) +) """ Trait (for example phenotype) variable names. Must all be continuous and correspond to 1 or 2D dataset variables of shape (samples[, traits]). 2D trait arrays will be assumed to contain separate traits within columns and concatenated to any 1D traits along the second axis (columns). """ -variant_allele = ArrayLikeSpec("variant_allele", kind={"S", "O"}, ndim=2) +variant_allele, variant_allele_spec = SgkitVariables.register_variable( + ArrayLikeSpec("variant_allele", kind={"S", "O"}, ndim=2) +) """The possible alleles for the variant.""" -variant_allele_count = ArrayLikeSpec("variant_allele_count", ndim=2, kind="u") +variant_allele_count, variant_allele_count_spec = SgkitVariables.register_variable( + ArrayLikeSpec("variant_allele_count", ndim=2, kind="u") +) """ Variant allele counts. With shape (variants, alleles) and values corresponding to the number of non-missing occurrences of each allele. """ -variant_allele_frequency = ArrayLikeSpec("variant_allele_frequency", ndim=2, kind="f") +( + variant_allele_frequency, + variant_allele_frequency_spec, +) = SgkitVariables.register_variable( + ArrayLikeSpec("variant_allele_frequency", ndim=2, kind="f") +) """The frequency of the occurrence of each allele.""" -variant_allele_total = ArrayLikeSpec("variant_allele_total", ndim=1, kind="i") +variant_allele_total, variant_allele_total_spec = SgkitVariables.register_variable( + ArrayLikeSpec("variant_allele_total", ndim=1, kind="i") +) """The number of occurrences of all alleles.""" -variant_beta = ArrayLikeSpec("variant_beta") +variant_beta, variant_beta_spec = SgkitVariables.register_variable( + ArrayLikeSpec("variant_beta") +) """Beta values associated with each variant and trait.""" -variant_call_rate = ArrayLikeSpec("variant_call_rate", ndim=1, kind="f") +variant_call_rate, variant_call_rate_spec = SgkitVariables.register_variable( + ArrayLikeSpec("variant_call_rate", ndim=1, kind="f") +) """The number of samples with heterozygous calls.""" -variant_contig = ArrayLikeSpec("variant_contig", kind="i", ndim=1) +variant_contig, variant_contig_spec = SgkitVariables.register_variable( + ArrayLikeSpec("variant_contig", kind="i", ndim=1) +) """The (index of the) contig for each variant.""" -variant_hwe_p_value = ArrayLikeSpec("variant_hwe_p_value", kind="f") +variant_hwe_p_value, variant_hwe_p_value_spec = SgkitVariables.register_variable( + ArrayLikeSpec("variant_hwe_p_value", kind="f") +) """P values from HWE test for each variant as float in [0, 1].""" -variant_id = ArrayLikeSpec("variant_id", kind="U", ndim=1) +variant_id, variant_id_spec = SgkitVariables.register_variable( + ArrayLikeSpec("variant_id", kind="U", ndim=1) +) """The unique identifier of the variant.""" -variant_n_called = ArrayLikeSpec("variant_n_called", ndim=1, kind="i") +variant_n_called, variant_n_called_spec = SgkitVariables.register_variable( + ArrayLikeSpec("variant_n_called", ndim=1, kind="i") +) """The number of samples with called genotypes.""" -variant_n_het = ArrayLikeSpec("variant_n_het", ndim=1, kind="i") +variant_n_het, variant_n_het_spec = SgkitVariables.register_variable( + ArrayLikeSpec("variant_n_het", ndim=1, kind="i") +) """The number of samples with heterozygous calls.""" -variant_n_hom_alt = ArrayLikeSpec("variant_n_hom_alt", ndim=1, kind="i") +variant_n_hom_alt, variant_n_hom_alt_spec = SgkitVariables.register_variable( + ArrayLikeSpec("variant_n_hom_alt", ndim=1, kind="i") +) """The number of samples with homozygous alternate calls.""" -variant_n_hom_ref = ArrayLikeSpec("variant_n_hom_ref", ndim=1, kind="i") +variant_n_hom_ref, variant_n_hom_ref_spec = SgkitVariables.register_variable( + ArrayLikeSpec("variant_n_hom_ref", ndim=1, kind="i") +) """The number of samples with homozygous reference calls.""" -variant_n_non_ref = ArrayLikeSpec("variant_n_non_ref", ndim=1, kind="i") +variant_n_non_ref, variant_n_non_ref_spec = SgkitVariables.register_variable( + ArrayLikeSpec("variant_n_non_ref", ndim=1, kind="i") +) """The number of samples that are not homozygous reference calls.""" -variant_p_value = ArrayLikeSpec("variant_p_value", kind="f") +variant_p_value, variant_p_value_spec = SgkitVariables.register_variable( + ArrayLikeSpec("variant_p_value", kind="f") +) """P values as float in [0, 1].""" -variant_position = ArrayLikeSpec("variant_position", kind="i", ndim=1) +variant_position, variant_position_spec = SgkitVariables.register_variable( + ArrayLikeSpec("variant_position", kind="i", ndim=1) +) """The reference position of the variant.""" -variant_t_value = ArrayLikeSpec("variant_t_value") +variant_t_value, variant_t_value_spec = SgkitVariables.register_variable( + ArrayLikeSpec("variant_t_value") +) """T statistics for each beta.""" - - -class SgkitVariables: - """Holds registry of Sgkit variables, and can validate a dataset against a spec""" - - registered_variables: Dict[Hashable, ArrayLikeSpec] = { - x.default_name: x for x in globals().values() if isinstance(x, ArrayLikeSpec) - } - - @classmethod - def register_variable(cls, spec: ArrayLikeSpec) -> None: - """Register variable spec""" - if spec.default_name in cls.registered_variables: - raise ValueError(f"`{spec.default_name}` already registered") - cls.registered_variables[spec.default_name] = spec - - @classmethod - @overload - def _validate( - cls, - xr_dataset: xr.Dataset, - *specs: Mapping[Hashable, ArrayLikeSpec], - ) -> xr.Dataset: - """ - Validate that xr_dataset contains array(s) of interest with alternative - variable name(s). To validate all variables in the dataset, skip `specs`. - """ - ... - - @classmethod - @overload - def _validate(cls, xr_dataset: xr.Dataset, *specs: ArrayLikeSpec) -> xr.Dataset: - """ - Validate that xr_dataset contains array(s) of interest with default - variable name(s). To validate all variables in the dataset, skip `specs`. - """ - ... - - @classmethod - @overload - def _validate(cls, xr_dataset: xr.Dataset, *specs: Hashable) -> xr.Dataset: - """ - Validate that xr_dataset contains array(s) of interest with variable - name(s). Variable must be registered in `SgkitVariables.registered_variables`. - To validate all variables in the dataset, skip `specs`. - """ - ... - - @classmethod - def _validate( - cls, - xr_dataset: xr.Dataset, - *specs: Union[ArrayLikeSpec, Mapping[Hashable, ArrayLikeSpec], Hashable], - ) -> xr.Dataset: - if len(specs) == 0: - specs = tuple(xr_dataset.variables.keys()) - logger.debug(f"No specs provided, will validate all variables: {specs}") - for s in specs: - if isinstance(s, ArrayLikeSpec): - cls._check_field(xr_dataset, s, s.default_name) - elif isinstance(s, Mapping): - for fname, field_spec in s.items(): - cls._check_field(xr_dataset, field_spec, fname) - else: - try: - field_spec = cls.registered_variables[s] - except KeyError: - raise ValueError(f"No array spec registered for {s}") - cls._check_field(xr_dataset, field_spec, field_spec.default_name) - return xr_dataset - - @classmethod - def _check_field( - cls, xr_dataset: xr.Dataset, field_spec: ArrayLikeSpec, field: Hashable - ) -> None: - from sgkit.utils import check_array_like - - if field not in xr_dataset: - raise ValueError(f"{field} not present in {xr_dataset}") - try: - check_array_like( - xr_dataset[field], kind=field_spec.kind, ndim=field_spec.ndim - ) - except (TypeError, ValueError) as e: - raise ValueError( - f"{field} does not match the spec, see the error above for more detail" - ) from e - - -validate = SgkitVariables._validate -"""Shortcut for the SgkitVariables.validate"""