
Add generic constraint functions: oe_aggregation_expr(), compute_pli(), oe_confidence_interval(), calculate_raw_z_score(), calculate_raw_z_score_sd() #505

Merged
merged 44 commits into from
Apr 14, 2023
44 commits
4897b17
add draft
Nov 2, 2022
88b1941
add draft
Nov 2, 2022
ebc0e7e
add comments
Nov 9, 2022
c9116b7
black reformat
Nov 9, 2022
dffd326
merge main
Nov 9, 2022
588f818
Apply suggestions from code review
averywpx Nov 29, 2022
8966efb
fix change requests
Dec 1, 2022
06bec20
small fix
Dec 1, 2022
59eb457
add comments
Dec 1, 2022
0c9148e
add filter for compute_oe_per_transcript
Dec 1, 2022
ec371a5
Apply suggestions from code review
jkgoodrich Mar 3, 2023
f827d30
Merge branch 'main' of https://github.com/broadinstitute/gnomad_metho…
jkgoodrich Mar 3, 2023
5a64d72
Modifications to make most of the functions return expressions
jkgoodrich Mar 7, 2023
db5f159
Fixes during testing
jkgoodrich Mar 14, 2023
65581f0
Add docstrings
jkgoodrich Mar 14, 2023
4c98b60
Remove unneeded f-strings
jkgoodrich Mar 17, 2023
8f593bc
Apply suggestions from code review
jkgoodrich Mar 20, 2023
4649470
Merge branch 'constraint_finalize_datasets' of https://github.com/bro…
jkgoodrich Mar 20, 2023
33d0b60
Fix docstring for calculate_raw_z_score
jkgoodrich Mar 20, 2023
6badab5
Update gnomad/utils/constraint.py
jkgoodrich Mar 21, 2023
65f3f8c
Z-score -> z-score, delete unused function, and add ht to parameters
jkgoodrich Mar 21, 2023
911123c
small edit
klaricch Mar 23, 2023
5a44ea5
Changes during PR review
jkgoodrich Mar 28, 2023
a31ee84
Change `calculate_z_score` to `calculate_raw_z_score_sd` and modify i…
jkgoodrich Apr 2, 2023
27909d7
Use `add_filters_expr` for constraint filters and add `no_var_expr` p…
jkgoodrich Apr 2, 2023
a9e55b9
Fix `oe_confidence_interval` docstring
jkgoodrich Apr 2, 2023
bf476a0
Set dpois to missing if `exp_expr > 0`
jkgoodrich Apr 3, 2023
1dbbd19
Change constraint flags to include flag_prefix and remove no_variants…
jkgoodrich Apr 4, 2023
6b39fbb
Change expected_values default so it's not mutable.
jkgoodrich Apr 4, 2023
bdfdf38
Small docstring change
jkgoodrich Apr 4, 2023
a8c6719
Fix NaN returned for oe
jkgoodrich Apr 4, 2023
c2f792f
pli exclude exp 0
jkgoodrich Apr 4, 2023
b540b8f
return correct missing type for oe_aggregation_expr
jkgoodrich Apr 4, 2023
6ea9f21
constraint flags prefix -> postfix
jkgoodrich Apr 4, 2023
e29fddc
Remove redundancy with dpois_expr
jkgoodrich Apr 4, 2023
aee69bd
Use divide_null in calculate_raw_z_score
jkgoodrich Apr 4, 2023
635f98c
Fix types
jkgoodrich Apr 13, 2023
bcfee5c
Merge branch 'main' of https://github.com/broadinstitute/gnomad_metho…
jkgoodrich Apr 13, 2023
b0b12ad
Merge branch 'main' of https://github.com/broadinstitute/gnomad_metho…
jkgoodrich Apr 13, 2023
3019cc3
Merge branch 'constraint_finalize_datasets' of https://github.com/bro…
jkgoodrich Apr 13, 2023
b147e98
Small docstring change
jkgoodrich Apr 13, 2023
65a698f
Docstring fixes
jkgoodrich Apr 13, 2023
dfe085a
change `calculate_raw_z_score_sd` to have `mirror_neg_raw_z` instead …
jkgoodrich Apr 14, 2023
889f62a
Merge pull request #526 from broadinstitute/constraint_finalize_datas…
jkgoodrich Apr 14, 2023
336 changes: 326 additions & 10 deletions gnomad/utils/constraint.py
@@ -1,9 +1,11 @@
 """Script containing generic constraint functions that may be used in the constraint pipeline."""
 
+import copy
 import logging
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import hail as hl
+from hail.utils.misc import divide_null, new_temp_file
 
 from gnomad.utils.vep import explode_by_vep_annotation, process_consequences
@@ -669,24 +671,28 @@ def build_coverage_model(
 def get_all_pop_lengths(
     ht: hl.Table,
     pops: Tuple[str],
-    prefix: str = "observed_",
+    obs_expr: hl.expr.StructExpression,
 ) -> List[Tuple[str, str]]:
     """
-    Get the minimum length of observed variant counts array for each population downsamping.
+    Get the minimum length of observed variant counts array for each population downsampling.
 
-    The annotations are specified by the combination of `prefix` and each population in
-    `pops`.
-
-    :param ht: Input Table used to build population plateau models.
-    :param pops: Populations used to categorize observed variant counts in downsampings.
-    :param prefix: Prefix of population observed variant counts. Default is `observed_`.
+    The observed variant counts for each population in `pops` are specified by
+    annotations on the `obs_expr` expression.
+
+    The function also performs a check that arrays of variant counts within population
+    downsamplings all have the same lengths.
+
+    :param ht: Input Table containing `obs_expr`.
+    :param pops: Populations used to categorize observed variant counts in downsamplings.
+    :param obs_expr: Expression for the population observed variant counts. Should be a
+        struct containing an array for each pop in `pops`.
     :return: A Dictionary with the minimum array length for each population.
     """
     # TODO: This function will be converted into doing just the length check if there
     # is no usage of pop_lengths in the constraint pipeline.
     # Get minimum length of downsamplings for each population.
     pop_downsampling_lengths = ht.aggregate(
-        [hl.agg.min(hl.len(ht[f"{prefix}{pop}"])) for pop in pops]
+        [hl.agg.min(hl.len(obs_expr[pop])) for pop in pops]
     )
 
     # Zip population name with their downsampling length.
@@ -696,7 +702,7 @@ def get_all_pop_lengths(
     assert ht.all(
         hl.all(
             lambda f: f,
-            [hl.len(ht[f"{prefix}{pop}"]) == length for length, pop in pop_lengths],
+            [hl.len(obs_expr[pop]) == length for length, pop in pop_lengths],
         )
     ), (
         "The arrays of variant counts within population downsamplings have different"
@@ -877,3 +883,313 @@ def compute_expected_variants(
     agg_expr.update({ann: agg_func(ht[ann]) for ann in ann_to_sum})
 
     return agg_expr


def oe_aggregation_expr(
    ht: hl.Table,
    filter_expr: hl.expr.BooleanExpression,
    pops: Tuple[str] = (),
    exclude_mu_sum: bool = False,
) -> hl.expr.StructExpression:
    """
    Get aggregation expressions to compute the observed:expected ratio for rows defined by `filter_expr`.

    Return a Struct containing aggregation expressions to sum the number of observed
    variants, possible variants, expected variants, and mutation rate (if
    `exclude_mu_sum` is not True) for rows defined by `filter_expr`. The Struct also
    includes an aggregation expression for the observed:expected ratio.

    The following annotations are in the returned StructExpression:

        - obs - the sum of observed variants filtered to `filter_expr`.
        - mu - the sum of mutation rate of variants filtered to `filter_expr`.
        - possible - possible number of variants filtered to `filter_expr`.
        - exp - expected number of variants filtered to `filter_expr`.
        - oe - observed:expected ratio of variants filtered to `filter_expr`.

    If `pops` is specified:

        - pop_exp - Struct with the expected number of variants per population (for
          all pop in `pops`) filtered to `filter_expr`.
        - pop_obs - Struct with the observed number of variants per population (for
          all pop in `pops`) filtered to `filter_expr`.

    .. note::

        The following annotations should be present in `ht`:

            - observed_variants
            - mu
            - possible_variants
            - expected_variants

        If `pops` is specified, the following annotations should also be present:

            - expected_variants_{pop} for all pop in `pops`
            - downsampling_counts_{pop} for all pop in `pops`

    :param ht: Input Table to create observed:expected ratio aggregation expressions for.
    :param filter_expr: Boolean expression used to filter `ht` before aggregation.
    :param pops: List of populations to compute constraint metrics for. Default is ().
    :param exclude_mu_sum: Whether to exclude mu sum aggregation expression from
        returned struct. Default is False.
    :return: StructExpression with observed:expected ratio aggregation expressions.
    """
    # Create aggregators that sum the number of observed variants, possible variants,
    # and expected variants and compute the observed:expected ratio.
    agg_expr = {
        "obs": hl.agg.sum(ht.observed_variants),
        "exp": hl.agg.sum(ht.expected_variants),
        "possible": hl.agg.sum(ht.possible_variants),
    }
    agg_expr["oe"] = divide_null(agg_expr["obs"], agg_expr["exp"])

    # Create an aggregator that sums the mutation rate.
    if not exclude_mu_sum:
        agg_expr["mu"] = hl.agg.sum(ht.mu)

    # Create aggregators that sum the number of observed variants and expected
    # variants for each population if `pops` is specified.
    if pops:
        agg_expr["pop_exp"] = hl.struct(
            **{pop: hl.agg.array_sum(ht[f"expected_variants_{pop}"]) for pop in pops}
        )
        agg_expr["pop_obs"] = hl.struct(
            **{pop: hl.agg.array_sum(ht[f"downsampling_counts_{pop}"]) for pop in pops}
        )

    agg_expr = hl.struct(**agg_expr)
    return hl.agg.group_by(filter_expr, agg_expr).get(True, hl.missing(agg_expr.dtype))
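The same aggregation can be sketched in plain Python to show the summing and the `divide_null` behavior (an illustration only, not the Hail implementation; the row dictionaries and `keep` mask are hypothetical stand-ins for `ht` and `filter_expr`):

```python
from typing import Dict, List, Optional


def oe_aggregation_sketch(
    rows: List[Dict[str, float]], keep: List[bool]
) -> Optional[Dict[str, float]]:
    """Sum counts over rows passing the filter and compute obs/exp."""
    kept = [r for r, k in zip(rows, keep) if k]
    if not kept:
        # Mirrors hl.missing(...) when no row passes the filter.
        return None
    obs = sum(r["observed_variants"] for r in kept)
    exp = sum(r["expected_variants"] for r in kept)
    return {
        "obs": obs,
        "exp": exp,
        "possible": sum(r["possible_variants"] for r in kept),
        "mu": sum(r["mu"] for r in kept),
        # divide_null: missing (None) when the denominator is zero.
        "oe": obs / exp if exp != 0 else None,
    }
```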


def compute_pli(
    ht: hl.Table,
    obs_expr: hl.expr.Int64Expression,
    exp_expr: hl.expr.Float64Expression,
    expected_values: Optional[Dict[str, float]] = None,
    min_diff_convergence: float = 0.001,
) -> hl.expr.StructExpression:
    """
    Compute the pLI score using the observed and expected variant counts.

    Full details on pLI can be found in the ExAC paper: Lek, M., Karczewski, K.,
    Minikel, E. et al. Analysis of protein-coding genetic variation in 60,706 humans.
    Nature 536, 285–291 (2016).

    pLI is the probability of being loss-of-function intolerant, and this function
    computes that probability using the expectation-maximization (EM) algorithm.

    We assume a 3 state model, where each gene fits into one of three categories
    with respect to loss-of-function variation sensitivity:

        - Null: where protein truncating variation is completely tolerated by natural
          selection.
        - Recessive (Rec): where heterozygous pLoFs are tolerated but homozygous pLoFs
          are not.
        - Haploinsufficient (LI): where heterozygous pLoFs are not tolerated.

    The function requires the expected amount of loss-of-function depletion for each of
    these states. The default provided is based on the observed depletion of
    protein-truncating variation in the Blekhman autosomal recessive and ClinGen
    dosage sensitivity gene sets (Supplementary Information Table 12 of the above
    reference):

        - Null: 1.0, assume tolerant genes have the expected amount of truncating
          variation.
        - Rec: 0.463, derived from the empirical mean observed/expected rate of
          truncating variation for recessive disease genes.
        - LI: 0.089, derived from the empirical mean observed/expected rate of
          truncating variation for severe haploinsufficient genes.

    The output StructExpression will include the following annotations:

        - pLI: Probability of loss-of-function intolerance; probability that transcript
          falls into distribution of haploinsufficient genes.
        - pNull: Probability that transcript falls into distribution of unconstrained
          genes.
        - pRec: Probability that transcript falls into distribution of recessive genes.

    :param ht: Input Table containing `obs_expr` and `exp_expr`.
    :param obs_expr: Expression for the number of observed variants on each gene or
        transcript in `ht`.
    :param exp_expr: Expression for the number of expected variants on each gene or
        transcript in `ht`.
    :param expected_values: Dictionary containing the expected values for 'Null',
        'Rec', and 'LI' to use as starting values.
    :param min_diff_convergence: Minimum iteration change in LI to consider the EM
        model convergence criteria as met. Default is 0.001.
    :return: StructExpression for pLI scores.
    """
    if expected_values is None:
        expected_values = {"Null": 1.0, "Rec": 0.463, "LI": 0.089}

    # Set up initial values.
    last_pi = {"Null": 0, "Rec": 0, "LI": 0}
    pi = {"Null": 1 / 3, "Rec": 1 / 3, "LI": 1 / 3}

    # The Poisson density is only defined for a positive expected count.
    dpois_expr = {
        k: hl.or_missing(
            exp_expr > 0, hl.dpois(obs_expr, exp_expr * expected_values[k])
        )
        for k in pi
    }
    _ht = ht.select(dpois=dpois_expr)

    # Checkpoint the temp HT because it will need to be aggregated several times.
    _ht = _ht.checkpoint(new_temp_file(prefix="compute_pli", extension="ht"))

    # Run the EM algorithm until the convergence criterion is met, then calculate
    # pLI scores.
    while abs(pi["LI"] - last_pi["LI"]) > min_diff_convergence:
        last_pi = copy.deepcopy(pi)
        pi_expr = {k: v * _ht.dpois[k] for k, v in pi.items()}
        row_sum_expr = hl.sum([pi_expr[k] for k in pi])
        pi_expr = {k: pi_expr[k] / row_sum_expr for k in pi}
        pi = _ht.aggregate({k: hl.agg.mean(pi_expr[k]) for k in pi})

    # Get expression for pLI scores.
    pli_expr = {k: v * dpois_expr[k] for k, v in pi.items()}
    row_sum_expr = hl.sum([pli_expr[k] for k in pi])

    return hl.struct(**{f"p{k}": pli_expr[k] / row_sum_expr for k in pi})
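The EM loop above can be mirrored in plain Python to make the mechanics concrete (a sketch under the assumption of per-gene observed/expected lists; a `math`-based Poisson pmf stands in for `hl.dpois`, and the function name is illustrative):

```python
import math
from typing import Dict, List, Optional


def poisson_pmf(k: int, lam: float) -> float:
    return math.exp(-lam) * lam**k / math.factorial(k)


def compute_pli_sketch(
    obs: List[int],
    exp: List[float],
    expected_values: Optional[Dict[str, float]] = None,
    min_diff_convergence: float = 0.001,
) -> List[Dict[str, float]]:
    """Per-gene posterior probabilities of the Null/Rec/LI states via EM."""
    if expected_values is None:
        expected_values = {"Null": 1.0, "Rec": 0.463, "LI": 0.089}

    # Poisson density of each state, per gene; genes with exp <= 0 are dropped.
    dpois = [
        {k: poisson_pmf(o, e * expected_values[k]) for k in expected_values}
        for o, e in zip(obs, exp)
        if e > 0
    ]

    last_pi = {"Null": 0.0, "Rec": 0.0, "LI": 0.0}
    pi = {k: 1 / 3 for k in expected_values}
    while abs(pi["LI"] - last_pi["LI"]) > min_diff_convergence:
        last_pi = dict(pi)
        # E-step: per-gene state memberships; M-step: mean membership over genes.
        memberships = []
        for row in dpois:
            weighted = {k: pi[k] * row[k] for k in pi}
            total = sum(weighted.values())
            memberships.append({k: weighted[k] / total for k in pi})
        pi = {k: sum(m[k] for m in memberships) / len(memberships) for k in pi}

    # Final per-gene posteriors: pNull, pRec, pLI.
    result = []
    for row in dpois:
        weighted = {k: pi[k] * row[k] for k in pi}
        total = sum(weighted.values())
        result.append({f"p{k}": weighted[k] / total for k in pi})
    return result
```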


def oe_confidence_interval(
    obs_expr: hl.expr.Int64Expression,
    exp_expr: hl.expr.Float64Expression,
    alpha: float = 0.05,
) -> hl.expr.StructExpression:
    """
    Determine the confidence interval around the observed:expected ratio.

    For a given pair of observed (`obs_expr`) and expected (`exp_expr`) values, the
    function computes the density of the Poisson distribution (performed using Hail's
    `dpois` module) with fixed k (`x` in `dpois` is set to the observed number of
    variants) over a range of lambda (`lamb` in `dpois`) values, which are given by the
    expected number of variants times a varying parameter ranging between 0 and 2 (the
    observed:expected ratio is typically between 0 and 1, so we want to extend the
    upper bound of the confidence interval to capture this). The cumulative density
    function of the Poisson distribution density is computed and the value of the
    varying parameter is extracted at points corresponding to `alpha` (defaults to 5%)
    and 1-`alpha` (defaults to 95%) to indicate the lower and upper bounds of the
    confidence interval.

    The following annotations are in the output StructExpression:

        - lower - the lower bound of the confidence interval.
        - upper - the upper bound of the confidence interval.

    :param obs_expr: Expression for the observed variant counts of pLoF, missense, or
        synonymous variants in `ht`.
    :param exp_expr: Expression for the expected variant counts of pLoF, missense, or
        synonymous variants in `ht`.
    :param alpha: The significance level used to compute the confidence interval.
        Default is 0.05.
    :return: StructExpression for the confidence interval lower and upper bounds.
    """
    # Set up a grid of 2,000 points spanning the range between 0 and 2.
    range_expr = hl.range(0, 2000).map(lambda x: hl.float64(x) / 1000)
    range_dpois_expr = range_expr.map(lambda x: hl.dpois(obs_expr, exp_expr * x))

    # Compute the cumulative density function of the Poisson distribution density.
    cumulative_dpois_expr = hl.cumulative_sum(range_dpois_expr)
    max_cumulative_dpois_expr = cumulative_dpois_expr[-1]
    norm_dpois_expr = cumulative_dpois_expr.map(lambda x: x / max_cumulative_dpois_expr)

    # Extract the value of the varying parameter within the specified range.
    lower_idx_expr = hl.argmax(
        norm_dpois_expr.map(lambda x: hl.or_missing(x < alpha, x))
    )
    upper_idx_expr = hl.argmin(
        norm_dpois_expr.map(lambda x: hl.or_missing(x > 1 - alpha, x))
    )
    return hl.struct(
        lower=hl.if_else(obs_expr > 0, range_expr[lower_idx_expr], 0),
        upper=range_expr[upper_idx_expr],
    )
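The grid-scan described above is easy to reproduce outside Hail (a plain-Python sketch of the same math, with a `math`-based Poisson pmf standing in for `hl.dpois`; the function name is illustrative):

```python
import math
from typing import Tuple


def poisson_pmf(k: int, lam: float) -> float:
    return math.exp(-lam) * lam**k / math.factorial(k)


def oe_confidence_interval_sketch(
    obs: int, exp: float, alpha: float = 0.05
) -> Tuple[float, float]:
    """Confidence interval around obs/exp via a Poisson density scan."""
    # Scan the o/e ratio over [0, 2) in steps of 0.001.
    grid = [i / 1000 for i in range(2000)]
    dens = [poisson_pmf(obs, exp * x) for x in grid]

    # Cumulative density, normalized to end at 1.
    cumulative, total = [], 0.0
    for d in dens:
        total += d
        cumulative.append(total)
    norm = [c / cumulative[-1] for c in cumulative]

    # Lower bound: largest grid point whose CDF is still below alpha;
    # upper bound: smallest grid point whose CDF exceeds 1 - alpha.
    lower_idx = max((i for i, v in enumerate(norm) if v < alpha), default=0)
    upper_idx = min(i for i, v in enumerate(norm) if v > 1 - alpha)

    lower = grid[lower_idx] if obs > 0 else 0.0
    return lower, grid[upper_idx]
```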


def calculate_raw_z_score(
    obs_expr: hl.expr.Int64Expression,
    exp_expr: hl.expr.Float64Expression,
) -> hl.expr.Float64Expression:
    """
    Compute the signed raw z-score using observed and expected variant counts.

    The raw z-scores are positive when the transcript had fewer variants than expected,
    and are negative when transcripts had more variants than expected.

    :param obs_expr: Observed variant count expression.
    :param exp_expr: Expected variant count expression.
    :return: Expression for the signed raw z-score.
    """
    chisq_expr = divide_null((obs_expr - exp_expr) ** 2, exp_expr)
    return hl.sqrt(chisq_expr) * hl.if_else(obs_expr > exp_expr, -1, 1)
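The signed z-score is just the square root of a chi-square statistic with the sign flipped for excess variation; a minimal plain-Python sketch (illustrative name, `None` standing in for Hail's missing/`divide_null` semantics):

```python
import math
from typing import Optional


def raw_z_score_sketch(obs: int, exp: Optional[float]) -> Optional[float]:
    """Signed raw z-score: positive when obs < exp, negative when obs > exp."""
    if exp is None or exp == 0:
        # Mirrors divide_null: missing when the denominator is missing or zero.
        return None
    chisq = (obs - exp) ** 2 / exp
    return math.sqrt(chisq) * (-1 if obs > exp else 1)
```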


def get_constraint_flags(
    exp_expr: hl.expr.Float64Expression,
    raw_z_expr: hl.expr.Float64Expression,
    raw_z_lower_threshold: Optional[float] = -5.0,
    raw_z_upper_threshold: Optional[float] = 5.0,
    flag_postfix: str = "",
) -> Dict[str, hl.expr.Expression]:
    """
    Determine the constraint flags that define why constraint will not be calculated.

    Flags which are added:

        - "no_exp_{flag_postfix}" - for genes that have missing or zero expected
          variants.
        - "outlier_{flag_postfix}" - for genes that are raw z-score outliers:
          (`raw_z_expr` < `raw_z_lower_threshold`) or (`raw_z_expr` >
          `raw_z_upper_threshold`).

    :param exp_expr: Expression for the expected variant counts of pLoF, missense, or
        synonymous variants.
    :param raw_z_expr: Expression for the signed raw z-score of pLoF, missense, or
        synonymous variants.
    :param raw_z_lower_threshold: Lower threshold for the raw z-score. When `raw_z_expr`
        is less than this threshold it is considered an 'outlier'. Default is -5.0.
    :param raw_z_upper_threshold: Upper threshold for the raw z-score. When `raw_z_expr`
        is greater than this threshold it is considered an 'outlier'. Default is 5.0.
    :param flag_postfix: Postfix to add to the end of the constraint flag names.
    :return: Dictionary containing expressions for constraint flags.
    """
    outlier_expr = False
    if raw_z_lower_threshold is not None:
        outlier_expr |= raw_z_expr < raw_z_lower_threshold
    if raw_z_upper_threshold is not None:
        outlier_expr |= raw_z_expr > raw_z_upper_threshold

    if flag_postfix:
        flag_postfix = f"_{flag_postfix}"

    constraint_flags = {
        f"no_exp{flag_postfix}": hl.or_else(exp_expr <= 0, True),
        f"outlier{flag_postfix}": hl.or_else(outlier_expr, False),
    }

    return constraint_flags
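The flag logic, including the `or_else` defaults for missing values, can be sketched in plain Python (illustrative name; `None` stands in for Hail's missing values):

```python
from typing import Dict, Optional


def constraint_flags_sketch(
    exp: Optional[float],
    raw_z: Optional[float],
    lower: Optional[float] = -5.0,
    upper: Optional[float] = 5.0,
    postfix: str = "",
) -> Dict[str, bool]:
    """Flags explaining why constraint metrics would not be calculated."""
    postfix = f"_{postfix}" if postfix else ""
    outlier = False
    if raw_z is not None:
        if lower is not None:
            outlier = outlier or raw_z < lower
        if upper is not None:
            outlier = outlier or raw_z > upper
    return {
        # or_else(exp <= 0, True): missing expected counts are also flagged.
        f"no_exp{postfix}": exp is None or exp <= 0,
        # or_else(outlier_expr, False): a missing z-score is not an outlier.
        f"outlier{postfix}": outlier,
    }
```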


def calculate_raw_z_score_sd(
    raw_z_expr: hl.expr.Float64Expression,
    flag_expr: hl.expr.StringExpression,
    mirror_neg_raw_z: bool = True,
) -> hl.expr.Expression:
    """
    Calculate the standard deviation of the raw z-score.

    When `mirror_neg_raw_z` is True, all the negative raw z-scores (defined by
    `raw_z_expr`) are combined with those same z-scores multiplied by -1 (to create a
    mirrored distribution), and the standard deviation is computed on that mirrored
    distribution.

    :param raw_z_expr: Expression for the raw z-score.
    :param flag_expr: Expression for the constraint flags. The z-score will not be
        calculated if flags are present.
    :param mirror_neg_raw_z: Whether the standard deviation should be computed using a
        mirrored distribution of negative `raw_z_expr`. Default is True.
    :return: Expression for the standard deviation of the raw z-score.
    """
    filter_expr = (hl.len(flag_expr) == 0) & hl.is_defined(raw_z_expr)

    if mirror_neg_raw_z:
        filter_expr &= raw_z_expr < 0
        sd_expr = hl.agg.explode(
            lambda x: hl.agg.stats(x), [raw_z_expr, -raw_z_expr]
        ).stdev
    else:
        sd_expr = hl.agg.stats(raw_z_expr).stdev

    return hl.agg.filter(filter_expr, sd_expr)
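The mirroring trick is easiest to see in a plain-Python sketch (illustrative name; a list of flag strings stands in for `flag_expr`, and `statistics.pstdev` assumes a population standard deviation like Hail's aggregator):

```python
import statistics
from typing import List, Optional


def raw_z_score_sd_sketch(
    raw_z: List[Optional[float]],
    flags: List[str],
    mirror_neg_raw_z: bool = True,
) -> float:
    """SD of the raw z-score, optionally over a mirrored negative distribution."""
    # Keep rows with no flags and a defined raw z-score.
    kept = [z for z, f in zip(raw_z, flags) if not f and z is not None]
    if mirror_neg_raw_z:
        # Combine the negative z-scores with their sign-flipped copies.
        neg = [z for z in kept if z < 0]
        kept = neg + [-z for z in neg]
    return statistics.pstdev(kept)
```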