Skip to content

Commit

Permalink
Add ChrF++ (#641)
Browse files Browse the repository at this point in the history
* [WIP] Add some basics for chrf++

* Add ChrF++. Need to add some last missing params and finish docs

* Add support for whitespace and return_sentence_level_score

* Update CHANGELOG.md

* Fix some docs

* Apply some suggestions from code review

* Fix two mypy issues
  • Loading branch information
stancld authored Nov 29, 2021
1 parent 1fef5ae commit f50bbd6
Show file tree
Hide file tree
Showing 11 changed files with 1,126 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `MatchErrorRate` ([#619](https://github.com/PyTorchLightning/metrics/pull/619))
- `WordInfoLost` and `WordInfoPreserved` ([#630](https://github.com/PyTorchLightning/metrics/pull/630))
- `SQuAD` ([#623](https://github.com/PyTorchLightning/metrics/pull/623))
- `CHRFScore` ([#641](https://github.com/PyTorchLightning/metrics/pull/641))


- Added `MinMaxMetric` to wrappers ([#556](https://github.com/PyTorchLightning/metrics/pull/556))
Expand Down
2 changes: 2 additions & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,5 @@
.. _Scikit_Learn-Ranking.py: https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py
.. _Verified Uncertainty Calibration: https://arxiv.org/abs/1909.10155
.. _SQuAD Metric: https://arxiv.org/pdf/1606.05250.pdf
.. _chrF score: https://aclanthology.org/W15-3049.pdf
.. _chrF++ score: https://aclanthology.org/W17-4770.pdf
6 changes: 6 additions & 0 deletions docs/source/references/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,12 @@ char_error_rate [func]
.. autofunction:: torchmetrics.functional.char_error_rate
:noindex:

chrf_score [func]
~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.chrf_score
:noindex:

match_error_rate [func]
~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
6 changes: 6 additions & 0 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,12 @@ CharErrorRate
.. autoclass:: torchmetrics.CharErrorRate
:noindex:

CHRFScore
~~~~~~~~~

.. autoclass:: torchmetrics.CHRFScore
:noindex:

MatchErrorRate
~~~~~~~~~~~~~~

Expand Down
166 changes: 166 additions & 0 deletions tests/text/test_chrf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
from functools import partial
from typing import Sequence

import pytest
from torch import Tensor, tensor

from tests.text.helpers import INPUT_ORDER, TextTester
from torchmetrics.functional.text.chrf import chrf_score
from torchmetrics.text.chrf import CHRFScore
from torchmetrics.utilities.imports import _SACREBLEU_AVAILABLE

if _SACREBLEU_AVAILABLE:
from sacrebleu.metrics import CHRF

# Examples adapted from
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu and adjusted
# EXAMPLE 1: hypothesis with two alternative references
HYPOTHESIS_A = "It is a guide to action which ensures that the military always obeys the commands of the party"
REFERENCE_1A = "It is a guide to action that ensures that the military will forever heed Party commands"
REFERENCE_2A = "It is a guiding principle which makes the military forces always being under the command of the Party"

# EXAMPLE 2: references deliberately include one unrelated sentence
HYPOTHESIS_B = "he read the book because he was interested in world history"
REFERENCE_1B = "he was interested in world history because he read the book"
REFERENCE_2B = "It is the practical guide for the army always to heed the directions of the party"

# EXAMPLE 3: trailing/extra whitespace added intentionally to exercise the `whitespace` option
HYPOTHESIS_C = "the cat the cat on the mat "
REFERENCE_1C = "the cat is on the mat "
REFERENCE_2C = "there is a cat on the mat"

# Two batches, each batch holding two (hypothesis, references-pair) examples.
TUPLE_OF_REFERENCES = (
    ((REFERENCE_1A, REFERENCE_2A), (REFERENCE_1B, REFERENCE_2B)),
    ((REFERENCE_1B, REFERENCE_2B), (REFERENCE_1C, REFERENCE_2C)),
)
TUPLE_OF_HYPOTHESES = ((HYPOTHESIS_A, HYPOTHESIS_B), (HYPOTHESIS_B, HYPOTHESIS_C))

# Input payload consumed by the parametrized test class below.
BATCHES = {"preds": TUPLE_OF_HYPOTHESES, "targets": TUPLE_OF_REFERENCES}


def sacrebleu_chrf_fn(
    targets: Sequence[Sequence[str]],
    preds: Sequence[str],
    char_order: int,
    word_order: int,
    lowercase: bool,
    whitespace: bool,
) -> Tensor:
    """Compute a reference chrF/chrF++ score with sacrebleu for comparison against torchmetrics.

    Returns the corpus score rescaled from sacrebleu's 0-100 range to 0-1.
    """
    metric = CHRF(
        char_order=char_order,
        word_order=word_order,
        lowercase=lowercase,
        whitespace=whitespace,
        eps_smoothing=True,
    )
    # sacrebleu expects references transposed: one list per reference position,
    # each containing that reference for every hypothesis.
    transposed_targets = [[target[pos] for target in targets] for pos in range(len(targets[0]))]
    corpus_score = metric.corpus_score(preds, transposed_targets).score
    return tensor(corpus_score / 100)


@pytest.mark.parametrize(
    ["char_order", "word_order", "lowercase", "whitespace"],
    [
        pytest.param(6, 2, False, False),
        pytest.param(6, 2, False, True),
        pytest.param(4, 2, True, False),
        pytest.param(6, 0, True, False),
        pytest.param(6, 0, True, True),
        pytest.param(4, 0, False, True),
    ],
)
@pytest.mark.parametrize(
    ["preds", "targets"],
    [
        pytest.param(BATCHES["preds"], BATCHES["targets"]),
    ],
)
@pytest.mark.skipif(not _SACREBLEU_AVAILABLE, reason="test requires sacrebleu")
class TestCHRFScore(TextTester):
    """Check the CHRFScore module and the chrf_score functional against sacrebleu's CHRF."""

    @pytest.mark.parametrize("ddp", [False, True])
    @pytest.mark.parametrize("dist_sync_on_step", [False, True])
    def test_chrf_score_class(
        self, ddp, dist_sync_on_step, preds, targets, char_order, word_order, lowercase, whitespace
    ):
        # torchmetrics uses `n_`-prefixed names for the n-gram orders.
        args = dict(n_char_order=char_order, n_word_order=word_order, lowercase=lowercase, whitespace=whitespace)
        reference_fn = partial(
            sacrebleu_chrf_fn, char_order=char_order, word_order=word_order, lowercase=lowercase, whitespace=whitespace
        )

        self.run_class_metric_test(
            ddp=ddp,
            preds=preds,
            targets=targets,
            metric_class=CHRFScore,
            sk_metric=reference_fn,
            dist_sync_on_step=dist_sync_on_step,
            metric_args=args,
            input_order=INPUT_ORDER.TARGETS_FIRST,
        )

    def test_chrf_score_functional(self, preds, targets, char_order, word_order, lowercase, whitespace):
        args = dict(n_char_order=char_order, n_word_order=word_order, lowercase=lowercase, whitespace=whitespace)
        reference_fn = partial(
            sacrebleu_chrf_fn, char_order=char_order, word_order=word_order, lowercase=lowercase, whitespace=whitespace
        )

        self.run_functional_metric_test(
            preds,
            targets,
            metric_functional=chrf_score,
            sk_metric=reference_fn,
            metric_args=args,
            input_order=INPUT_ORDER.TARGETS_FIRST,
        )

    def test_chrf_score_differentiability(self, preds, targets, char_order, word_order, lowercase, whitespace):
        args = dict(n_char_order=char_order, n_word_order=word_order, lowercase=lowercase, whitespace=whitespace)

        self.run_differentiability_test(
            preds=preds,
            targets=targets,
            metric_module=CHRFScore,
            metric_functional=chrf_score,
            metric_args=args,
            input_order=INPUT_ORDER.TARGETS_FIRST,
        )


def test_chrf_empty_functional():
    """Functional chrf_score must return 0 for empty hypotheses/references."""
    empty_preds = []
    empty_targets = [[]]
    assert chrf_score(empty_targets, empty_preds) == tensor(0.0)


def test_chrf_empty_class():
    """The CHRFScore module must return 0 for empty hypotheses/references."""
    metric = CHRFScore()
    empty_preds = []
    empty_targets = [[]]
    assert metric(empty_targets, empty_preds) == tensor(0.0)


def test_chrf_return_sentence_level_score_functional():
    """With return_sentence_level_score=True, chrf_score returns (corpus_score, sentence_scores)."""
    hyp = [HYPOTHESIS_B]
    ref = [[REFERENCE_1B, REFERENCE_2B]]
    _, chrf_sentence_score = chrf_score(ref, hyp, return_sentence_level_score=True)
    # A bare `isinstance(...)` expression discards its result and verifies nothing;
    # the check must be asserted for the test to actually guard the return type.
    assert isinstance(chrf_sentence_score, Tensor)


def test_chrf_return_sentence_level_class():
    """With return_sentence_level_score=True, the CHRFScore module returns (corpus_score, sentence_scores)."""
    chrf = CHRFScore(return_sentence_level_score=True)
    hyp = [HYPOTHESIS_B]
    ref = [[REFERENCE_1B, REFERENCE_2B]]
    _, chrf_sentence_score = chrf(ref, hyp)
    # A bare `isinstance(...)` expression discards its result and verifies nothing;
    # the check must be asserted for the test to actually guard the return type.
    assert isinstance(chrf_sentence_score, Tensor)
2 changes: 2 additions & 0 deletions torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
BERTScore,
BLEUScore,
CharErrorRate,
CHRFScore,
MatchErrorRate,
ROUGEScore,
SacreBLEUScore,
Expand All @@ -94,6 +95,7 @@
"BootStrapper",
"CalibrationError",
"CatMetric",
"CHRFScore",
"CohenKappa",
"ConfusionMatrix",
"CosineSimilarity",
Expand Down
2 changes: 2 additions & 0 deletions torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
from torchmetrics.functional.text.bert import bert_score
from torchmetrics.functional.text.bleu import bleu_score
from torchmetrics.functional.text.cer import char_error_rate
from torchmetrics.functional.text.chrf import chrf_score
from torchmetrics.functional.text.mer import match_error_rate
from torchmetrics.functional.text.rouge import rouge_score
from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score
Expand All @@ -83,6 +84,7 @@
"bert_score",
"bleu_score",
"calibration_error",
"chrf_score",
"cohen_kappa",
"confusion_matrix",
"cosine_similarity",
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/functional/text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from torchmetrics.functional.text.bleu import bleu_score # noqa: F401
from torchmetrics.functional.text.cer import char_error_rate # noqa: F401
from torchmetrics.functional.text.chrf import chrf_score # noqa: F401
from torchmetrics.functional.text.mer import match_error_rate # noqa: F401
from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score # noqa: F401
from torchmetrics.functional.text.squad import squad # noqa: F401
Expand Down
Loading

0 comments on commit f50bbd6

Please sign in to comment.