Add torchmetrics' own implementation of Rouge score metrics (#443)
* Make Rouge-N work
* Add RougeL score calculation
* Add some docs and enable using the Porter stemmer
* Enable RougeLSum calculation
* Add a few references and clean some parts
* Clean some minor stuff
* Restore the decimal_places argument to ROUGEScore and prepare a deprecation warning
* Use 0 in (x,y) instead of x == 0 or y == 0
* Enable the _rouge_score_update method to take a List of sentence_results
* Replace dangerous default dict() values
* Replace _RougeScore class with dict
* Update error messages
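
A minimal usage sketch of the new native implementation, pieced together from the calls exercised in the updated tests below (the use_stemmer argument and metric keys such as rouge1_fmeasure come straight from the diff; treat anything beyond that, including defaults, as an assumption):

    from torchmetrics.functional.text.rouge import rouge_score
    from torchmetrics.text.rouge import ROUGEScore

    preds = ["The quick brown fox jumps over the lazy dog"]
    targets = ["The quick brown dog jumps on the log."]

    # Functional API: returns a dict of tensors keyed like "rouge1_fmeasure".
    scores = rouge_score(preds, targets, use_stemmer=True)
    print(scores["rouge1_fmeasure"])

    # Module API: accumulate predictions/targets over batches, then compute once.
    rouge = ROUGEScore(use_stemmer=True)
    rouge.update(preds, targets)
    print(rouge.compute()["rougeL_precision"])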

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
4 people authored Aug 17, 2021
1 parent 9df5f88 commit a2712eb
Showing 6 changed files with 316 additions and 208 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -20,10 +20,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Removed

- Removed `rouge-score` as dependency for text package ([#443](https://github.com/PyTorchLightning/metrics/pull/443))

### Fixed

- Fixed bug in the ranking of samples in `SpearmanCorrCoef` metric ([#448](https://github.com/PyTorchLightning/metrics/pull/448))
- Fixed ranking of samples in `SpearmanCorrCoef` metric ([#448](https://github.com/PyTorchLightning/metrics/pull/448))


## [0.5.0] - 2021-08-09
3 changes: 3 additions & 0 deletions requirements/test.txt
@@ -25,3 +25,6 @@ mir_eval>=0.6
#pesq @ https://github.com/ludlows/python-pesq/archive/refs/heads/master.zip
#SRMRpy @ https://github.com/jfsantos/SRMRpy/archive/refs/heads/master.zip
speechmetrics @ https://github.com/aliutkus/speechmetrics/archive/refs/heads/master.zip

# text
rouge-score>=0.0.4
1 change: 0 additions & 1 deletion requirements/text.txt
@@ -1,4 +1,3 @@
jiwer>=2.2.0
nltk>=3.6
rouge-score>=0.0.4
bert-score==0.3.10
188 changes: 89 additions & 99 deletions tests/text/test_rouge.py
@@ -16,7 +16,6 @@

import pytest
import torch
from torch import tensor

from torchmetrics.functional.text.rouge import rouge_score
from torchmetrics.text.rouge import ROUGEScore
@@ -30,16 +29,13 @@

ROUGE_KEYS = ("rouge1", "rouge2", "rougeL", "rougeLsum")

PRECISION = 0
RECALL = 1
F_MEASURE = 2

SINGLE_SENTENCE_EXAMPLE_PREDS = "The quick brown fox jumps over the lazy dog"
SINGLE_SENTENCE_EXAMPLE_TARGET = "The quick brown dog jumps on the log."

PREDS = "My name is John".split()
TARGETS = "Is your name John".split()


BATCHES_RS_PREDS = [SINGLE_SENTENCE_EXAMPLE_PREDS]
BATCHES_RS_PREDS.extend(PREDS)
BATCHES_RS_TARGETS = [SINGLE_SENTENCE_EXAMPLE_TARGET]
@@ -55,145 +51,139 @@ def _compute_rouge_score(preds: List[str], targets: List[str], use_stemmer: bool
scorer = RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
aggregator = BootstrapAggregator()
for pred, target in zip(preds, targets):
aggregator.add_scores(scorer.score(pred, target))
aggregator.add_scores(scorer.score(target, pred))
return aggregator.aggregate()


@pytest.mark.skipif(not (_NLTK_AVAILABLE or _ROUGE_SCORE_AVAILABLE), reason="test requires nltk and rouge-score")
@pytest.mark.skipif(not _NLTK_AVAILABLE, reason="test requires nltk")
@pytest.mark.parametrize(
["pl_rouge_metric_key", "rouge_score_key", "metric", "decimal_places", "use_stemmer", "newline_sep"],
["pl_rouge_metric_key", "use_stemmer"],
[
pytest.param("rouge1_precision", "rouge1", PRECISION, 1, True, True),
pytest.param("rouge1_recall", "rouge1", RECALL, 2, True, False),
pytest.param("rouge1_fmeasure", "rouge1", F_MEASURE, 3, False, True),
pytest.param("rouge2_precision", "rouge2", PRECISION, 4, False, False),
pytest.param("rouge2_recall", "rouge2", RECALL, 5, True, True),
pytest.param("rouge2_fmeasure", "rouge2", F_MEASURE, 6, True, False),
pytest.param("rougeL_precision", "rougeL", PRECISION, 6, False, True),
pytest.param("rougeL_recall", "rougeL", RECALL, 5, False, False),
pytest.param("rougeL_fmeasure", "rougeL", F_MEASURE, 3, True, True),
pytest.param("rougeLsum_precision", "rougeLsum", PRECISION, 2, True, False),
pytest.param("rougeLsum_recall", "rougeLsum", RECALL, 1, False, True),
pytest.param("rougeLsum_fmeasure", "rougeLsum", F_MEASURE, 8, False, False),
pytest.param("rouge1_precision", True),
pytest.param("rouge1_recall", True),
pytest.param("rouge1_fmeasure", False),
pytest.param("rouge2_precision", False),
pytest.param("rouge2_recall", True),
pytest.param("rouge2_fmeasure", True),
pytest.param("rougeL_precision", False),
pytest.param("rougeL_recall", False),
pytest.param("rougeL_fmeasure", True),
pytest.param("rougeLsum_precision", True),
pytest.param("rougeLsum_recall", False),
pytest.param("rougeLsum_fmeasure", False),
],
)
def test_rouge_metric_functional_single_sentence(
pl_rouge_metric_key, rouge_score_key, metric, decimal_places, use_stemmer, newline_sep
):
scorer = RougeScorer(ROUGE_KEYS)
rs_scores = scorer.score(SINGLE_SENTENCE_EXAMPLE_PREDS, SINGLE_SENTENCE_EXAMPLE_TARGET)
rs_output = round(rs_scores[rouge_score_key][metric], decimal_places)
def test_rouge_metric_functional_single_sentence(pl_rouge_metric_key, use_stemmer):
rouge_level, metric = pl_rouge_metric_key.split("_")

scorer = RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
rs_scores = scorer.score(SINGLE_SENTENCE_EXAMPLE_TARGET, SINGLE_SENTENCE_EXAMPLE_PREDS)
rs_result = torch.tensor(getattr(rs_scores[rouge_level], metric), dtype=torch.float32)

pl_output = rouge_score(
[SINGLE_SENTENCE_EXAMPLE_PREDS],
[SINGLE_SENTENCE_EXAMPLE_TARGET],
newline_sep=newline_sep,
use_stemmer=use_stemmer,
decimal_places=decimal_places,
)
pl_output = rouge_score([SINGLE_SENTENCE_EXAMPLE_PREDS], [SINGLE_SENTENCE_EXAMPLE_TARGET], use_stemmer=use_stemmer)

assert torch.allclose(pl_output[pl_rouge_metric_key], tensor(rs_output, dtype=torch.float32))
assert torch.allclose(pl_output[pl_rouge_metric_key], rs_result)


@pytest.mark.skipif(not (_NLTK_AVAILABLE or _ROUGE_SCORE_AVAILABLE), reason="test requires nltk and rouge-score")
@pytest.mark.skipif(not _NLTK_AVAILABLE, reason="test requires nltk")
@pytest.mark.parametrize(
["pl_rouge_metric_key", "rouge_score_key", "metric", "decimal_places", "use_stemmer", "newline_sep"],
["pl_rouge_metric_key", "use_stemmer"],
[
pytest.param("rouge1_precision", "rouge1", PRECISION, 1, True, True),
pytest.param("rouge1_recall", "rouge1", RECALL, 2, True, False),
pytest.param("rouge1_fmeasure", "rouge1", F_MEASURE, 3, False, True),
pytest.param("rouge2_precision", "rouge2", PRECISION, 4, False, False),
pytest.param("rouge2_recall", "rouge2", RECALL, 5, True, True),
pytest.param("rouge2_fmeasure", "rouge2", F_MEASURE, 6, True, False),
pytest.param("rougeL_precision", "rougeL", PRECISION, 6, False, True),
pytest.param("rougeL_recall", "rougeL", RECALL, 5, False, False),
pytest.param("rougeL_fmeasure", "rougeL", F_MEASURE, 3, True, True),
pytest.param("rougeLsum_precision", "rougeLsum", PRECISION, 2, True, False),
pytest.param("rougeLsum_recall", "rougeLsum", RECALL, 1, False, True),
pytest.param("rougeLsum_fmeasure", "rougeLsum", F_MEASURE, 8, False, False),
pytest.param("rouge1_precision", True),
pytest.param("rouge1_recall", True),
pytest.param("rouge1_fmeasure", False),
pytest.param("rouge2_precision", False),
pytest.param("rouge2_recall", True),
pytest.param("rouge2_fmeasure", True),
pytest.param("rougeL_precision", False),
pytest.param("rougeL_recall", False),
pytest.param("rougeL_fmeasure", True),
pytest.param("rougeLsum_precision", True),
pytest.param("rougeLsum_recall", False),
pytest.param("rougeLsum_fmeasure", False),
],
)
def test_rouge_metric_functional(
pl_rouge_metric_key, rouge_score_key, metric, decimal_places, use_stemmer, newline_sep
):
def test_rouge_metric_functional(pl_rouge_metric_key, use_stemmer):
rouge_level, metric = pl_rouge_metric_key.split("_")

rs_scores = _compute_rouge_score(PREDS, TARGETS, use_stemmer=use_stemmer)
rs_output = round(rs_scores[rouge_score_key].mid[metric], decimal_places)
rs_result = torch.tensor(getattr(rs_scores[rouge_level].mid, metric), dtype=torch.float32)

pl_output = rouge_score(
PREDS, TARGETS, newline_sep=newline_sep, use_stemmer=use_stemmer, decimal_places=decimal_places
)
pl_output = rouge_score(PREDS, TARGETS, use_stemmer=use_stemmer)

assert torch.allclose(pl_output[pl_rouge_metric_key], tensor(rs_output, dtype=torch.float32))
assert torch.allclose(pl_output[pl_rouge_metric_key], rs_result)


@pytest.mark.skipif(not (_NLTK_AVAILABLE or _ROUGE_SCORE_AVAILABLE), reason="test requires nltk and rouge-score")
@pytest.mark.skipif(not _NLTK_AVAILABLE, reason="test requires nltk")
@pytest.mark.parametrize(
["pl_rouge_metric_key", "rouge_score_key", "metric", "decimal_places", "use_stemmer", "newline_sep"],
["pl_rouge_metric_key", "use_stemmer"],
[
pytest.param("rouge1_precision", "rouge1", PRECISION, 1, True, True),
pytest.param("rouge1_recall", "rouge1", RECALL, 2, True, False),
pytest.param("rouge1_fmeasure", "rouge1", F_MEASURE, 3, False, True),
pytest.param("rouge2_precision", "rouge2", PRECISION, 4, False, False),
pytest.param("rouge2_recall", "rouge2", RECALL, 5, True, True),
pytest.param("rouge2_fmeasure", "rouge2", F_MEASURE, 6, True, False),
pytest.param("rougeL_precision", "rougeL", PRECISION, 6, False, True),
pytest.param("rougeL_recall", "rougeL", RECALL, 5, False, False),
pytest.param("rougeL_fmeasure", "rougeL", F_MEASURE, 3, True, True),
pytest.param("rougeLsum_precision", "rougeLsum", PRECISION, 2, True, False),
pytest.param("rougeLsum_recall", "rougeLsum", RECALL, 1, False, True),
pytest.param("rougeLsum_fmeasure", "rougeLsum", F_MEASURE, 8, False, False),
pytest.param("rouge1_precision", True),
pytest.param("rouge1_recall", True),
pytest.param("rouge1_fmeasure", False),
pytest.param("rouge2_precision", False),
pytest.param("rouge2_recall", True),
pytest.param("rouge2_fmeasure", True),
pytest.param("rougeL_precision", False),
pytest.param("rougeL_recall", False),
pytest.param("rougeL_fmeasure", True),
pytest.param("rougeLsum_precision", True),
pytest.param("rougeLsum_recall", False),
pytest.param("rougeLsum_fmeasure", False),
],
)
def test_rouge_metric_class(pl_rouge_metric_key, rouge_score_key, metric, decimal_places, use_stemmer, newline_sep):
scorer = RougeScorer(ROUGE_KEYS)
rs_scores = scorer.score(SINGLE_SENTENCE_EXAMPLE_PREDS, SINGLE_SENTENCE_EXAMPLE_TARGET)
rs_output = round(rs_scores[rouge_score_key][metric], decimal_places)
def test_rouge_metric_class(pl_rouge_metric_key, use_stemmer):
rouge_level, metric = pl_rouge_metric_key.split("_")

scorer = RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
rs_scores = scorer.score(SINGLE_SENTENCE_EXAMPLE_TARGET, SINGLE_SENTENCE_EXAMPLE_PREDS)
rs_result = torch.tensor(getattr(rs_scores[rouge_level], metric), dtype=torch.float32)

rouge = ROUGEScore(newline_sep=newline_sep, use_stemmer=use_stemmer, decimal_places=decimal_places)
rouge = ROUGEScore(use_stemmer=use_stemmer)
pl_output = rouge([SINGLE_SENTENCE_EXAMPLE_PREDS], [SINGLE_SENTENCE_EXAMPLE_TARGET])

assert torch.allclose(pl_output[pl_rouge_metric_key], tensor(rs_output, dtype=torch.float32))
assert torch.allclose(pl_output[pl_rouge_metric_key], rs_result)


@pytest.mark.skipif(not (_NLTK_AVAILABLE or _ROUGE_SCORE_AVAILABLE), reason="test requires nltk and rouge-score")
@pytest.mark.skipif(not _NLTK_AVAILABLE, reason="test requires nltk")
@pytest.mark.parametrize(
["pl_rouge_metric_key", "rouge_score_key", "metric", "decimal_places", "use_stemmer", "newline_sep"],
["pl_rouge_metric_key", "use_stemmer"],
[
pytest.param("rouge1_precision", "rouge1", PRECISION, 1, True, True),
pytest.param("rouge1_recall", "rouge1", RECALL, 2, True, False),
pytest.param("rouge1_fmeasure", "rouge1", F_MEASURE, 3, False, True),
pytest.param("rouge2_precision", "rouge2", PRECISION, 4, False, False),
pytest.param("rouge2_recall", "rouge2", RECALL, 5, True, True),
pytest.param("rouge2_fmeasure", "rouge2", F_MEASURE, 6, True, False),
pytest.param("rougeL_precision", "rougeL", PRECISION, 6, False, True),
pytest.param("rougeL_recall", "rougeL", RECALL, 5, False, False),
pytest.param("rougeL_fmeasure", "rougeL", F_MEASURE, 3, True, True),
pytest.param("rougeLsum_precision", "rougeLsum", PRECISION, 2, True, False),
pytest.param("rougeLsum_recall", "rougeLsum", RECALL, 1, False, True),
pytest.param("rougeLsum_fmeasure", "rougeLsum", F_MEASURE, 8, False, False),
pytest.param("rouge1_precision", True),
pytest.param("rouge1_recall", True),
pytest.param("rouge1_fmeasure", False),
pytest.param("rouge2_precision", False),
pytest.param("rouge2_recall", True),
pytest.param("rouge2_fmeasure", True),
pytest.param("rougeL_precision", False),
pytest.param("rougeL_recall", False),
pytest.param("rougeL_fmeasure", True),
pytest.param("rougeLsum_precision", True),
pytest.param("rougeLsum_recall", False),
pytest.param("rougeLsum_fmeasure", False),
],
)
def test_rouge_metric_class_batches(
pl_rouge_metric_key, rouge_score_key, metric, decimal_places, use_stemmer, newline_sep
):
def test_rouge_metric_class_batches(pl_rouge_metric_key, use_stemmer):
rouge_level, metric = pl_rouge_metric_key.split("_")

rs_scores = _compute_rouge_score(BATCHES_RS_PREDS, BATCHES_RS_TARGETS, use_stemmer=use_stemmer)
rs_output = round(rs_scores[rouge_score_key].mid[metric], decimal_places)
rs_result = torch.tensor(getattr(rs_scores[rouge_level].mid, metric), dtype=torch.float32)

rouge = ROUGEScore(newline_sep=newline_sep, use_stemmer=use_stemmer, decimal_places=decimal_places)
rouge = ROUGEScore(use_stemmer=use_stemmer)
for batch in BATCHES:
rouge.update(batch["preds"], batch["targets"])
pl_output = rouge.compute()

assert torch.allclose(pl_output[pl_rouge_metric_key], tensor(rs_output, dtype=torch.float32))
assert torch.allclose(pl_output[pl_rouge_metric_key], rs_result)


def test_rouge_metric_raises_errors_and_warnings():
"""Test that expected warnings and errors are raised."""
if not (_NLTK_AVAILABLE and _ROUGE_SCORE_AVAILABLE):
if not _NLTK_AVAILABLE:
with pytest.raises(
ValueError,
match="ROUGE metric requires that both nltk and rouge-score is installed."
"Either as `pip install torchmetrics[text]` or `pip install nltk rouge-score`",
match="ROUGE metric requires that nltk is installed."
"Either as `pip install torchmetrics[text]` or `pip install nltk`",
):
ROUGEScore()
