Skip to content

Commit

Permalink
refactor: extract detect_language into a utility method used by both …
Browse files Browse the repository at this point in the history
…REST and CLI
  • Loading branch information
osma committed Sep 17, 2024
1 parent 5d9c081 commit 2c06655
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 17 deletions.
7 changes: 3 additions & 4 deletions annif/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
OperationFailedException,
)
from annif.project import Access
from annif.simplemma_util import get_language_detector
from annif.simplemma_util import detect_language
from annif.util import metric_code

logger = annif.logger
Expand Down Expand Up @@ -745,13 +745,12 @@ def run_detect_language(languages):
raise click.UsageError("At least one language is required as an argument")

text = sys.stdin.read()
detector = get_language_detector(tuple(languages))
try:
proportions = detector.proportion_in_each_language(text)
proportions = detect_language(text, languages)
except ValueError as e:
raise click.UsageError(e)

for lang, score in sorted(proportions.items(), key=lambda x: x[1], reverse=True):
for lang, score in proportions.items():
if lang == "unk":
lang = "?"
click.echo(f"{lang}\t{score:.04f}")
Expand Down
17 changes: 6 additions & 11 deletions annif/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import connexion

import annif.registry
import annif.simplemma_util
from annif.corpus import Document, DocumentList, SubjectSet
from annif.exception import AnnifException
from annif.project import Access
from annif.simplemma_util import get_language_detector

if TYPE_CHECKING:
from connexion.lifecycle import ConnexionResponse
Expand Down Expand Up @@ -89,9 +89,8 @@ def detect_language(body: dict[str, Any]):
text = body.get("text")
languages = body.get("languages")

detector = get_language_detector(tuple(languages))
try:
proportions = detector.proportion_in_each_language(text)
proportions = annif.simplemma_util.detect_language(text, tuple(languages))
except ValueError:
return connexion.problem(
status=400,
Expand All @@ -100,14 +99,10 @@ def detect_language(body: dict[str, Any]):
)

result = {
"results": sorted(
[
{"language": lang if lang != "unk" else None, "score": score}
for lang, score in proportions.items()
],
key=lambda x: x["score"],
reverse=True,
)
"results": [
{"language": lang if lang != "unk" else None, "score": score}
for lang, score in proportions.items()
]
}
return result, 200, {"Content-Type": "application/json"}

Expand Down
8 changes: 7 additions & 1 deletion annif/simplemma_util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Wrapper code for using Simplemma functionality in Annif"""

from typing import Tuple, Union
from typing import Dict, Tuple, Union

from simplemma import LanguageDetector, Lemmatizer
from simplemma.strategies import DefaultStrategy
Expand All @@ -15,3 +15,9 @@

def get_language_detector(lang: Union[str, Tuple[str, ...]]) -> LanguageDetector:
    """Create a Simplemma language detector for the given language(s).

    Args:
        lang: A single language code, or a tuple of language codes, that the
            detector should consider.

    Returns:
        A ``LanguageDetector`` configured with the module-level
        ``_lemmatization_strategy``.
    """
    detector = LanguageDetector(
        lang,
        lemmatization_strategy=_lemmatization_strategy,
    )
    return detector


def detect_language(text: str, languages: Tuple[str, ...]) -> Dict[str, float]:
    """Estimate how much of ``text`` belongs to each candidate language.

    Args:
        text: The text to analyse.
        languages: Language codes to consider; passed unchanged to
            ``get_language_detector``.

    Returns:
        A dict mapping each key reported by the detector to its proportion,
        inserted in descending order of proportion so that iterating the
        dict yields the best-scoring entry first.

    Any exception raised by the detector propagates to the caller unchanged.
    """
    scores = get_language_detector(languages).proportion_in_each_language(text)
    # reverse=True keeps the highest proportion first; dicts preserve
    # insertion order, so the ranking survives the conversion back to a dict.
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    return dict(ranked)
11 changes: 10 additions & 1 deletion tests/test_simplemma_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from annif.simplemma_util import get_language_detector
from annif.simplemma_util import detect_language, get_language_detector


def test_get_language_detector():
Expand All @@ -17,3 +17,12 @@ def test_get_language_detector_many():
text = "She said 'au revoir' and left"
proportion = detector.proportion_in_target_languages(text)
assert proportion == pytest.approx(1.0)


def test_detect_language():
    """detect_language returns per-language proportions, best match first."""
    scores = detect_language("She said 'au revoir' and left", ("fr", "en"))

    assert scores["en"] == pytest.approx(0.75)
    assert scores["fr"] == pytest.approx(0.25)
    # The result is ordered by descending score, so English must lead.
    assert next(iter(scores)) == "en"

0 comments on commit 2c06655

Please sign in to comment.