finished persisting cache

leokim-l committed Aug 15, 2024
1 parent 112bbe9 commit 1b0a231
Showing 3 changed files with 40 additions and 122 deletions.
32 changes: 25 additions & 7 deletions src/malco/post_process/mondo_score_utils.py
@@ -3,16 +3,13 @@
from pathlib import Path

from typing import List
from cachetools import cached, LRUCache
from cachetools.keys import hashkey


FULL_SCORE = 1.0
PARTIAL_SCORE = 0.5



@cached(pc1, info=True, key=lambda term, adapter: hashkey(term))
def omim_mappings(term: str, adapter) -> List[str]:
"""
Get the OMIM mappings for a term.
@@ -37,8 +34,7 @@ def omim_mappings(term: str, adapter) -> List[str]:
return omims


@cached(pc2, info=True, key=lambda prediction, ground_truth, mondo: hashkey(prediction, ground_truth))
def score_grounded_result(prediction: str, ground_truth: str, mondo) -> float:
def score_grounded_result(prediction: str, ground_truth: str, mondo, cache=None) -> float:
"""
Score the grounded result.
@@ -74,14 +70,36 @@ def score_grounded_result(prediction: str, ground_truth: str, mondo) -> float:
# prediction is the correct OMIM
return FULL_SCORE

if ground_truth in omim_mappings(prediction, mondo):

ground_truths = get_ground_truth_from_cache_or_compute(prediction, mondo, cache)
if ground_truth in ground_truths:
# prediction is a MONDO that directly maps to a correct OMIM
return FULL_SCORE

descendants_list = mondo.descendants([prediction], predicates=[IS_A], reflexive=True)
for mondo_descendant in descendants_list:
if ground_truth in omim_mappings(mondo_descendant, mondo):
ground_truths = get_ground_truth_from_cache_or_compute(mondo_descendant, mondo, cache)
if ground_truth in ground_truths:
# prediction is a MONDO that maps to a correct OMIM via a descendant
return PARTIAL_SCORE
return 0.0

def get_ground_truth_from_cache_or_compute(
term,
adapter,
cache,
):
if cache is None:
return omim_mappings(term, adapter)

k = hashkey(term)
try:
ground_truths = cache[k]
cache.hits += 1
except KeyError:
# cache miss
ground_truths = omim_mappings(term, adapter)
cache[k] = ground_truths
cache.misses += 1
return ground_truths
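
For context, a minimal sketch of how the new cache= parameter can be exercised end to end. The cache filename here is illustrative (the pipeline derives its paths from output_dir), and pc.close() assumes shelved_cache's shelve-backed PersistentCache API; the hits/misses counters must be set up by the caller, exactly as compute_mrr_and_ranks does below.

from cachetools import LRUCache
from oaklib import get_adapter
from shelved_cache import PersistentCache

from malco.post_process.mondo_score_utils import (
    get_ground_truth_from_cache_or_compute,
    score_grounded_result,
)

mondo = get_adapter("sqlite:obo:mondo")

# Illustrative cache file; maxsize matches the omim_mappings cache below.
pc = PersistentCache(LRUCache, "omim_mappings_cache_demo", maxsize=16384)
pc.hits = pc.misses = 0  # the helper increments these on every lookup

# The first call misses and computes; the second is served from the cache,
# and entries survive across processes via the shelve file.
print(get_ground_truth_from_cache_or_compute("MONDO:0007566", mondo, pc))  # ['OMIM:132800']
print(score_grounded_result("MONDO:0007566", "OMIM:132800", mondo, cache=pc))  # 1.0
print(f"hits={pc.hits}, misses={pc.misses}")

pc.close()  # flush the shelve file to disk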

127 changes: 13 additions & 114 deletions src/malco/post_process/ranking_utils.py
@@ -11,121 +11,17 @@
from oaklib.interfaces import MappingProviderInterface
from oaklib import get_adapter

from cachetools import cached, LRUCache
from malco.post_process.mondo_score_utils import score_grounded_result
from cachetools import LRUCache
from typing import List
from cachetools.keys import hashkey
from shelved_cache import PersistentCache

FULL_SCORE = 1.0
PARTIAL_SCORE = 0.5

#pc1 = {}
#pc2 = {}

#efilename1 = str(output_dir / "omim_mappings_cache")
#filenam2 = str(output_dir / "score_grounded_result_cache")
#filename1 = str("omim_mappings_cache")
# filename2 = str("score_grounded_result_cache")
# global pc1
# global pc2
#pc1 = PersistentCache(LRUCache(maxsize=16384), filename1)
#pc2 = PersistentCache(LRUCache(maxsize=4096), filename2)
#pc1 = PersistentCache(LRUCache, filename1, maxsize=16384)
# pc2 = PersistentCache(LRUCache, filename2, maxsize=4096)

#@cached(pc1, info=True, key=lambda term, adapter: hashkey(term))
def omim_mappings(term: str, adapter) -> List[str]:
"""
Get the OMIM mappings for a term.
Example:
>>> from oaklib import get_adapter
>>> omim_mappings("MONDO:0007566", get_adapter("sqlite:obo:mondo"))
['OMIM:132800']
Args:
term (str): The term.
adapter: The mondo adapter.
Returns:
str: The OMIM mappings.
"""
omims = []
for m in adapter.sssom_mappings([term], source="OMIM"):
if m.predicate_id == "skos:exactMatch":
omims.append(m.object_id)
return omims


# @cached(pc2, info=True, key=lambda prediction, ground_truth, mondo: hashkey(prediction, ground_truth))
def score_grounded_result(prediction: str, ground_truth: str, mondo, cache=None) -> float:
"""
Score the grounded result.
Exact match:
>>> from oaklib import get_adapter
>>> score_grounded_result("OMIM:132800", "OMIM:132800", get_adapter("sqlite:obo:mondo"))
1.0
The predicted Mondo is equivalent to the ground truth OMIM
(via skos:exactMatches in Mondo):
>>> score_grounded_result("MONDO:0007566", "OMIM:132800", get_adapter("sqlite:obo:mondo"))
1.0
The predicted Mondo is a disease entity that groups multiple
OMIMs, one of which is the ground truth:
>>> score_grounded_result("MONDO:0008029", "OMIM:158810", get_adapter("sqlite:obo:mondo"))
0.5
Args:
prediction (str): The prediction.
ground_truth (str): The ground truth.
mondo: The mondo adapter.
Returns:
float: The score.
"""
if not isinstance(mondo, MappingProviderInterface):
raise ValueError("Adapter is not a MappingProviderInterface")

if prediction == ground_truth:
# prediction is the correct OMIM
return FULL_SCORE


ground_truths = get_ground_truth_from_cache_or_compute(prediction, mondo, cache)
if ground_truth in ground_truths:
# prediction is a MONDO that directly maps to a correct OMIM
return FULL_SCORE

descendants_list = mondo.descendants([prediction], predicates=[IS_A], reflexive=True)
for mondo_descendant in descendants_list:
ground_truths = get_ground_truth_from_cache_or_compute(mondo_descendant, mondo, cache)
if ground_truth in ground_truths:
# prediction is a MONDO that maps to a correct OMIM via a descendant
return PARTIAL_SCORE
return 0.0

def get_ground_truth_from_cache_or_compute(
term,
adapter,
cache,
):
if cache is None:
return omim_mappings(term, adapter)

k = hashkey(term)
try:
ground_truths = cache[k]
except KeyError:
# cache miss
ground_truths = omim_mappings(term, adapter)
cache[k] = ground_truths
return ground_truths

def cache_info(self):
return f"CacheInfo: hits={self.hits}, misses={self.misses}, maxsize={self.wrapped.maxsize}, currsize={self.wrapped.currsize}"

def mondo_adapter() -> OboGraphInterface:
"""
@@ -151,9 +47,12 @@ def compute_mrr_and_ranks(
pc2 = PersistentCache(LRUCache, pc2_cache_file, maxsize=4096)
pc1_cache_file = str(output_dir / "omim_mappings_cache")
pc1 = PersistentCache(LRUCache, pc1_cache_file, maxsize=16384)
pc1.hits = pc1.misses = 0
pc2.hits = pc2.misses = 0
PersistentCache.cache_info = cache_info



for subdir, dirs, files in os.walk(output_dir): # maybe change this so it only looks into multilingual/multimodel? I.e. use that as outputdir...?
for subdir, dirs, files in os.walk(output_dir):
for filename in files:
if filename.startswith("result") and filename.endswith(".tsv"):
file_path = os.path.join(subdir, filename)
@@ -197,16 +96,17 @@ def compute_mrr_and_ranks(
# Make sure caching is used in the following by doing the cache lookup explicitly (the @cached decorator was removed)
results = []
for idx, row in df.iterrows():
#breakpoint()

# lambda prediction, ground_truth, mondo: hashkey(prediction, ground_truth)
k = hashkey(row['term'], row['correct_term'])
try:
val = pc2[k]
pc2.hits += 1
except KeyError:
# cache miss
val = score_grounded_result(row['term'], row['correct_term'], mondo, pc1)
pc2[k] = val
pc2.misses += 1
is_correct = val > 0
results.append(is_correct)

@@ -248,12 +148,11 @@ def compute_mrr_and_ranks(
rank_df.loc[i,"n10p"] += 1

# Write cache characteristics to file
breakpoint()
cf.write(results_files[i])
cf.write('\nscore_grounded_result cache info:\n')
#cf.write(str(score_grounded_result.cache_info()))
cf.write(str(pc2.cache_info()))
cf.write('\nomim_mappings cache info:\n')
#cf.write(str(omim_mappings.cache_info()))
cf.write(str(pc1.cache_info()))
cf.write('\n\n')
i = i + 1

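The explicit try/except lookups above replace the removed @cached decorators. A self-contained sketch of the same pattern, assuming only the names shown in this diff (the demo filename and the cached_score wrapper are illustrative); note the key hashes just the two terms and deliberately leaves out the mondo adapter, mirroring the key= lambda of the old decorator:

from cachetools import LRUCache
from cachetools.keys import hashkey
from shelved_cache import PersistentCache

from malco.post_process.mondo_score_utils import score_grounded_result

def cache_info(self):
    return (f"CacheInfo: hits={self.hits}, misses={self.misses}, "
            f"maxsize={self.wrapped.maxsize}, currsize={self.wrapped.currsize}")

PersistentCache.cache_info = cache_info  # same monkey-patch as in compute_mrr_and_ranks

pc2 = PersistentCache(LRUCache, "score_grounded_result_cache_demo", maxsize=4096)
pc2.hits = pc2.misses = 0

def cached_score(prediction, correct_term, mondo, pc1=None):
    k = hashkey(prediction, correct_term)  # the adapter is constant, so it is not part of the key
    try:
        val = pc2[k]
        pc2.hits += 1
    except KeyError:  # cache miss: compute and persist
        val = score_grounded_result(prediction, correct_term, mondo, pc1)
        pc2[k] = val
        pc2.misses += 1
    return val

print(pc2.cache_info())  # e.g. CacheInfo: hits=0, misses=0, maxsize=4096, currsize=0
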
3 changes: 2 additions & 1 deletion src/malco/run/run.py
@@ -4,6 +4,7 @@
from malco.run import search_ppkts

def call_ontogpt(lang, raw_results_dir, input_dir, model, modality):
# TODO
# Check what ppkts have already been computed in current output dir, for current run parameters
# ontogpt will run every txt that is in inputdir, we need a tmp inputdir
# This tmp inputdir contains only the prompts that have not yet been computed for a given, fixed model (pars set)
@@ -38,7 +39,7 @@ def call_ontogpt(lang, raw_results_dir, input_dir, model, modality):
process.communicate()
print(f"Finished command for language {lang} and model {model}")

#TODO get rid of parallelization?
#TODO decide whether to get rid of parallelization
def run(testdata_dir: Path,
raw_results_dir: Path,
input_dir: Path,
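
A hypothetical sketch of the TODO in call_ontogpt: build a temporary input directory holding only the prompts that lack a result for the current parameter set, so ontogpt is not re-run on phenopackets it has already processed. The result-file convention assumed here (a .yaml sharing the prompt's stem) and the tmp-dir location are illustration only, not the repo's actual layout.

import shutil
from pathlib import Path

def select_uncomputed_prompts(input_dir: Path, raw_results_dir: Path) -> Path:
    """Copy prompts that have no result yet into a tmp dir and return it."""
    done = {p.stem for p in raw_results_dir.glob("*.yaml")}  # assumed result format
    tmp_dir = input_dir.parent / "tmp_uncomputed"
    tmp_dir.mkdir(exist_ok=True)
    for prompt in input_dir.glob("*.txt"):  # ontogpt runs every txt in its input dir
        if prompt.stem not in done:
            shutil.copy(prompt, tmp_dir / prompt.name)
    return tmp_dir  # point ontogpt at this directory instead of input_dir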
