Add SQuAD Metric (#623)
* Add SQuAD metric - functional
* Add SQuAD metric - module
* Add tests for squad metric
* Add documentation
* Changes from code review
* Apply suggestions from code review

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Daniel Stancl <46073029+stancld@users.noreply.github.com>
Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
6 people authored Nov 24, 2021
1 parent 157c0f6 commit 02ea8de
Showing 11 changed files with 523 additions and 0 deletions.
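Before the file-by-file diff, a minimal usage sketch of the new metric (not part of the commit itself); it mirrors the inputs and calls exercised in tests/text/test_squad.py below, where a prediction that exactly matches its reference scores 100.0 on both exact match and F1:

    from torchmetrics import SQuAD
    from torchmetrics.functional import squad

    sample_pred = {"prediction_text": "1976", "id": "id1"}
    sample_target = {"answers": {"answer_start": [97], "text": ["1976"]}, "id": "id1"}

    # Functional interface: returns a dict of tensors keyed by "exact_match" and "f1".
    scores = squad(sample_pred, sample_target)  # both scores are 100.0 for an exact match

    # Module interface: accumulates scores across updates and averages them on compute().
    metric = SQuAD()
    metric.update(preds=[sample_pred], targets=[sample_target])
    print(metric.compute())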
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added NLP metrics:
  - `MatchErrorRate` ([#619](https://github.com/PyTorchLightning/metrics/pull/619))
  - `SQuAD` ([#623](https://github.com/PyTorchLightning/metrics/pull/623))


- Added `MinMaxMetric` to wrappers ([#556](https://github.com/PyTorchLightning/metrics/pull/556))
1 change: 1 addition & 0 deletions docs/source/links.rst
@@ -72,3 +72,4 @@
.. _Python ROUGE Implementation: https://pypi.org/project/rouge-score/
.. _Scikit_Learn-Ranking.py: https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py
.. _Verified Uncertainty Calibration: https://arxiv.org/abs/1909.10155
.. _SQuAD Metric: https://arxiv.org/pdf/1606.05250.pdf
7 changes: 7 additions & 0 deletions docs/source/references/functional.rst
@@ -458,6 +458,13 @@ sacre_bleu_score [func]
.. autofunction:: torchmetrics.functional.sacre_bleu_score
    :noindex:


squad [func]
~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.squad
    :noindex:

wer [func]
~~~~~~~~~~

7 changes: 7 additions & 0 deletions docs/source/references/modules.rst
@@ -640,6 +640,13 @@ SacreBLEUScore
.. autoclass:: torchmetrics.SacreBLEUScore
    :noindex:


SQuAD
~~~~~

.. autoclass:: torchmetrics.SQuAD
    :noindex:

WER
~~~

101 changes: 101 additions & 0 deletions tests/text/test_squad.py
@@ -0,0 +1,101 @@
import os

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from tests.helpers.testers import _assert_allclose, _assert_tensor
from torchmetrics.functional.text import squad
from torchmetrics.text.squad import SQuAD

SAMPLE_1 = {
    "exact_match": 100.0,
    "f1": 100.0,
    "predictions": {"prediction_text": "1976", "id": "id1"},
    "references": {"answers": {"answer_start": [97], "text": ["1976"]}, "id": "id1"},
}

SAMPLE_2 = {
    "exact_match": 0.0,
    "f1": 0.0,
    "predictions": {"prediction_text": "Hello", "id": "id2"},
    "references": {"answers": {"answer_start": [97], "text": ["World"]}, "id": "id2"},
}

BATCH = {
    "exact_match": [100.0, 0.0],
    "f1": [100.0, 0.0],
    "predictions": [
        {"prediction_text": "1976", "id": "id1"},
        {"prediction_text": "Hello", "id": "id2"},
    ],
    "references": [
        {"answers": {"answer_start": [97], "text": ["1976"]}, "id": "id1"},
        {"answers": {"answer_start": [97], "text": ["World"]}, "id": "id2"},
    ],
}


@pytest.mark.parametrize(
    "preds,targets,exact_match,f1",
    [
        (SAMPLE_1["predictions"], SAMPLE_1["references"], SAMPLE_1["exact_match"], SAMPLE_1["f1"]),
        (SAMPLE_2["predictions"], SAMPLE_2["references"], SAMPLE_2["exact_match"], SAMPLE_2["f1"]),
    ],
)
def test_score_fn(preds, targets, exact_match, f1):
    """Test the functional `squad` interface."""
    metrics_score = squad(preds, targets)
    _assert_tensor(metrics_score["exact_match"])
    _assert_tensor(metrics_score["f1"])
    _assert_allclose(metrics_score["exact_match"], exact_match)
    _assert_allclose(metrics_score["f1"], f1)


@pytest.mark.parametrize(
    "preds,targets,exact_match,f1",
    [(BATCH["predictions"], BATCH["references"], BATCH["exact_match"], BATCH["f1"])],
)
def test_accumulation(preds, targets, exact_match, f1):
    """Test that the metric accumulates scores correctly over multiple updates."""
    squad_metric = SQuAD()
    for pred, target in zip(preds, targets):
        squad_metric.update(preds=[pred], targets=[target])
    metrics_score = squad_metric.compute()

    _assert_tensor(metrics_score["exact_match"])
    _assert_tensor(metrics_score["f1"])
    _assert_allclose(metrics_score["exact_match"], torch.mean(torch.tensor(exact_match)))
    _assert_allclose(metrics_score["f1"], torch.mean(torch.tensor(f1)))


def _squad_score_ddp(rank, world_size, pred, target, exact_match, f1):
    """Define a DDP process for the SQuAD metric."""
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    squad_metric = SQuAD()
    squad_metric.update(pred, target)
    metrics_score = squad_metric.compute()
    _assert_tensor(metrics_score["exact_match"])
    _assert_tensor(metrics_score["f1"])
    _assert_allclose(metrics_score["exact_match"], exact_match)
    _assert_allclose(metrics_score["f1"], f1)
    dist.destroy_process_group()


def _test_score_ddp_fn(rank, world_size, preds, targets, exact_match, f1):
    """Core functionality for the `test_score_ddp` test."""
    _squad_score_ddp(rank, world_size, preds[rank], targets[rank], exact_match[rank], f1[rank])


@pytest.mark.parametrize(
    "preds,targets,exact_match,f1",
    [(BATCH["predictions"], BATCH["references"], BATCH["exact_match"], BATCH["f1"])],
)
@pytest.mark.skipif(not dist.is_available(), reason="test requires torch distributed")
def test_score_ddp(preds, targets, exact_match, f1):
    """Test the metric with DDP."""
    world_size = 2
    mp.spawn(_test_score_ddp_fn, args=(world_size, preds, targets, exact_match, f1), nprocs=world_size, join=False)
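As a worked example of the behaviour that test_accumulation asserts (a sketch, not part of the diff): the BATCH fixture mixes one exact match (100.0) and one complete miss (0.0), so after updating the module metric sample by sample, compute() averages the per-sample scores to 50.0 for both exact match and F1:

    from torchmetrics.text.squad import SQuAD

    predictions = [
        {"prediction_text": "1976", "id": "id1"},
        {"prediction_text": "Hello", "id": "id2"},
    ]
    references = [
        {"answers": {"answer_start": [97], "text": ["1976"]}, "id": "id1"},
        {"answers": {"answer_start": [97], "text": ["World"]}, "id": "id2"},
    ]

    metric = SQuAD()
    for pred, target in zip(predictions, references):
        metric.update(preds=[pred], targets=[target])
    result = metric.compute()
    # One exact match (100.0) and one miss (0.0) average to 50.0 for both scores.
    assert abs(result["exact_match"].item() - 50.0) < 1e-6
    assert abs(result["f1"].item() - 50.0) < 1e-6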
2 changes: 2 additions & 0 deletions torchmetrics/__init__.py
@@ -74,6 +74,7 @@
    MatchErrorRate,
    ROUGEScore,
    SacreBLEUScore,
    SQuAD,
)
from torchmetrics.wrappers import BootStrapper, MetricTracker, MinMaxMetric, MultioutputWrapper # noqa: E402

@@ -144,6 +145,7 @@
"SNR",
"SpearmanCorrcoef",
"Specificity",
"SQuAD",
"SSIM",
"StatScores",
"STOI",
2 changes: 2 additions & 0 deletions torchmetrics/functional/__init__.py
@@ -71,6 +71,7 @@
from torchmetrics.functional.text.mer import match_error_rate
from torchmetrics.functional.text.rouge import rouge_score
from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score
from torchmetrics.functional.text.squad import squad
from torchmetrics.functional.text.wer import wer

__all__ = [
@@ -130,6 +131,7 @@
"snr",
"spearman_corrcoef",
"specificity",
"squad",
"ssim",
"stat_scores",
"stoi",
1 change: 1 addition & 0 deletions torchmetrics/functional/text/__init__.py
@@ -16,4 +16,5 @@
from torchmetrics.functional.text.cer import char_error_rate # noqa: F401
from torchmetrics.functional.text.mer import match_error_rate # noqa: F401
from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score # noqa: F401
from torchmetrics.functional.text.squad import squad # noqa: F401
from torchmetrics.functional.text.wer import wer # noqa: F401