Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add step to generate association data #888

Merged
merged 8 commits into from
Nov 1, 2024
2 changes: 2 additions & 0 deletions docs/python_api/steps/l2g.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ title: Locus to Gene (L2G)
::: gentropy.l2g.LocusToGeneStep

::: gentropy.l2g.LocusToGeneEvidenceStep

::: gentropy.l2g.LocusToGeneAssociationsStep
19 changes: 19 additions & 0 deletions src/gentropy/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,3 +315,22 @@ def copy_to_gcs(source_path: str, destination_blob: str) -> None:
bucket = client.bucket(bucket_name=urlparse(destination_blob).hostname)
blob = bucket.blob(blob_name=urlparse(destination_blob).path.lstrip("/"))
blob.upload_from_filename(source_path)

def calculate_harmonic_sum(input_array: Column) -> Column:
"""Calculate the harmonic sum of an array.

Args:
input_array (Column): input array of doubles

Returns:
Column: column of harmonic sums
"""
return f.aggregate(
f.arrays_zip(
f.sort_array(input_array, False).alias("score"),
f.sequence(f.lit(1), f.size(input_array)).alias("pos")
),
f.lit(0.0),
lambda acc, x: acc
+ x["score"]/f.pow(x["pos"], 2)/f.lit(sum(1 / ((i + 1)**2) for i in range(100)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are only the first 100 terms used?
And there should be a division by 1.644 somewhere...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That part represents the division by ~1.644.
I initially used the first 1000 terms (1.6439..) but changed it to 100 (1.6349..).
Should I just change it to f.lit(1.644)?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, misread it. All is fine. But I would use 1000 (to be consistent with the platform if documentation is correct)

)
14 changes: 14 additions & 0 deletions src/gentropy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,15 @@ class LocusToGeneEvidenceStepConfig(StepConfig):
locus_to_gene_threshold: float = 0.05
_target_: str = "gentropy.l2g.LocusToGeneEvidenceStep"

@dataclass
class LocusToGeneAssociationsStepConfig(StepConfig):
    """Configuration of the locus to gene association step.

    All four paths default to ``MISSING``; no usable default is defined here.

    Attributes:
        evidence_input_path (str): Path to the L2G evidence input dataset.
        disease_index_path (str): Path to the disease index file.
        direct_associations_output_path (str): Path to the direct associations
            output dataset.
        indirect_associations_output_path (str): Path to the indirect
            associations output dataset.
    """

    evidence_input_path: str = MISSING
    disease_index_path: str = MISSING
    direct_associations_output_path: str = MISSING
    indirect_associations_output_path: str = MISSING
    # Dotted path of the step class this configuration belongs to.
    _target_: str = "gentropy.l2g.LocusToGeneAssociationsStep"

@dataclass
class StudyLocusValidationStepConfig(StepConfig):
Expand Down Expand Up @@ -733,5 +742,10 @@ def register_config() -> None:
name="locus_to_gene_evidence",
node=LocusToGeneEvidenceStepConfig,
)
cs.store(
group="step",
name="locus_to_gene_associations",
node=LocusToGeneAssociationsStepConfig,
)
cs.store(group="step", name="finngen_ukb_meta_ingestion", node=FinngenUkbMetaConfig)
cs.store(group="step", name="credible_set_qc", node=CredibleSetQCStepConfig)
60 changes: 59 additions & 1 deletion src/gentropy/l2g.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from wandb import login as wandb_login

from gentropy.common.session import Session
from gentropy.common.utils import access_gcp_secret
from gentropy.common.utils import access_gcp_secret, calculate_harmonic_sum
from gentropy.dataset.colocalisation import Colocalisation
from gentropy.dataset.gene_index import GeneIndex
from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix
Expand Down Expand Up @@ -320,3 +320,61 @@ def __init__(
.write.mode(session.write_mode)
.json(evidence_output_path)
)

class LocusToGeneAssociationsStep:
    """Step that aggregates L2G evidence into target-disease associations."""

    def __init__(
        self,
        session: Session,
        evidence_input_path: str,
        disease_index_path: str,
        direct_associations_output_path: str,
        indirect_associations_output_path: str,
    ) -> None:
        """Build the direct and indirect association datasets and write them out.

        Evidence scores are grouped per target and disease (direct) or per
        target and ancestor disease (indirect). Each group's distinct scores
        are reduced with `calculate_harmonic_sum`, and both results are
        written as parquet.

        Args:
            session (Session): Session object that contains the Spark session
            evidence_input_path (str): Path to the L2G evidence input dataset (JSON)
            disease_index_path (str): Path to disease index file (parquet)
            direct_associations_output_path (str): Path to the direct associations output dataset
            indirect_associations_output_path (str): Path to the indirect associations output dataset
        """
        spark = session.spark

        # Disease index: one row per (disease, ancestor disease) pair.
        # NOTE(review): a disease is only propagated to the entries of its
        # `ancestors` array; whether that array contains the disease itself is
        # not visible here -- confirm against the disease index schema.
        disease_to_ancestor = spark.read.parquet(disease_index_path).select(
            f.col("id").alias("diseaseId"),
            f.explode("ancestors").alias("ancestorDiseaseId"),
        )

        # L2G evidence reduced to the three columns needed for scoring.
        evidence = spark.read.json(evidence_input_path).select(
            f.col("targetFromSourceId").alias("targetId"),
            f.col("diseaseFromSourceMappedId").alias("diseaseId"),
            f.col("resourceScore"),
        )

        # Direct associations: scores aggregated on the annotated disease.
        # `collect_set` keeps distinct values only, so identical scores within
        # a group are counted once.
        direct_scores = evidence.groupBy("targetId", "diseaseId").agg(
            f.collect_set("resourceScore").alias("scores")
        )
        direct_associations = direct_scores.select(
            "targetId",
            "diseaseId",
            calculate_harmonic_sum(f.col("scores")).alias("harmonicSum"),
        )
        direct_associations.write.mode(session.write_mode).parquet(
            direct_associations_output_path
        )

        # Indirect associations: the same aggregation after mapping every
        # evidence row to each ancestor of its disease.
        propagated_evidence = evidence.join(
            disease_to_ancestor, on="diseaseId", how="inner"
        )
        indirect_scores = propagated_evidence.groupBy(
            "targetId", "ancestorDiseaseId"
        ).agg(f.collect_set("resourceScore").alias("scores"))
        indirect_associations = indirect_scores.select(
            "targetId",
            "ancestorDiseaseId",
            calculate_harmonic_sum(f.col("scores")).alias("harmonicSum"),
        )
        indirect_associations.write.mode(session.write_mode).parquet(
            indirect_associations_output_path
        )
Loading