Skip to content

Commit

Permalink
Speech Translation Evals (#54)
Browse files Browse the repository at this point in the history
* add covost2 dataset for validation & test

* add more ST scenarios and restructure
  • Loading branch information
farzadab authored Jul 25, 2024
1 parent bd3e917 commit cc48b74
Show file tree
Hide file tree
Showing 6 changed files with 255 additions and 25 deletions.
208 changes: 196 additions & 12 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ openai = "~1.33.0"
jiwer = "~3.0.4"
tensorboardx = "~2.6.2.2"
wandb = "~0.17.1"
sacrebleu = "^2.4.2"

[tool.poetry.group.dev.dependencies]
black = "~24.4.2"
Expand Down
21 changes: 11 additions & 10 deletions ultravox/evaluation/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@
from ultravox.evaluation import string_based
from ultravox.evaluation import wer

# Dispatch table: metric name -> evaluator callable.
# Each evaluator takes an eval_types.Sample and returns a metric-specific Result.
METRIC_REGISTRY = dict(
    asr=wer.evaluate_answer_asr,
    boolq=gpt_eval_boolq.evaluate_answer_boolq,
    instruct=gpt_eval_instruct.evaluate_answer_instruct,
    conversation=gpt_eval_conv.evaluate_conversation_response,
    exact_match_last_word=string_based.match_last_word,
    bleu=string_based.bleu,
)


def evaluate_answer(sample: eval_types.Sample, metric: str) -> eval_types.Result:
    """Score a single sample using the named metric.

    Dispatches through METRIC_REGISTRY; the old hand-written if/elif chain
    duplicated the registry entries one-for-one and has been removed.

    Args:
        sample: The sample (expected/generated answers etc.) to evaluate.
        metric: Registry key, e.g. "asr", "boolq", "instruct", "bleu".

    Returns:
        The metric-specific result object produced by the registered evaluator.

    Raises:
        ValueError: If `metric` is not a key of METRIC_REGISTRY.
    """
    # LBYL rather than try/except KeyError: a KeyError raised *inside* an
    # evaluator must not be misreported as an unknown metric.
    if metric not in METRIC_REGISTRY:
        raise ValueError(f"Unknown metric: {metric}")
    return METRIC_REGISTRY[metric](sample)
10 changes: 10 additions & 0 deletions ultravox/evaluation/eval_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,14 @@ class ExactMatchResult:
reason: str


@dataclasses.dataclass
class BleuResult:
    """Holds the BLEU score computed for one generated answer.

    Note: BLEU is supposed to be computed on a corpus level, not on a single sample.
    """

    # BLEU score for the generated answer.
    score: float

# Union of all per-metric result types the evaluators can return.
# BleuResult was added for the new "bleu" metric but was missing from this
# union, making the declared return type of evaluate_answer incomplete.
Result = Union[InstructResult, WerResult, ExactMatchResult, BleuResult]
16 changes: 16 additions & 0 deletions ultravox/evaluation/string_based.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import re

import sacrebleu

from ultravox.evaluation import eval_types


Expand All @@ -21,3 +23,17 @@ def match_last_word(sample: eval_types.Sample) -> eval_types.ExactMatchResult:
return eval_types.ExactMatchResult(
score=last_word == expected_tf, reason="exact_match check"
)


def bleu(sample: eval_types.Sample) -> eval_types.BleuResult:
    """
    Compute BLEU score for a single sample.
    Note: BLEU is supposed to be computed on a corpus level, not on a single sample.
    As such, reported values here might not be easily comparable to other metrics.
    """
    # sentence_bleu compares one hypothesis against a list of references;
    # here we have exactly one reference (the expected answer).
    bleu_output = sacrebleu.sentence_bleu(
        hypothesis=sample.generated_answer,
        references=[sample.expected_answer],
    )
    return eval_types.BleuResult(score=bleu_output.score)
24 changes: 21 additions & 3 deletions ultravox/training/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,29 @@ class EvalScenario:


EVAL_SCENARIOS = [
# automatic speech recognition scenarios
EvalScenario("boolq__wer", "boolq_in", "asr"),
# automatic speech translation scenarios
EvalScenario("covost2_en_de__bleu", "covost2:en_de", "bleu"),
EvalScenario("covost2_en_zh-CN__bleu", "covost2:en_zh-CN", "bleu"),
EvalScenario("covost2_es_en__bleu", "covost2:es_en", "bleu"),
EvalScenario(
"covost2_en_de__bleu__text_only", "covost2:en_de", "bleu", include_audio=False
),
EvalScenario(
"covost2_en_zh-CN__bleu__text_only",
"covost2:en_zh-CN",
"bleu",
include_audio=False,
),
EvalScenario(
"covost2_es_en__bleu__text_only", "covost2:es_en", "bleu", include_audio=False
),
# SQA scenarios
EvalScenario("anyinstruct__instruct_follow", "anyinstruct", "instruct"),
EvalScenario(
"boolq__binary", "boolq_extended", "exact_match_last_word", new_tokens=128
),
EvalScenario("boolq__wer", "boolq_in", "asr"),
EvalScenario("soda__sensible_generation", "soda", "conversation", new_tokens=64),
# Text-only scenarios: tests for catastrophic forgetting.
EvalScenario(
"anyinstruct__instruct_follow__text_only",
"anyinstruct",
Expand All @@ -78,6 +94,8 @@ class EvalScenario:
new_tokens=128,
include_audio=False,
),
# Conversation dialogue scenarios
EvalScenario("soda__sensible_generation", "soda", "conversation", new_tokens=64),
EvalScenario(
"soda__sensible_generation__text_only",
"soda",
Expand Down

0 comments on commit cc48b74

Please sign in to comment.