diff --git a/src/gentropy/dataset/study_locus.py b/src/gentropy/dataset/study_locus.py index 9ef536b4d..44778311e 100644 --- a/src/gentropy/dataset/study_locus.py +++ b/src/gentropy/dataset/study_locus.py @@ -326,6 +326,25 @@ def filter_credible_set( ) return self + @staticmethod + def filter_ld_set(ld_set: Column, r2_threshold: float) -> Column: + """Filter the LD set by a given R2 threshold. + + Args: + ld_set (Column): LD set + r2_threshold (float): R2 threshold to filter the LD set on + + Returns: + Column: Filtered LD index + """ + return f.when( + ld_set.isNotNull(), + f.filter( + ld_set, + lambda tag: tag["r2Overall"] >= r2_threshold, + ), + ) + def find_overlaps( self: StudyLocus, study_index: StudyIndex, intra_study_overlap: bool = False ) -> StudyLocusOverlap: @@ -524,20 +543,24 @@ def annotate_locus_statistics( return self def annotate_ld( - self: StudyLocus, study_index: StudyIndex, ld_index: LDIndex + self: StudyLocus, + study_index: StudyIndex, + ld_index: LDIndex, + r2_threshold: float = 0.0, ) -> StudyLocus: """Annotate LD information to study-locus. Args: study_index (StudyIndex): Study index to resolve ancestries. ld_index (LDIndex): LD index to resolve LD information. + r2_threshold (float): R2 threshold to filter the LD index. Default is 0.0. Returns: StudyLocus: Study locus annotated with ld information from LD index. """ from gentropy.method.ld import LDAnnotator - return LDAnnotator.ld_annotate(self, study_index, ld_index) + return LDAnnotator.ld_annotate(self, study_index, ld_index, r2_threshold) def clump(self: StudyLocus) -> StudyLocus: """Perform LD clumping of the studyLocus. diff --git a/src/gentropy/method/ld.py b/src/gentropy/method/ld.py index f0eab7c4b..68b78b103 100644 --- a/src/gentropy/method/ld.py +++ b/src/gentropy/method/ld.py @@ -1,4 +1,5 @@ """Performing linkage disequilibrium (LD) operations.""" + from __future__ import annotations from typing import TYPE_CHECKING @@ -120,6 +121,7 @@ def ld_annotate( associations: StudyLocus, studies: StudyIndex, ld_index: LDIndex, + r2_threshold: float = 0.5, ) -> StudyLocus: """Annotate linkage disequilibrium (LD) information to a set of studyLocus. @@ -131,10 +133,14 @@ def ld_annotate( 5. Flags associations with variants that are not found in the LD reference 6. Rescues lead variant when no LD information is available but lead variant is available + !!! note + Because the LD index has a pre-set threshold of R2 = 0.5, this is the minimum threshold for the LD information to be included in the ldSet. + Args: associations (StudyLocus): Dataset to be LD annotated studies (StudyIndex): Dataset with study information ld_index (LDIndex): Dataset with LD information for every variant present in LD matrix + r2_threshold (float): R2 threshold to filter the LD set on. Default is 0.5. Returns: StudyLocus: including additional column with LD information. @@ -175,6 +181,12 @@ def ld_annotate( ), ) .drop("ldPopulationStructure") + # Filter the LD set by the R2 threshold and set to null if no LD information passes the threshold + .withColumn( + "ldSet", + StudyLocus.filter_ld_set(f.col("ldSet"), r2_threshold), + ) + .withColumn("ldSet", f.when(f.size("ldSet") > 0, f.col("ldSet"))) # QC: Flag associations with variants that are not found in the LD reference .withColumn( "qualityControls", diff --git a/tests/gentropy/dataset/test_study_locus.py b/tests/gentropy/dataset/test_study_locus.py index 1401b9dd3..772e49742 100644 --- a/tests/gentropy/dataset/test_study_locus.py +++ b/tests/gentropy/dataset/test_study_locus.py @@ -11,7 +11,7 @@ from gentropy.dataset.study_locus import CredibleInterval, StudyLocus from gentropy.dataset.study_locus_overlap import StudyLocusOverlap from gentropy.dataset.summary_statistics import SummaryStatistics -from pyspark.sql import Column, SparkSession +from pyspark.sql import Column, Row, SparkSession from pyspark.sql.types import ( ArrayType, BooleanType, @@ -23,11 +23,6 @@ ) -def test_study_locus_creation(mock_study_locus: StudyLocus) -> None: - """Test study locus creation with mock data.""" - assert isinstance(mock_study_locus, StudyLocus) - - @pytest.mark.parametrize( "has_overlap, expected", [ @@ -531,3 +526,17 @@ def test_ldannotate( assert isinstance( mock_study_locus.annotate_ld(mock_study_index, mock_ld_index), StudyLocus ) + + +def test_filter_ld_set(spark: SparkSession) -> None: + """Test filter_ld_set.""" + observed_data = [ + Row(studyLocusId="sl1", ldSet=[{"tagVariantId": "tag1", "r2Overall": 0.4}]) + ] + observed_df = spark.createDataFrame( + observed_data, ["studyLocusId", "ldSet"] + ).withColumn("ldSet", StudyLocus.filter_ld_set(f.col("ldSet"), 0.5)) + expected_tags_in_ld = 0 + assert ( + observed_df.filter(f.size("ldSet") > 1).count() == expected_tags_in_ld + ), "Expected tags in ld set differ from observed."