Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add get_is_haploid_expr, get_dp_gq_adj_expr, get_adj_het_ab_expr, and some helpful parameters to agg_by_strata and compute_stats_per_ref_site #673

Merged
merged 8 commits into from
Feb 12, 2024
38 changes: 24 additions & 14 deletions gnomad/sample_qc/sex.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import pandas as pd
from sklearn.mixture import GaussianMixture

from gnomad.utils.annotations import prep_ploidy_ht

logging.basicConfig(format="%(levelname)s (%(name)s %(lineno)s): %(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
Expand All @@ -23,25 +25,33 @@ def adjusted_sex_ploidy_expr(
xx_karyotype_str: str = "XX",
) -> hl.expr.CallExpression:
"""
Create an entry expression to convert males to haploid on non-PAR X/Y and females to missing on Y.
Create an entry expression to convert XY to haploid on non-PAR X/Y and XX to missing on Y.

:param locus_expr: Locus
:param gt_expr: Genotype
:param karyotype_expr: Karyotype
:param xy_karyotype_str: Male sex karyotype representation
:param xx_karyotype_str: Female sex karyotype representation
:param locus_expr: Locus expression.
:param gt_expr: Genotype expression.
:param karyotype_expr: Sex karyotype expression.
:param xy_karyotype_str: String representing XY karyotype. Default is "XY".
:param xx_karyotype_str: String representing XX karyotype. Default is "XX".
:return: Genotype adjusted for sex ploidy
"""
male = karyotype_expr == xy_karyotype_str
female = karyotype_expr == xx_karyotype_str
x_nonpar = locus_expr.in_x_nonpar()
y_par = locus_expr.in_y_par()
y_nonpar = locus_expr.in_y_nonpar()
# An optimization that annotates the locus's matrix table with the
# fields in the case statements below as an optimization step
col_idx, row_idx = prep_ploidy_ht(
locus_expr, karyotype_expr, xy_karyotype_str, xx_karyotype_str
)

return (
hl.case(missing_false=True)
.when(female & (y_par | y_nonpar), hl.null(hl.tcall))
.when(male & (x_nonpar | y_nonpar) & gt_expr.is_het(), hl.null(hl.tcall))
.when(male & (x_nonpar | y_nonpar), hl.call(gt_expr[0], phased=False))
.when(~row_idx.in_non_par, gt_expr)
.when(col_idx.xx & (row_idx.y_par | row_idx.y_nonpar), hl.null(hl.tcall))
.when(
col_idx.xy & (row_idx.x_nonpar | row_idx.y_nonpar) & gt_expr.is_het(),
hl.null(hl.tcall),
)
.when(
col_idx.xy & (row_idx.x_nonpar | row_idx.y_nonpar),
hl.call(gt_expr[0], phased=False),
)
.default(gt_expr)
)

Expand Down
234 changes: 211 additions & 23 deletions gnomad/utils/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,163 @@ def create_frequency_bins_expr(
return bin_expr


def prep_ploidy_ht(
locus_expr: hl.expr.LocusExpression = None,
karyotype_expr: hl.expr.StringExpression = None,
xy_karyotype_str: str = "XY",
xx_karyotype_str: str = "XX",
) -> Tuple[hl.expr.StructExpression, hl.expr.StructExpression]:
"""
Prepare relevant ploidy annotations for downstream calculations on a matrix table.

This method annotates the matrix table with the following fields:

- `xy`: Boolean indicating if the sample is XY
- `xx`: Boolean indicating if the sample is XX
- `in_non_par`: Boolean indicating if the locus is in a non-PAR region
- `x_nonpar`: Boolean indicating if the locus is in a non-PAR region of the X chromosome
- `y_par`: Boolean indicating if the locus is in a PAR region of the Y chromosome
- `y_nonpar`: Boolean indicating if the locus is in a non-PAR region of the Y chromosome

This method is used as an optimization for the `get_is_haploid_expr`
and `adjusted_sex_ploidy_expr` methods.

:param locus_expr: Locus expression.
:param karyotype_expr: Karyotype expression.
:param xy_karyotype_str: String representing XY karyotype. Default is "XY".
:param xx_karyotype_str: String representing XX karyotype. Default is "XX".
:return: Tuple of index expressions for columns and rows.
"""
source_mt = locus_expr._indices.source
col_ht = source_mt.annotate_cols(
xy=karyotype_expr.upper() == xy_karyotype_str,
xx=karyotype_expr.upper() == xx_karyotype_str,
).cols()
row_ht = source_mt.annotate_rows(
in_non_par=~locus_expr.in_autosome_or_par(),
x_nonpar=locus_expr.in_x_nonpar(),
y_par=locus_expr.in_y_par(),
y_nonpar=locus_expr.in_y_nonpar(),
).rows()
col_idx = col_ht[source_mt.col_key]
row_idx = row_ht[source_mt.row_key]
return col_idx, row_idx


def get_is_haploid_expr(
gt_expr: Optional[hl.expr.CallExpression] = None,
locus_expr: Optional[hl.expr.LocusExpression] = None,
karyotype_expr: Optional[hl.expr.StringExpression] = None,
xy_karyotype_str: str = "XY",
xx_karyotype_str: str = "XX",
) -> hl.expr.BooleanExpression:
"""
Determine if a genotype or locus and karyotype combination is haploid.

.. note::

One of `gt_expr` or `locus_expr` and `karyotype_expr` is required.

:param gt_expr: Optional genotype expression.
:param locus_expr: Optional locus expression.
:param karyotype_expr: Optional sex karyotype expression.
:param xy_karyotype_str: String representing XY karyotype. Default is "XY".
:param xx_karyotype_str: String representing XX karyotype. Default is "XX".
:return: Boolean expression indicating if the genotype is haploid.
"""
if gt_expr is None and locus_expr is None and karyotype_expr is None:
raise ValueError(
"One of 'gt_expr' or 'locus_expr' and 'karyotype_expr' is required."
)

if gt_expr is not None:
return gt_expr.is_haploid()

if locus_expr is None or karyotype_expr is None:
raise ValueError(
"Both 'locus_expr' and 'karyotype_expr' are required if no 'gt_expr' is "
"supplied."
)
# An optimization that annotates the locus's matrix table with the
# fields in the case statements below as an optimization step
col_idx, row_idx = prep_ploidy_ht(
locus_expr, karyotype_expr, xy_karyotype_str, xx_karyotype_str
)

return row_idx.in_non_par & hl.or_missing(
~(col_idx.xx & (row_idx.y_par | row_idx.y_nonpar)),
col_idx.xy & (row_idx.x_nonpar | row_idx.y_nonpar),
)


def get_dp_gq_adj_expr(
gq_expr: Union[hl.expr.Int32Expression, hl.expr.Int64Expression],
dp_expr: Union[hl.expr.Int32Expression, hl.expr.Int64Expression],
gt_expr: Optional[hl.expr.CallExpression] = None,
locus_expr: Optional[hl.expr.LocusExpression] = None,
karyotype_expr: Optional[hl.expr.StringExpression] = None,
adj_gq: int = 20,
adj_dp: int = 10,
haploid_adj_dp: int = 5,
) -> hl.expr.BooleanExpression:
"""
Get adj annotation using only GQ and DP.

Default thresholds correspond to gnomAD values.

.. note::

This function can be used to annotate adj taking into account only GQ and DP.
It is useful for cases where the GT field is not available, such as in the
reference data of a VariantDataset.

.. note::

One of `gt_expr` or `locus_expr` and `karyotype_expr` is required.

:param gq_expr: GQ expression.
:param dp_expr: DP expression.
:param gt_expr: Optional genotype expression.
:param locus_expr: Optional locus expression.
:param karyotype_expr: Optional sex karyotype expression.
:param adj_gq: GQ threshold for adj. Default is 20.
:param adj_dp: DP threshold for adj. Default is 10.
:param haploid_adj_dp: Haploid DP threshold for adj. Default is 5.
:return: Boolean expression indicating adj filter.
"""
return (gq_expr >= adj_gq) & hl.if_else(
get_is_haploid_expr(gt_expr, locus_expr, karyotype_expr),
dp_expr >= haploid_adj_dp,
dp_expr >= adj_dp,
)


def get_adj_het_ab_expr(
gt_expr: hl.expr.CallExpression,
dp_expr: Union[hl.expr.Int32Expression, hl.expr.Int64Expression],
ad_expr: hl.expr.ArrayNumericExpression,
adj_ab: float = 0.2,
) -> hl.expr.BooleanExpression:
"""
Get adj het AB annotation.

:param gt_expr: Genotype expression.
:param dp_expr: DP expression.
:param ad_expr: AD expression.
:param adj_ab: AB threshold for adj. Default is 0.2.
:return: Boolean expression indicating adj het AB filter.
"""
return (
hl.case()
.when(~gt_expr.is_het(), True)
.when(gt_expr.is_het_ref(), ad_expr[gt_expr[1]] / dp_expr >= adj_ab)
.default(
(ad_expr[gt_expr[0]] / dp_expr >= adj_ab)
& (ad_expr[gt_expr[1]] / dp_expr >= adj_ab)
)
)


def get_adj_expr(
gt_expr: hl.expr.CallExpression,
gq_expr: Union[hl.expr.Int32Expression, hl.expr.Int64Expression],
Expand All @@ -617,19 +774,14 @@ def get_adj_expr(

Defaults correspond to gnomAD values.
"""
return (
(gq_expr >= adj_gq)
& hl.if_else(gt_expr.is_haploid(), dp_expr >= haploid_adj_dp, dp_expr >= adj_dp)
& (
hl.case()
.when(~gt_expr.is_het(), True)
.when(gt_expr.is_het_ref(), ad_expr[gt_expr[1]] / dp_expr >= adj_ab)
.default(
(ad_expr[gt_expr[0]] / dp_expr >= adj_ab)
& (ad_expr[gt_expr[1]] / dp_expr >= adj_ab)
)
)
)
return get_dp_gq_adj_expr(
gq_expr,
dp_expr,
gt_expr=gt_expr,
adj_gq=adj_gq,
adj_dp=adj_dp,
haploid_adj_dp=haploid_adj_dp,
) & get_adj_het_ab_expr(gt_expr, dp_expr, ad_expr, adj_ab)


def annotate_adj(
Expand Down Expand Up @@ -897,7 +1049,7 @@ def fs_from_sb(


def sor_from_sb(
sb: Union[hl.expr.ArrayNumericExpression, hl.expr.ArrayExpression]
sb: Union[hl.expr.ArrayNumericExpression, hl.expr.ArrayExpression],
) -> hl.expr.Float64Expression:
"""
Compute `SOR` (Symmetric Odds Ratio test) annotation from the `SB` (strand balance table) field.
Expand Down Expand Up @@ -1949,6 +2101,7 @@ def agg_by_strata(
entry_agg_funcs: Dict[str, Tuple[Callable, Callable]],
select_fields: Optional[List[str]] = None,
group_membership_ht: Optional[hl.Table] = None,
entry_agg_group_membership: Optional[Dict[str, List[dict]]] = None,
) -> hl.Table:
"""
Get row expression for annotations of each entry aggregation function(s) by strata.
Expand All @@ -1969,6 +2122,14 @@ def agg_by_strata(
:param group_membership_ht: Optional Table containing group membership annotations
to stratify the aggregations by. If not provided, the 'group_membership'
annotation is expected to be present on `mt`.
:param entry_agg_group_membership: Optional dict indicating the subset of group
strata in 'freq_meta' to run the entry aggregation functions on. The keys of
the dict can be any of the keys in `entry_agg_funcs` and the values are lists
of dicts. Each dict in the list contains the strata in 'freq_meta' to use for
the corresponding entry aggregation function. If provided, 'freq_meta' must be
present in `group_membership_ht` or `mt` and represent the same strata as those
in 'group_membership'. If not provided, all entries of the 'group_membership'
annotation will have the entry aggregation functions applied to them.
:return: Table with annotations of stratified aggregations.
"""
if group_membership_ht is None and "group_membership" not in mt.col:
Expand Down Expand Up @@ -2019,6 +2180,18 @@ def agg_by_strata(
)
global_expr["adj_groups"] = hl.range(n_groups).map(lambda x: False)

if entry_agg_group_membership is not None and "freq_meta" not in group_globals:
raise ValueError(
"The 'freq_meta' global annotation must be supplied when the"
" 'entry_agg_group_membership' is specified."
)

entry_agg_group_membership = entry_agg_group_membership or {}
entry_agg_group_membership = {
ann: [group_globals["freq_meta"].index(s) for s in strata]
for ann, strata in entry_agg_group_membership.items()
}

n_adj_groups = hl.eval(hl.len(global_expr["adj_groups"]))
if n_adj_groups != n_groups:
raise ValueError(
Expand Down Expand Up @@ -2052,19 +2225,21 @@ def agg_by_strata(
# own ArrayExpression. This is important to prevent memory issues when performing
# the below array aggregations.
ht = ht.select(
**{
ann: ht.entries.map(lambda e: e[ann])
for ann in select_fields + list(select_expr.keys())
}
*select_fields,
**{ann: ht.entries.map(lambda e: e[ann]) for ann in select_expr.keys()},
)

def _agg_by_group(
ht: hl.Table, agg_func: Callable, ann_expr: hl.expr.ArrayExpression
indices_by_group_expr: hl.expr.ArrayExpression,
adj_groups_expr: hl.expr.ArrayExpression,
agg_func: Callable,
ann_expr: hl.expr.ArrayExpression,
) -> hl.expr.ArrayExpression:
"""
Aggregate `agg_expr` by group using the `agg_func` function.

:param ht: Input Hail Table.
:param indices_by_group_expr: ArrayExpression of indices of samples in each group.
:param adj_groups_expr: ArrayExpression indicating whether each group is adj.
:param agg_func: Aggregation function to apply to `ann_expr`.
:param ann_expr: Expression to aggregate by group.
:return: Aggregated array expression.
Expand All @@ -2079,14 +2254,27 @@ def _agg_by_group(

return hl.map(
lambda s_indices, adj: s_indices.aggregate(lambda i: f(i, adj)),
ht.indices_by_group,
ht.adj_groups,
indices_by_group_expr,
adj_groups_expr,
)

# Add annotations for any supplied entry transform and aggregation functions.
# Filter groups to only those in entry_agg_group_membership if specified.
# If there are no specific entry group indices for an annotation, use ht[g]
# to consider all groups without filtering.
ht = ht.select(
*select_fields,
**{ann: _agg_by_group(ht, f[1], ht[ann]) for ann, f in entry_agg_funcs.items()},
**{
ann: _agg_by_group(
*[
[ht[g][i] for i in entry_agg_group_membership.get(ann, [])] or ht[g]
for g in ["indices_by_group", "adj_groups"]
],
agg_func=f[1],
ann_expr=ht[ann],
)
for ann, f in entry_agg_funcs.items()
},
)

return ht.drop("cols")
Expand Down
Loading
Loading