Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add the step class for fine-mapping #554

Merged
merged 8 commits into from
Apr 2, 2024
Merged
140 changes: 140 additions & 0 deletions src/gentropy/susie_finemapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
"""Step to run a finemapping using."""

from __future__ import annotations

from typing import Any

import numpy as np
import pyspark.sql.functions as f
from pyspark.sql import DataFrame, Window

from gentropy.common.session import Session
from gentropy.dataset.study_locus import StudyLocus


class SusieFineMapperStep:
    """SuSie finemapping. It has generic methods to run SuSie fine mapping for a study locus.

    This class/step is the temporary solution of the fine-mapping wrapper for the development purposes.
    In the future this step will be refactored and moved to the methods module.
    """

    @staticmethod
    def susie_inf_to_studylocus(
        susie_output: dict[str, Any],
        session: Session,
        _studyId: str,
        _region: str,
        variant_index: DataFrame,
        cs_lbf_thr: float = 2,
    ) -> StudyLocus:
        """Convert SuSiE-inf output to studyLocus DataFrame.

        Args:
            susie_output (dict[str, Any]): SuSiE-inf output dictionary
            session (Session): Spark session
            _studyId (str): study ID
            _region (str): region
            variant_index (DataFrame): DataFrame with variant information;
                assumed to contain exactly the same variants, in the same
                order, as the rows of the SuSiE-inf matrices — TODO confirm
            cs_lbf_thr (float): credible set logBF threshold, default is 2

        Returns:
            StudyLocus: StudyLocus object with fine-mapped credible sets
        """
        # Column vector (n_variants, 1) of variant IDs, so it can be
        # horizontally stacked with the per-credible-set matrices below.
        variants = np.array(
            [row["variantId"] for row in variant_index.select("variantId").collect()]
        ).reshape(-1, 1)
        PIPs = susie_output["PIP"]
        lbfs = susie_output["lbf_variable"]
        mu = susie_output["mu"]
        # Resulting column layout (L = number of candidate credible sets):
        #   [variantId | PIP_1..PIP_L | lbf_1..lbf_L | mu_1..mu_L]
        susie_result = np.hstack((variants, PIPs, lbfs, mu))

        # Number of candidate credible sets (columns of the PIP matrix).
        L_snps = PIPs.shape[1]

        # Extracting credible sets: visit candidate sets in decreasing order
        # of their credible-set (natural-log) Bayes factor.
        order_creds = list(enumerate(susie_output["lbf"]))
        order_creds.sort(key=lambda x: x[1], reverse=True)
        cred_sets = None
        counter = 0
        for i, cs_lbf_value in order_creds:
            # Always keep the strongest set (counter == 0); skip later sets
            # below the logBF threshold.
            # NOTE(review): counter is incremented for skipped sets too, so
            # credibleSetIndex counts candidates, not kept sets — confirm
            # this is intended.
            if counter > 0 and cs_lbf_value < cs_lbf_thr:
                counter += 1
                continue
            counter += 1
            # Sort variants by this set's PIP (column i + 1), descending.
            sorted_arr = susie_result[
                susie_result[:, i + 1].astype(float).argsort()[::-1]
            ]
            # Smallest prefix whose cumulative PIP reaches 0.99 — the 99%
            # credible set. If the cumulative sum never reaches 0.99, keep
            # every row (argmax returns 0 in that case, hence the guard).
            cumsum_arr = np.cumsum(sorted_arr[:, i + 1].astype(float))
            filter_row = np.argmax(cumsum_arr >= 0.99)
            if filter_row == 0 and cumsum_arr[0] < 0.99:
                filter_row = len(cumsum_arr)
            filter_row += 1
            filtered_arr = sorted_arr[:filter_row]
            # Columns for this set i: variantId, PIP, logBF, beta (mu).
            cred_set = filtered_arr[:, [0, i + 1, i + L_snps + 1, i + 2 * L_snps + 1]]
            # Whole-partition window: lets collect_list gather the full
            # locus into a single array column on every row.
            win = Window.rowsBetween(
                Window.unboundedPreceding, Window.unboundedFollowing
            )
            cred_set = (
                session.spark.createDataFrame(
                    cred_set.tolist(),
                    ["variantId", "posteriorProbability", "logBF", "beta"],
                )
                # Inner join (default); per the author this is 1:1 with
                # variant_index, so no rows are expected to be dropped.
                .join(
                    variant_index.select(
                        "variantId",
                        "chromosome",
                        "position",
                    ),
                    "variantId",
                )
                .sort(f.desc("posteriorProbability"))
                .withColumn(
                    "locus",
                    f.collect_list(
                        f.struct(
                            f.col("variantId").cast("string").alias("variantId"),
                            f.col("posteriorProbability")
                            .cast("double")
                            .alias("posteriorProbability"),
                            f.col("logBF").cast("double").alias("logBF"),
                            f.col("beta").cast("double").alias("beta"),
                        )
                    ).over(win),
                )
                # Keep only the top row (the lead variant); its "locus"
                # column already holds the whole credible set.
                .limit(1)
                .withColumns(
                    {
                        "studyId": f.lit(_studyId),
                        "region": f.lit(_region),
                        "credibleSetIndex": f.lit(counter),
                        # 0.4342944819 == log10(e): converts the natural-log
                        # Bayes factor to log10.
                        "credibleSetlog10BF": f.lit(cs_lbf_value * 0.4342944819),
                        "finemappingMethod": f.lit("SuSiE-inf"),
                    }
                )
                .withColumn(
                    "studyLocusId",
                    StudyLocus.assign_study_locus_id(
                        f.col("studyId"), f.col("variantId")
                    ),
                )
                .select(
                    "studyLocusId",
                    "studyId",
                    "region",
                    "credibleSetIndex",
                    "locus",
                    "variantId",
                    "chromosome",
                    "position",
                    "finemappingMethod",
                    "credibleSetlog10BF",
                )
            )
            # Accumulate one single-row DataFrame per kept credible set.
            if cred_sets is None:
                cred_sets = cred_set
            else:
                cred_sets = cred_sets.unionByName(cred_set)
        return StudyLocus(
            _df=cred_sets,
            _schema=StudyLocus.get_schema(),
        )
29 changes: 29 additions & 0 deletions tests/gentropy/method/test_susie_inf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
from __future__ import annotations

import numpy as np
from gentropy.common.session import Session
from gentropy.dataset.study_locus import StudyLocus
from gentropy.dataset.summary_statistics import SummaryStatistics
from gentropy.method.susie_inf import SUSIE_inf
from gentropy.susie_finemapper import SusieFineMapperStep


class TestSUSIE_inf:
Expand Down Expand Up @@ -48,3 +52,28 @@ def test_SUSIE_inf_cred(
)
cred = SUSIE_inf.cred_inf(susie_output["PIP"], LD=ld)
assert cred[0] == [5]

def test_SUSIE_inf_convert_to_study_locus(
self: TestSUSIE_inf,
sample_data_for_susie_inf: list[np.ndarray],
sample_summary_statistics: SummaryStatistics,
session: Session,
) -> None:
"""Test of SuSiE-inf credible set generator."""
ld = sample_data_for_susie_inf[0]
z = sample_data_for_susie_inf[1]
susie_output = SUSIE_inf.susie_inf(
z=z,
LD=ld,
est_tausq=False,
)
gwas_df = sample_summary_statistics._df.limit(21)
L1 = SusieFineMapperStep.susie_inf_to_studylocus(
susie_output=susie_output,
session=session,
_studyId="sample_id",
_region="sample_region",
variant_index=gwas_df,
cs_lbf_thr=2,
)
assert isinstance(L1, StudyLocus), "L1 is not an instance of StudyLocus"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what the test dataset looks like; it would be great to assert that the number of credible sets is what you are expecting, and to validate that the locus object is healthy. However, I understand if that is not a high priority for now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can create a more meaningful test for the bigger function that will use this converter at later stages.