diff --git a/docs/python_api/steps/l2g.md b/docs/python_api/steps/l2g.md index 5594f1605..e6aeb0ebb 100644 --- a/docs/python_api/steps/l2g.md +++ b/docs/python_api/steps/l2g.md @@ -7,3 +7,5 @@ title: Locus to Gene (L2G) ::: gentropy.l2g.LocusToGeneStep ::: gentropy.l2g.LocusToGeneEvidenceStep + +::: gentropy.l2g.LocusToGeneAssociationsStep diff --git a/src/gentropy/common/spark_helpers.py b/src/gentropy/common/spark_helpers.py index 4e40ac4f1..a1bf9670a 100644 --- a/src/gentropy/common/spark_helpers.py +++ b/src/gentropy/common/spark_helpers.py @@ -847,3 +847,38 @@ def get_struct_field_schema(schema: t.StructType, name: str) -> t.DataType: if not matching_fields: raise ValueError("Provided name %s is not present in the schema.", name) return matching_fields[0].dataType + +def calculate_harmonic_sum(input_array: Column) -> Column: + """Calculate the harmonic sum of an array. + + Args: + input_array (Column): input array of doubles + + Returns: + Column: column of harmonic sums + + Examples: + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([ + ... Row([0.3, 0.8, 1.0]), + ... Row([0.7, 0.2, 0.9]), + ... ], ["input_array"] + ... ) + >>> df.select("*", calculate_harmonic_sum(f.col("input_array")).alias("harmonic_sum")).show() + +---------------+------------------+ + | input_array| harmonic_sum| + +---------------+------------------+ + |[0.3, 0.8, 1.0]|0.7502326177269538| + |[0.7, 0.2, 0.9]|0.6674366756805108| + +---------------+------------------+ + + """ + return f.aggregate( + f.arrays_zip( + f.sort_array(input_array, False).alias("score"), + f.sequence(f.lit(1), f.size(input_array)).alias("pos") + ), + f.lit(0.0), + lambda acc, x: acc + + x["score"]/f.pow(x["pos"], 2)/f.lit(sum(1 / ((i + 1)**2) for i in range(1000))) + ) diff --git a/src/gentropy/config.py b/src/gentropy/config.py index c5889dbab..e9bf26f31 100644 --- a/src/gentropy/config.py +++ b/src/gentropy/config.py @@ -632,6 +632,15 @@ class LocusToGeneEvidenceStepConfig(StepConfig): locus_to_gene_threshold: float = 0.05 _target_: str = "gentropy.l2g.LocusToGeneEvidenceStep" +@dataclass +class LocusToGeneAssociationsStepConfig(StepConfig): + """Configuration of the locus to gene association step.""" + + evidence_input_path: str = MISSING + disease_index_path: str = MISSING + direct_associations_output_path: str = MISSING + indirect_associations_output_path: str = MISSING + _target_: str = "gentropy.l2g.LocusToGeneAssociationsStep" @dataclass class StudyLocusValidationStepConfig(StepConfig): @@ -733,5 +742,10 @@ def register_config() -> None: name="locus_to_gene_evidence", node=LocusToGeneEvidenceStepConfig, ) + cs.store( + group="step", + name="locus_to_gene_associations", + node=LocusToGeneAssociationsStepConfig, + ) cs.store(group="step", name="finngen_ukb_meta_ingestion", node=FinngenUkbMetaConfig) cs.store(group="step", name="credible_set_qc", node=CredibleSetQCStepConfig) diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py index ca52fbf04..1004fa0fb 100644 --- a/src/gentropy/l2g.py +++ b/src/gentropy/l2g.py @@ -9,6 +9,7 @@ from wandb import login as wandb_login from gentropy.common.session import Session +from gentropy.common.spark_helpers import calculate_harmonic_sum from gentropy.common.utils import access_gcp_secret from gentropy.dataset.colocalisation import Colocalisation from gentropy.dataset.gene_index import GeneIndex @@ -320,3 +321,61 @@ def __init__( .write.mode(session.write_mode) .json(evidence_output_path) ) + +class LocusToGeneAssociationsStep: + """Locus to gene associations step.""" + + def __init__( + self, + session: Session, + evidence_input_path: str, + disease_index_path: str, + direct_associations_output_path: str, + indirect_associations_output_path: str, + ) -> None: + """Create direct and indirect association datasets. + + Args: + session (Session): Session object that contains the Spark session + evidence_input_path (str): Path to the L2G evidence input dataset + disease_index_path (str): Path to disease index file + direct_associations_output_path (str): Path to the direct associations output dataset + indirect_associations_output_path (str): Path to the indirect associations output dataset + """ + # Read in the disease index + disease_index = ( + session.spark.read.parquet(disease_index_path) + .select( + f.col("id").alias("diseaseId"), + f.explode("ancestors").alias("ancestorDiseaseId") + ) + ) + + # Read in the L2G evidence + disease_target_evidence = ( + session.spark.read.json(evidence_input_path) + .select( + f.col("targetFromSourceId").alias("targetId"), + f.col("diseaseFromSourceMappedId").alias("diseaseId"), + f.col("resourceScore") + ) + ) + + # Generate direct assocations and save file + ( + disease_target_evidence + .groupBy("targetId", "diseaseId") + .agg(f.collect_set("resourceScore").alias("scores")) + .select("targetId", "diseaseId", calculate_harmonic_sum(f.col("scores")).alias("harmonicSum")) + .write.mode(session.write_mode).parquet(direct_associations_output_path) + ) + + # Generate indirect assocations and save file + ( + disease_target_evidence + .join(disease_index, on="diseaseId", how="inner") + .groupBy("targetId", "ancestorDiseaseId") + .agg(f.collect_set("resourceScore").alias("scores")) + .select("targetId", "ancestorDiseaseId", calculate_harmonic_sum(f.col("scores")).alias("harmonicSum")) + .write.mode(session.write_mode).parquet(indirect_associations_output_path) + )