Skip to content

Commit

Permalink
feat: add the step class for fine-mapping (#554)
Browse files Browse the repository at this point in the history
* feat: add the step class for fine-mapping

* test: adding test for susie to studylocus converter

* chore: fix the class description

* chore: answering comments

---------

Co-authored-by: Daniel Suveges <daniel.suveges@protonmail.com>
  • Loading branch information
addramir and DSuveges authored Apr 2, 2024
1 parent 255c42d commit d76ebbe
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 0 deletions.
140 changes: 140 additions & 0 deletions src/gentropy/susie_finemapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
"""Step to run a finemapping using."""

from __future__ import annotations

from typing import Any

import numpy as np
import pyspark.sql.functions as f
from pyspark.sql import DataFrame, Window

from gentropy.common.session import Session
from gentropy.dataset.study_locus import StudyLocus


class SusieFineMapperStep:
    """SuSiE fine-mapping step.

    It has generic methods to run SuSiE fine-mapping for a study locus.
    This class/step is a temporary fine-mapping wrapper for development purposes.
    In the future this step will be refactored and moved to the methods module.
    """

    @staticmethod
    def susie_inf_to_studylocus(
        susie_output: dict[str, Any],
        session: Session,
        _studyId: str,
        _region: str,
        variant_index: DataFrame,
        cs_lbf_thr: float = 2,
    ) -> StudyLocus:
        """Convert SuSiE-inf output to a StudyLocus object.

        Credible sets are ranked by their `lbf` value (highest first). The top
        ranked set is always reported; every other set is reported only when its
        `lbf` is at least `cs_lbf_thr`. For each reported set the variants are
        taken in decreasing posterior probability until the cumulative posterior
        probability reaches 0.99 (all variants are kept if it never does). The
        variant with the highest posterior probability becomes the lead variant
        of the resulting row.

        Args:
            susie_output (dict[str, Any]): SuSiE-inf output dictionary. The keys
                `PIP`, `lbf_variable`, `mu` (indexed as variants x credible sets)
                and `lbf` (one value per credible set) are read.
            session (Session): Session whose Spark session is used to build the DataFrames
            _studyId (str): study ID
            _region (str): region
            variant_index (DataFrame): DataFrame with variant information; the
                columns `variantId`, `chromosome` and `position` are read.
            cs_lbf_thr (float): credible set logBF threshold, default is 2

        Returns:
            StudyLocus: StudyLocus object with fine-mapped credible sets

        Raises:
            ValueError: if `susie_output["lbf"]` is empty, or if the number of
                variants in `variant_index` does not match the number of rows of
                the SuSiE-inf matrices.
        """
        # NOTE(review): assumes the rows collected from `variant_index` come back
        # in the same order as the rows of the SuSiE-inf matrices - confirm with
        # the caller, Spark does not guarantee row order in general.
        variant_ids = [
            row["variantId"] for row in variant_index.select("variantId").collect()
        ]
        pips = np.asarray(susie_output["PIP"], dtype=float)
        lbfs = np.asarray(susie_output["lbf_variable"], dtype=float)
        betas = np.asarray(susie_output["mu"], dtype=float)

        n_variants = len(variant_ids)
        if not (pips.shape[0] == lbfs.shape[0] == betas.shape[0] == n_variants):
            raise ValueError(
                "Number of variants in variant_index "
                f"({n_variants}) does not match the SuSiE-inf output "
                f"(PIP: {pips.shape[0]}, lbf_variable: {lbfs.shape[0]}, "
                f"mu: {betas.shape[0]})."
            )

        # Credible sets as (column index, lbf) pairs, best supported first.
        ranked_creds = sorted(
            enumerate(susie_output["lbf"]), key=lambda pair: pair[1], reverse=True
        )
        if not ranked_creds:
            # Fail with a clear message instead of handing `None` to StudyLocus.
            raise ValueError("SuSiE-inf output contains no credible sets ('lbf' is empty).")

        # Loop-invariant Spark objects.
        variant_coordinates = variant_index.select(
            "variantId",
            "chromosome",
            "position",
        )
        # Ordering the window makes the order of the collected `locus` array
        # explicit instead of relying on a preceding sort being preserved.
        win = Window.orderBy(f.desc("posteriorProbability")).rowsBetween(
            Window.unboundedPreceding, Window.unboundedFollowing
        )

        cred_sets = None
        for cs_index, (i, cs_lbf_value) in enumerate(ranked_creds, start=1):
            # The best credible set is always kept; the rest must pass the threshold.
            if cs_index > 1 and cs_lbf_value < cs_lbf_thr:
                continue

            # Variant indices by decreasing posterior probability for this set.
            order = pips[:, i].argsort()[::-1]
            cumulative_pip = np.cumsum(pips[order, i])
            n_keep = np.argmax(cumulative_pip >= 0.99)
            if n_keep == 0 and cumulative_pip[0] < 0.99:
                # 0.99 is never reached: keep every variant.
                n_keep = len(cumulative_pip)
            n_keep += 1

            # Plain Python values with an explicit schema keep the numeric
            # columns numeric, so the descending sort below is numeric rather
            # than lexicographic (where "9.5e-10" would sort above "0.99").
            rows = [
                (
                    str(variant_ids[j]),
                    float(pips[j, i]),
                    float(lbfs[j, i]),
                    float(betas[j, i]),
                )
                for j in order[:n_keep]
            ]
            cred_set = (
                session.spark.createDataFrame(
                    rows,
                    "variantId string, posteriorProbability double, "
                    "logBF double, beta double",
                )
                .join(variant_coordinates, "variantId")
                .sort(f.desc("posteriorProbability"))
                .withColumn(
                    "locus",
                    f.collect_list(
                        f.struct(
                            f.col("variantId").cast("string").alias("variantId"),
                            f.col("posteriorProbability")
                            .cast("double")
                            .alias("posteriorProbability"),
                            f.col("logBF").cast("double").alias("logBF"),
                            f.col("beta").cast("double").alias("beta"),
                        )
                    ).over(win),
                )
                # Only the lead (highest posterior probability) variant row is
                # kept; the full credible set lives in the `locus` array.
                .limit(1)
                .withColumns(
                    {
                        "studyId": f.lit(_studyId),
                        "region": f.lit(_region),
                        "credibleSetIndex": f.lit(cs_index),
                        # 0.4342944819 ~ log10(e): rescales the logBF to log10.
                        "credibleSetlog10BF": f.lit(
                            float(cs_lbf_value) * 0.4342944819
                        ),
                        "finemappingMethod": f.lit("SuSiE-inf"),
                    }
                )
                .withColumn(
                    "studyLocusId",
                    StudyLocus.assign_study_locus_id(
                        f.col("studyId"), f.col("variantId")
                    ),
                )
                .select(
                    "studyLocusId",
                    "studyId",
                    "region",
                    "credibleSetIndex",
                    "locus",
                    "variantId",
                    "chromosome",
                    "position",
                    "finemappingMethod",
                    "credibleSetlog10BF",
                )
            )
            if cred_sets is None:
                cred_sets = cred_set
            else:
                cred_sets = cred_sets.unionByName(cred_set)

        return StudyLocus(
            _df=cred_sets,
            _schema=StudyLocus.get_schema(),
        )
29 changes: 29 additions & 0 deletions tests/gentropy/method/test_susie_inf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
from __future__ import annotations

import numpy as np
from gentropy.common.session import Session
from gentropy.dataset.study_locus import StudyLocus
from gentropy.dataset.summary_statistics import SummaryStatistics
from gentropy.method.susie_inf import SUSIE_inf
from gentropy.susie_finemapper import SusieFineMapperStep


class TestSUSIE_inf:
Expand Down Expand Up @@ -48,3 +52,28 @@ def test_SUSIE_inf_cred(
)
cred = SUSIE_inf.cred_inf(susie_output["PIP"], LD=ld)
assert cred[0] == [5]

def test_SUSIE_inf_convert_to_study_locus(
    self: TestSUSIE_inf,
    sample_data_for_susie_inf: list[np.ndarray],
    sample_summary_statistics: SummaryStatistics,
    session: Session,
) -> None:
    """Check that SuSiE-inf output is converted into a StudyLocus object."""
    # The fixture is read positionally: element 0 is passed as the LD matrix,
    # element 1 as the z-scores.
    fine_mapping_output = SUSIE_inf.susie_inf(
        z=sample_data_for_susie_inf[1],
        LD=sample_data_for_susie_inf[0],
        est_tausq=False,
    )
    # presumably 21 rows to match the number of variants in the SuSiE fixture
    # - verify against the fixture definition
    variant_df = sample_summary_statistics._df.limit(21)
    L1 = SusieFineMapperStep.susie_inf_to_studylocus(
        susie_output=fine_mapping_output,
        session=session,
        _studyId="sample_id",
        _region="sample_region",
        variant_index=variant_df,
        cs_lbf_thr=2,
    )
    assert isinstance(L1, StudyLocus), "L1 is not an instance of StudyLocus"

0 comments on commit d76ebbe

Please sign in to comment.