From 994fe73ce501ca44a55f274eabfa9c6b5150b7db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hynek=20Kydl=C3=AD=C4=8Dek?= Date: Tue, 1 Oct 2024 16:06:21 +0200 Subject: [PATCH] Multilingual Hellaswag tasks (#332) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add multilingual dynamic generative metrics * draft * finish multichoice config * update tokenizers + install nltk reqs * use punkt tab * Update src/lighteval/utils/imports.py Co-authored-by: Nathan Habib <30601243+NathanHB@users.noreply.github.com> * Update src/lighteval/metrics/normalizations.py Co-authored-by: Nathan Habib <30601243+NathanHB@users.noreply.github.com> * fix imports * remove unused import * finish implementation of templates + move stuff around * resolve nits * when in rome do as romans do (handle error messages the same way) * fix utils * nicer tests + fix them * nicer todo * add nice docstrings 📃 * add even more docstrings * nit * fix test * add multilingual to dev group * merge nli, add languages to literals * translation literals * add nli * add copa tasks + fix translation literals * add hellaswag tasks * remove custom telugu hellaswag * remove hindi hellaswag * add rcb + chinese nli * Update src/lighteval/tasks/multilingual/tasks.py Co-authored-by: Clémentine Fourrier <22726840+clefourrier@users.noreply.github.com> * Update src/lighteval/tasks/multilingual/tasks.py Co-authored-by: Clémentine Fourrier <22726840+clefourrier@users.noreply.github.com> * Update src/lighteval/tasks/multilingual/tasks.py Co-authored-by: Clémentine Fourrier <22726840+clefourrier@users.noreply.github.com> * Update src/lighteval/tasks/multilingual/tasks.py Co-authored-by: Clémentine Fourrier <22726840+clefourrier@users.noreply.github.com> * Update src/lighteval/tasks/multilingual/tasks.py Co-authored-by: Clémentine Fourrier <22726840+clefourrier@users.noreply.github.com> * Update src/lighteval/tasks/multilingual/tasks.py Co-authored-by: Clémentine Fourrier <22726840+clefourrier@users.noreply.github.com> * Update src/lighteval/tasks/multilingual/tasks.py Co-authored-by: Clémentine Fourrier <22726840+clefourrier@users.noreply.github.com> * add two new tasks + docs * add nice docs * update hellaswag with docs * move hellaswag to lighteval suite * Update src/lighteval/tasks/multilingual/tasks.py Co-authored-by: Clémentine Fourrier <22726840+clefourrier@users.noreply.github.com> * enable returning None from templates + better typing * change unofficial hellaswag names to have community_prefix + unify hellaswag preprocessing * let strip be optional in hellaswag --------- Co-authored-by: Nathan Habib <30601243+NathanHB@users.noreply.github.com> Co-authored-by: Hynek Kydlicek Co-authored-by: Clémentine Fourrier <22726840+clefourrier@users.noreply.github.com> --- src/lighteval/tasks/default_prompts.py | 30 ++-- src/lighteval/tasks/lighteval_task.py | 2 +- src/lighteval/tasks/multilingual/tasks.py | 142 ++++++++++++++++ src/lighteval/tasks/templates/continuation.py | 18 +- src/lighteval/tasks/templates/copa.py | 9 +- src/lighteval/tasks/templates/hellaswag.py | 152 +++++++++++++++++ src/lighteval/tasks/templates/multichoice.py | 7 +- src/lighteval/tasks/templates/nli.py | 16 +- .../tasks/templates/utils/adapter_utils.py | 15 +- .../templates/utils/translation_literals.py | 1 + tests/tasks/templates/test_hellaswag.py | 160 ++++++++++++++++++ 11 files changed, 522 insertions(+), 30 deletions(-) create mode 100644 src/lighteval/tasks/templates/hellaswag.py create mode 100644 tests/tasks/templates/test_hellaswag.py
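For orientation before the diff itself: once the patch is applied, the new Hellaswag configs are appended to TASKS_TABLE in src/lighteval/tasks/multilingual/tasks.py and can be listed by name. A minimal sketch, assuming only the import path and the naming scheme visible in the diff below:

    from lighteval.tasks.multilingual.tasks import TASKS_TABLE

    # MLMM variants are named f"hellaswag_{lang}_{formulation}" and the Turkish/Thai
    # ones f"community_hellaswag_{lang}_{formulation}", e.g. "hellaswag_fr_cf",
    # "hellaswag_zh_mcf", "community_hellaswag_tr_hybrid".
    hellaswag_names = sorted(t.name for t in TASKS_TABLE if "hellaswag" in t.name)
    print(len(hellaswag_names), hellaswag_names[:3])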
diff --git a/src/lighteval/tasks/default_prompts.py b/src/lighteval/tasks/default_prompts.py index 482022d9..3a9b97f0 100644 --- a/src/lighteval/tasks/default_prompts.py +++ b/src/lighteval/tasks/default_prompts.py @@ -755,21 +755,29 @@ def headqa(line, task_name: str = None): ) -def hellaswag_harness(line, task_name: str = None): - def preprocess(text): - """Comes from AiHarness""" - # text = text.strip() - # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag. - text = text.replace(" [title]", ". ") - text = re.sub("\\[.*?\\]", "", text) - text = text.replace(" ", " ") - return text +def hellaswag_preprocess( + text: str, wikihow_artifacts: list[str] = [" [title]"], truncate_dots: bool = False, strip_text: bool = False +): + """Comes from AiHarness""" + # text = text.strip() + # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag. + for dot_repl in wikihow_artifacts: + text = text.replace(dot_repl, ". ") + text = re.sub("\\[.*?\\]", "", text) + text = text.replace(" ", " ") + if truncate_dots: + # str.replace would treat the pattern literally; use re.sub to collapse repeated dots + text = re.sub(r"\.+", ".", text) + if strip_text: + text = text.strip() + return text + +def hellaswag_harness(line, task_name: str = None): ctx = f"{line['ctx_a']} {line['ctx_b'].capitalize()} " return Doc( task_name=task_name, - query=preprocess(line["activity_label"] + ": " + ctx), - choices=[preprocess(ending) for ending in line["endings"]], + query=hellaswag_preprocess(line["activity_label"] + ": " + ctx), + choices=[hellaswag_preprocess(ending) for ending in line["endings"]], gold_index=int(line["label"]) if line["label"] != "" else -1, # -1 for test # "metric": "choices_loglikelihood", ) diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py index 8e23899a..00b4763b 100644 --- a/src/lighteval/tasks/lighteval_task.py +++ b/src/lighteval/tasks/lighteval_task.py @@ -89,7 +89,7 @@ class LightevalTaskConfig: """ name: str - prompt_function: Callable[[dict, str], Doc] + prompt_function: Callable[[dict, str], Doc | None] hf_repo: str hf_subset: str metric: ListLike[Metric | Metrics] diff --git a/src/lighteval/tasks/multilingual/tasks.py b/src/lighteval/tasks/multilingual/tasks.py index 7e59c8b5..14af524e 100644 --- a/src/lighteval/tasks/multilingual/tasks.py +++ b/src/lighteval/tasks/multilingual/tasks.py @@ -27,6 +27,7 @@ from lighteval.metrics.normalizations import LogProbTokenNorm from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.templates.copa import get_copa_prompt_function +from lighteval.tasks.templates.hellaswag import get_hellaswag_prompt_function from lighteval.tasks.templates.nli import get_nli_prompt_function from lighteval.tasks.templates.utils.formulation import ( CFFormulation, @@ -386,6 +387,9 @@ ), hf_repo="ai4bharat/IndicCOPA", hf_subset=f"translation-{standardize_tag(language.value)}", + # Since we use trust_dataset, we have to be careful about what is inside the dataset + # script. We thus lock the revision to ensure that the script doesn't change + hf_revision="d356ef19a4eb287e88a51d07a56b73ba88c7f188", evaluation_splits=["test"], metric=[ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), @@ -443,6 +447,141 @@ ] +# ------------------------------- Hellaswag Tasks ------------------------------- # +# Hellaswag is a commonsense reasoning task that requires models to complete a given scenario +# with the most plausible ending. It tests the model's ability to understand and reason about +# everyday situations and human behavior.
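To make the reworked preprocessing concrete, here is roughly what hellaswag_preprocess does to a WikiHow-flavored string; a minimal sketch, with the expected outputs (derived from the replacement steps above) in the comments:

    from lighteval.tasks.default_prompts import hellaswag_preprocess

    raw = "A man sits. [title] He stands [step] up."
    # " [title]" -> ". " yields "A man sits..  He stands [step] up.", the bracketed
    # "[step]" tag is then stripped, and double spaces are collapsed.
    print(hellaswag_preprocess(raw, strip_text=True))
    # -> "A man sits.. He stands up."
    # truncate_dots additionally collapses the repeated dots:
    print(hellaswag_preprocess(raw, truncate_dots=True, strip_text=True))
    # -> "A man sits. He stands up."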
+ +# MLMM-Hellaswag: Multilingual adaptation of Hellaswag +# Paper: https://arxiv.org/abs/2306.07610 +# This is a multilingual version of Hellaswag, part of the MLMM (Massive Language Model Meta-Evaluation) benchmark. +# It evaluates commonsense reasoning abilities across multiple languages. +mlmm_hellaswag_tasks = [ + LightevalTaskConfig( + name=f"hellaswag_{lang.value}_{formulation.name.lower()}", + suite=["lighteval"], + prompt_function=get_hellaswag_prompt_function( + language=lang, + adapter=lambda line: { + # We don't use activity_label as it's not available in this dataset + "ctx_a": line["ctx_a"], + "ctx_b": line["ctx_b"], + "continuations": line["endings"], + "gold_idx": int(line["label"]), + }, + formulation=formulation, + ), + hf_repo="jon-tow/okapi_hellaswag", + hf_subset=standardize_tag(lang.value), + # Since we use trust_dataset, we have to be careful about what is inside the dataset + # script. We thus lock the revision to ensure that the script doesn't change + hf_revision="96ed8e0dfc6172dad1d3df338d7b8ba6c1ff9d83", + evaluation_splits=["validation"], + metric=[ + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + ], + trust_dataset=True, + ) + for lang in [ + Language.ARABIC, + Language.BENGALI, + Language.CATALAN, + Language.DANISH, + Language.GERMAN, + Language.SPANISH, + Language.BASQUE, + Language.FRENCH, + Language.GUJARATI, + Language.HINDI, + Language.CROATIAN, + Language.HUNGARIAN, + Language.ARMENIAN, + Language.INDONESIAN, + Language.ICELANDIC, + Language.ITALIAN, + Language.KANNADA, + Language.MALAYALAM, + Language.MARATHI, + Language.NORWEGIAN, + Language.NEPALI, + Language.DUTCH, + Language.PORTUGUESE, + Language.ROMANIAN, + Language.RUSSIAN, + Language.SLOVAK, + Language.SERBIAN, + Language.SWEDISH, + Language.TAMIL, + Language.TELUGU, + Language.UKRAINIAN, + Language.VIETNAMESE, + Language.CHINESE, + ] + for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()] +] + +# Hellaswag Turkish +# This is a Turkish adaptation of the Hellaswag task. +# While there's no specific paper for this version, it has been found to work well for evaluating +# Turkish language models on commonsense reasoning tasks. + +# We don't handle these in a single task, as there are quite a few differences (dataset/subset, dot replacement, etc.) +# that would make it hard to read +hellaswag_tur_tasks = [ + LightevalTaskConfig( + name=f"community_hellaswag_{Language.TURKISH.value}_{formulation.name.lower()}", + suite=["lighteval"], + prompt_function=get_hellaswag_prompt_function( + language=Language.TURKISH, + adapter=lambda line: { + "ctx_a": line["ctx_a"], + "ctx_b": line["ctx_b"], + "continuations": line["endings"], + "gold_idx": int(line["label"]), + }, + formulation=formulation, + # https://github.com/malhajar17/lm-evaluation-harness_turkish/blob/main/lm_eval/tasks/hellaswag_tr-v0.2/utils.py + wikihow_artifacts=[" [title]", " [başlık]", " [adım]", " [header]"], + ), + hf_repo="malhajar/hellaswag_tr-v0.2", + hf_subset="default", + evaluation_splits=["validation"], + metric=[ + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + ], + ) + for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()] +] + +# Hellaswag Thai +# This is a Thai adaptation of the Hellaswag task. +# Similar to the Turkish version, there's no specific paper, but it has been found to be effective +# for evaluating Thai language models on commonsense reasoning tasks.
+hellaswag_tha_tasks = [ + LightevalTaskConfig( + name=f"community_hellaswag_{Language.THAI.value}_{formulation.name.lower()}", + suite=["lighteval"], + prompt_function=get_hellaswag_prompt_function( + language=Language.THAI, + adapter=lambda line: { + "ctx_a": line["ctx_a"], + "ctx_b": line["ctx_b"], + "continuations": line["endings"], + "gold_idx": int(line["label"]), + }, + formulation=formulation, + ), + hf_repo="HuggingFaceFW-Dev/hellaswag_thai", + hf_subset="default", + evaluation_splits=["validation"], + few_shots_split="train", + metric=[ + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + ], + ) + for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()] +] + TASKS_TABLE = [ *xnli_tasks, *xnli2_tasks, @@ -454,4 +593,7 @@ *xcopa_tasks, *copa_indic_tasks, *parus_tasks, + *mlmm_hellaswag_tasks, + *hellaswag_tur_tasks, + *hellaswag_tha_tasks, ] diff --git a/src/lighteval/tasks/templates/continuation.py b/src/lighteval/tasks/templates/continuation.py index b7ea2e18..84c11230 100644 --- a/src/lighteval/tasks/templates/continuation.py +++ b/src/lighteval/tasks/templates/continuation.py @@ -84,7 +84,7 @@ class ContinuationDictAdapter(TypedDict): def get_continuation_prompt_function( language: Language, - adapter: Callable[[dict], ContinuationInput] | ContinuationDictAdapter, + adapter: Callable[[dict], ContinuationInput | None] | ContinuationDictAdapter, formulation: Formulation = MCFFormulation(), ): """ @@ -121,11 +121,13 @@ def get_continuation_prompt_function( Returns: Callable: A function that generates Continuation prompt based on the given parameters. """ - adapter_fn: Callable[[dict], ContinuationInput] = create_adapter_from_dict(adapter) # type: ignore + adapter_fn = create_adapter_from_dict(adapter) translation_literals = TRANSLATION_LITERALS[language] def prepare_prompt(line: dict): cont_input = adapter_fn(line) + if cont_input is None: + return None instruction_val = cont_input.get("instruction") instruction = f"{instruction_val}\n" if instruction_val else "" @@ -140,7 +142,11 @@ def prepare_prompt(line: dict): return cont_input, instruction, context, continuations def prompt_fn_cf(line, task_name: str): - cont_input, instruction, context, continuations = prepare_prompt(line) + prepared_prompt = prepare_prompt(line) + if prepared_prompt is None: + return None + + cont_input, instruction, context, continuations = prepared_prompt context_follows_sentence_space = punctuation_ends_sentence(context, translation_literals) answers = build_answers(continuations, formulation, translation_literals, context_follows_sentence_space) @@ -160,7 +166,11 @@ def prompt_fn_cf(line, task_name: str): ) def prompt_fn_mcf(line, task_name: str): - cont_input, instruction, context, continuations = prepare_prompt(line) + prepared_prompt = prepare_prompt(line) + if prepared_prompt is None: + return None + + cont_input, instruction, context, continuations = prepared_prompt options = build_choices(continuations, formulation, translation_literals) options = f"{options}\n" if options else "" diff --git a/src/lighteval/tasks/templates/copa.py b/src/lighteval/tasks/templates/copa.py index 03ae9f0e..2129332f 100644 --- a/src/lighteval/tasks/templates/copa.py +++ b/src/lighteval/tasks/templates/copa.py @@ -74,7 +74,9 @@ class COPAAdapter(TypedDict): def get_copa_prompt_function( - language: Language, adapter: Callable[[dict], COPAInput] | COPAAdapter, formulation: Formulation = MCFFormulation() + language: Language, + adapter: Callable[[dict], COPAInput | None] | COPAAdapter, + 
formulation: Formulation = MCFFormulation(), ): """ Create a templated prompt function for a COPA task. @@ -109,7 +111,7 @@ def get_copa_prompt_function( Returns: Callable: A function that generates COPA prompts based on the given parameters. """ - adapter_fn: Callable[[dict], COPAInput] = create_adapter_from_dict(adapter) # type: ignore + adapter_fn = create_adapter_from_dict(adapter) continuation_prompt_fn = get_continuation_prompt_function( language, {"context": "context", "continuations": "continuations", "gold_idx": "gold_idx"}, formulation ) @@ -120,6 +122,9 @@ def copa_prompt( task_name: str, ): input_data = adapter_fn(line) + if input_data is None: + return None + context = capitalize(input_data["context"].rstrip(PUNCT)) cause_or_effect_trans = ( translation_literals.cause_word diff --git a/src/lighteval/tasks/templates/hellaswag.py b/src/lighteval/tasks/templates/hellaswag.py new file mode 100644 index 00000000..536e90a5 --- /dev/null +++ b/src/lighteval/tasks/templates/hellaswag.py @@ -0,0 +1,152 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
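One pattern worth calling out before the body of the new file: this PR lets adapters return None, and every templated prompt function then short-circuits to None so the sample is skipped rather than raising on malformed rows. A self-contained sketch of the idea (illustrative names only, not lighteval's API):

    from typing import Callable, Optional

    def make_prompt_fn(adapter_fn: Callable[[dict], Optional[dict]]):
        # Mirrors the skip pattern threaded through continuation.py, copa.py,
        # multichoice.py, nli.py, and the new hellaswag.py below.
        def prompt_fn(line: dict, task_name: str) -> Optional[str]:
            input_data = adapter_fn(line)
            if input_data is None:
                return None  # sample is dropped instead of raising
            return f"{task_name}: {input_data['context']}"
        return prompt_fn

    # Rows without a usable context are vetoed by the adapter:
    fn = make_prompt_fn(lambda line: {"context": line["ctx"]} if line.get("ctx") else None)
    assert fn({"ctx": ""}, "demo") is None
    assert fn({"ctx": "He runs"}, "demo") == "demo: He runs"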
+ +from typing import Callable + +from typing_extensions import NotRequired, TypedDict + +from lighteval.tasks.default_prompts import hellaswag_preprocess +from lighteval.tasks.templates.continuation import get_continuation_prompt_function +from lighteval.tasks.templates.multichoice import create_adapter_from_dict +from lighteval.tasks.templates.utils.formatting_utils import ( + capitalize, + fix_capitalization, + fix_ending_punct, + punctuation_ends_sentence, +) +from lighteval.tasks.templates.utils.formulation import Formulation, MCFFormulation +from lighteval.tasks.templates.utils.translation_literals import TRANSLATION_LITERALS +from lighteval.utils.language import Language + + +# Hellaswag query template: optional activity label followed by the context +HELLASWAG_QUERY = "{activity_label}{ctx}" + + +class HellaswagInput(TypedDict): + ctx_a: str + continuations: list[str] + gold_idx: int | list[int] + instruction: NotRequired[str] + activity_label: NotRequired[str] + ctx_b: NotRequired[str] + + +class HellaswagAdapter(TypedDict): + ctx_a: str + continuations: str + gold_idx: str + instruction: NotRequired[str] + activity_label: NotRequired[str] + ctx_b: NotRequired[str] + + +def get_hellaswag_prompt_function( + language: Language, + adapter: Callable[[dict], HellaswagInput | None] | HellaswagAdapter, + formulation: Formulation = MCFFormulation(), + wikihow_artifacts: list[str] = [" [title]"], +): + """ + Create a templated prompt function for a Hellaswag task. + + Format: + Activity label: Context | (Continuation 1, Continuation 2, Continuation 3) + + Args: + language (Language): The language of the Hellaswag task. + adapter (Callable[[dict], HellaswagInput] | HellaswagAdapter): A function or dictionary to adapt the input data to the required HellaswagInput format. + Must map data from the dataset row to the HellaswagInput format. + Note: The gold_idx must be an index or list of indices in the continuations list, indicating the correct continuation(s). + formulation (Formulation, optional): The formulation to use for the task. Defaults to MCFFormulation(). + wikihow_artifacts (list[str], optional): A list of strings to replace with a dot. These artifacts have to be replaced + because of the WikiHow source of the dataset. + + Returns: + Callable: A function that generates Hellaswag prompts based on the given parameters.
+ """ + + translation_literals = TRANSLATION_LITERALS[language] + + def process_context(ctx): + if ctx == "": + return "" + return capitalize( + fix_ending_punct( + hellaswag_preprocess(ctx, truncate_dots=True, wikihow_artifacts=wikihow_artifacts, strip_text=True), + translation_literals, + ) + ) + + def join_ctxs(ctx_a, ctx_b): + space = ( + translation_literals.sentence_space + if punctuation_ends_sentence(ctx_a, translation_literals) + else translation_literals.word_space + ) + return f"{ctx_a.rstrip()}{space}{fix_capitalization(ctx_a, ctx_b, translation_literals)}" + + adapter_fn = create_adapter_from_dict(adapter) + continuation_prompt_fn = get_continuation_prompt_function( + language, {"context": "context", "continuations": "continuations", "gold_idx": "gold_idx"}, formulation + ) + + def hellaswag_prompt( + line: dict, + task_name: str, + ): + input_data = adapter_fn(line) + if input_data is None: + return None + + activity_label = input_data.get("activity_label", "") + activity_label = f"{capitalize(activity_label)}:\n" if activity_label else "" + + # Last one should be left as is + ctx_a, ctx_b = capitalize(input_data["ctx_a"]), input_data.get("ctx_b", "") + if ctx_b: + ctx_a = join_ctxs(process_context(ctx_a), process_context(ctx_b)) + + # Removoal of the [header] can happen and we need the first letter to be capital afterwards + full_context = HELLASWAG_QUERY.format(activity_label=activity_label, ctx=ctx_a) + choices = [ + hellaswag_preprocess( + continuation, wikihow_artifacts=wikihow_artifacts, truncate_dots=True, strip_text=True + ) + for continuation in input_data["continuations"] + ] + + # It can happen that the continuations are empty we thus skip the task + if any(len(c.strip()) == 0 for c in choices): + return None + + return continuation_prompt_fn( + { + "instruction": input_data.get("instruction", ""), + "context": full_context, + "continuations": choices, + "gold_idx": input_data["gold_idx"], + }, + task_name, + ) + + return hellaswag_prompt diff --git a/src/lighteval/tasks/templates/multichoice.py b/src/lighteval/tasks/templates/multichoice.py index 1c6f2bfb..daa8fffd 100644 --- a/src/lighteval/tasks/templates/multichoice.py +++ b/src/lighteval/tasks/templates/multichoice.py @@ -80,7 +80,7 @@ class MCQDictAdapter(TypedDict): def get_mcq_prompt_function( language: Language, - adapter: Callable[[dict], MCQInput] | MCQDictAdapter, + adapter: Callable[[dict], MCQInput | None] | MCQDictAdapter, formulation: Formulation = MCFFormulation(), ): """ @@ -120,10 +120,13 @@ def get_mcq_prompt_function( Callable: A function that generates MCQ prompts based on the given parameters. 
""" - adapter_fn: Callable[[dict], MCQInput] = create_adapter_from_dict(adapter) # type: ignore + adapter_fn = create_adapter_from_dict(adapter) def prompt_fn(line, task_name: str): mcq_input = adapter_fn(line) + if mcq_input is None: + return None + translation_literals = TRANSLATION_LITERALS[language] instruction_val = mcq_input.get("instruction") diff --git a/src/lighteval/tasks/templates/nli.py b/src/lighteval/tasks/templates/nli.py index f20a07ea..5c7abec0 100644 --- a/src/lighteval/tasks/templates/nli.py +++ b/src/lighteval/tasks/templates/nli.py @@ -74,7 +74,7 @@ class NLIAdapter(TypedDict): def _nli_prompt_function_natural( language: Language, - adapter: Callable[[dict], NLIInput] | NLIAdapter, + adapter: Callable[[dict], NLIInput | None] | NLIAdapter, relations: list[RelationType], ): """ @@ -107,11 +107,14 @@ def get_relation_label(label: RelationType, translation_literals: TranslationLit return translation_literals.also translation_literals = TRANSLATION_LITERALS[language] - adapter_fn = create_adapter_from_dict(adapter) # type: ignore + adapter_fn = create_adapter_from_dict(adapter) def nli_natural_prompt(line: dict, task_name: str): labels = [capitalize(get_relation_label(label, translation_literals)) for label in relations] input_data = adapter_fn(line) + if input_data is None: + return None + premise, hypothesis, label = input_data["premise"], input_data["hypothesis"], input_data["gold_idx"] premise = capitalize(input_data["premise"].rstrip(PUNCT)) @@ -159,7 +162,7 @@ def nli_natural_prompt(line: dict, task_name: str): def get_nli_prompt_function( language: Language, - adapter: Callable[[dict], NLIInput] | NLIAdapter, + adapter: Callable[[dict], NLIInput | None] | NLIAdapter, relations: list[RelationType], formulation: Formulation = MCFFormulation(), ): @@ -211,7 +214,7 @@ def get_relation_label(label: RelationType, translation_literals: TranslationLit elif label == "neutral": return translation_literals.neither - adapter_fn = create_adapter_from_dict(adapter) # type: ignore + adapter_fn = create_adapter_from_dict(adapter) # For hybrid we use inlined choices so we use the cf formulation in multichoice prompt fn mcq_prompt_fn = get_mcq_prompt_function( @@ -221,10 +224,13 @@ def get_relation_label(label: RelationType, translation_literals: TranslationLit ) def prompt_fn(line: dict, task_name: str): + input_data = adapter_fn(line) + if input_data is None: + return None + # Template based on dicussion here: https://github.com/EleutherAI/lm-evaluation-harness/issues/450 labels = [capitalize(get_relation_label(label, translation_literals)) for label in relations] - input_data = adapter_fn(line) premise, hypothesis, gold_idx = input_data["premise"], input_data["hypothesis"], input_data["gold_idx"] premise = fix_ending_punct(capitalize(input_data["premise"]), translation_literals) hypothesis = input_data["hypothesis"] diff --git a/src/lighteval/tasks/templates/utils/adapter_utils.py b/src/lighteval/tasks/templates/utils/adapter_utils.py index 0e91b201..c260e568 100644 --- a/src/lighteval/tasks/templates/utils/adapter_utils.py +++ b/src/lighteval/tasks/templates/utils/adapter_utils.py @@ -21,10 +21,15 @@ # SOFTWARE. 
-from typing import Any, Callable +from typing import Any, Callable, Mapping, TypeVar -def create_adapter_from_dict(adapter: dict[str, str | None] | Callable[[dict], Any]): +AdapterReturnTypeVar = TypeVar("AdapterReturnTypeVar") + + +def create_adapter_from_dict( + adapter: Mapping[str, Any] | Callable[[dict], AdapterReturnTypeVar], +) -> Callable[[dict], AdapterReturnTypeVar]: """ Creates adapter function for the template input from a dict. Args: @@ -32,10 +37,10 @@ def create_adapter_from_dict(adapter: dict[str, str | None] | Callable[[dict], A """ - if not isinstance(adapter, dict): + if not isinstance(adapter, Mapping): return adapter def adapter_fn(line: dict): - return {key: line[value] for key, value in adapter.items()} # type: ignore + return {key: line[value] for key, value in adapter.items()} - return adapter_fn + return adapter_fn # type: ignore diff --git a/src/lighteval/tasks/templates/utils/translation_literals.py b/src/lighteval/tasks/templates/utils/translation_literals.py index 71edc763..0c521e68 100644 --- a/src/lighteval/tasks/templates/utils/translation_literals.py +++ b/src/lighteval/tasks/templates/utils/translation_literals.py @@ -349,4 +349,5 @@ def __getattribute__(self, name: str) -> str: Language.SORANI: TranslationLiterals(language=Language.SORANI), Language.CEBUANO: TranslationLiterals(language=Language.CEBUANO), Language.WAR: TranslationLiterals(language=Language.WAR), + Language.SWEDISH: TranslationLiterals(language=Language.SWEDISH), } diff --git a/tests/tasks/templates/test_hellaswag.py b/tests/tasks/templates/test_hellaswag.py new file mode 100644 index 00000000..2ef7b895 --- /dev/null +++ b/tests/tasks/templates/test_hellaswag.py @@ -0,0 +1,160 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +from lighteval.tasks.templates.hellaswag import get_hellaswag_prompt_function +from lighteval.tasks.templates.utils.formulation import CFFormulation, MCFFormulation +from lighteval.utils.language import Language + + +def test_hellaswag_prompt_cf(): + """ + Tests that the hellaswag prompt function works correctly. + Since it's pretty much a wrapper around the continuation template, we just test a single formulation.
+ + """ + test_input = { + "activity_label": "fitness", + "ctx_a": "He is strong", + "ctx_b": "He is fast", + "continuations": ["he has big muscles", "he is weak"], + "gold_idx": 0, + } + + prompt_fn = get_hellaswag_prompt_function( + Language.ENGLISH, + { + "activity_label": "activity_label", + "continuations": "continuations", + "gold_idx": "gold_idx", + "ctx_a": "ctx_a", + "ctx_b": "ctx_b", + }, + CFFormulation(), + ) + + doc = prompt_fn(test_input, "test_task") + assert doc.query == "Fitness:\nHe is strong he is fast" + + assert doc.unconditioned_query == "" + assert doc.choices == [" he has big muscles", " he is weak"] + assert doc.gold_index == [0] + + +def test_hellaswag_prompt_mcf(): + """ + Tests that the hellaswag prompt function works correctly. + Since it's pretty much a wrapper around the continuation template, we just test a single formulation. + + """ + test_input = { + "activity_label": "fitness", + "ctx_a": "He is strong", + "ctx_b": "He is fast", + "continuations": ["he has big muscles", "he is weak"], + "gold_idx": 0, + } + + prompt_fn = get_hellaswag_prompt_function( + Language.ENGLISH, + { + "activity_label": "activity_label", + "continuations": "continuations", + "gold_idx": "gold_idx", + "ctx_a": "ctx_a", + "ctx_b": "ctx_b", + }, + MCFFormulation(), + ) + + doc = prompt_fn(test_input, "test_task") + assert ( + doc.query + == """\ +Fitness:\nHe is strong he is fast + A. he has big muscles + B. he is weak +Answer:\ +""" + ) + + assert doc.unconditioned_query == "Answer:" + assert doc.choices == [" A", " B"] + assert doc.gold_index == [0] + + +def test_hellaswag_ctx_joining(): + """ + Tests that the hellaswag prompt function works correctly. + Since it's pretty much a wrapper around the continuation template, we just test a single formulation. + + """ + test_input = { + "activity_label": "fitness", + "ctx_a": "He is strong.", + "ctx_b": "he is fast.", + "continuations": ["he has big muscles", "he is weak"], + "gold_idx": 0, + } + + prompt_fn = get_hellaswag_prompt_function( + Language.ENGLISH, + { + "activity_label": "activity_label", + "continuations": "continuations", + "gold_idx": "gold_idx", + "ctx_a": "ctx_a", + "ctx_b": "ctx_b", + }, + CFFormulation(), + ) + + doc = prompt_fn(test_input, "test_task") + assert doc.query == "Fitness:\nHe is strong. He is fast." + + +def test_hellaswag_single_ctx(): + """ + Tests that the hellaswag prompt function works correctly. + Since it's pretty much a wrapper around the continuation template, we just test a single formulation. + + """ + test_input = { + "activity_label": "fitness", + "ctx_a": "He is strong.", + "continuations": ["he has big muscles", "he is weak"], + "gold_idx": 0, + } + + prompt_fn = get_hellaswag_prompt_function( + Language.ENGLISH, + { + "activity_label": "activity_label", + "continuations": "continuations", + "gold_idx": "gold_idx", + "ctx_a": "ctx_a", + }, + CFFormulation(), + ) + + doc = prompt_fn(test_input, "test_task") + assert doc.query == "Fitness:\nHe is strong."
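For reuse outside this PR, a task built on the new template should look roughly like the Turkish/Thai configs above. In this sketch, my-org/my-hellaswag is a hypothetical dataset repo, and the loglikelihood_acc_metric import path is an assumption based on its use in tasks.py:

    from lighteval.metrics.dynamic_metrics import loglikelihood_acc_metric  # assumed import path
    from lighteval.metrics.normalizations import LogProbTokenNorm
    from lighteval.tasks.lighteval_task import LightevalTaskConfig
    from lighteval.tasks.templates.hellaswag import get_hellaswag_prompt_function
    from lighteval.tasks.templates.utils.formulation import CFFormulation
    from lighteval.utils.language import Language

    my_hellaswag = LightevalTaskConfig(
        name="community_hellaswag_en_cf",
        suite=["community"],
        prompt_function=get_hellaswag_prompt_function(
            language=Language.ENGLISH,
            # Lambda adapter (rather than a dict) so the string label can be cast to int
            adapter=lambda line: {
                "ctx_a": line["ctx_a"],
                "continuations": line["endings"],
                "gold_idx": int(line["label"]),
            },
            formulation=CFFormulation(),
        ),
        hf_repo="my-org/my-hellaswag",  # hypothetical dataset repo
        hf_subset="default",
        evaluation_splits=["validation"],
        metric=[loglikelihood_acc_metric(normalization=LogProbTokenNorm())],
    )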