From fc41967b45013f14bb03332c84092bd69c34623d Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Mon, 15 Nov 2021 14:28:13 +0530 Subject: [PATCH 01/17] Add SQuAD metric - functional. --- docs/source/links.rst | 1 + docs/source/references/functional.rst | 7 + torchmetrics/functional/__init__.py | 2 + torchmetrics/functional/text/__init__.py | 1 + torchmetrics/functional/text/squad.py | 256 +++++++++++++++++++++++ 5 files changed, 267 insertions(+) create mode 100644 torchmetrics/functional/text/squad.py diff --git a/docs/source/links.rst b/docs/source/links.rst index 4b72fb82a78..8bc4ddc3c58 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -72,3 +72,4 @@ .. _Python ROUGE Implementation: https://pypi.org/project/rouge-score/ .. _Scikit_Learn-Ranking.py: https: //github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py .. _Verified Uncertainty Calibration: https://arxiv.org/abs/1909.10155 +.. _SQuAD Metric: https://arxiv.org/pdf/1606.05250.pdf diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst index 95d4193845d..595d5442f22 100644 --- a/docs/source/references/functional.rst +++ b/docs/source/references/functional.rst @@ -458,6 +458,13 @@ sacre_bleu_score [func] .. autofunction:: torchmetrics.functional.sacre_bleu_score :noindex: + +squad [func] +~~~~~~~~~~~~ + +.. autofunction:: torchmetrics.functional.squad + :noindex: + wer [func] ~~~~~~~~~~ diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py index 69b7a810667..70a302f67d9 100644 --- a/torchmetrics/functional/__init__.py +++ b/torchmetrics/functional/__init__.py @@ -71,6 +71,7 @@ from torchmetrics.functional.text.mer import match_error_rate from torchmetrics.functional.text.rouge import rouge_score from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score +from torchmetrics.functional.text.squad import squad from torchmetrics.functional.text.wer import wer __all__ = [ @@ -132,6 +133,7 @@ "snr", "spearman_corrcoef", "specificity", + "squad", "ssim", "stat_scores", "stoi", diff --git a/torchmetrics/functional/text/__init__.py b/torchmetrics/functional/text/__init__.py index 6d312e47504..713e08f33d3 100644 --- a/torchmetrics/functional/text/__init__.py +++ b/torchmetrics/functional/text/__init__.py @@ -16,4 +16,5 @@ from torchmetrics.functional.text.cer import char_error_rate # noqa: F401 from torchmetrics.functional.text.mer import match_error_rate # noqa: F401 from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score # noqa: F401 +from torchmetrics.functional.text.squad import squad # noqa: F401 from torchmetrics.functional.text.wer import wer # noqa: F401 diff --git a/torchmetrics/functional/text/squad.py b/torchmetrics/functional/text/squad.py new file mode 100644 index 00000000000..c4f59b02951 --- /dev/null +++ b/torchmetrics/functional/text/squad.py @@ -0,0 +1,256 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# Adapted from: +# Link: https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/ +# Link: https://github.com/huggingface/datasets/blob/master/metrics/squad/squad.py +import re +import string +from collections import Counter +from typing import Any, Dict, List, Tuple, Union + +import torch +from torch import Tensor, tensor + +from torchmetrics.utilities import rank_zero_warn + +PREDS_TYPE = Dict[str, str] +TARGETS_TYPE = List[Dict[str, List[Dict[str, List[Dict[str, Any]]]]]] + + +def normalize_text(s: str) -> str: + """Lower text and remove punctuation, articles and extra whitespace.""" + + def remove_articles(text: str) -> str: + return re.sub(r"\b(a|an|the)\b", " ", text) + + def white_space_fix(text: str) -> str: + return " ".join(text.split()) + + def remove_punc(text: str) -> str: + exclude = set(string.punctuation) + return "".join(ch for ch in text if ch not in exclude) + + def lower(text: str) -> str: + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def get_tokens(s: str) -> List[str]: + """Split a sentence into separate tokens.""" + if not s: + return [] + return normalize_text(s).split() + + +def compute_f1_score(predictied_answer, target_answer) -> Tensor: + """Compute F1 Score for two sentences.""" + target_tokens: Tensor = get_tokens(target_answer) + predicted_tokens: Tensor = get_tokens(predictied_answer) + common = Counter(target_tokens) & Counter(predicted_tokens) + num_same: Tensor = tensor(sum(common.values())) + if len(target_tokens) == 0 or len(predicted_tokens) == 0: + # If either is no-answer, then F1 is 1 if they agree, 0 otherwise + return tensor(int(target_tokens == predicted_tokens)) + if num_same == 0: + return tensor(0.0) + precision: Tensor = 1.0 * num_same / tensor(len(predicted_tokens)) + recall: Tensor = 1.0 * num_same / tensor(len(target_tokens)) + f1: Tensor = (2 * precision * recall) / (precision + recall) + return f1 + + +def compute_exact_match_score(prediction, ground_truth) -> Tensor: + """Compute Exact Match for two sentences.""" + return tensor(int(normalize_text(prediction) == normalize_text(ground_truth))) + + +def metric_max_over_ground_truths(metric_fn, prediction, ground_truths) -> Tensor: + """Calculate maximum score for a predicted answer with all reference answers.""" + scores_for_ground_truths: List[Tensor] = [] + for ground_truth in ground_truths: + score: Tensor = metric_fn(prediction, ground_truth) + scores_for_ground_truths.append(score) + return torch.max(tensor(scores_for_ground_truths)) + + +def _squad_update( + preds: PREDS_TYPE, + targets: TARGETS_TYPE, +) -> Tuple[Tensor, Tensor, Tensor]: + """Compute F1 Score and Exact Match for a collection of predictions and references. + + Args: + preds: + A dictionary mapping an `id` to the predicted `answer`. + targets: + A list of dictionary mapping `paragraphs` to list of dictionary mapping `qas` to a list of dictionary + containing `id` and list of all possible `answers`. + + Return: + Tuple containing F1 score, Exact match score and total number of examples. + + Example: + >>> from torchmetrics.functional.text.squad import _squad_update + >>> predictions = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}] + >>> targets = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}] + >>> preds_dict = {prediction["id"]: prediction["prediction_text"] for prediction in predictions} + >>> targets_dict = [ + ... { + ... "paragraphs": [ + ... { + ... "qas": [ + ... 
{ + ... "answers": [{"text": answer_text} for answer_text in target["answers"]["text"]], + ... "id": target["id"], + ... } + ... for target in targets + ... ] + ... } + ... ] + ... } + ... ] + >>> _squad_update(preds_dict, targets_dict) + (tensor(1.), tensor(1.), tensor(1)) + """ + f1: Tensor = tensor(0.0) + exact_match: Tensor = tensor(0.0) + total: Tensor = tensor(0) + for article in targets: + for paragraph in article["paragraphs"]: + for qa in paragraph["qas"]: + total += 1 + if qa["id"] not in preds: + rank_zero_warn(f"Unanswered question {qa['id']} will receive score 0.") + continue + ground_truths = list(map(lambda x: x["text"], qa["answers"])) + prediction = preds[qa["id"]] + exact_match += metric_max_over_ground_truths(compute_exact_match_score, prediction, ground_truths) + f1 += metric_max_over_ground_truths(compute_f1_score, prediction, ground_truths) + + return f1, exact_match, total + + +def _squad_compute(scores: Tuple[Tensor, Tensor, Tensor]) -> Dict[str, Tensor]: + """Aggregate the F1 Score and Exact match for the batch. + + Args: + scores: + F1 Score, Exact Match, and Total number of examples in the batch + + Return: + Dictionary containing the F1 score, Exact match score for the batch. + """ + f1: Tensor = scores[0] + exact_match: Tensor = scores[1] + total: Tensor = scores[2] + exact_match = 100.0 * exact_match / total + f1 = 100.0 * f1 / total + return {"exact_match": exact_match, "f1": f1} + + +def squad( + preds: List[Dict[str, str]], + targets: List[Dict[str, Union[str, Dict[str, Union[List[str], List[int]]]]]], +) -> Dict[str, Tensor]: + """Calculate `SQuAD Metric`_ . + + Args: + preds: + An iterable of predicted sentences or a single predicted sentence. + targets: + An iterable of target sentences or a single target sentence. + + Return: + Dictionary containing the F1 score, Exact match score for the batch. + + Example: + >>> from torchmetrics.functional.text.squad import squad + >>> predictions = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}] + >>> references = [{"answers": {"answer_start": [97], "text": ["1976"]},"id": "56e10a3be3433e1400422b22"}] + >>> squad(predictions, references) + {'exact_match': tensor(100.), 'f1': tensor(100.)} + + Raises: + KeyError: + If the required keys are missing in either predictions or targets. + + References: + [1] SQuAD: 100,000+ Questions for Machine Comprehension of Text by Pranav Rajpurkar, Jian Zhang, Konstantin + Lopyrev, Percy Liang `SQuAD Metric`_ . + """ + + for pred in preds: + keys = pred.keys() + if "prediction_text" not in keys or "id" not in keys: + raise KeyError( + "Expected keys in a single prediction are 'prediction_text' and 'id'." + "Please make sure that 'prediction_text' maps to the answer string and 'id' maps to the key string." + ) + + for target in targets: + keys = target.keys() + if "answers" not in keys or "id" not in keys: + raise KeyError( + "Expected keys in a single target are 'answers' and 'id'." + "Please make sure that 'answers' maps to a `SQuAD` format dictionary and 'id' maps to the key string.\n" + "SQuAD Format: " + "{" + " 'answers': {" + " 'answer_start': [1]," + " 'text': ['This is a test text']" + " }," + " 'context': 'This is a test context.'," + " 'id': '1'," + " 'question': 'Is this a test?'," + " 'title': 'train test'" + "}" + ) + + answers_keys = target["answers"].keys() + if "text" not in answers_keys: + raise KeyError( + "Expected keys in a 'answers' are 'text'." 
+ "Please make sure that 'answer' maps to a `SQuAD` format dictionary.\n" + "SQuAD Format: " + "{" + " 'answers': {" + " 'answer_start': [1]," + " 'text': ['This is a test text']" + " }," + " 'context': 'This is a test context.'," + " 'id': '1'," + " 'question': 'Is this a test?'," + " 'title': 'train test'" + "}" + ) + + preds_dict = {prediction["id"]: prediction["prediction_text"] for prediction in preds} + targets_dict = [ + { + "paragraphs": [ + { + "qas": [ + { + "answers": [{"text": answer_text} for answer_text in target["answers"]["text"]], + "id": target["id"], + } + for target in targets + ] + } + ] + } + ] + scores: Tuple[Tensor, Tensor, Tensor] = _squad_update(preds_dict, targets_dict) + return _squad_compute(scores) From a403ccb6ed0a08cf6241da6793d45a884e1f88b7 Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Mon, 15 Nov 2021 15:18:37 +0530 Subject: [PATCH 02/17] Add SQuAD metric - module. --- docs/source/references/modules.rst | 7 ++ torchmetrics/__init__.py | 2 + torchmetrics/functional/text/squad.py | 9 +- torchmetrics/text/__init__.py | 1 + torchmetrics/text/squad.py | 165 ++++++++++++++++++++++++++ 5 files changed, 178 insertions(+), 6 deletions(-) create mode 100644 torchmetrics/text/squad.py diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 9088d8541a2..7e07dffdb70 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -640,6 +640,13 @@ SacreBLEUScore .. autoclass:: torchmetrics.SacreBLEUScore :noindex: + +SQuAD +~~~~~ + +.. autoclass:: torchmetrics.SQuAD + :noindex: + WER ~~~ diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index d4321776c5d..d349d842f82 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -74,6 +74,7 @@ MatchErrorRate, ROUGEScore, SacreBLEUScore, + SQuAD, ) from torchmetrics.wrappers import BootStrapper, MetricTracker, MultioutputWrapper # noqa: E402 @@ -143,6 +144,7 @@ "SNR", "SpearmanCorrcoef", "Specificity", + "SQuAD", "SSIM", "StatScores", "STOI", diff --git a/torchmetrics/functional/text/squad.py b/torchmetrics/functional/text/squad.py index c4f59b02951..d3a0913be23 100644 --- a/torchmetrics/functional/text/squad.py +++ b/torchmetrics/functional/text/squad.py @@ -17,7 +17,7 @@ import re import string from collections import Counter -from typing import Any, Dict, List, Tuple, Union +from typing import Dict, List, Tuple, Union import torch from torch import Tensor, tensor @@ -25,7 +25,7 @@ from torchmetrics.utilities import rank_zero_warn PREDS_TYPE = Dict[str, str] -TARGETS_TYPE = List[Dict[str, List[Dict[str, List[Dict[str, Any]]]]]] +TARGETS_TYPE = List[Dict[str, List[Dict[str, List[Dict[str, Union[str, List[Dict[str, str]]]]]]]]] def normalize_text(s: str) -> str: @@ -85,10 +85,7 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths) -> Tenso return torch.max(tensor(scores_for_ground_truths)) -def _squad_update( - preds: PREDS_TYPE, - targets: TARGETS_TYPE, -) -> Tuple[Tensor, Tensor, Tensor]: +def _squad_update(preds: PREDS_TYPE, targets: TARGETS_TYPE) -> Tuple[Tensor, Tensor, Tensor]: """Compute F1 Score and Exact Match for a collection of predictions and references. 
Args: diff --git a/torchmetrics/text/__init__.py b/torchmetrics/text/__init__.py index f7cf05cdfa9..a0885b568c5 100644 --- a/torchmetrics/text/__init__.py +++ b/torchmetrics/text/__init__.py @@ -17,4 +17,5 @@ from torchmetrics.text.mer import MatchErrorRate # noqa: F401 from torchmetrics.text.rouge import ROUGEScore # noqa: F401 from torchmetrics.text.sacre_bleu import SacreBLEUScore # noqa: F401 +from torchmetrics.text.squad import SQuAD # noqa: F401 from torchmetrics.text.wer import WER # noqa: F401 diff --git a/torchmetrics/text/squad.py b/torchmetrics/text/squad.py new file mode 100644 index 00000000000..e833a9807b4 --- /dev/null +++ b/torchmetrics/text/squad.py @@ -0,0 +1,165 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from torch import Tensor + +from torchmetrics import Metric +from torchmetrics.functional.text.squad import _squad_compute, _squad_update + + +class SQuAD(Metric): + """Calculate `SQuAD Metric`_. + + Args: + compute_on_step: + Forward only calls ``update()`` and returns None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When `None`, DDP + will be used to perform the allgather. + + Example: + >>> from torchmetrics.text.squad import SQuAD + >>> predictions = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}] + >>> references = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}] + >>> _sqaud = SQuAD() + >>> _sqaud(predictions, references) + {'exact_match': tensor(100.), 'f1': tensor(100.)} + + References: + [1] SQuAD: 100,000+ Questions for Machine Comprehension of Text by Pranav Rajpurkar, Jian Zhang, Konstantin + Lopyrev, Percy Liang `SQuAD Metric`_ . + """ + + higher_is_better = True + + def __init__( + self, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Optional[Callable] = None, + ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + + self.add_state(name="f1_score", default=[], dist_reduce_fx=None) + self.add_state(name="exact_match", default=[], dist_reduce_fx=None) + self.add_state(name="total", default=[], dist_reduce_fx=None) + + def update( + self, + preds: List[Dict[str, str]], + targets: List[Dict[str, Union[str, Dict[str, Union[List[str], List[int]]]]]], + ) -> None: # type: ignore + """Compute F1 Score and Exact Match for a collection of predictions and references. + + Args: + preds: + An iterable of predicted sentences or a single predicted sentence. 
+ targets: + An iterable of target sentences or a single target sentence. + + Raises: + KeyError: + If the required keys are missing in either predictions or targets. + """ + + for pred in preds: + keys = pred.keys() + if "prediction_text" not in keys or "id" not in keys: + raise KeyError( + "Expected keys in a single prediction are 'prediction_text' and 'id'." + "Please make sure that 'prediction_text' maps to the answer string and 'id' maps to the key string." + ) + + for target in targets: + keys = target.keys() + if "answers" not in keys or "id" not in keys: + raise KeyError( + "Expected keys in a single target are 'answers' and 'id'." + "Please make sure that 'answers' maps to a `SQuAD` format dictionary and 'id' maps to the key" + "string.\n" + "SQuAD Format: " + "{" + " 'answers': {" + " 'answer_start': [1]," + " 'text': ['This is a test text']" + " }," + " 'context': 'This is a test context.'," + " 'id': '1'," + " 'question': 'Is this a test?'," + " 'title': 'train test'" + "}" + ) + + answers_keys = target["answers"].keys() + if "text" not in answers_keys: + raise KeyError( + "Expected keys in a 'answers' are 'text'." + "Please make sure that 'answer' maps to a `SQuAD` format dictionary.\n" + "SQuAD Format: " + "{" + " 'answers': {" + " 'answer_start': [1]," + " 'text': ['This is a test text']" + " }," + " 'context': 'This is a test context.'," + " 'id': '1'," + " 'question': 'Is this a test?'," + " 'title': 'train test'" + "}" + ) + + preds_dict = {prediction["id"]: prediction["prediction_text"] for prediction in preds} + targets_dict = [ + { + "paragraphs": [ + { + "qas": [ + { + "answers": [{"text": answer_text} for answer_text in target["answers"]["text"]], + "id": target["id"], + } + for target in targets + ] + } + ] + } + ] + scores: Tuple[Tensor, Tensor, Tensor] = _squad_update(preds_dict, targets_dict) + getattr(self, "f1_score").append(scores[0].to(self.device)) + getattr(self, "exact_match").append(scores[1].to(self.device)) + getattr(self, "total").append(scores[2].to(self.device)) + + def compute(self) -> Dict[str, Tensor]: + """Aggregate the F1 Score and Exact match for the batch. + + Return: + Dictionary containing the F1 score, Exact match score for the batch. + """ + f1_score = torch.sum(torch.tensor(getattr(self, "f1_score"))) + exact_match = torch.sum(torch.tensor(getattr(self, "exact_match"))) + total = torch.sum(torch.tensor(getattr(self, "total"))) + return _squad_compute((f1_score, exact_match, total)) From 723d27dc9c44d4a8da7df5520c41e026f091ceab Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Mon, 15 Nov 2021 17:52:43 +0530 Subject: [PATCH 03/17] Add tests for squad metric. 
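
A minimal smoke test of what the new tests exercise, mirroring the example
values already used in the metric docstrings (the ids are illustrative):

    from torchmetrics.functional.text import squad

    preds = [{"prediction_text": "1976", "id": "id1"}]
    targets = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "id1"}]
    squad(preds, targets)  # {'exact_match': tensor(100.), 'f1': tensor(100.)}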
--- tests/text/test_squad.py | 103 ++++++++++++++++++++++
 torchmetrics/functional/text/squad.py | 20 +++--
 torchmetrics/text/squad.py | 13 +++-
 3 files changed, 127 insertions(+), 9 deletions(-)
 create mode 100644 tests/text/test_squad.py

diff --git a/tests/text/test_squad.py b/tests/text/test_squad.py
new file mode 100644
index 00000000000..33cc5925a1a
--- /dev/null
+++ b/tests/text/test_squad.py
@@ -0,0 +1,103 @@
+import os
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+from tests.helpers.testers import _assert_allclose, _assert_tensor
+from torchmetrics.functional.text import squad
+from torchmetrics.text.squad import SQuAD
+
+os.environ["TOKENIZERS_PARALLELISM"] = "1"
+
+SAMPLE_1 = {
+    "exact_match": 100.0,
+    "f1": 100.0,
+    "predictions": {"prediction_text": "1976", "id": "id1"},
+    "references": {"answers": {"answer_start": [97], "text": ["1976"]}, "id": "id1"},
+}
+
+SAMPLE_2 = {
+    "exact_match": 0.0,
+    "f1": 0.0,
+    "predictions": {"prediction_text": "Hello", "id": "id2"},
+    "references": {"answers": {"answer_start": [97], "text": ["World"]}, "id": "id2"},
+}
+
+BATCH = {
+    "exact_match": [100.0, 0.0],
+    "f1": [100.0, 0.0],
+    "predictions": [
+        {"prediction_text": "1976", "id": "id1"},
+        {"prediction_text": "Hello", "id": "id2"},
+    ],
+    "references": [
+        {"answers": {"answer_start": [97], "text": ["1976"]}, "id": "id1"},
+        {"answers": {"answer_start": [97], "text": ["World"]}, "id": "id2"},
+    ],
+}
+
+
+@pytest.mark.parametrize(
+    "preds,targets,exact_match,f1",
+    [
+        (SAMPLE_1["predictions"], SAMPLE_1["references"], SAMPLE_1["exact_match"], SAMPLE_1["f1"]),
+        (SAMPLE_2["predictions"], SAMPLE_2["references"], SAMPLE_2["exact_match"], SAMPLE_2["f1"]),
+    ],
+)
+def test_score_fn(preds, targets, exact_match, f1):
+    """Tests for functional."""
+    metrics_score = squad(preds, targets)
+    _assert_tensor(metrics_score["exact_match"])
+    _assert_tensor(metrics_score["f1"])
+    _assert_allclose(metrics_score["exact_match"], exact_match)
+    _assert_allclose(metrics_score["f1"], f1)
+
+
+@pytest.mark.parametrize(
+    "preds,targets,exact_match,f1",
+    [(BATCH["predictions"], BATCH["references"], BATCH["exact_match"], BATCH["f1"])],
+)
+def test_accumulation(preds, targets, exact_match, f1):
+    """Tests that the metric works with accumulation."""
+    squad_metric = SQuAD()
+    for pred, target in zip(preds, targets):
+        squad_metric.update(preds=[pred], targets=[target])
+    metrics_score = squad_metric.compute()
+
+    _assert_tensor(metrics_score["exact_match"])
+    _assert_tensor(metrics_score["f1"])
+    _assert_allclose(metrics_score["exact_match"], torch.mean(torch.tensor(exact_match)))
+    _assert_allclose(metrics_score["f1"], torch.mean(torch.tensor(f1)))
+
+
+def _bert_score_ddp(rank, world_size, pred, target, exact_match, f1):
+    """Define a DDP process for BERTScore."""
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12355"
+    dist.init_process_group("gloo", rank=rank, world_size=world_size)
+    squad_metric = SQuAD()
+    squad_metric.update(pred, target)
+    metrics_score = squad_metric.compute()
+    _assert_tensor(metrics_score["exact_match"])
+    _assert_tensor(metrics_score["f1"])
+    _assert_allclose(metrics_score["exact_match"], exact_match)
+    _assert_allclose(metrics_score["f1"], f1)
+    dist.destroy_process_group()
+
+
+def _test_score_ddp_fn(rank, world_size, preds, targets, exact_match, f1):
+    """Core functionality for the `test_score_ddp` test."""
+    _bert_score_ddp(rank, world_size, preds[rank], targets[rank], 
exact_match[rank], f1[rank]) + + +@pytest.mark.parametrize( + "preds,targets,exact_match,f1", + [(BATCH["predictions"], BATCH["references"], BATCH["exact_match"], BATCH["f1"])], +) +@pytest.mark.skipif(not dist.is_available(), reason="test requires torch distributed") +def test_score_ddp(preds, targets, exact_match, f1): + """Tests for metric using DDP.""" + world_size = 2 + mp.spawn(_test_score_ddp_fn, args=(world_size, preds, targets, exact_match, f1), nprocs=world_size, join=False) diff --git a/torchmetrics/functional/text/squad.py b/torchmetrics/functional/text/squad.py index d3a0913be23..7b4761b3f5c 100644 --- a/torchmetrics/functional/text/squad.py +++ b/torchmetrics/functional/text/squad.py @@ -24,8 +24,10 @@ from torchmetrics.utilities import rank_zero_warn -PREDS_TYPE = Dict[str, str] -TARGETS_TYPE = List[Dict[str, List[Dict[str, List[Dict[str, Union[str, List[Dict[str, str]]]]]]]]] +SINGLE_PRED_TYPE = Dict[str, str] +PREDS_TYPE = Union[SINGLE_PRED_TYPE, List[SINGLE_PRED_TYPE]] +SINGLE_TARGET_TYPE = Dict[str, Union[str, Dict[str, Union[List[str], List[int]]]]] +TARGETS_TYPE = Union[SINGLE_TARGET_TYPE, List[SINGLE_TARGET_TYPE]] def normalize_text(s: str) -> str: @@ -85,7 +87,9 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths) -> Tenso return torch.max(tensor(scores_for_ground_truths)) -def _squad_update(preds: PREDS_TYPE, targets: TARGETS_TYPE) -> Tuple[Tensor, Tensor, Tensor]: +def _squad_update( + preds: Dict[str, str], targets: List[Dict[str, List[Dict[str, List[Dict[str, Union[str, List[Dict[str, str]]]]]]]]] +) -> Tuple[Tensor, Tensor, Tensor]: """Compute F1 Score and Exact Match for a collection of predictions and references. Args: @@ -158,8 +162,8 @@ def _squad_compute(scores: Tuple[Tensor, Tensor, Tensor]) -> Dict[str, Tensor]: def squad( - preds: List[Dict[str, str]], - targets: List[Dict[str, Union[str, Dict[str, Union[List[str], List[int]]]]]], + preds: PREDS_TYPE, + targets: TARGETS_TYPE, ) -> Dict[str, Tensor]: """Calculate `SQuAD Metric`_ . @@ -188,6 +192,12 @@ def squad( Lopyrev, Percy Liang `SQuAD Metric`_ . """ + if isinstance(preds, Dict): + preds = [preds] + + if isinstance(targets, Dict): + targets = [targets] + for pred in preds: keys = pred.keys() if "prediction_text" not in keys or "id" not in keys: diff --git a/torchmetrics/text/squad.py b/torchmetrics/text/squad.py index e833a9807b4..1cb4d72f3c6 100644 --- a/torchmetrics/text/squad.py +++ b/torchmetrics/text/squad.py @@ -11,13 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Optional, Tuple import torch from torch import Tensor from torchmetrics import Metric -from torchmetrics.functional.text.squad import _squad_compute, _squad_update +from torchmetrics.functional.text.squad import PREDS_TYPE, TARGETS_TYPE, _squad_compute, _squad_update class SQuAD(Metric): @@ -70,8 +70,8 @@ def __init__( def update( self, - preds: List[Dict[str, str]], - targets: List[Dict[str, Union[str, Dict[str, Union[List[str], List[int]]]]]], + preds: PREDS_TYPE, + targets: TARGETS_TYPE, ) -> None: # type: ignore """Compute F1 Score and Exact Match for a collection of predictions and references. @@ -85,6 +85,11 @@ def update( KeyError: If the required keys are missing in either predictions or targets. 
""" + if isinstance(preds, Dict): + preds = [preds] + + if isinstance(targets, Dict): + targets = [targets] for pred in preds: keys = pred.keys() From ec1c3659d6cc0f1da98ac1a87bd72dfb6d0c57ef Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Mon, 15 Nov 2021 18:47:15 +0530 Subject: [PATCH 04/17] Add documentation. --- torchmetrics/functional/text/squad.py | 33 +++++++++++++++++++++++++-- torchmetrics/text/squad.py | 32 ++++++++++++++++++++++++-- 2 files changed, 61 insertions(+), 4 deletions(-) diff --git a/torchmetrics/functional/text/squad.py b/torchmetrics/functional/text/squad.py index 7b4761b3f5c..645bb431af6 100644 --- a/torchmetrics/functional/text/squad.py +++ b/torchmetrics/functional/text/squad.py @@ -169,9 +169,38 @@ def squad( Args: preds: - An iterable of predicted sentences or a single predicted sentence. + A Dictionary or List of Dictionary-s that map `id` and `prediction_text` to the respective values. + + Example prediction: + + .. code-block:: python + + {"prediction_text": "TorchMetrics is awesome", "id": "123"} + targets: - An iterable of target sentences or a single target sentence. + A Dictioinary or List of Dictionary-s that contain the `answers` and `id` in the SQuAD Format. + + Example target: + + .. code-block:: python + + { + 'answers': [{'answer_start': [1], 'text': ['This is a test answer']}], + 'id': '1', + } + + Reference SQuAD Format: + + .. code-block:: python + + { + 'answers': {'answer_start': [1], 'text': ['This is a test text']}, + 'context': 'This is a test context.', + 'id': '1', + 'question': 'Is this a test?', + 'title': 'train test' + } + Return: Dictionary containing the F1 score, Exact match score for the batch. diff --git a/torchmetrics/text/squad.py b/torchmetrics/text/squad.py index 1cb4d72f3c6..d68c8e62206 100644 --- a/torchmetrics/text/squad.py +++ b/torchmetrics/text/squad.py @@ -77,9 +77,37 @@ def update( Args: preds: - An iterable of predicted sentences or a single predicted sentence. + A Dictionary or List of Dictionary-s that map `id` and `prediction_text` to the respective values. + + Example prediction: + + .. code-block:: python + + {"prediction_text": "TorchMetrics is awesome", "id": "123"} + targets: - An iterable of target sentences or a single target sentence. + A Dictioinary or List of Dictionary-s that contain the `answers` and `id` in the SQuAD Format. + + Example target: + + .. code-block:: python + + { + 'answers': [{'answer_start': [1], 'text': ['This is a test answer']}], + 'id': '1', + } + + Reference SQuAD Format: + + .. code-block:: python + + { + 'answers': {'answer_start': [1], 'text': ['This is a test text']}, + 'context': 'This is a test context.', + 'id': '1', + 'question': 'Is this a test?', + 'title': 'train test' + } Raises: KeyError: From e8c6bc78a9b811d96000d9ee1a9a6c4f0b44fde0 Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Mon, 15 Nov 2021 18:48:53 +0530 Subject: [PATCH 05/17] Update CHANGELOG. 
--- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 041342a62e7..c0767a54a90 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added NLP metrics: - `MatchErrorRate` ([#619](https://github.com/PyTorchLightning/metrics/pull/619)) + - `SQuAD` ([#623](https://github.com/PyTorchLightning/metrics/pull/623)) ### Changed From f5a810e9352184adf592c5f44c1a1bba3838ce94 Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Mon, 15 Nov 2021 21:16:02 +0530 Subject: [PATCH 06/17] Changes from code review. --- tests/text/test_squad.py | 6 +-- torchmetrics/functional/text/squad.py | 59 +++++++++------------------ torchmetrics/text/squad.py | 45 +++++++------------- 3 files changed, 37 insertions(+), 73 deletions(-) diff --git a/tests/text/test_squad.py b/tests/text/test_squad.py index 33cc5925a1a..3c018253f40 100644 --- a/tests/text/test_squad.py +++ b/tests/text/test_squad.py @@ -72,8 +72,8 @@ def test_accumulation(preds, targets, exact_match, f1): _assert_allclose(metrics_score["f1"], torch.mean(torch.tensor(f1))) -def _bert_score_ddp(rank, world_size, pred, target, exact_match, f1): - """Define a DDP process for BERTScore.""" +def _squad_score_ddp(rank, world_size, pred, target, exact_match, f1): + """Define a DDP process for SQuAD metric.""" os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "12355" dist.init_process_group("gloo", rank=rank, world_size=world_size) @@ -89,7 +89,7 @@ def _bert_score_ddp(rank, world_size, pred, target, exact_match, f1): def _test_score_ddp_fn(rank, world_size, preds, targets, exact_match, f1): """Core functionality for the `test_score_ddp` test.""" - _bert_score_ddp(rank, world_size, preds[rank], targets[rank], exact_match[rank], f1[rank]) + _squad_score_ddp(rank, world_size, preds[rank], targets[rank], exact_match[rank], f1[rank]) @pytest.mark.parametrize( diff --git a/torchmetrics/functional/text/squad.py b/torchmetrics/functional/text/squad.py index 645bb431af6..a9ef8dc775d 100644 --- a/torchmetrics/functional/text/squad.py +++ b/torchmetrics/functional/text/squad.py @@ -19,7 +19,6 @@ from collections import Counter from typing import Dict, List, Tuple, Union -import torch from torch import Tensor, tensor from torchmetrics.utilities import rank_zero_warn @@ -30,7 +29,16 @@ TARGETS_TYPE = Union[SINGLE_TARGET_TYPE, List[SINGLE_TARGET_TYPE]] -def normalize_text(s: str) -> str: +SQuAD_FORMAT = { + "answers": {"answer_start": [1], "text": ["This is a test text"]}, + "context": "This is a test context.", + "id": "1", + "question": "Is this a test?", + "title": "train test", +} + + +def _normalize_text(s: str) -> str: """Lower text and remove punctuation, articles and extra whitespace.""" def remove_articles(text: str) -> str: @@ -51,15 +59,13 @@ def lower(text: str) -> str: def get_tokens(s: str) -> List[str]: """Split a sentence into separate tokens.""" - if not s: - return [] - return normalize_text(s).split() + return [] if not s else _normalize_text(s).split() def compute_f1_score(predictied_answer, target_answer) -> Tensor: """Compute F1 Score for two sentences.""" - target_tokens: Tensor = get_tokens(target_answer) - predicted_tokens: Tensor = get_tokens(predictied_answer) + target_tokens: List[str] = get_tokens(target_answer) + predicted_tokens: List[str] = get_tokens(predictied_answer) common = Counter(target_tokens) & Counter(predicted_tokens) num_same: Tensor = tensor(sum(common.values())) if 
len(target_tokens) == 0 or len(predicted_tokens) == 0: @@ -75,16 +81,12 @@ def compute_f1_score(predictied_answer, target_answer) -> Tensor: def compute_exact_match_score(prediction, ground_truth) -> Tensor: """Compute Exact Match for two sentences.""" - return tensor(int(normalize_text(prediction) == normalize_text(ground_truth))) + return tensor(int(_normalize_text(prediction) == _normalize_text(ground_truth))) def metric_max_over_ground_truths(metric_fn, prediction, ground_truths) -> Tensor: """Calculate maximum score for a predicted answer with all reference answers.""" - scores_for_ground_truths: List[Tensor] = [] - for ground_truth in ground_truths: - score: Tensor = metric_fn(prediction, ground_truth) - scores_for_ground_truths.append(score) - return torch.max(tensor(scores_for_ground_truths)) + return max(metric_fn(prediction, truth) for truth in ground_truths) def _squad_update( @@ -143,7 +145,7 @@ def _squad_update( return f1, exact_match, total -def _squad_compute(scores: Tuple[Tensor, Tensor, Tensor]) -> Dict[str, Tensor]: +def _squad_compute(f1: Tensor, exact_match: Tensor, total: Tensor) -> Dict[str, Tensor]: """Aggregate the F1 Score and Exact match for the batch. Args: @@ -153,9 +155,6 @@ def _squad_compute(scores: Tuple[Tensor, Tensor, Tensor]) -> Dict[str, Tensor]: Return: Dictionary containing the F1 score, Exact match score for the batch. """ - f1: Tensor = scores[0] - exact_match: Tensor = scores[1] - total: Tensor = scores[2] exact_match = 100.0 * exact_match / total f1 = 100.0 * f1 / total return {"exact_match": exact_match, "f1": f1} @@ -242,16 +241,7 @@ def squad( "Expected keys in a single target are 'answers' and 'id'." "Please make sure that 'answers' maps to a `SQuAD` format dictionary and 'id' maps to the key string.\n" "SQuAD Format: " - "{" - " 'answers': {" - " 'answer_start': [1]," - " 'text': ['This is a test text']" - " }," - " 'context': 'This is a test context.'," - " 'id': '1'," - " 'question': 'Is this a test?'," - " 'title': 'train test'" - "}" + f"{SQuAD_FORMAT}" ) answers_keys = target["answers"].keys() @@ -260,16 +250,7 @@ def squad( "Expected keys in a 'answers' are 'text'." "Please make sure that 'answer' maps to a `SQuAD` format dictionary.\n" "SQuAD Format: " - "{" - " 'answers': {" - " 'answer_start': [1]," - " 'text': ['This is a test text']" - " }," - " 'context': 'This is a test context.'," - " 'id': '1'," - " 'question': 'Is this a test?'," - " 'title': 'train test'" - "}" + f"{SQuAD_FORMAT}" ) preds_dict = {prediction["id"]: prediction["prediction_text"] for prediction in preds} @@ -288,5 +269,5 @@ def squad( ] } ] - scores: Tuple[Tensor, Tensor, Tensor] = _squad_update(preds_dict, targets_dict) - return _squad_compute(scores) + f1, exact_match, total = _squad_update(preds_dict, targets_dict) + return _squad_compute(f1, exact_match, total) diff --git a/torchmetrics/text/squad.py b/torchmetrics/text/squad.py index d68c8e62206..91cda85dbb1 100644 --- a/torchmetrics/text/squad.py +++ b/torchmetrics/text/squad.py @@ -11,17 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, Optional import torch from torch import Tensor from torchmetrics import Metric -from torchmetrics.functional.text.squad import PREDS_TYPE, TARGETS_TYPE, _squad_compute, _squad_update +from torchmetrics.functional.text.squad import PREDS_TYPE, TARGETS_TYPE, SQuAD_FORMAT, _squad_compute, _squad_update class SQuAD(Metric): - """Calculate `SQuAD Metric`_. + """Calculate `SQuAD Metric`_ which corresponds to the scoring script for version 1 of the Stanford Question + Answering Dataset (SQuAD). Args: compute_on_step: @@ -36,11 +37,11 @@ class SQuAD(Metric): will be used to perform the allgather. Example: - >>> from torchmetrics.text.squad import SQuAD + >>> from torchmetrics import SQuAD >>> predictions = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}] >>> references = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}] - >>> _sqaud = SQuAD() - >>> _sqaud(predictions, references) + >>> sqaud = SQuAD() + >>> sqaud(predictions, references) {'exact_match': tensor(100.), 'f1': tensor(100.)} References: @@ -135,16 +136,7 @@ def update( "Please make sure that 'answers' maps to a `SQuAD` format dictionary and 'id' maps to the key" "string.\n" "SQuAD Format: " - "{" - " 'answers': {" - " 'answer_start': [1]," - " 'text': ['This is a test text']" - " }," - " 'context': 'This is a test context.'," - " 'id': '1'," - " 'question': 'Is this a test?'," - " 'title': 'train test'" - "}" + f"{SQuAD_FORMAT}" ) answers_keys = target["answers"].keys() @@ -153,16 +145,7 @@ def update( "Expected keys in a 'answers' are 'text'." "Please make sure that 'answer' maps to a `SQuAD` format dictionary.\n" "SQuAD Format: " - "{" - " 'answers': {" - " 'answer_start': [1]," - " 'text': ['This is a test text']" - " }," - " 'context': 'This is a test context.'," - " 'id': '1'," - " 'question': 'Is this a test?'," - " 'title': 'train test'" - "}" + f"{SQuAD_FORMAT}" ) preds_dict = {prediction["id"]: prediction["prediction_text"] for prediction in preds} @@ -181,10 +164,10 @@ def update( ] } ] - scores: Tuple[Tensor, Tensor, Tensor] = _squad_update(preds_dict, targets_dict) - getattr(self, "f1_score").append(scores[0].to(self.device)) - getattr(self, "exact_match").append(scores[1].to(self.device)) - getattr(self, "total").append(scores[2].to(self.device)) + f1_score, exact_match, total = _squad_update(preds_dict, targets_dict) + getattr(self, "f1_score").append(f1_score.to(self.device)) + getattr(self, "exact_match").append(exact_match.to(self.device)) + getattr(self, "total").append(total.to(self.device)) def compute(self) -> Dict[str, Tensor]: """Aggregate the F1 Score and Exact match for the batch. @@ -195,4 +178,4 @@ def compute(self) -> Dict[str, Tensor]: f1_score = torch.sum(torch.tensor(getattr(self, "f1_score"))) exact_match = torch.sum(torch.tensor(getattr(self, "exact_match"))) total = torch.sum(torch.tensor(getattr(self, "total"))) - return _squad_compute((f1_score, exact_match, total)) + return _squad_compute(f1_score, exact_match, total) From dd5500097b0605b73d36cb5bf76599f134afc22b Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Mon, 15 Nov 2021 21:23:21 +0530 Subject: [PATCH 07/17] Changes from code review. 
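
The simplified state handling can be sanity-checked by accumulating two
batches, a sketch built on the fixtures from tests/text/test_squad.py
(one sample matches exactly, the other not at all, so both aggregates
land at 50):

    from torchmetrics.text.squad import SQuAD

    metric = SQuAD()
    metric.update(
        preds=[{"prediction_text": "1976", "id": "id1"}],
        targets=[{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "id1"}],
    )
    metric.update(
        preds=[{"prediction_text": "Hello", "id": "id2"}],
        targets=[{"answers": {"answer_start": [97], "text": ["World"]}, "id": "id2"}],
    )
    metric.compute()  # {'exact_match': tensor(50.), 'f1': tensor(50.)}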
--- torchmetrics/text/squad.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/torchmetrics/text/squad.py b/torchmetrics/text/squad.py index 91cda85dbb1..eef29d2926d 100644 --- a/torchmetrics/text/squad.py +++ b/torchmetrics/text/squad.py @@ -139,8 +139,7 @@ def update( f"{SQuAD_FORMAT}" ) - answers_keys = target["answers"].keys() - if "text" not in answers_keys: + if "text" not in target["answers"].keys(): raise KeyError( "Expected keys in a 'answers' are 'text'." "Please make sure that 'answer' maps to a `SQuAD` format dictionary.\n" @@ -165,9 +164,9 @@ def update( } ] f1_score, exact_match, total = _squad_update(preds_dict, targets_dict) - getattr(self, "f1_score").append(f1_score.to(self.device)) - getattr(self, "exact_match").append(exact_match.to(self.device)) - getattr(self, "total").append(total.to(self.device)) + self.f1_score.append(f1_score.to(self.device)) + self.exact_match.append(exact_match.to(self.device)) + self.total.append(total.to(self.device)) def compute(self) -> Dict[str, Tensor]: """Aggregate the F1 Score and Exact match for the batch. @@ -175,7 +174,7 @@ def compute(self) -> Dict[str, Tensor]: Return: Dictionary containing the F1 score, Exact match score for the batch. """ - f1_score = torch.sum(torch.tensor(getattr(self, "f1_score"))) - exact_match = torch.sum(torch.tensor(getattr(self, "exact_match"))) - total = torch.sum(torch.tensor(getattr(self, "total"))) + f1_score = torch.sum(torch.tensor(self.f1_score)) + exact_match = torch.sum(torch.tensor(self.exact_match)) + total = torch.sum(torch.tensor(self.total)) return _squad_compute(f1_score, exact_match, total) From a6920e48e8d276082ac37d95d6fe9cbf3d7cd93f Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Mon, 15 Nov 2021 22:09:50 +0530 Subject: [PATCH 08/17] Fix mypy issues and ignore few. 
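
With the metric states moving to sum-reduced tensors here, the final
aggregation in `_squad_compute` stays a plain normalization; the same
arithmetic as a standalone sketch:

    from torch import tensor

    f1, exact_match, total = tensor(1.0), tensor(1.0), tensor(2)
    {"exact_match": 100.0 * exact_match / total, "f1": 100.0 * f1 / total}
    # {'exact_match': tensor(50.), 'f1': tensor(50.)}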
--- torchmetrics/functional/text/squad.py | 25 ++++++++++++++----------- torchmetrics/text/squad.py | 26 ++++++++++---------------- 2 files changed, 24 insertions(+), 27 deletions(-) diff --git a/torchmetrics/functional/text/squad.py b/torchmetrics/functional/text/squad.py index a9ef8dc775d..c9362cc2e98 100644 --- a/torchmetrics/functional/text/squad.py +++ b/torchmetrics/functional/text/squad.py @@ -17,7 +17,7 @@ import re import string from collections import Counter -from typing import Dict, List, Tuple, Union +from typing import Callable, Dict, List, Tuple, Union from torch import Tensor, tensor @@ -27,7 +27,7 @@ PREDS_TYPE = Union[SINGLE_PRED_TYPE, List[SINGLE_PRED_TYPE]] SINGLE_TARGET_TYPE = Dict[str, Union[str, Dict[str, Union[List[str], List[int]]]]] TARGETS_TYPE = Union[SINGLE_TARGET_TYPE, List[SINGLE_TARGET_TYPE]] - +UPDATE_METHOD_SINGLE_PRED_TYPE = Union[List[Dict[str, Union[str, int]]], str, Dict[str, Union[List[str], List[int]]]] SQuAD_FORMAT = { "answers": {"answer_start": [1], "text": ["This is a test text"]}, @@ -62,7 +62,7 @@ def get_tokens(s: str) -> List[str]: return [] if not s else _normalize_text(s).split() -def compute_f1_score(predictied_answer, target_answer) -> Tensor: +def compute_f1_score(predictied_answer: str, target_answer: str) -> Tensor: """Compute F1 Score for two sentences.""" target_tokens: List[str] = get_tokens(target_answer) predicted_tokens: List[str] = get_tokens(predictied_answer) @@ -79,18 +79,19 @@ def compute_f1_score(predictied_answer, target_answer) -> Tensor: return f1 -def compute_exact_match_score(prediction, ground_truth) -> Tensor: +def compute_exact_match_score(prediction: str, ground_truth: str) -> Tensor: """Compute Exact Match for two sentences.""" return tensor(int(_normalize_text(prediction) == _normalize_text(ground_truth))) -def metric_max_over_ground_truths(metric_fn, prediction, ground_truths) -> Tensor: +def metric_max_over_ground_truths(metric_fn: Callable, prediction: str, ground_truths: List[str]) -> Tensor: """Calculate maximum score for a predicted answer with all reference answers.""" return max(metric_fn(prediction, truth) for truth in ground_truths) def _squad_update( - preds: Dict[str, str], targets: List[Dict[str, List[Dict[str, List[Dict[str, Union[str, List[Dict[str, str]]]]]]]]] + preds: Dict[str, str], + targets: List[Dict[str, List[Dict[str, List[Dict[str, UPDATE_METHOD_SINGLE_PRED_TYPE]]]]]], ) -> Tuple[Tensor, Tensor, Tensor]: """Compute F1 Score and Exact Match for a collection of predictions and references. @@ -137,8 +138,8 @@ def _squad_update( if qa["id"] not in preds: rank_zero_warn(f"Unanswered question {qa['id']} will receive score 0.") continue - ground_truths = list(map(lambda x: x["text"], qa["answers"])) - prediction = preds[qa["id"]] + ground_truths = list(map(lambda x: x["text"], qa["answers"])) # type: ignore + prediction = preds[qa["id"]] # type: ignore exact_match += metric_max_over_ground_truths(compute_exact_match_score, prediction, ground_truths) f1 += metric_max_over_ground_truths(compute_f1_score, prediction, ground_truths) @@ -244,8 +245,8 @@ def squad( f"{SQuAD_FORMAT}" ) - answers_keys = target["answers"].keys() - if "text" not in answers_keys: + answers: Dict[str, Union[List[str], List[int]]] = target["answers"] # type: ignore + if "text" not in answers.keys(): raise KeyError( "Expected keys in a 'answers' are 'text'." 
"Please make sure that 'answer' maps to a `SQuAD` format dictionary.\n" @@ -260,7 +261,9 @@ def squad( { "qas": [ { - "answers": [{"text": answer_text} for answer_text in target["answers"]["text"]], + "answers": [ + {"text": answer_text} for answer_text in target["answers"]["text"] # type: ignore + ], "id": target["id"], } for target in targets diff --git a/torchmetrics/text/squad.py b/torchmetrics/text/squad.py index eef29d2926d..4d092cc0bd9 100644 --- a/torchmetrics/text/squad.py +++ b/torchmetrics/text/squad.py @@ -65,15 +65,11 @@ def __init__( dist_sync_fn=dist_sync_fn, ) - self.add_state(name="f1_score", default=[], dist_reduce_fx=None) - self.add_state(name="exact_match", default=[], dist_reduce_fx=None) - self.add_state(name="total", default=[], dist_reduce_fx=None) + self.add_state(name="f1_score", default=torch.tensor(0, dtype=torch.float), dist_reduce_fx="sum") + self.add_state(name="exact_match", default=torch.tensor(0, dtype=torch.float), dist_reduce_fx="sum") + self.add_state(name="total", default=torch.tensor(0, dtype=torch.int), dist_reduce_fx="sum") - def update( - self, - preds: PREDS_TYPE, - targets: TARGETS_TYPE, - ) -> None: # type: ignore + def update(self, preds: PREDS_TYPE, targets: TARGETS_TYPE) -> None: # type: ignore """Compute F1 Score and Exact Match for a collection of predictions and references. Args: @@ -139,7 +135,8 @@ def update( f"{SQuAD_FORMAT}" ) - if "text" not in target["answers"].keys(): + answers: Dict[str, Any] = target["answers"] + if "text" not in answers.keys(): raise KeyError( "Expected keys in a 'answers' are 'text'." "Please make sure that 'answer' maps to a `SQuAD` format dictionary.\n" @@ -164,9 +161,9 @@ def update( } ] f1_score, exact_match, total = _squad_update(preds_dict, targets_dict) - self.f1_score.append(f1_score.to(self.device)) - self.exact_match.append(exact_match.to(self.device)) - self.total.append(total.to(self.device)) + self.f1_score += f1_score + self.exact_match += exact_match + self.total += total def compute(self) -> Dict[str, Tensor]: """Aggregate the F1 Score and Exact match for the batch. @@ -174,7 +171,4 @@ def compute(self) -> Dict[str, Tensor]: Return: Dictionary containing the F1 score, Exact match score for the batch. """ - f1_score = torch.sum(torch.tensor(self.f1_score)) - exact_match = torch.sum(torch.tensor(self.exact_match)) - total = torch.sum(torch.tensor(self.total)) - return _squad_compute(f1_score, exact_match, total) + return _squad_compute(self.f1_score, self.exact_match, self.total) # type: ignore From 20bf46ffc2f67f5c561c32b7e1021160882727df Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Tue, 16 Nov 2021 10:56:11 +0530 Subject: [PATCH 09/17] Add typing and differentiability. 
--- torchmetrics/functional/text/squad.py | 4 ++-- torchmetrics/text/squad.py | 13 ++++++++++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/torchmetrics/functional/text/squad.py b/torchmetrics/functional/text/squad.py index c9362cc2e98..d9854c0de0f 100644 --- a/torchmetrics/functional/text/squad.py +++ b/torchmetrics/functional/text/squad.py @@ -17,7 +17,7 @@ import re import string from collections import Counter -from typing import Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, List, Tuple, Union from torch import Tensor, tensor @@ -91,7 +91,7 @@ def metric_max_over_ground_truths(metric_fn: Callable, prediction: str, ground_t def _squad_update( preds: Dict[str, str], - targets: List[Dict[str, List[Dict[str, List[Dict[str, UPDATE_METHOD_SINGLE_PRED_TYPE]]]]]], + targets: List[Dict[str, List[Dict[str, List[Dict[str, Any]]]]]], ) -> Tuple[Tensor, Tensor, Tensor]: """Compute F1 Score and Exact Match for a collection of predictions and references. diff --git a/torchmetrics/text/squad.py b/torchmetrics/text/squad.py index 4d092cc0bd9..b33743bb262 100644 --- a/torchmetrics/text/squad.py +++ b/torchmetrics/text/squad.py @@ -49,8 +49,13 @@ class SQuAD(Metric): Lopyrev, Percy Liang `SQuAD Metric`_ . """ + is_differentiable = False higher_is_better = True + f1_score: Tensor + exact_match: Tensor + total: Tensor + def __init__( self, compute_on_step: bool = True, @@ -135,7 +140,7 @@ def update(self, preds: PREDS_TYPE, targets: TARGETS_TYPE) -> None: # type: ign f"{SQuAD_FORMAT}" ) - answers: Dict[str, Any] = target["answers"] + answers: Dict[str, Any] = target["answers"] # type: ignore if "text" not in answers.keys(): raise KeyError( "Expected keys in a 'answers' are 'text'." @@ -151,7 +156,9 @@ def update(self, preds: PREDS_TYPE, targets: TARGETS_TYPE) -> None: # type: ign { "qas": [ { - "answers": [{"text": answer_text} for answer_text in target["answers"]["text"]], + "answers": [ + {"text": answer_text} for answer_text in target["answers"]["text"] # type: ignore + ], "id": target["id"], } for target in targets @@ -171,4 +178,4 @@ def compute(self) -> Dict[str, Tensor]: Return: Dictionary containing the F1 score, Exact match score for the batch. """ - return _squad_compute(self.f1_score, self.exact_match, self.total) # type: ignore + return _squad_compute(self.f1_score, self.exact_match, self.total) From d359290a3498b2d8e9ece27ca25996b33bcb230c Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Wed, 17 Nov 2021 16:44:17 +0530 Subject: [PATCH 10/17] Update function signature typing. 
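
The tightened `Callable[[str, str], Tensor]` signature matches how the
helper is invoked throughout the module, for example (values are
illustrative):

    from torchmetrics.functional.text.squad import compute_exact_match_score, metric_max_over_ground_truths

    # Best exact-match score of one prediction against several reference answers.
    metric_max_over_ground_truths(compute_exact_match_score, "1976", ["1975", "1976"])  # tensor(1)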
--- torchmetrics/functional/text/squad.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchmetrics/functional/text/squad.py b/torchmetrics/functional/text/squad.py index d9854c0de0f..a5023452811 100644 --- a/torchmetrics/functional/text/squad.py +++ b/torchmetrics/functional/text/squad.py @@ -84,7 +84,9 @@ def compute_exact_match_score(prediction: str, ground_truth: str) -> Tensor: return tensor(int(_normalize_text(prediction) == _normalize_text(ground_truth))) -def metric_max_over_ground_truths(metric_fn: Callable, prediction: str, ground_truths: List[str]) -> Tensor: +def metric_max_over_ground_truths( + metric_fn: Callable[[str, str], Tensor], prediction: str, ground_truths: List[str] +) -> Tensor: """Calculate maximum score for a predicted answer with all reference answers.""" return max(metric_fn(prediction, truth) for truth in ground_truths) From 9d9774584a588a4be643715e2a6d9a6259bede7e Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 18 Nov 2021 17:50:15 +0100 Subject: [PATCH 11/17] Apply suggestions from code review Co-authored-by: Daniel Stancl <46073029+stancld@users.noreply.github.com> Co-authored-by: Nicki Skafte Detlefsen --- tests/text/test_squad.py | 2 - torchmetrics/functional/text/squad.py | 60 ++++++++++----------------- torchmetrics/text/squad.py | 19 ++++----- 3 files changed, 30 insertions(+), 51 deletions(-) diff --git a/tests/text/test_squad.py b/tests/text/test_squad.py index 3c018253f40..0f52681d605 100644 --- a/tests/text/test_squad.py +++ b/tests/text/test_squad.py @@ -9,8 +9,6 @@ from torchmetrics.functional.text import squad from torchmetrics.text.squad import SQuAD -os.environ["TOKENIZERS_PARALLELISM"] = "1" - SAMPLE_1 = { "exact_match": 100.0, "f1": 100.0, diff --git a/torchmetrics/functional/text/squad.py b/torchmetrics/functional/text/squad.py index a5023452811..8cf43f9bf50 100644 --- a/torchmetrics/functional/text/squad.py +++ b/torchmetrics/functional/text/squad.py @@ -57,15 +57,15 @@ def lower(text: str) -> str: return white_space_fix(remove_articles(remove_punc(lower(s)))) -def get_tokens(s: str) -> List[str]: +def _get_tokens(s: str) -> List[str]: """Split a sentence into separate tokens.""" return [] if not s else _normalize_text(s).split() -def compute_f1_score(predictied_answer: str, target_answer: str) -> Tensor: +def _compute_f1_score(predicted_answer: str, target_answer: str) -> Tensor: """Compute F1 Score for two sentences.""" - target_tokens: List[str] = get_tokens(target_answer) - predicted_tokens: List[str] = get_tokens(predictied_answer) + target_tokens = _get_tokens(target_answer) + predicted_tokens = _get_tokens(predicted_answer) common = Counter(target_tokens) & Counter(predicted_tokens) num_same: Tensor = tensor(sum(common.values())) if len(target_tokens) == 0 or len(predicted_tokens) == 0: @@ -79,12 +79,12 @@ def compute_f1_score(predictied_answer: str, target_answer: str) -> Tensor: return f1 -def compute_exact_match_score(prediction: str, ground_truth: str) -> Tensor: +def _compute_exact_match_score(prediction: str, ground_truth: str) -> Tensor: """Compute Exact Match for two sentences.""" return tensor(int(_normalize_text(prediction) == _normalize_text(ground_truth))) -def metric_max_over_ground_truths( +def _metric_max_over_ground_truths( metric_fn: Callable[[str, str], Tensor], prediction: str, ground_truths: List[str] ) -> Tensor: """Calculate maximum score for a predicted answer with all reference answers.""" @@ -112,27 +112,16 @@ def _squad_update( >>> predictions = [{"prediction_text": "1976", 
"id": "56e10a3be3433e1400422b22"}] >>> targets = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}] >>> preds_dict = {prediction["id"]: prediction["prediction_text"] for prediction in predictions} - >>> targets_dict = [ - ... { - ... "paragraphs": [ - ... { - ... "qas": [ - ... { - ... "answers": [{"text": answer_text} for answer_text in target["answers"]["text"]], - ... "id": target["id"], - ... } - ... for target in targets - ... ] - ... } - ... ] - ... } - ... ] + >>> targets_dict = [dict(paragraphs=[ + ... dict(qas=[dict(answers=[{"text": answer_text} for answer_text in target["answers"]["text"]], id=target["id"]) + ... for target in targets] + ... )])] >>> _squad_update(preds_dict, targets_dict) (tensor(1.), tensor(1.), tensor(1)) """ - f1: Tensor = tensor(0.0) - exact_match: Tensor = tensor(0.0) - total: Tensor = tensor(0) + f1 = tensor(0.0) + exact_match = tensor(0.0) + total = tensor(0) for article in targets: for paragraph in article["paragraphs"]: for qa in paragraph["qas"]: @@ -142,8 +131,8 @@ def _squad_update( continue ground_truths = list(map(lambda x: x["text"], qa["answers"])) # type: ignore prediction = preds[qa["id"]] # type: ignore - exact_match += metric_max_over_ground_truths(compute_exact_match_score, prediction, ground_truths) - f1 += metric_max_over_ground_truths(compute_f1_score, prediction, ground_truths) + exact_match += _metric_max_over_ground_truths(_compute_exact_match_score, prediction, ground_truths) + f1 += _metric_max_over_ground_truths(_compute_f1_score, prediction, ground_truths) return f1, exact_match, total @@ -258,21 +247,16 @@ def squad( preds_dict = {prediction["id"]: prediction["prediction_text"] for prediction in preds} targets_dict = [ - { - "paragraphs": [ - { - "qas": [ - { - "answers": [ - {"text": answer_text} for answer_text in target["answers"]["text"] # type: ignore - ], - "id": target["id"], - } + dict(paragraphs=[ + dict(qas=[ + dict(answers=[ + dict(text=answer_text) for answer_text in target["answers"]["text"] # type: ignore + ], id=target["id"]) for target in targets ] - } + ) ] - } + ) ] f1, exact_match, total = _squad_update(preds_dict, targets_dict) return _squad_compute(f1, exact_match, total) diff --git a/torchmetrics/text/squad.py b/torchmetrics/text/squad.py index b33743bb262..9968d9547cd 100644 --- a/torchmetrics/text/squad.py +++ b/torchmetrics/text/squad.py @@ -151,21 +151,18 @@ def update(self, preds: PREDS_TYPE, targets: TARGETS_TYPE) -> None: # type: ign preds_dict = {prediction["id"]: prediction["prediction_text"] for prediction in preds} targets_dict = [ - { - "paragraphs": [ - { - "qas": [ - { - "answers": [ - {"text": answer_text} for answer_text in target["answers"]["text"] # type: ignore + dict(paragraphs=[ + dict(qas=[ + dict(answers=[ + dict(text=answer_text) for answer_text in target["answers"]["text"] # type: ignore ], - "id": target["id"], - } + id=target["id"] + ) for target in targets ] - } + ) ] - } + ) ] f1_score, exact_match, total = _squad_update(preds_dict, targets_dict) self.f1_score += f1_score From f5cca24066a23dd6c8435ecb70e21369413b5822 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 18 Nov 2021 16:50:46 +0000 Subject: [PATCH 12/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/functional/text/squad.py | 13 +++++++++---- torchmetrics/text/squad.py | 11 +++++++---- 2 files changed, 16 insertions(+), 8 deletions(-) diff 
--git a/torchmetrics/functional/text/squad.py b/torchmetrics/functional/text/squad.py index 8cf43f9bf50..793accd3e11 100644 --- a/torchmetrics/functional/text/squad.py +++ b/torchmetrics/functional/text/squad.py @@ -247,11 +247,16 @@ def squad( preds_dict = {prediction["id"]: prediction["prediction_text"] for prediction in preds} targets_dict = [ - dict(paragraphs=[ - dict(qas=[ - dict(answers=[ + dict( + paragraphs=[ + dict( + qas=[ + dict( + answers=[ dict(text=answer_text) for answer_text in target["answers"]["text"] # type: ignore - ], id=target["id"]) + ], + id=target["id"], + ) for target in targets ] ) diff --git a/torchmetrics/text/squad.py b/torchmetrics/text/squad.py index 9968d9547cd..bbb707828ce 100644 --- a/torchmetrics/text/squad.py +++ b/torchmetrics/text/squad.py @@ -151,12 +151,15 @@ def update(self, preds: PREDS_TYPE, targets: TARGETS_TYPE) -> None: # type: ign preds_dict = {prediction["id"]: prediction["prediction_text"] for prediction in preds} targets_dict = [ - dict(paragraphs=[ - dict(qas=[ - dict(answers=[ + dict( + paragraphs=[ + dict( + qas=[ + dict( + answers=[ dict(text=answer_text) for answer_text in target["answers"]["text"] # type: ignore ], - id=target["id"] + id=target["id"], ) for target in targets ] From ef592dd638ee556f7dbc45a88e0f637cdd43b996 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 18 Nov 2021 17:53:43 +0100 Subject: [PATCH 13/17] Apply suggestions from code review --- torchmetrics/text/squad.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/text/squad.py b/torchmetrics/text/squad.py index bbb707828ce..8784ef09d9b 100644 --- a/torchmetrics/text/squad.py +++ b/torchmetrics/text/squad.py @@ -134,8 +134,8 @@ def update(self, preds: PREDS_TYPE, targets: TARGETS_TYPE) -> None: # type: ign if "answers" not in keys or "id" not in keys: raise KeyError( "Expected keys in a single target are 'answers' and 'id'." 
- "Please make sure that 'answers' maps to a `SQuAD` format dictionary and 'id' maps to the key" - "string.\n" + " Please make sure that 'answers' maps to a `SQuAD` format dictionary and 'id' maps to the key" + " string.\n" "SQuAD Format: " f"{SQuAD_FORMAT}" ) From b43436d355fffb2ccaf6900b664a3391d4592894 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 18 Nov 2021 17:57:25 +0100 Subject: [PATCH 14/17] Apply suggestions from code review --- torchmetrics/text/squad.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/torchmetrics/text/squad.py b/torchmetrics/text/squad.py index 8784ef09d9b..857e78f617b 100644 --- a/torchmetrics/text/squad.py +++ b/torchmetrics/text/squad.py @@ -150,19 +150,16 @@ def update(self, preds: PREDS_TYPE, targets: TARGETS_TYPE) -> None: # type: ign ) preds_dict = {prediction["id"]: prediction["prediction_text"] for prediction in preds} + _fn_answer = lamda tgt: dict( + answers=[ + dict(text=txt) for txt in tgt["answers"]["text"] # type: ignore + ], id=tgt["id"] + ) targets_dict = [ dict( paragraphs=[ dict( - qas=[ - dict( - answers=[ - dict(text=answer_text) for answer_text in target["answers"]["text"] # type: ignore - ], - id=target["id"], - ) - for target in targets - ] + qas=[_fn_answer(target) for target in targets] ) ] ) From b6176da6e2c9f658699ec2941004ec3ff4367e4e Mon Sep 17 00:00:00 2001 From: Jirka Date: Thu, 18 Nov 2021 18:04:19 +0100 Subject: [PATCH 15/17] simple --- torchmetrics/functional/text/squad.py | 30 ++++++++------------------- torchmetrics/text/squad.py | 18 ++++------------ 2 files changed, 13 insertions(+), 35 deletions(-) diff --git a/torchmetrics/functional/text/squad.py b/torchmetrics/functional/text/squad.py index 793accd3e11..42718d6e3e1 100644 --- a/torchmetrics/functional/text/squad.py +++ b/torchmetrics/functional/text/squad.py @@ -112,10 +112,11 @@ def _squad_update( >>> predictions = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}] >>> targets = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}] >>> preds_dict = {prediction["id"]: prediction["prediction_text"] for prediction in predictions} - >>> targets_dict = [dict(paragraphs=[ - ... dict(qas=[dict(answers=[{"text": answer_text} for answer_text in target["answers"]["text"]], id=target["id"]) - ... for target in targets] - ... )])] + >>> targets_dict = [ + ... dict(paragraphs=[dict(qas=[dict(answers=[ + ... {"text": txt} for txt in target["answers"]["text"]], id=target["id"]) for target in targets + ... ])]) + ... 
] >>> _squad_update(preds_dict, targets_dict) (tensor(1.), tensor(1.), tensor(1)) """ @@ -246,22 +247,9 @@ def squad( ) preds_dict = {prediction["id"]: prediction["prediction_text"] for prediction in preds} - targets_dict = [ - dict( - paragraphs=[ - dict( - qas=[ - dict( - answers=[ - dict(text=answer_text) for answer_text in target["answers"]["text"] # type: ignore - ], - id=target["id"], - ) - for target in targets - ] - ) - ] - ) - ] + _fn_answer = lambda tgt: dict( + answers=[dict(text=txt) for txt in tgt["answers"]["text"]], id=tgt["id"] # type: ignore + ) + targets_dict = [dict(paragraphs=[dict(qas=[_fn_answer(target) for target in targets])])] f1, exact_match, total = _squad_update(preds_dict, targets_dict) return _squad_compute(f1, exact_match, total) diff --git a/torchmetrics/text/squad.py b/torchmetrics/text/squad.py index 857e78f617b..60dde54b17b 100644 --- a/torchmetrics/text/squad.py +++ b/torchmetrics/text/squad.py @@ -150,20 +150,10 @@ def update(self, preds: PREDS_TYPE, targets: TARGETS_TYPE) -> None: # type: ign ) preds_dict = {prediction["id"]: prediction["prediction_text"] for prediction in preds} - _fn_answer = lambda tgt: dict( - answers=[ - dict(text=txt) for txt in tgt["answers"]["text"] # type: ignore - ], id=tgt["id"] - ) - targets_dict = [ - dict( - paragraphs=[ - dict( - qas=[_fn_answer(target) for target in targets] - ) - ] - ) - ] + _fn_answer = lambda tgt: dict( + answers=[dict(text=txt) for txt in tgt["answers"]["text"]], id=tgt["id"] # type: ignore + ) + targets_dict = [dict(paragraphs=[dict(qas=[_fn_answer(target) for target in targets])])] f1_score, exact_match, total = _squad_update(preds_dict, targets_dict) self.f1_score += f1_score self.exact_match += exact_match self.total += total From e5939800a2fddab2271a9aaf457620ed2b77f3dc Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Sat, 20 Nov 2021 19:02:26 +0530 Subject: [PATCH 16/17] Remove extra typing in methods. --- torchmetrics/functional/text/squad.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchmetrics/functional/text/squad.py b/torchmetrics/functional/text/squad.py index 42718d6e3e1..55d41cd1cca 100644 --- a/torchmetrics/functional/text/squad.py +++ b/torchmetrics/functional/text/squad.py @@ -67,15 +67,15 @@ def _compute_f1_score(predicted_answer: str, target_answer: str) -> Tensor: target_tokens = _get_tokens(target_answer) predicted_tokens = _get_tokens(predicted_answer) common = Counter(target_tokens) & Counter(predicted_tokens) - num_same: Tensor = tensor(sum(common.values())) + num_same = tensor(sum(common.values())) if len(target_tokens) == 0 or len(predicted_tokens) == 0: # If either is no-answer, then F1 is 1 if they agree, 0 otherwise return tensor(int(target_tokens == predicted_tokens)) if num_same == 0: return tensor(0.0) - precision: Tensor = 1.0 * num_same / tensor(len(predicted_tokens)) - recall: Tensor = 1.0 * num_same / tensor(len(target_tokens)) - f1: Tensor = (2 * precision * recall) / (precision + recall) + precision = 1.0 * num_same / tensor(len(predicted_tokens)) + recall = 1.0 * num_same / tensor(len(target_tokens)) + f1 = (2 * precision * recall) / (precision + recall) return f1 From e6224a75e23cdab90e587e414e39f111a3b941dc Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Sat, 20 Nov 2021 19:09:18 +0530 Subject: [PATCH 17/17] Reduce code duplication during input type checking.
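The validation previously duplicated between `squad()` and `SQuAD.update()` moves into a shared `_squad_input_check` helper, defined in the diff below. A minimal sketch of the helper's contract (illustrative only; it reuses the sample from the `_squad_update` doctest, and single-example dicts are accepted and wrapped into one-element lists):

    from torchmetrics.functional.text.squad import _squad_input_check

    preds = {"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}
    targets = {"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}

    preds_dict, targets_dict = _squad_input_check(preds, targets)
    # preds_dict == {"56e10a3be3433e1400422b22": "1976"}
    # targets_dict == [{"paragraphs": [{"qas": [
    #     {"answers": [{"text": "1976"}], "id": "56e10a3be3433e1400422b22"}]}]}]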
--- torchmetrics/functional/text/squad.py | 85 +++++++++++++++------------ torchmetrics/text/squad.py | 48 +++------------ 2 files changed, 55 insertions(+), 78 deletions(-) diff --git a/torchmetrics/functional/text/squad.py b/torchmetrics/functional/text/squad.py index 55d41cd1cca..f2fdcad8207 100644 --- a/torchmetrics/functional/text/squad.py +++ b/torchmetrics/functional/text/squad.py @@ -91,6 +91,52 @@ def _metric_max_over_ground_truths( return max(metric_fn(prediction, truth) for truth in ground_truths) +def _squad_input_check( + preds: PREDS_TYPE, targets: TARGETS_TYPE +) -> Tuple[Dict[str, str], List[Dict[str, List[Dict[str, List[Dict[str, Any]]]]]]]: + """Check for types and convert the input to the format necessary to compute the metric.""" + + if isinstance(preds, Dict): + preds = [preds] + + if isinstance(targets, Dict): + targets = [targets] + + for pred in preds: + keys = pred.keys() + if "prediction_text" not in keys or "id" not in keys: + raise KeyError( + "Expected keys in a single prediction are 'prediction_text' and 'id'." + " Please make sure that 'prediction_text' maps to the answer string and 'id' maps to the key string." + ) + + for target in targets: + keys = target.keys() + if "answers" not in keys or "id" not in keys: + raise KeyError( + "Expected keys in a single target are 'answers' and 'id'." + " Please make sure that 'answers' maps to a `SQuAD` format dictionary and 'id' maps to the key string.\n" + "SQuAD Format: " + f"{SQuAD_FORMAT}" + ) + + answers: Dict[str, Union[List[str], List[int]]] = target["answers"] # type: ignore + if "text" not in answers.keys(): + raise KeyError( + "Expected keys in 'answers' are 'text'." + " Please make sure that 'answers' maps to a `SQuAD` format dictionary.\n" + "SQuAD Format: " + f"{SQuAD_FORMAT}" + ) + + preds_dict = {prediction["id"]: prediction["prediction_text"] for prediction in preds} + _fn_answer = lambda tgt: dict( + answers=[dict(text=txt) for txt in tgt["answers"]["text"]], id=tgt["id"] # type: ignore + ) + targets_dict = [dict(paragraphs=[dict(qas=[_fn_answer(target) for target in targets])])] + return preds_dict, targets_dict + + def _squad_update( preds: Dict[str, str], targets: List[Dict[str, List[Dict[str, List[Dict[str, Any]]]]]], @@ -213,43 +259,6 @@ def squad( Lopyrev, Percy Liang `SQuAD Metric`_ . """ - if isinstance(preds, Dict): - preds = [preds] - - if isinstance(targets, Dict): - targets = [targets] - - for pred in preds: - keys = pred.keys() - if "prediction_text" not in keys or "id" not in keys: - raise KeyError( - "Expected keys in a single prediction are 'prediction_text' and 'id'." - "Please make sure that 'prediction_text' maps to the answer string and 'id' maps to the key string." - ) - - for target in targets: - keys = target.keys() - if "answers" not in keys or "id" not in keys: - raise KeyError( - "Expected keys in a single target are 'answers' and 'id'." - "Please make sure that 'answers' maps to a `SQuAD` format dictionary and 'id' maps to the key string.\n" - "SQuAD Format: " - f"{SQuAD_FORMAT}" - ) - - answers: Dict[str, Union[List[str], List[int]]] = target["answers"] # type: ignore - if "text" not in answers.keys(): - raise KeyError( - "Expected keys in a 'answers' are 'text'."
- "Please make sure that 'answer' maps to a `SQuAD` format dictionary.\n" - "SQuAD Format: " - f"{SQuAD_FORMAT}" - ) - - preds_dict = {prediction["id"]: prediction["prediction_text"] for prediction in preds} - _fn_answer = lambda tgt: dict( - answers=[dict(text=txt) for txt in tgt["answers"]["text"]], id=tgt["id"] # type: ignore - ) - targets_dict = [dict(paragraphs=[dict(qas=[_fn_answer(target) for target in targets])])] + preds_dict, targets_dict = _squad_input_check(preds, targets) f1, exact_match, total = _squad_update(preds_dict, targets_dict) return _squad_compute(f1, exact_match, total) diff --git a/torchmetrics/text/squad.py b/torchmetrics/text/squad.py index 60dde54b17b..8b8cf6560d5 100644 --- a/torchmetrics/text/squad.py +++ b/torchmetrics/text/squad.py @@ -17,7 +17,13 @@ from torch import Tensor from torchmetrics import Metric -from torchmetrics.functional.text.squad import PREDS_TYPE, TARGETS_TYPE, SQuAD_FORMAT, _squad_compute, _squad_update +from torchmetrics.functional.text.squad import ( + PREDS_TYPE, + TARGETS_TYPE, + _squad_compute, + _squad_input_check, + _squad_update, +) class SQuAD(Metric): @@ -115,45 +121,7 @@ def update(self, preds: PREDS_TYPE, targets: TARGETS_TYPE) -> None: # type: ign KeyError: If the required keys are missing in either predictions or targets. """ - if isinstance(preds, Dict): - preds = [preds] - - if isinstance(targets, Dict): - targets = [targets] - - for pred in preds: - keys = pred.keys() - if "prediction_text" not in keys or "id" not in keys: - raise KeyError( - "Expected keys in a single prediction are 'prediction_text' and 'id'." - "Please make sure that 'prediction_text' maps to the answer string and 'id' maps to the key string." - ) - - for target in targets: - keys = target.keys() - if "answers" not in keys or "id" not in keys: - raise KeyError( - "Expected keys in a single target are 'answers' and 'id'." - " Please make sure that 'answers' maps to a `SQuAD` format dictionary and 'id' maps to the key" - " string.\n" - "SQuAD Format: " - f"{SQuAD_FORMAT}" - ) - - answers: Dict[str, Any] = target["answers"] # type: ignore - if "text" not in answers.keys(): - raise KeyError( - "Expected keys in a 'answers' are 'text'." - "Please make sure that 'answer' maps to a `SQuAD` format dictionary.\n" - "SQuAD Format: " - f"{SQuAD_FORMAT}" - ) - - preds_dict = {prediction["id"]: prediction["prediction_text"] for prediction in preds} - _fn_answer = lambda tgt: dict( - answers=[dict(text=txt) for txt in tgt["answers"]["text"]], id=tgt["id"] # type: ignore - ) - targets_dict = [dict(paragraphs=[dict(qas=[_fn_answer(target) for target in targets])])] + preds_dict, targets_dict = _squad_input_check(preds, targets) f1_score, exact_match, total = _squad_update(preds_dict, targets_dict) self.f1_score += f1_score self.exact_match += exact_match