feat: add step to generate association data (#888)
* feat: add step to generate association data

* fix: evidence input file is json

* feat: changed maximum theoretical harmonic sum formula

* test: add test for calculate_harmonic_sum function

* chore: move calculate_harmonic_sum function to spark_helpers.py

* chore: update import statement
vivienho authored Nov 1, 2024
1 parent fa38ca6 commit b812f67
Showing 4 changed files with 110 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/python_api/steps/l2g.md
@@ -7,3 +7,5 @@ title: Locus to Gene (L2G)
::: gentropy.l2g.LocusToGeneStep

::: gentropy.l2g.LocusToGeneEvidenceStep

::: gentropy.l2g.LocusToGeneAssociationsStep
35 changes: 35 additions & 0 deletions src/gentropy/common/spark_helpers.py
@@ -847,3 +847,38 @@ def get_struct_field_schema(schema: t.StructType, name: str) -> t.DataType:
    if not matching_fields:
        raise ValueError("Provided name %s is not present in the schema.", name)
    return matching_fields[0].dataType

def calculate_harmonic_sum(input_array: Column) -> Column:
    """Calculate the harmonic sum of an array.

    Args:
        input_array (Column): input array of doubles

    Returns:
        Column: column of harmonic sums

    Examples:
        >>> from pyspark.sql import Row
        >>> df = spark.createDataFrame([
        ...     Row([0.3, 0.8, 1.0]),
        ...     Row([0.7, 0.2, 0.9]),
        ...     ], ["input_array"]
        ... )
        >>> df.select("*", calculate_harmonic_sum(f.col("input_array")).alias("harmonic_sum")).show()
        +---------------+------------------+
        |    input_array|      harmonic_sum|
        +---------------+------------------+
        |[0.3, 0.8, 1.0]|0.7502326177269538|
        |[0.7, 0.2, 0.9]|0.6674366756805108|
        +---------------+------------------+
        <BLANKLINE>
    """
    return f.aggregate(
        f.arrays_zip(
            f.sort_array(input_array, False).alias("score"),
            f.sequence(f.lit(1), f.size(input_array)).alias("pos")
        ),
        f.lit(0.0),
        lambda acc, x: acc
        + x["score"] / f.pow(x["pos"], 2) / f.lit(sum(1 / ((i + 1) ** 2) for i in range(1000)))
    )
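
For reference, here is a minimal pure-Python sketch (not part of the commit) of what the Spark expression above computes: each score is divided by the square of its descending rank, and the total is normalised by the maximum theoretical harmonic sum, approximated by the first 1000 terms of the sum of 1/i². The helper name below is hypothetical.

# Pure-Python sketch of the normalised harmonic sum; name and structure are illustrative.
def harmonic_sum(scores: list[float]) -> float:
    """Sum of score_i / rank_i**2 over descending ranks, scaled by the 1000-term maximum."""
    max_theoretical = sum(1 / i**2 for i in range(1, 1001))
    ranked = sorted(scores, reverse=True)  # highest score gets rank 1
    return sum(s / rank**2 for rank, s in enumerate(ranked, start=1)) / max_theoretical

print(harmonic_sum([0.3, 0.8, 1.0]))  # ~0.75023, matching the doctest above
print(harmonic_sum([0.7, 0.2, 0.9]))  # ~0.66744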
14 changes: 14 additions & 0 deletions src/gentropy/config.py
@@ -632,6 +632,15 @@ class LocusToGeneEvidenceStepConfig(StepConfig):
    locus_to_gene_threshold: float = 0.05
    _target_: str = "gentropy.l2g.LocusToGeneEvidenceStep"

@dataclass
class LocusToGeneAssociationsStepConfig(StepConfig):
    """Configuration of the locus to gene association step."""

    evidence_input_path: str = MISSING
    disease_index_path: str = MISSING
    direct_associations_output_path: str = MISSING
    indirect_associations_output_path: str = MISSING
    _target_: str = "gentropy.l2g.LocusToGeneAssociationsStep"

@dataclass
class StudyLocusValidationStepConfig(StepConfig):
@@ -733,5 +742,10 @@ def register_config() -> None:
        name="locus_to_gene_evidence",
        node=LocusToGeneEvidenceStepConfig,
    )
    cs.store(
        group="step",
        name="locus_to_gene_associations",
        node=LocusToGeneAssociationsStepConfig,
    )
    cs.store(group="step", name="finngen_ukb_meta_ingestion", node=FinngenUkbMetaConfig)
    cs.store(group="step", name="credible_set_qc", node=CredibleSetQCStepConfig)
59 changes: 59 additions & 0 deletions src/gentropy/l2g.py
@@ -9,6 +9,7 @@
from wandb import login as wandb_login

from gentropy.common.session import Session
from gentropy.common.spark_helpers import calculate_harmonic_sum
from gentropy.common.utils import access_gcp_secret
from gentropy.dataset.colocalisation import Colocalisation
from gentropy.dataset.gene_index import GeneIndex
@@ -320,3 +321,61 @@ def __init__(
            .write.mode(session.write_mode)
            .json(evidence_output_path)
        )

class LocusToGeneAssociationsStep:
    """Locus to gene associations step."""

    def __init__(
        self,
        session: Session,
        evidence_input_path: str,
        disease_index_path: str,
        direct_associations_output_path: str,
        indirect_associations_output_path: str,
    ) -> None:
        """Create direct and indirect association datasets.

        Args:
            session (Session): Session object that contains the Spark session
            evidence_input_path (str): Path to the L2G evidence input dataset
            disease_index_path (str): Path to the disease index file
            direct_associations_output_path (str): Path to the direct associations output dataset
            indirect_associations_output_path (str): Path to the indirect associations output dataset
        """
        # Read in the disease index
        disease_index = (
            session.spark.read.parquet(disease_index_path)
            .select(
                f.col("id").alias("diseaseId"),
                f.explode("ancestors").alias("ancestorDiseaseId")
            )
        )

        # Read in the L2G evidence
        disease_target_evidence = (
            session.spark.read.json(evidence_input_path)
            .select(
                f.col("targetFromSourceId").alias("targetId"),
                f.col("diseaseFromSourceMappedId").alias("diseaseId"),
                f.col("resourceScore")
            )
        )

        # Generate direct associations and save file
        (
            disease_target_evidence
            .groupBy("targetId", "diseaseId")
            .agg(f.collect_set("resourceScore").alias("scores"))
            .select("targetId", "diseaseId", calculate_harmonic_sum(f.col("scores")).alias("harmonicSum"))
            .write.mode(session.write_mode).parquet(direct_associations_output_path)
        )

        # Generate indirect associations and save file
        (
            disease_target_evidence
            .join(disease_index, on="diseaseId", how="inner")
            .groupBy("targetId", "ancestorDiseaseId")
            .agg(f.collect_set("resourceScore").alias("scores"))
            .select("targetId", "ancestorDiseaseId", calculate_harmonic_sum(f.col("scores")).alias("harmonicSum"))
            .write.mode(session.write_mode).parquet(indirect_associations_output_path)
        )
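
A hedged usage sketch (not part of the commit): like the other gentropy steps, the class does all of its work in __init__, so instantiating it runs the step. The Session constructor arguments and the paths below are assumptions for illustration.

from gentropy.common.session import Session
from gentropy.l2g import LocusToGeneAssociationsStep

# Assumed Session constructor arguments; adjust to the deployment environment.
session = Session(app_name="locus_to_gene_associations")

LocusToGeneAssociationsStep(
    session=session,
    evidence_input_path="gs://bucket/l2g_evidence.json",  # hypothetical paths
    disease_index_path="gs://bucket/disease_index.parquet",
    direct_associations_output_path="gs://bucket/direct_associations",
    indirect_associations_output_path="gs://bucket/indirect_associations",
)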
