Fix QuestionAnsweringEvaluator for squad v2, fix examples (#190)
* fix squad v2, better doc

* fix warning bootstrap

* fix test

* fix test

Co-authored-by: Felix Marty <felix@huggingface.co>
fxmarty authored Jul 20, 2022
1 parent 6ef285c commit 4e7f682
Showing 5 changed files with 56 additions and 44 deletions.
15 changes: 7 additions & 8 deletions src/evaluate/evaluator/image_classification.py
@@ -34,7 +34,7 @@ def predictions_processor(self, predictions, label_mapping):

return {"predictions": pred_label}

- def compute(self, *args, **kwargs) -> Tuple[Dict[str, float], Any]:
+ def compute(self, input_column: str = "image", *args, **kwargs) -> Tuple[Dict[str, float], Any]:
"""
Compute the metric for a given pipeline and dataset combination.
Args:
@@ -68,7 +68,7 @@ def compute(self, *args, **kwargs) -> Tuple[Dict[str, float], Any]:
debugging.
input_column (`str`, defaults to `"image"`):
the name of the column containing the images as PIL ImageFile in the dataset specified by `data`.
- label_column (`str`, defaults to `"labels"`):
+ label_column (`str`, defaults to `"label"`):
the name of the column containing the labels in the dataset specified by `data`.
label_mapping (`Dict[str, Number]`, *optional*, defaults to `None`):
We want to map class labels defined by the model in the pipeline to values consistent with those
@@ -80,20 +80,19 @@ def compute(self, *args, **kwargs) -> Tuple[Dict[str, float], Any]:
Examples:
```python
>>> from evaluate import evaluator
- >>> from datasets import Dataset, load_dataset
+ >>> from datasets import load_dataset
>>> task_evaluator = evaluator("image-classification")
>>> data = load_dataset("beans", split="test[:2]")
>>> data = load_dataset("beans", split="test[:40]")
>>> results = task_evaluator.compute(
>>> model_or_pipeline="nateraw/vit-base-beans",
>>> data=data,
+ >>> label_column="labels",
>>> metric="accuracy",
>>> label_mapping={'angular_leaf_spot': 0, 'bean_rust': 1, 'healthy': 2},
>>> strategy="bootstrap",
>>> n_resamples=10,
>>> random_state=0
>>> strategy="bootstrap"
>>> )
```"""

- result = super().compute(*args, **kwargs)
+ result = super().compute(input_column=input_column, *args, **kwargs)

return result
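The fix above makes `input_column` an explicit keyword with an image-specific default and forwards it to the base evaluator. As a quick illustration of that pattern (using placeholder classes, not the actual `evaluate` internals), a subclass can pin a task-appropriate default and delegate the rest:

```python
# Minimal sketch of the forwarding pattern used in the diff above. The classes here are
# stand-ins, not the real evaluate.Evaluator hierarchy.
from typing import Any, Dict, Tuple


class BaseEvaluatorSketch:
    def compute(self, input_column: str = "text", **kwargs) -> Tuple[Dict[str, float], Any]:
        # The real base class would read data[input_column]; this stub just echoes its inputs.
        return {"input_column_used": input_column, **kwargs}, None


class ImageClassificationEvaluatorSketch(BaseEvaluatorSketch):
    def compute(self, input_column: str = "image", *args, **kwargs) -> Tuple[Dict[str, float], Any]:
        # Pin the image-specific default and forward it explicitly, so the base class
        # never silently falls back to its own generic default.
        return super().compute(input_column=input_column, *args, **kwargs)


if __name__ == "__main__":
    result, _ = ImageClassificationEvaluatorSketch().compute(label_column="labels")
    print(result)  # {'input_column_used': 'image', 'label_column': 'labels'}
```

Without the explicit keyword, a caller relying on the default would hit the base class's generic text-oriented default instead of the `"image"` column, which is the behaviour the docstring already promised.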
44 changes: 26 additions & 18 deletions src/evaluate/evaluator/question_answering.py
@@ -48,6 +48,8 @@ class QuestionAnsweringEvaluator(Evaluator):
[`QuestionAnsweringPipeline`](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.QuestionAnsweringPipeline).
"""

+ PIPELINE_KWARGS = {"handle_impossible_answer": False}

def __init__(self, task="question-answering", default_metric_name=None):
super().__init__(task, default_metric_name=default_metric_name)

@@ -89,7 +91,7 @@ def prepare_data(

return metric_inputs, {"question": data[question_column], "context": data[context_column]}

- def is_squad_v2_schema(self, data: Dataset, label_column: str = "answers"):
+ def is_squad_v2_format(self, data: Dataset, label_column: str = "answers"):
"""
Check if the provided dataset follows the squad v2 data schema, i.e. whether it may contain samples where the answer is not in the context.
In this case, the answer text list should be `[]`.
@@ -103,11 +105,11 @@ def is_squad_v2_schema(self, data: Dataset, label_column: str = "answers"):
else:
return False

- def predictions_processor(self, predictions: List, squad_v2_schema: bool, ids: List):
+ def predictions_processor(self, predictions: List, squad_v2_format: bool, ids: List):
result = []
for i in range(len(predictions)):
pred = {"prediction_text": predictions[i]["answer"], "id": ids[i]}
- if squad_v2_schema:
+ if squad_v2_format:
pred["no_answer_probability"] = predictions[i]["score"]
result.append(pred)
return {"predictions": result}
@@ -126,6 +128,7 @@ def compute(
context_column: str = "context",
id_column: str = "id",
label_column: str = "answers",
+ squad_v2_format: Optional[bool] = None,
) -> Tuple[Dict[str, float], Any]:
"""
Compute the metric for a given pipeline and dataset combination.
@@ -167,6 +170,9 @@ def compute(
dataset specified by `data`.
label_column (`str`, defaults to `"answers"`):
the name of the column containing the answers in the dataset specified by `data`.
+ squad_v2_format (`bool`, *optional*, defaults to `None`):
+ whether the dataset follows the format of the squad_v2 dataset, where a question may have no answer in the context. If this parameter is not provided,
+ the format will be automatically inferred.
Return:
A `Dict`. The keys represent metric keys calculated for the `metric` specified in function arguments. For the
`"simple"` strategy, the value is the metric score. For the `"bootstrap"` strategy, the value is a `Dict`
@@ -176,9 +182,9 @@ def compute(
```python
>>> from evaluate import evaluator
>>> from datasets import load_dataset
>>> e = evaluator("question-answering")
>>> task_evaluator = evaluator("question-answering")
>>> data = load_dataset("squad", split="validation[:2]")
- >>> results = e.compute(
+ >>> results = task_evaluator.compute(
>>> model_or_pipeline="sshleifer/tiny-distilbert-base-cased-distilled-squad",
>>> data=data,
>>> metric="squad",
@@ -187,26 +193,21 @@ def compute(
<Tip>
- Datasets where the answer may be missing in the context are supported, for example SQuAD v2 dataset. If using transformers pipeline
- with models trained on this type of data, make sure to pass `handle_impossible_answer=True` as an argument to the pipeline.
+ Datasets where the answer may be missing in the context are supported, for example SQuAD v2 dataset. In this case, it is safer to pass `squad_v2_format=True` to
+ the compute() call.
</Tip>
```python
>>> from evaluate import evaluator
>>> from datasets import load_dataset
- >>> from transformers import pipeline
>>> task_evaluator = evaluator("question-answering")
>>> data = load_dataset("squad_v2", split="validation[:2]")
- >>> pipe = pipeline(
- >>> task="question-answering",
- >>> model="sshleifer/mrm8488/bert-tiny-finetuned-squadv2",
- >>> handle_impossible_answer=True
- >>> )
>>> results = task_evaluator.compute(
- >>> model_or_pipeline=pipe,
+ >>> model_or_pipeline="mrm8488/bert-tiny-finetuned-squadv2",
>>> data=data,
>>> metric="squad_v2",
+ >>> squad_v2_format=True,
>>> )
```
"""
@@ -220,23 +221,30 @@ def compute(
label_column=label_column,
)

- squad_v2_schema = self.is_squad_v2_schema(data=data, label_column=label_column)
+ if squad_v2_format is None:
+ squad_v2_format = self.is_squad_v2_format(data=data, label_column=label_column)
+ logger.warn(
+ f"`squad_v2_format` parameter not provided to QuestionAnsweringEvaluator.compute(). Auto-inferred `squad_v2_format` to {squad_v2_format}."
+ )

pipe = self.prepare_pipeline(model_or_pipeline=model_or_pipeline, tokenizer=tokenizer)

metric = self.prepare_metric(metric)

- if squad_v2_schema and metric.name == "squad":
+ if squad_v2_format and metric.name == "squad":
logger.warn(
"The dataset has SQuAD v2 format but you are using the SQuAD metric. Consider passing the 'squad_v2' metric."
)
- if not squad_v2_schema and metric.name == "squad_v2":
+ if not squad_v2_format and metric.name == "squad_v2":
logger.warn(
"The dataset has SQuAD v1 format but you are using the SQuAD v2 metric. Consider passing the 'squad' metric."
)

+ if squad_v2_format:
+ self.PIPELINE_KWARGS["handle_impossible_answer"] = True

predictions, perf_results = self.call_pipeline(pipe, **pipe_inputs)
- predictions = self.predictions_processor(predictions, squad_v2_schema=squad_v2_schema, ids=data[id_column])
+ predictions = self.predictions_processor(predictions, squad_v2_format=squad_v2_format, ids=data[id_column])

metric_inputs.update(predictions)

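Taken together, the changes in this file do two things: `is_squad_v2_format` infers whether the dataset contains unanswerable questions (empty `answers["text"]` lists), and `squad_v2_format=True` switches the pipeline to accept "no answer" predictions via `handle_impossible_answer`. A rough, standalone sketch of that behaviour follows; the helper below is an illustrative assumption, not the library's implementation:

```python
# Rough sketch of the squad_v2 handling added above: detect v2-style data by looking for
# an example with no gold answer, then ask the QA pipeline to allow "no answer" predictions.
from datasets import load_dataset


def looks_like_squad_v2(data, label_column: str = "answers") -> bool:
    """Treat a dataset as squad_v2-like if any example has an empty gold-answer list."""
    return any(len(example[label_column]["text"]) == 0 for example in data)


squad_v1 = load_dataset("squad", split="validation[:50]")
squad_v2 = load_dataset("squad_v2", split="validation[:50]")

print(looks_like_squad_v2(squad_v1))  # False: every SQuAD v1 question is answerable
print(looks_like_squad_v2(squad_v2))  # True whenever the slice contains an unanswerable question

# When the data is (or is declared to be) squad_v2-like, the evaluator also sets the
# transformers question-answering pipeline kwarg that permits empty answers, roughly:
pipeline_kwargs = {"handle_impossible_answer": looks_like_squad_v2(squad_v2)}
```

Passing `squad_v2_format` explicitly skips the auto-detection (and its warning), which is why the updated docstring example above and the parity test further down pass `squad_v2_format=True` directly.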
2 changes: 1 addition & 1 deletion src/evaluate/evaluator/text_classification.py
@@ -84,7 +84,7 @@ def compute(self, *args, **kwargs) -> Tuple[Dict[str, float], Any]:
Examples:
```python
>>> from evaluate import evaluator
- >>> from datasets import Dataset, load_dataset
+ >>> from datasets import load_dataset
>>> task_evaluator = evaluator("text-classification")
>>> data = load_dataset("imdb", split="test[:2]")
>>> results = task_evaluator.compute(
4 changes: 2 additions & 2 deletions tests/test_evaluator.py
@@ -361,7 +361,7 @@ def test_model_init(self):
metric="squad",
)
self.assertEqual(scores["exact_match"], 0)
self.assertEqual(scores["f1"], 100 / 3)
self.assertEqual(scores["f1"], 0)

model = AutoModelForQuestionAnswering.from_pretrained(self.default_model)
tokenizer = AutoTokenizer.from_pretrained(self.default_model)
@@ -372,7 +372,7 @@ def test_model_init(self):
tokenizer=tokenizer,
)
self.assertEqual(scores["exact_match"], 0)
self.assertEqual(scores["f1"], 100 / 3)
self.assertEqual(scores["f1"], 0)

def test_class_init(self):
# squad_v1-like dataset
35 changes: 20 additions & 15 deletions tests/test_trainer_evaluator_parity.py
@@ -61,8 +61,8 @@ def test_text_classification_parity(self):

pipe = pipeline(task="text-classification", model=model_name, tokenizer=model_name)

e = evaluator(task="text-classification")
evaluator_results = e.compute(
task_evaluator = evaluator(task="text-classification")
evaluator_results = task_evaluator.compute(
model_or_pipeline=pipe,
data=eval_dataset,
metric="accuracy",
@@ -119,8 +119,8 @@ def collate_fn(examples):

pipe = pipeline(task="image-classification", model=model_name, feature_extractor=model_name)

e = evaluator(task="image-classification")
evaluator_results = e.compute(
task_evaluator = evaluator(task="image-classification")
evaluator_results = task_evaluator.compute(
model_or_pipeline=pipe,
data=eval_dataset,
metric="accuracy",
@@ -133,7 +133,8 @@ def collate_fn(examples):
self.assertEqual(transformers_results["eval_accuracy"], evaluator_results["accuracy"])

def test_question_answering_parity(self):
model_name = "mrm8488/bert-tiny-finetuned-squadv2"
model_name_v1 = "anas-awadalla/bert-tiny-finetuned-squad"
model_name_v2 = "mrm8488/bert-tiny-finetuned-squadv2"

subprocess.run(
"git sparse-checkout set examples/pytorch/question-answering",
@@ -144,7 +145,7 @@ def test_question_answering_parity(self):
# test squad_v1-like dataset
subprocess.run(
f"python examples/pytorch/question-answering/run_qa.py"
f" --model_name_or_path {model_name}"
f" --model_name_or_path {model_name_v1}"
f" --dataset_name squad"
f" --do_eval"
f" --output_dir {os.path.join(self.dir_path, 'questionanswering_squad_transformers')}"
@@ -162,11 +163,15 @@ def test_question_answering_parity(self):
eval_dataset = load_dataset("squad", split="validation[:100]")

pipe = pipeline(
task="question-answering", model=model_name, tokenizer=model_name, max_answer_len=30, padding="max_length"
task="question-answering",
model=model_name_v1,
tokenizer=model_name_v1,
max_answer_len=30,
padding="max_length",
)

e = evaluator(task="question-answering")
evaluator_results = e.compute(
task_evaluator = evaluator(task="question-answering")
evaluator_results = task_evaluator.compute(
model_or_pipeline=pipe,
data=eval_dataset,
metric="squad",
@@ -179,7 +184,7 @@ def test_question_answering_parity(self):
# test squad_v2-like dataset
subprocess.run(
f"python examples/pytorch/question-answering/run_qa.py"
f" --model_name_or_path {model_name}"
f" --model_name_or_path {model_name_v2}"
f" --dataset_name squad_v2"
f" --version_2_with_negative"
f" --do_eval"
@@ -199,18 +204,18 @@

pipe = pipeline(
task="question-answering",
- model=model_name,
- tokenizer=model_name,
+ model=model_name_v2,
+ tokenizer=model_name_v2,
max_answer_len=30,
handle_impossible_answer=True,
)

e = evaluator(task="question-answering")
evaluator_results = e.compute(
task_evaluator = evaluator(task="question-answering")
evaluator_results = task_evaluator.compute(
model_or_pipeline=pipe,
data=eval_dataset,
metric="squad_v2",
strategy="simple",
squad_v2_format=True,
)

self.assertEqual(transformers_results["eval_f1"], evaluator_results["f1"])