Skip to content

Commit

Permalink
Fix embedding similarity crash (#90)
Browse files Browse the repository at this point in the history
Normalizes `output` and `expected` to strings so that they are valid
cache keys.

Fixes BRA-1517.
  • Loading branch information
aphinx authored Nov 1, 2024
1 parent f30aa1e commit 6d205c6
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
3 changes: 3 additions & 0 deletions py/autoevals/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from Levenshtein import distance

from autoevals.partial import ScorerWithPartial
from autoevals.value import normalize_value

from .oai import arun_cached_request, run_cached_request

Expand Down Expand Up @@ -59,6 +60,7 @@ def __init__(self, prefix="", model=MODEL, expected_min=0.7, api_key=None, base_
self.extra_args["base_url"] = base_url

async def _a_embed(self, value):
value = normalize_value(value, maybe_object=False)
with self._CACHE_LOCK:
if value in self._CACHE:
return self._CACHE[value]
Expand All @@ -71,6 +73,7 @@ async def _a_embed(self, value):
return result

def _embed(self, value):
value = normalize_value(value, maybe_object=False)
with self._CACHE_LOCK:
if value in self._CACHE:
return self._CACHE[value]
Expand Down
21 changes: 21 additions & 0 deletions py/autoevals/test_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import asyncio

from autoevals import EmbeddingSimilarity
from autoevals.value import normalize_value

SYNONYMS = [
("water", ["water", "H2O", "agua"]),
Expand Down Expand Up @@ -27,3 +30,21 @@ def test_embeddings():
result = evaluator(word1, word2)
print(f"[{word1}]", f"[{word2}]", result)
assert result.score < 0.5


VALUES = [
("water", "wind"),
(["cold", "water"], ["cold", "wind"]),
({"water": "wet"}, {"wind": "dry"}),
]


def test_embedding_values():
for run_async in [False, True]:
evaluator = EmbeddingSimilarity()
for (word1, word2) in VALUES:
if run_async:
result = asyncio.run(evaluator.eval_async(word1, word2))
else:
result = evaluator(word1, word2)
print(f"[{word1}]", f"[{word2}]", f"run_async={run_async}", result)

0 comments on commit 6d205c6

Please sign in to comment.