From d76ebbef03a130a43d9d89a5daa21c24f9a88907 Mon Sep 17 00:00:00 2001
From: Yakov
Date: Tue, 2 Apr 2024 14:52:29 +0100
Subject: [PATCH] feat: add the step class for fine-mapping (#554)

* feat: add the step class for fine-mapping

* test: adding test for susie to studylocus converter

* chore: fix the class description

* chore: answering comments

---------

Co-authored-by: Daniel Suveges
---
 src/gentropy/susie_finemapper.py        | 140 ++++++++++++++++++++++++
 tests/gentropy/method/test_susie_inf.py |  29 +++++
 2 files changed, 169 insertions(+)
 create mode 100644 src/gentropy/susie_finemapper.py

diff --git a/src/gentropy/susie_finemapper.py b/src/gentropy/susie_finemapper.py
new file mode 100644
index 000000000..bd02716c2
--- /dev/null
+++ b/src/gentropy/susie_finemapper.py
@@ -0,0 +1,140 @@
+"""Step to run fine-mapping using SuSiE-inf."""
+
+from __future__ import annotations
+
+from typing import Any
+
+import numpy as np
+import pyspark.sql.functions as f
+from pyspark.sql import DataFrame, Window
+
+from gentropy.common.session import Session
+from gentropy.dataset.study_locus import StudyLocus
+
+
+class SusieFineMapperStep:
+    """SuSiE fine-mapping step. It provides generic methods to run SuSiE fine-mapping for a study locus.
+
+    This class/step is a temporary solution for the fine-mapping wrapper, intended for development purposes.
+    In the future this step will be refactored and moved to the methods module.
+    """
+
+    @staticmethod
+    def susie_inf_to_studylocus(
+        susie_output: dict[str, Any],
+        session: Session,
+        _studyId: str,
+        _region: str,
+        variant_index: DataFrame,
+        cs_lbf_thr: float = 2,
+    ) -> StudyLocus:
+        """Convert SuSiE-inf output to a StudyLocus DataFrame.
+
+        Args:
+            susie_output (dict[str, Any]): SuSiE-inf output dictionary
+            session (Session): Spark session
+            _studyId (str): study ID
+            _region (str): region
+            variant_index (DataFrame): DataFrame with variant information
+            cs_lbf_thr (float): credible set logBF threshold, default is 2
+
+        Returns:
+            StudyLocus: StudyLocus object with fine-mapped credible sets
+        """
+        variants = np.array(
+            [row["variantId"] for row in variant_index.select("variantId").collect()]
+        ).reshape(-1, 1)
+        PIPs = susie_output["PIP"]
+        lbfs = susie_output["lbf_variable"]
+        mu = susie_output["mu"]
+        susie_result = np.hstack((variants, PIPs, lbfs, mu))
+
+        L_snps = PIPs.shape[1]
+
+        # Extracting credible sets
+        order_creds = list(enumerate(susie_output["lbf"]))
+        order_creds.sort(key=lambda x: x[1], reverse=True)
+        cred_sets = None
+        counter = 0
+        for i, cs_lbf_value in order_creds:
+            if counter > 0 and cs_lbf_value < cs_lbf_thr:
+                counter += 1
+                continue
+            counter += 1
+            sorted_arr = susie_result[
+                susie_result[:, i + 1].astype(float).argsort()[::-1]
+            ]
+            cumsum_arr = np.cumsum(sorted_arr[:, i + 1].astype(float))
+            filter_row = np.argmax(cumsum_arr >= 0.99)
+            if filter_row == 0 and cumsum_arr[0] < 0.99:
+                filter_row = len(cumsum_arr)
+            filter_row += 1
+            filtered_arr = sorted_arr[:filter_row]
+            cred_set = filtered_arr[:, [0, i + 1, i + L_snps + 1, i + 2 * L_snps + 1]]
+            win = Window.rowsBetween(
+                Window.unboundedPreceding, Window.unboundedFollowing
+            )
+            cred_set = (
+                session.spark.createDataFrame(
+                    cred_set.tolist(),
+                    ["variantId", "posteriorProbability", "logBF", "beta"],
+                )
+                .join(
+                    variant_index.select(
+                        "variantId",
+                        "chromosome",
+                        "position",
+                    ),
+                    "variantId",
+                )
+                .sort(f.desc("posteriorProbability"))
+                .withColumn(
+                    "locus",
+                    f.collect_list(
+                        f.struct(
+                            f.col("variantId").cast("string").alias("variantId"),
+                            f.col("posteriorProbability")
+                            .cast("double")
+                            .alias("posteriorProbability"),
+                            f.col("logBF").cast("double").alias("logBF"),
+                            f.col("beta").cast("double").alias("beta"),
+                        )
+                    ).over(win),
+                )
+                .limit(1)
+                .withColumns(
+                    {
+                        "studyId": f.lit(_studyId),
+                        "region": f.lit(_region),
+                        "credibleSetIndex": f.lit(counter),
+                        "credibleSetlog10BF": f.lit(cs_lbf_value * 0.4342944819),
+                        "finemappingMethod": f.lit("SuSiE-inf"),
+                    }
+                )
+                .withColumn(
+                    "studyLocusId",
+                    StudyLocus.assign_study_locus_id(
+                        f.col("studyId"), f.col("variantId")
+                    ),
+                )
+                .select(
+                    "studyLocusId",
+                    "studyId",
+                    "region",
+                    "credibleSetIndex",
+                    "locus",
+                    "variantId",
+                    "chromosome",
+                    "position",
+                    "finemappingMethod",
+                    "credibleSetlog10BF",
+                )
+            )
+            if cred_sets is None:
+                cred_sets = cred_set
+            else:
+                cred_sets = cred_sets.unionByName(cred_set)
+        return StudyLocus(
+            _df=cred_sets,
+            _schema=StudyLocus.get_schema(),
+        )
diff --git a/tests/gentropy/method/test_susie_inf.py b/tests/gentropy/method/test_susie_inf.py
index b671bf274..393f786d7 100644
--- a/tests/gentropy/method/test_susie_inf.py
+++ b/tests/gentropy/method/test_susie_inf.py
@@ -3,7 +3,11 @@
 from __future__ import annotations
 
 import numpy as np
+from gentropy.common.session import Session
+from gentropy.dataset.study_locus import StudyLocus
+from gentropy.dataset.summary_statistics import SummaryStatistics
 from gentropy.method.susie_inf import SUSIE_inf
+from gentropy.susie_finemapper import SusieFineMapperStep
 
 
 class TestSUSIE_inf:
@@ -48,3 +52,28 @@ def test_SUSIE_inf_cred(
         )
         cred = SUSIE_inf.cred_inf(susie_output["PIP"], LD=ld)
         assert cred[0] == [5]
+
+    def test_SUSIE_inf_convert_to_study_locus(
+        self: TestSUSIE_inf,
+        sample_data_for_susie_inf: list[np.ndarray],
+        sample_summary_statistics: SummaryStatistics,
+        session: Session,
+    ) -> None:
+        """Test of SuSiE-inf credible set generator."""
+        ld = sample_data_for_susie_inf[0]
+        z = sample_data_for_susie_inf[1]
+        susie_output = SUSIE_inf.susie_inf(
+            z=z,
+            LD=ld,
+            est_tausq=False,
+        )
+        gwas_df = sample_summary_statistics._df.limit(21)
+        L1 = SusieFineMapperStep.susie_inf_to_studylocus(
+            susie_output=susie_output,
+            session=session,
+            _studyId="sample_id",
+            _region="sample_region",
+            variant_index=gwas_df,
+            cs_lbf_thr=2,
+        )
+        assert isinstance(L1, StudyLocus), "L1 is not an instance of StudyLocus"