Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add aggregations to variant QC evaluation for additional plots #609

Merged
merged 5 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 32 additions & 31 deletions gnomad/variant_qc/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,23 @@ def compute_ranked_bin(
_rand=hl.rand_unif(0, 1),
)

# Checkpoint bin Table prior to variant count aggregation.
bin_ht = bin_ht.checkpoint(hl.utils.new_temp_file("bin", "ht"))

# Compute variant counts per group defined by bin_expr. This is used to determine
# bin assignment.
bin_group_variant_counts = bin_ht.aggregate(
hl.Struct(
**{
bin_id: hl.agg.filter(
bin_ht[f"_filter_{bin_id}"],
hl.agg.count(),
)
for bin_id in bin_expr
}
)
)

logger.info(
"Sorting the HT by score_expr followed by a random float between 0 and 1. "
"Then adding a row index per grouping defined by bin_expr..."
Expand All @@ -97,22 +114,6 @@ def compute_ranked_bin(
)
bin_ht = bin_ht.key_by("locus", "alleles")

# Annotate globals with variant counts per group defined by bin_expr. This
# is used to determine bin assignment.
bin_ht = bin_ht.annotate_globals(
bin_group_variant_counts=bin_ht.aggregate(
hl.Struct(
**{
bin_id: hl.agg.filter(
bin_ht[f"_filter_{bin_id}"],
hl.agg.count(),
)
for bin_id in bin_expr
}
)
)
)

logger.info("Binning ranked rows into %d bins...", n_bins)
bin_ht = bin_ht.select(
"snv",
Expand All @@ -123,7 +124,7 @@ def compute_ranked_bin(
n_bins
* (
bin_ht[f"{bin_id}_rank"]
/ hl.float64(bin_ht.bin_group_variant_counts[bin_id])
/ hl.float64(bin_group_variant_counts[bin_id])
)
)
+ 1
Expand All @@ -143,20 +144,18 @@ def compute_ranked_bin(
# in bin names in the table
if compute_snv_indel_separately:
bin_expr_no_snv = {
bin_id.rsplit("_", 1)[0] for bin_id in bin_ht.bin_group_variant_counts
bin_id.rsplit("_", 1)[0] for bin_id in bin_group_variant_counts
}
bin_ht = bin_ht.annotate_globals(
bin_group_variant_counts=hl.struct(
**{
bin_id: hl.struct(
**{
snv: bin_ht.bin_group_variant_counts[f"{bin_id}_{snv}"]
for snv in ["snv", "indel"]
}
)
for bin_id in bin_expr_no_snv
}
)
bin_group_variant_counts = hl.struct(
**{
bin_id: hl.struct(
**{
snv: bin_group_variant_counts[f"{bin_id}_{snv}"]
for snv in ["snv", "indel"]
}
)
for bin_id in bin_expr_no_snv
}
)

bin_ht = bin_ht.transmute(
Expand All @@ -170,6 +169,8 @@ def compute_ranked_bin(
}
)

bin_ht = bin_ht.annotate_globals(bin_group_variant_counts=bin_group_variant_counts)

return bin_ht


Expand Down Expand Up @@ -264,7 +265,7 @@ def compute_binned_truth_sample_concordance(
score=indexed_binned_score_ht.score,
global_bin=indexed_binned_score_ht.bin,
)

ht = ht.checkpoint(hl.utils.new_temp_file("pre_bin", "ht"))
# Annotate the truth sample bin
bin_ht = compute_ranked_bin(
ht,
Expand Down
94 changes: 63 additions & 31 deletions gnomad/variant_qc/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# noqa: D100

import logging
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

import hail as hl
import pyspark.sql
Expand Down Expand Up @@ -69,10 +69,11 @@ def create_binned_ht(
:return: table with bin number for each variant
"""

def _update_bin_expr(
bin_expr: Dict[str, hl.expr.BooleanExpression],
def _new_bin_expr(
bin_expr: Union[Dict[str, hl.expr.BooleanExpression], Dict[str, bool]],
new_expr: hl.expr.BooleanExpression,
new_id: str,
update: bool = False,
) -> Dict[str, hl.expr.BooleanExpression]:
"""
Update a dictionary of expressions to add another stratification.
Expand All @@ -83,29 +84,33 @@ def _update_bin_expr(
:return: Dictionary of `bin_expr` updated with `new_expr` added as an additional stratification to all
expressions already in `bin_expr` (when `update` is True); otherwise a dictionary containing only the newly stratified expressions
"""
bin_expr.update(
{
f"{new_id}_{bin_id}": bin_expr & new_expr
for bin_id, bin_expr in bin_expr.items()
}
)
return bin_expr
new_bin_expr = {
f"{new_id}_{bin_id}": bin_expr & new_expr
for bin_id, bin_expr in bin_expr.items()
}
if update:
bin_expr.update(new_bin_expr)
return bin_expr
else:
return new_bin_expr

# Desired bins and sub-bins
bin_expr = {"bin": True}

if singleton:
bin_expr = _update_bin_expr(bin_expr, ht.ac_raw == 1, "singleton")
bin_expr = _new_bin_expr(bin_expr, ht.ac_raw == 1, "singleton", update=True)

if biallelic:
bin_expr = _update_bin_expr(bin_expr, ~ht.was_split, "biallelic")
bin_expr = _new_bin_expr(bin_expr, ~ht.was_split, "biallelic", update=True)

if adj:
bin_expr = _update_bin_expr(bin_expr, (ht.ac > 0), "adj")
bin_expr = _new_bin_expr(bin_expr, (ht.ac > 0), "adj", update=True)

if add_substrat:
if add_substrat is not None:
new_bin_expr = {}
for add_id, add_expr in add_substrat.items():
bin_expr = _update_bin_expr(bin_expr, add_expr, add_id)
new_bin_expr.update(_new_bin_expr(bin_expr, add_expr, add_id))
bin_expr.update(new_bin_expr)

bin_ht = compute_ranked_bin(
ht, score_expr=ht.score, bin_expr=bin_expr, n_bins=n_bins
Expand Down Expand Up @@ -223,22 +228,39 @@ def score_bin_agg(
"Either 'fail_hard_filters' or 'info' must be present in the input Table!"
)

ins_expr = hl.is_insertion(ht.alleles[0], ht.alleles[1])
del_expr = hl.is_deletion(ht.alleles[0], ht.alleles[1])
indel_1bp_expr = indel_length == 1
count_where_expr = {
"n_ins": ins_expr,
"n_del": del_expr,
"n_ti": hl.is_transition(ht.alleles[0], ht.alleles[1]),
"n_tv": hl.is_transversion(ht.alleles[0], ht.alleles[1]),
"n_1bp_indel": indel_1bp_expr,
"n_1bp_ins": ins_expr & indel_1bp_expr,
"n_2bp_ins": ins_expr & (indel_length == 2),
"n_3bp_ins": ins_expr & (indel_length == 3),
"n_1bp_del": del_expr & indel_1bp_expr,
"n_2bp_del": del_expr & (indel_length == 2),
"n_3bp_del": del_expr & (indel_length == 3),
"n_mod3bp_indel": (indel_length % 3) == 0,
"n_singleton": ht.singleton,
"fail_hard_filters": fail_hard_filters_expr,
"n_pos_train": ht.positive_train_site,
"n_neg_train": ht.negative_train_site,
"n_clinvar": hl.is_defined(clinvar),
"n_clinvar_path": hl.is_defined(clinvar_path),
"n_omni": truth_data.omni,
"n_mills": truth_data.mills,
"n_hapmap": truth_data.hapmap,
"n_kgp_phase1_hc": truth_data.kgp_phase1_hc,
}

return dict(
min_score=hl.agg.min(ht.score),
max_score=hl.agg.max(ht.score),
n=hl.agg.count(),
n_ins=hl.agg.count_where(hl.is_insertion(ht.alleles[0], ht.alleles[1])),
n_del=hl.agg.count_where(hl.is_deletion(ht.alleles[0], ht.alleles[1])),
n_ti=hl.agg.count_where(hl.is_transition(ht.alleles[0], ht.alleles[1])),
n_tv=hl.agg.count_where(hl.is_transversion(ht.alleles[0], ht.alleles[1])),
n_1bp_indel=hl.agg.count_where(indel_length == 1),
n_mod3bp_indel=hl.agg.count_where((indel_length % 3) == 0),
n_singleton=hl.agg.count_where(ht.singleton),
fail_hard_filters=hl.agg.count_where(fail_hard_filters_expr),
n_pos_train=hl.agg.count_where(ht.positive_train_site),
n_neg_train=hl.agg.count_where(ht.negative_train_site),
n_clinvar=hl.agg.count_where(hl.is_defined(clinvar)),
n_clinvar_path=hl.agg.count_where(hl.is_defined(clinvar_path)),
**{k: hl.agg.count_where(v) for k, v in count_where_expr.items()},
n_de_novos_singleton_adj=hl.agg.filter(
ht.ac == 1, hl.agg.sum(fam.n_de_novos_adj)
),
Expand All @@ -247,6 +269,20 @@ def score_bin_agg(
),
n_de_novos_adj=hl.agg.sum(fam.n_de_novos_adj),
n_de_novo=hl.agg.sum(fam.n_de_novos_raw),
n_de_novos_AF_001_adj=hl.agg.filter(
hl.if_else(
fam.ac_parents_adj == 0, 0.0, fam.ac_parents_adj / fam.an_parents_adj
)
< 0.001,
hl.agg.sum(fam.n_de_novos_adj),
),
n_de_novos_AF_001=hl.agg.filter(
hl.if_else(
fam.ac_parents_raw == 0, 0.0, fam.ac_parents_raw / fam.an_parents_raw
)
< 0.001,
hl.agg.sum(fam.n_de_novos_raw),
),
n_trans_singletons=hl.agg.filter(
ht.ac_raw == 2, hl.agg.sum(fam.n_transmitted_raw)
),
Expand All @@ -257,10 +293,6 @@ def score_bin_agg(
n_train_trans_singletons=hl.agg.filter(
(ht.ac_raw == 2) & ht.positive_train_site, hl.agg.sum(fam.n_transmitted_raw)
),
n_omni=hl.agg.count_where(truth_data.omni),
n_mills=hl.agg.count_where(truth_data.mills),
n_hapmap=hl.agg.count_where(truth_data.hapmap),
n_kgp_phase1_hc=hl.agg.count_where(truth_data.kgp_phase1_hc),
)


Expand Down