Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SQuAD Metric. #623

Merged
merged 21 commits into from
Nov 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added NLP metrics:
- `MatchErrorRate` ([#619](https://github.com/PyTorchLightning/metrics/pull/619))
- `SQuAD` ([#623](https://github.com/PyTorchLightning/metrics/pull/623))


- Added `MinMaxMetric` to wrappers ([#556](https://github.com/PyTorchLightning/metrics/pull/556))
Expand Down
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,4 @@
.. _Python ROUGE Implementation: https://pypi.org/project/rouge-score/
.. _Scikit_Learn-Ranking.py: https: //github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py
.. _Verified Uncertainty Calibration: https://arxiv.org/abs/1909.10155
.. _SQuAD Metric: https://arxiv.org/pdf/1606.05250.pdf
7 changes: 7 additions & 0 deletions docs/source/references/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,13 @@ sacre_bleu_score [func]
.. autofunction:: torchmetrics.functional.sacre_bleu_score
:noindex:


squad [func]
~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.squad
:noindex:

wer [func]
~~~~~~~~~~

Expand Down
7 changes: 7 additions & 0 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,13 @@ SacreBLEUScore
.. autoclass:: torchmetrics.SacreBLEUScore
:noindex:


SQuAD
~~~~~

.. autoclass:: torchmetrics.SQuAD
:noindex:

WER
~~~

Expand Down
101 changes: 101 additions & 0 deletions tests/text/test_squad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import os

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from tests.helpers.testers import _assert_allclose, _assert_tensor
from torchmetrics.functional.text import squad
from torchmetrics.text.squad import SQuAD

SAMPLE_1 = {
"exact_match": 100.0,
"f1": 100.0,
"predictions": {"prediction_text": "1976", "id": "id1"},
"references": {"answers": {"answer_start": [97], "text": ["1976"]}, "id": "id1"},
}

SAMPLE_2 = {
"exact_match": 0.0,
"f1": 0.0,
"predictions": {"prediction_text": "Hello", "id": "id2"},
"references": {"answers": {"answer_start": [97], "text": ["World"]}, "id": "id2"},
}

BATCH = {
"exact_match": [100.0, 0.0],
"f1": [100.0, 0.0],
"predictions": [
{"prediction_text": "1976", "id": "id1"},
{"prediction_text": "Hello", "id": "id2"},
],
"references": [
{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "id1"},
{"answers": {"answer_start": [97], "text": ["World"]}, "id": "id2"},
],
}


@pytest.mark.parametrize(
"preds,targets,exact_match,f1",
[
(SAMPLE_1["predictions"], SAMPLE_1["references"], SAMPLE_1["exact_match"], SAMPLE_1["exact_match"]),
(SAMPLE_2["predictions"], SAMPLE_2["references"], SAMPLE_2["exact_match"], SAMPLE_2["exact_match"]),
],
)
def test_score_fn(preds, targets, exact_match, f1):
"""Tests for functional."""
metrics_score = squad(preds, targets)
_assert_tensor(metrics_score["exact_match"])
_assert_tensor(metrics_score["f1"])
_assert_allclose(metrics_score["exact_match"], exact_match)
_assert_allclose(metrics_score["f1"], f1)


@pytest.mark.parametrize(
"preds,targets,exact_match,f1",
[(BATCH["predictions"], BATCH["references"], BATCH["exact_match"], BATCH["f1"])],
)
def test_accumulation(preds, targets, exact_match, f1):
"""Tests for metric works with accumulation."""
squad_metric = SQuAD()
for pred, target in zip(preds, targets):
squad_metric.update(preds=[pred], targets=[target])
metrics_score = squad_metric.compute()

_assert_tensor(metrics_score["exact_match"])
_assert_tensor(metrics_score["f1"])
_assert_allclose(metrics_score["exact_match"], torch.mean(torch.tensor(exact_match)))
_assert_allclose(metrics_score["f1"], torch.mean(torch.tensor(f1)))


def _squad_score_ddp(rank, world_size, pred, target, exact_match, f1):
"""Define a DDP process for SQuAD metric."""
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
dist.init_process_group("gloo", rank=rank, world_size=world_size)
squad_metric = SQuAD()
squad_metric.update(pred, target)
metrics_score = squad_metric.compute()
_assert_tensor(metrics_score["exact_match"])
_assert_tensor(metrics_score["f1"])
_assert_allclose(metrics_score["exact_match"], exact_match)
_assert_allclose(metrics_score["f1"], f1)
dist.destroy_process_group()


def _test_score_ddp_fn(rank, world_size, preds, targets, exact_match, f1):
"""Core functionality for the `test_score_ddp` test."""
_squad_score_ddp(rank, world_size, preds[rank], targets[rank], exact_match[rank], f1[rank])


@pytest.mark.parametrize(
"preds,targets,exact_match,f1",
[(BATCH["predictions"], BATCH["references"], BATCH["exact_match"], BATCH["f1"])],
)
@pytest.mark.skipif(not dist.is_available(), reason="test requires torch distributed")
def test_score_ddp(preds, targets, exact_match, f1):
"""Tests for metric using DDP."""
world_size = 2
mp.spawn(_test_score_ddp_fn, args=(world_size, preds, targets, exact_match, f1), nprocs=world_size, join=False)
2 changes: 2 additions & 0 deletions torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
MatchErrorRate,
ROUGEScore,
SacreBLEUScore,
SQuAD,
)
from torchmetrics.wrappers import BootStrapper, MetricTracker, MinMaxMetric, MultioutputWrapper # noqa: E402

Expand Down Expand Up @@ -144,6 +145,7 @@
"SNR",
"SpearmanCorrcoef",
"Specificity",
"SQuAD",
"SSIM",
"StatScores",
"STOI",
Expand Down
2 changes: 2 additions & 0 deletions torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
from torchmetrics.functional.text.mer import match_error_rate
from torchmetrics.functional.text.rouge import rouge_score
from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score
from torchmetrics.functional.text.squad import squad
from torchmetrics.functional.text.wer import wer

__all__ = [
Expand Down Expand Up @@ -130,6 +131,7 @@
"snr",
"spearman_corrcoef",
"specificity",
"squad",
"ssim",
"stat_scores",
"stoi",
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/functional/text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@
from torchmetrics.functional.text.cer import char_error_rate # noqa: F401
from torchmetrics.functional.text.mer import match_error_rate # noqa: F401
from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score # noqa: F401
from torchmetrics.functional.text.squad import squad # noqa: F401
from torchmetrics.functional.text.wer import wer # noqa: F401
Loading