diff --git a/tests/gentropy/conftest.py b/tests/gentropy/conftest.py index 13992f8f2..78bd567da 100644 --- a/tests/gentropy/conftest.py +++ b/tests/gentropy/conftest.py @@ -3,7 +3,6 @@ from __future__ import annotations from pathlib import Path -from typing import Any import dbldatagen as dg import hail as hl @@ -648,46 +647,3 @@ def sample_data_for_susie_inf() -> list[np.ndarray]: lbf_moments = np.loadtxt("tests/gentropy/data_samples/01_test_lbf_moments.csv") lbf_mle = np.loadtxt("tests/gentropy/data_samples/01_test_lbf_mle.csv") return [ld, z, lbf_moments, lbf_mle] - - -@pytest.fixture() -def sample_data_for_coloc(spark: SparkSession) -> list[Any]: - """Sample data for Coloc tests.""" - overlap_df = spark.read.parquet( - "tests/gentropy/data_samples/coloc_test_data.snappy.parquet" - ) - expected_df = spark.createDataFrame( - [ - { - "h0": 1.3769995397857477e-18, - "h1": 2.937336451601565e-10, - "h2": 8.593226431647826e-12, - "h3": 8.338916748775843e-4, - "h4": 0.9991661080227981, - } - ] - ) - single_snp_coloc = spark.createDataFrame( - [ - { - "leftStudyLocusId": 1, - "rightStudyLocusId": 2, - "chromosome": "1", - "tagVariantId": "snp", - "left_logBF": 10.3, - "right_logBF": 10.5, - } - ] - ) - expected_single_snp_coloc = spark.createDataFrame( - [ - { - "h0": 9.254841951638903e-5, - "h1": 2.7517068829182966e-4, - "h2": 3.3609423764447284e-4, - "h3": 9.254841952564387e-13, - "h4": 0.9992961866536217, - } - ] - ) - return [overlap_df, expected_df, single_snp_coloc, expected_single_snp_coloc] diff --git a/tests/gentropy/data_samples/coloc_test_data.snappy.parquet b/tests/gentropy/data_samples/coloc_test_data.snappy.parquet deleted file mode 100644 index 71b3913eb..000000000 Binary files a/tests/gentropy/data_samples/coloc_test_data.snappy.parquet and /dev/null differ diff --git a/tests/gentropy/method/test_colocalisation_method.py b/tests/gentropy/method/test_colocalisation_method.py index 1f52244be..d311d88ad 100644 --- a/tests/gentropy/method/test_colocalisation_method.py +++ b/tests/gentropy/method/test_colocalisation_method.py @@ -4,10 +4,12 @@ from typing import Any +import pytest from gentropy.dataset.colocalisation import Colocalisation from gentropy.dataset.study_locus_overlap import StudyLocusOverlap from gentropy.method.colocalisation import Coloc, ECaviar -from pyspark.sql import functions as f +from pandas.testing import assert_frame_equal +from pyspark.sql import SparkSession def test_coloc(mock_study_locus_overlap: StudyLocusOverlap) -> None: @@ -15,43 +17,92 @@ def test_coloc(mock_study_locus_overlap: StudyLocusOverlap) -> None: assert isinstance(Coloc.colocalise(mock_study_locus_overlap), Colocalisation) -def test_coloc_colocalise( - sample_data_for_coloc: list[Any], - threshold: float = 1e-4, -) -> None: - """Compare COLOC results with R implementation, using provided sample dataset from R package (StudyLocusOverlap).""" - test_overlap_df = sample_data_for_coloc[0] - test_overlap = StudyLocusOverlap( - _df=test_overlap_df, _schema=StudyLocusOverlap.get_schema() - ) - test_result = Coloc.colocalise(test_overlap) - expected = sample_data_for_coloc[1] - difference = test_result.df.select("h0", "h1", "h2", "h3", "h4").subtract(expected) - for col in difference.columns: - assert difference.filter(f.abs(f.col(col)) > threshold).count() == 0 - - -def test_single_snp_coloc( - sample_data_for_coloc: list[Any], +@pytest.mark.parametrize( + "observed_data, expected_data", + [ + # associations with a single overlapping SNP + ( + # observed overlap + [ + { + "leftStudyLocusId": 1, + "rightStudyLocusId": 2, + "chromosome": "1", + "tagVariantId": "snp", + "statistics": {"left_logBF": 10.3, "right_logBF": 10.5}, + }, + ], + # expected coloc + [ + { + "h0": 9.254841951638903e-5, + "h1": 2.7517068829182966e-4, + "h2": 3.3609423764447284e-4, + "h3": 9.254841952564387e-13, + "h4": 0.9992961866536217, + }, + ], + ), + # associations with multiple overlapping SNPs + ( + # observed overlap + [ + { + "leftStudyLocusId": 1, + "rightStudyLocusId": 2, + "chromosome": "1", + "tagVariantId": "snp1", + "statistics": {"left_logBF": 10.3, "right_logBF": 10.5}, + }, + { + "leftStudyLocusId": 1, + "rightStudyLocusId": 2, + "chromosome": "1", + "tagVariantId": "snp2", + "statistics": {"left_logBF": 10.3, "right_logBF": 10.5}, + }, + ], + # expected coloc + [ + { + "h0": 4.6230151407950416e-5, + "h1": 2.749086942648107e-4, + "h2": 3.357742374172504e-4, + "h3": 9.983447421747411e-4, + "h4": 0.9983447421747356, + }, + ], + ), + ], +) +def test_coloc_semantic( + spark: SparkSession, + observed_data: list[Any], + expected_data: list[Any], threshold: float = 1e-5, ) -> None: - """Test edge case of coloc where only one causal SNP is present in the StudyLocusOverlap.""" - test_overlap_df = sample_data_for_coloc[2] - test_overlap = StudyLocusOverlap( - _df=test_overlap_df.select( - "leftStudyLocusId", - "rightStudyLocusId", - "chromosome", - "tagVariantId", - f.struct(f.col("left_logBF"), f.col("right_logBF")).alias("statistics"), - ), + """Test our COLOC with the implementation in R.""" + observed_overlap = StudyLocusOverlap( + _df=spark.createDataFrame(observed_data, schema=StudyLocusOverlap.get_schema()), _schema=StudyLocusOverlap.get_schema(), ) - test_result = Coloc.colocalise(test_overlap) - expected = sample_data_for_coloc[3] - difference = test_result.df.select("h0", "h1", "h2", "h3", "h4").subtract(expected) - for col in difference.columns: - assert difference.filter(f.abs(f.col(col)) > threshold).count() == 0 + observed_coloc_pdf = ( + Coloc.colocalise(observed_overlap) + .df.select("h0", "h1", "h2", "h3", "h4") + .toPandas() + ) + expected_coloc_pdf = ( + spark.createDataFrame(expected_data) + .select("h0", "h1", "h2", "h3", "h4") + .toPandas() + ) + + assert_frame_equal( + observed_coloc_pdf, + expected_coloc_pdf, + check_exact=False, + check_dtype=True, + ) def test_ecaviar(mock_study_locus_overlap: StudyLocusOverlap) -> None: