Include micro averaging - Bleu Score #2179

Merged: 21 commits, Sep 7, 2021
ignite/metrics/nlp/bleu.py: 164 changes (120 additions, 44 deletions)
@@ -1,6 +1,5 @@
 import math
-from collections import Counter
-from typing import Any, Callable, Sequence, Tuple, Union, ValuesView
+from typing import Any, Callable, Sequence, Tuple, Union
 
 import torch
@@ -29,35 +28,37 @@ def __init__(self, method: str):
             raise ValueError(f"Smooth is not valid (expected: {valid}, got: {method})")
         self.smooth = method
 
-    def __call__(self, numerators: Counter, denominators: Counter) -> Sequence[float]:
+    def __call__(self, numerators: torch.Tensor, denominators: torch.Tensor) -> Sequence[float]:
         method = getattr(self, self.smooth)
         return method(numerators, denominators)
 
     @staticmethod
-    def smooth1(numerators: Counter, denominators: Counter) -> Sequence[float]:
+    def smooth1(numerators: torch.Tensor, denominators: torch.Tensor) -> Sequence[float]:
         epsilon = 0.1
-        denominators_ = [max(1, d) for d in denominators.values()]
-        return [n / d if n != 0 else epsilon / d for n, d in zip(numerators.values(), denominators_)]
+        denominators_ = [max(1, d.item()) for d in denominators]
+        return [n.item() / d if n != 0 else epsilon / d for n, d in zip(numerators, denominators_)]
 
     @staticmethod
-    def nltk_smooth2(numerators: Counter, denominators: Counter) -> Sequence[float]:
-        denominators_ = [max(1, d) for d in denominators.values()]
-        return _Smoother._smooth2(numerators.values(), denominators_)
+    def nltk_smooth2(numerators: torch.Tensor, denominators: torch.Tensor) -> Sequence[float]:
+        denominators_ = torch.tensor([max(1, d.item()) for d in denominators])
+        return _Smoother._smooth2(numerators, denominators_)
 
     @staticmethod
-    def smooth2(numerators: Counter, denominators: Counter) -> Sequence[float]:
-        return _Smoother._smooth2(numerators.values(), denominators.values())
+    def smooth2(numerators: torch.Tensor, denominators: torch.Tensor) -> Sequence[float]:
+        return _Smoother._smooth2(numerators, denominators)
 
     @staticmethod
-    def _smooth2(
-        numerators: Union[ValuesView[int], Sequence[int]], denominators: Union[ValuesView[int], Sequence[int]]
-    ) -> Sequence[float]:
-        return [(n + 1) / (d + 1) if i != 0 else n / d for i, (n, d) in enumerate(zip(numerators, denominators))]
+    def _smooth2(numerators: torch.Tensor, denominators: torch.Tensor) -> Sequence[float]:
+        return [
+            (n.item() + 1) / (d.item() + 1) if i != 0 else n.item() / d.item()
+            for i, (n, d) in enumerate(zip(numerators, denominators))
+        ]
 
     @staticmethod
-    def no_smooth(numerators: Counter, denominators: Counter) -> Sequence[float]:
-        denominators_ = [max(1, d) for d in denominators.values()]
-        return [n / d for n, d in zip(numerators.values(), denominators_)]
+    def no_smooth(numerators: torch.Tensor, denominators: torch.Tensor) -> Sequence[float]:
+        denominators_ = [max(1, d) for d in denominators]
+        return [n.item() / d for n, d in zip(numerators, denominators_)]
 
 
 class Bleu(Metric):
@@ -97,6 +98,9 @@ class Bleu(Metric):
         device: specifies which device updates are accumulated on. Setting the
             metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
             non-blocking. By default, CPU.
+        average: specifies which type of averaging to use (macro or micro);
+            for more details, refer to https://www.nltk.org/_modules/nltk/translate/bleu_score.html
+            Default: "macro"
 
     Example:
 
@@ -114,6 +118,8 @@ class Bleu(Metric):
             print(m.compute())
 
     .. versionadded:: 0.4.5
+    .. versionchanged:: 0.5.0
+        added ``average`` option to handle micro and macro averaging modes.
     """
 
     def __init__(
@@ -122,26 +128,37 @@ def __init__(
         smooth: str = "no_smooth",
         output_transform: Callable = lambda x: x,
         device: Union[str, torch.device] = torch.device("cpu"),
+        average: str = "macro",
     ):
         if ngram <= 0:
             raise ValueError(f"ngram order must be greater than zero (got: {ngram})")
         self.ngrams_order = ngram
         self.weights = [1 / self.ngrams_order] * self.ngrams_order
         self.smoother = _Smoother(method=smooth)
 
+        if average not in ["macro", "micro"]:
+            raise ValueError(f'Average must be either "macro" or "micro" (got: {average})')
+        self.average = average
+
         super(Bleu, self).__init__(output_transform=output_transform, device=device)
 
-    def _corpus_bleu(
-        self, references: Sequence[Sequence[Sequence[Any]]], candidates: Sequence[Sequence[Any]],
-    ) -> float:
-        p_numerators: Counter = Counter()
-        p_denominators: Counter = Counter()
+    def _n_gram_counter(
+        self,
+        references: Sequence[Sequence[Sequence[Any]]],
+        candidates: Sequence[Sequence[Any]],
+        p_numerators: torch.Tensor,
+        p_denominators: torch.Tensor,
+    ) -> Tuple[int, int]:
 
         if len(references) != len(candidates):
             raise ValueError(
                 f"nb of candidates should be equal to nb of reference lists ({len(candidates)} != "
                 f"{len(references)})"
             )
 
+        hyp_lengths = 0
+        ref_lengths = 0
+
         # Iterate through each hypothesis and their corresponding references.
         for refs, hyp in zip(references, candidates):
             # For each order of ngram, calculate the numerator and
@@ -151,54 +168,113 @@ def _corpus_bleu(
                 p_numerators[i] += numerator
                 p_denominators[i] += denominator
 
+            # Calculate the hypothesis lengths
+            hyp_lengths += len(hyp)
+
+            # Calculate the closest reference lengths.
+            ref_lengths += _closest_ref_length(refs, len(hyp))
+
+        return hyp_lengths, ref_lengths
+
+    def _brevity_penalty_smoothing(
+        self, p_numerators: torch.Tensor, p_denominators: torch.Tensor, hyp_length_sum: int, ref_length_sum: int
+    ) -> float:
         # Returns 0 if there's no matching n-grams
         # We only need to check for p_numerators[1] == 0, since if there's
         # no unigrams, there won't be any higher order ngrams.
         if p_numerators[1] == 0:
             return 0
 
-        # If no smoother, returns 0 if there's at least one a not matching n-grams
-        if self.smoother.smooth == "no_smooth" and min(p_numerators.values()) == 0:
+        # If no smoother, returns 0 if there's at least one non-matching n-gram
+        if self.smoother.smooth == "no_smooth" and min(p_numerators[1:]).item() == 0:
             return 0
 
-        # Calculate the hypothesis lengths
-        hyp_lengths = [len(hyp) for hyp in candidates]
-
-        # Calculate the closest reference lengths.
-        ref_lengths = [_closest_ref_length(refs, hyp_len) for refs, hyp_len in zip(references, hyp_lengths)]
-
-        # Sum of hypothesis and references lengths
-        hyp_len = sum(hyp_lengths)
-        ref_len = sum(ref_lengths)
-
         # Calculate corpus-level brevity penalty.
-        if hyp_len < ref_len:
-            bp = math.exp(1 - ref_len / hyp_len) if hyp_len > 0 else 0.0
+        if hyp_length_sum < ref_length_sum:
+            bp = math.exp(1 - ref_length_sum / hyp_length_sum) if hyp_length_sum > 0 else 0.0
         else:
             bp = 1.0
 
         # Smoothing
-        p_n = self.smoother(p_numerators, p_denominators)
+        p_n = self.smoother(p_numerators[1:], p_denominators[1:])
 
         # Compute the geometric mean
         s = [w_i * math.log(p_i) for w_i, p_i in zip(self.weights, p_n)]
         gm = bp * math.exp(math.fsum(s))
         return gm
 
+    def _sentence_bleu(self, references: Sequence[Sequence[Any]], candidates: Sequence[Any]) -> float:
+        return self._corpus_bleu([references], [candidates])
+
+    def _corpus_bleu(
+        self, references: Sequence[Sequence[Sequence[Any]]], candidates: Sequence[Sequence[Any]],
+    ) -> float:
+        p_numerators: torch.Tensor = torch.zeros(self.ngrams_order + 1)
+        p_denominators: torch.Tensor = torch.zeros(self.ngrams_order + 1)
+
+        hyp_length_sum, ref_length_sum = self._n_gram_counter(
+            references=references, candidates=candidates, p_numerators=p_numerators, p_denominators=p_denominators,
+        )
+        bleu_score = self._brevity_penalty_smoothing(
+            p_numerators=p_numerators,
+            p_denominators=p_denominators,
+            hyp_length_sum=hyp_length_sum,
+            ref_length_sum=ref_length_sum,
+        )
+
+        return bleu_score
+
     @reinit__is_reduced
     def reset(self) -> None:
-        self._sum_of_bleu = torch.tensor(0.0, dtype=torch.double, device=self._device)
-        self._num_sentences = 0
+        if self.average == "macro":
+            self._sum_of_bleu = torch.tensor(0.0, dtype=torch.double, device=self._device)
+            self._num_sentences = 0
+
+        if self.average == "micro":
+            self.p_numerators = torch.zeros(self.ngrams_order + 1)
+            self.p_denominators = torch.zeros(self.ngrams_order + 1)
+            self.hyp_length_sum = 0
+            self.ref_length_sum = 0
 
     @reinit__is_reduced
     def update(self, output: Tuple[Sequence[Sequence[Any]], Sequence[Sequence[Sequence[Any]]]]) -> None:
         y_pred, y = output
-        for _y_pred, _y in zip(y_pred, y):
-            self._sum_of_bleu += self._corpus_bleu(references=[_y], candidates=[_y_pred])
-            self._num_sentences += 1
+
+        if self.average == "macro":
+            for refs, hyp in zip(y, y_pred):
+                self._sum_of_bleu += self._sentence_bleu(references=refs, candidates=hyp)
+                self._num_sentences += 1
+
+        elif self.average == "micro":
+            hyp_lengths, ref_lengths = self._n_gram_counter(
+                references=y, candidates=y_pred, p_numerators=self.p_numerators, p_denominators=self.p_denominators
+            )
+            self.hyp_length_sum += hyp_lengths
+            self.ref_length_sum += ref_lengths
 
     @sync_all_reduce("_sum_of_bleu", "_num_sentences")
-    def compute(self) -> torch.Tensor:
+    def _compute_macro(self) -> torch.Tensor:
         if self._num_sentences == 0:
             raise NotComputableError("Bleu must have at least one example before it can be computed.")
 
         return self._sum_of_bleu / self._num_sentences
 
+    @sync_all_reduce("p_numerators", "p_denominators", "hyp_length_sum", "ref_length_sum")
+    def _compute_micro(self) -> float:
+        bleu_score = self._brevity_penalty_smoothing(
+            p_numerators=self.p_numerators,
+            p_denominators=self.p_denominators,
+            hyp_length_sum=self.hyp_length_sum,
+            ref_length_sum=self.ref_length_sum,
+        )
+        return bleu_score
+
+    def compute(self) -> Union[None, torch.Tensor, float]:
+        if self.average == "macro":
+            return self._compute_macro()
+        elif self.average == "micro":
+            return self._compute_micro()
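
For readers following the refactor: `_brevity_penalty_smoothing` computes the standard corpus-level BLEU, a brevity penalty times the geometric mean of the modified n-gram precisions, BLEU = bp * exp(sum_n w_n * log(p_n)). Below is a minimal standalone sketch of that formula in plain Python; the function name and argument layout are chosen here for illustration and the smoothing step is omitted, so it is not the PR's implementation.

import math
from typing import Sequence


def bleu_from_counts(
    numerators: Sequence[int],  # matched n-gram counts, one entry per order 1..N
    denominators: Sequence[int],  # total candidate n-gram counts per order
    hyp_length_sum: int,  # total length of all candidate sentences
    ref_length_sum: int,  # sum of closest-reference lengths
) -> float:
    # Brevity penalty: exp(1 - r/c) when the candidate corpus is shorter
    # than the references, otherwise 1.0.
    if hyp_length_sum < ref_length_sum:
        bp = math.exp(1 - ref_length_sum / hyp_length_sum) if hyp_length_sum > 0 else 0.0
    else:
        bp = 1.0

    # Geometric mean of the modified precisions with uniform weights,
    # computed in log space for numerical stability.
    weights = [1 / len(numerators)] * len(numerators)
    p_n = [n / d for n, d in zip(numerators, denominators)]
    return bp * math.exp(math.fsum(w * math.log(p) for w, p in zip(weights, p_n)))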
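The two averaging modes differ in where this formula is applied: macro computes one BLEU score per sentence and averages the scores, while micro pools the numerators, denominators, and lengths over the whole corpus and computes a single score. A minimal usage sketch of the new option, assuming the import path and the (y_pred, y) batch format shown in the class docstring; the sentences below are invented for illustration.

from ignite.metrics.nlp import Bleu

# Two candidate translations, each with its own list of tokenized references.
y_pred = [
    "the cat sat on the mat".split(),
    "a quick brown fox".split(),
]
y = [
    ["the cat is on the mat".split(), "there is a cat on the mat".split()],
    ["the quick brown fox".split(), "a fast brown fox jumps".split()],
]

# Macro: mean of per-sentence BLEU scores (the previous behaviour).
macro = Bleu(ngram=4, smooth="smooth1", average="macro")
macro.update((y_pred, y))
print(macro.compute())

# Micro: n-gram counts and lengths are pooled across the corpus before
# a single BLEU score is computed.
micro = Bleu(ngram=4, smooth="smooth1", average="micro")
micro.update((y_pred, y))
print(micro.compute())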