Implementation of calibration error metrics #394
Merged · 51 commits · Aug 3, 2021

Changes from 1 commit

Commits:
a202225  basic ECE functional + class metric working (edwardclem, May 21, 2021)
67300d5  max calibration error and multidim-multiclass (edwardclem, Jun 4, 2021)
2a65d97  comb metrics, working functional l2, class broken (edwardclem, Jun 6, 2021)
0cd7a33  removed debias term, ddp still broken (edwardclem, Jul 23, 2021)
be2cee1  updated docs (edwardclem, Jul 23, 2021)
d6fe8ab  Merge branch 'master' into master (Borda, Jul 26, 2021)
91b0451  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 26, 2021)
d9e004c  fixed part of ddp, added changelog (edwardclem, Jul 31, 2021)
cd6a334  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 31, 2021)
e316f64  fixed ddp, still need to fix input unit tests (edwardclem, Jul 31, 2021)
ed2430f  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 31, 2021)
190ea63  removing sklearn_calibration (edwardclem, Jul 31, 2021)
b2e8ca6  more docstring fixes (edwardclem, Jul 31, 2021)
5c661f0  fixed tests for invalid inputs and added regex (edwardclem, Jul 31, 2021)
41d6bd8  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 31, 2021)
8dd2a2d  added test for non-int val bins (edwardclem, Jul 31, 2021)
ab0f0e1  Merge branch 'master' of github.com:edwardclem/metrics (edwardclem, Jul 31, 2021)
9e542fb  removed doctest from calibration_error (edwardclem, Jul 31, 2021)
7a980a7  flake8/typing cleanup (edwardclem, Jul 31, 2021)
8f837ad  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 31, 2021)
6d711a2  fix docs (edwardclem, Jul 31, 2021)
9094203  Merge branch 'master' into master (SkafteNicki, Aug 2, 2021)
b50f155  Apply suggestions from code review (Borda, Aug 2, 2021)
5fcac0c  Merge branch 'master' into master (SkafteNicki, Aug 2, 2021)
59c0338  fix order (SkafteNicki, Aug 2, 2021)
9051a1a  flake8 + rendering (SkafteNicki, Aug 2, 2021)
f97be31  fix styling (SkafteNicki, Aug 2, 2021)
98fc849  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 2, 2021)
8682241  Apply suggestions from code review (Borda, Aug 2, 2021)
b11a80b  Update torchmetrics/classification/calibration_error.py (SkafteNicki, Aug 2, 2021)
e6cb17c  Merge branch 'master' into master (SkafteNicki, Aug 2, 2021)
88365ad  Merge branch 'master' into master (Borda, Aug 2, 2021)
a81252b  Merge branch 'master' into master (Borda, Aug 2, 2021)
086886f  Apply suggestions from code review (Borda, Aug 2, 2021)
77da9ce  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 2, 2021)
c11acc9  : (Borda, Aug 2, 2021)
9fa9863  Merge branch 'master' of https://github.com/edwardclem/metrics into e… (Borda, Aug 2, 2021)
53c58b6  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 2, 2021)
c0db244  ... (Borda, Aug 2, 2021)
3bbc9f5  Merge branch 'master' of https://github.com/edwardclem/metrics into e… (Borda, Aug 2, 2021)
f50ec75  Merge branch 'master' of github.com:edwardclem/metrics (edwardclem, Aug 3, 2021)
7fb4508  fixed class variable issue (edwardclem, Aug 3, 2021)
2d71884  added docstrings (edwardclem, Aug 3, 2021)
940fa6c  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 3, 2021)
8d5a4a1  more flake8 fixes (edwardclem, Aug 3, 2021)
939bb75  Merge branch 'master' of github.com:edwardclem/metrics (edwardclem, Aug 3, 2021)
4870b3f  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 3, 2021)
75cfcac  removed duplicate reference (edwardclem, Aug 3, 2021)
3d5e91a  Merge branch 'master' of github.com:edwardclem/metrics (edwardclem, Aug 3, 2021)
7e9cf6d  Apply suggestions from code review (Borda, Aug 3, 2021)
984b879  Merge branch 'master' into master (mergify[bot], Aug 3, 2021)
fixed ddp, still need to fix input unit tests
edwardclem committed Jul 31, 2021

commit e316f645637985141f56ea677b0debf6eb8d7d04 (signed with the committer's verified signature)
tests/classification/test_calibration_error.py (14 additions, 1 deletion)

@@ -97,5 +97,18 @@ def test_invalid_input(preds, targets):
     ]
 )
 def test_invalid_norm(preds, target):
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Norm l3 is not supported. Please select from l1, l2, or max. "):
         calibration_error(preds, target, norm="l3")
+
+
+@pytest.mark.parametrize("n_bins", [-10, 0, -1])
+@pytest.mark.parametrize(
+    "preds, target", [
+        (_input_binary_prob.preds, _input_binary_prob.target),
+        (_input_mcls_prob.preds, _input_mcls_prob.target),
+        (_input_mdmc_prob.preds, _input_mdmc_prob.target),
+    ]
+)
+def test_invalid_bins(preds, target, n_bins):
+    with pytest.raises(ValueError, match=f"Expected argument `n_bins` to be an int larger than 0 but got {n_bins}"):
+        calibration_error(preds, target, n_bins=n_bins)
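
For orientation, a minimal sketch of the behaviour these parametrized tests pin down; the import path `torchmetrics.functional.calibration_error` is an assumption about where the metric is exposed once the PR lands:

    import torch
    from torchmetrics.functional import calibration_error  # assumed import path once the PR is merged

    preds = torch.tensor([0.25, 0.75, 0.60, 0.10])  # binary confidence scores
    target = torch.tensor([0, 1, 1, 0])

    # A supported norm returns a scalar calibration error ...
    print(calibration_error(preds, target, n_bins=2, norm="l1"))

    # ... while an unsupported norm raises the ValueError the regex test matches.
    try:
        calibration_error(preds, target, norm="l3")
    except ValueError as err:
        print(err)  # Norm l3 is not supported. Please select from l1, l2, or max.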
torchmetrics/classification/calibration_error.py (8 additions, 9 deletions)

@@ -55,20 +55,17 @@ def __init__(

         Where :math:`p_i` is the top-1 prediction accuracy in bin i and :math:`c_i` is the average confidence of predictions in bin i.

-
-        # NOTE: L2-norm debiasing is not yet supported.
-
+        NOTE: L2-norm debiasing is not yet supported.

         Args:
-            n_bins (int, optional): Number of bins to use when computing t. Defaults to 15.
+            n_bins (int, optional): Number of bins to use when computing probabilities and accuracies. Defaults to 15.
             norm (str, optional): Norm used to compare empirical and expected probability bins.
                 Defaults to "l1", or Expected Calibration Error.
-            debias (bool, optional): Applies debiasing term, only implemented for l2 norm. Defaults to True.
             compute_on_step (bool, optional): Forward only calls ``update()`` and returns None if this is set to False. Defaults to False.
             dist_sync_on_step (bool, optional): Synchronize metric state across processes at each ``forward()``
                 before returning the value at the step. Defaults to False.
             process_group (Optional[Any], optional): Specify the process group on which synchronization is called. Defaults to None (which selects the entire world).
             dist_sync_fn (Callable, optional): Callback that performs the ``allgather`` operation on the metric state. When ``None``, DDP
                 will be used to perform the ``allgather``. Defaults to None.
         """
         super().__init__(
             compute_on_step=compute_on_step,

@@ -80,20 +77,22 @@ def __init__(
         if norm not in ["l1", "l2", "max"]:
             raise ValueError(f"Norm {norm} is not supported. Please select from l1, l2, or max. ")

+        if not isinstance(n_bins, int) or n_bins <= 0:
+            raise ValueError(f"Expected argument `n_bins` to be an int larger than 0 but got {n_bins}")
         self.n_bins = n_bins
         self.register_buffer("bin_boundaries", torch.linspace(0, 1, n_bins + 1))
         self.norm = norm

         self.add_state("confidences", [], dist_reduce_fx="cat")
         self.add_state("accuracies", [], dist_reduce_fx="cat")

-    def update(self, preds: Tensor, target: Tensor):
+    def update(self, preds: Tensor, target: Tensor) -> None:
         """
         Computes top-level confidences and accuracies for the input probabilities and appends them to internal state.

         Args:
-            preds (Tensor): [description]
-            target (Tensor): [description]
+            preds (Tensor): Model output probabilities.
+            target (Tensor): Ground-truth target class labels.
         """
         confidences, accuracies = _ce_update(preds, target)

@@ -105,7 +104,7 @@ def compute(self) -> Tensor:
         Computes calibration error across all confidences and accuracies.

         Returns:
-            Tensor: [description]
+            Tensor: Calibration error across previously collected examples.
         """
         confidences = dim_zero_cat(self.confidences)
         accuracies = dim_zero_cat(self.accuracies)
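
As a hedged usage sketch of the state-based version (assuming the class ends up exported as `torchmetrics.CalibrationError` after the merge): each `update()` appends per-batch confidences and accuracies to the list states, which `dist_reduce_fx="cat"` concatenates across DDP processes before `compute()` bins them once:

    import torch
    from torchmetrics import CalibrationError  # assumed export path once the PR is merged

    metric = CalibrationError(n_bins=15, norm="l1")

    # Two batches of binary probabilities; states accumulate across update() calls.
    metric.update(torch.tensor([0.25, 0.75, 0.60]), torch.tensor([0, 1, 1]))
    metric.update(torch.tensor([0.90, 0.20, 0.50]), torch.tensor([1, 0, 0]))

    # compute() concatenates the accumulated states and reduces them to a scalar.
    print(metric.compute())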
torchmetrics/functional/classification/calibration_error.py (11 additions, 8 deletions)

@@ -14,15 +14,15 @@
 from typing import Optional, Tuple

 import torch
-from torch import Tensor, tensor
+from torch import Tensor, tensor, FloatTensor
 from torch.nn import functional as F

 from torchmetrics.utilities.checks import _input_format_classification
 from torchmetrics.utilities.enums import DataType


 def _ce_compute(
-    confidences: Tensor, accuracies: Tensor, bin_boundaries: Tensor, norm: str = "l1", debias: bool = False
+    confidences: FloatTensor, accuracies: FloatTensor, bin_boundaries: FloatTensor, norm: str = "l1", debias: bool = False
 ) -> Tensor:

     conf_bin = torch.zeros_like(bin_boundaries)

@@ -56,7 +56,7 @@ def _ce_compute(
     return ce


-def _ce_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
+def _ce_update(preds: Tensor, target: Tensor) -> Tuple[FloatTensor, FloatTensor]:
     _, _, mode = _input_format_classification(preds, target)

     if mode == DataType.BINARY:

@@ -74,8 +74,8 @@ def _ce_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
         raise ValueError(
             f"Calibration error is not well-defined for data with size {preds.size()} and targets {target.size()}"
         )
-
-    return confidences, accuracies
+    # must be cast to float for ddp allgather to work
+    return confidences.float(), accuracies.float()


 def calibration_error(preds: Tensor, target: Tensor, n_bins: int = 15, norm: str = "l1"):

@@ -106,8 +106,8 @@ def calibration_error(preds: Tensor, target: Tensor, n_bins: int = 15, norm: str

     Args:
-        preds (Tensor): [description]
-        target (Tensor): [description]
+        preds (Tensor): Model output probabilities.
+        target (Tensor): Ground-truth target class labels.
         n_bins (int, optional): Number of bins to use when computing probabilities and accuracies. Defaults to 15.
         norm (str, optional): Norm used to compare empirical and expected probability bins.
             Defaults to "l1", or Expected Calibration Error.

@@ -118,6 +118,9 @@ def calibration_error(preds: Tensor, target: Tensor, n_bins: int = 15, norm: str

     confidences, accuracies = _ce_update(preds, target)

-    bin_boundaries = torch.linspace(0, 1, n_bins + 1).to(preds.device)
+    if not isinstance(n_bins, int) or n_bins <= 0:
+        raise ValueError(f"Expected argument `n_bins` to be an int larger than 0 but got {n_bins}")
+
+    bin_boundaries = torch.linspace(0, 1, n_bins + 1, dtype=torch.float).to(preds.device)

     return _ce_compute(confidences, accuracies, bin_boundaries, norm=norm)
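
To ground the :math:`p_i` / :math:`c_i` notation from the class docstring, here is a self-contained sketch of the L1 ("expected calibration error") reduction that `_ce_compute` performs over the boundaries built by `torch.linspace(0, 1, n_bins + 1)`; `ece_sketch` is a hypothetical helper for illustration, not part of the PR:

    import torch

    def ece_sketch(confidences: torch.Tensor, accuracies: torch.Tensor, n_bins: int = 15) -> torch.Tensor:
        # L1 calibration error: sum over bins of |c_i - p_i| * (fraction of samples in bin i),
        # where p_i is the top-1 accuracy and c_i the mean confidence inside bin i.
        bin_boundaries = torch.linspace(0, 1, n_bins + 1, dtype=torch.float)
        ece = torch.zeros(1)
        for lower, upper in zip(bin_boundaries[:-1], bin_boundaries[1:]):
            in_bin = (confidences > lower) & (confidences <= upper)
            prop_in_bin = in_bin.float().mean()  # weight of this bin
            if prop_in_bin > 0:
                p_i = accuracies[in_bin].float().mean()  # top-1 accuracy in bin i
                c_i = confidences[in_bin].mean()         # average confidence in bin i
                ece += torch.abs(c_i - p_i) * prop_in_bin
        return ece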