Implementation of calibration error metrics #394

Merged: 51 commits, merged on Aug 3, 2021.
Commits
a202225  basic ECE functional + class metric working (edwardclem, May 21, 2021)
67300d5  max calibration error and multidim-multiclass (edwardclem, Jun 4, 2021)
2a65d97  comb metrics, working functional l2, class broken (edwardclem, Jun 6, 2021)
0cd7a33  removed debias term, ddp still broken (edwardclem, Jul 23, 2021)
be2cee1  updated docs (edwardclem, Jul 23, 2021)
d6fe8ab  Merge branch 'master' into master (Borda, Jul 26, 2021)
91b0451  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 26, 2021)
d9e004c  fixed part of ddp, added changelog (edwardclem, Jul 31, 2021)
cd6a334  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 31, 2021)
e316f64  fixed ddp, still need to fix input unit tests (edwardclem, Jul 31, 2021)
ed2430f  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 31, 2021)
190ea63  removing sklearn_calibration (edwardclem, Jul 31, 2021)
b2e8ca6  more docstring fixes (edwardclem, Jul 31, 2021)
5c661f0  fixed tests for invalid inputs and added regex (edwardclem, Jul 31, 2021)
41d6bd8  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 31, 2021)
8dd2a2d  added test for non-int val bins (edwardclem, Jul 31, 2021)
ab0f0e1  Merge branch 'master' of github.com:edwardclem/metrics (edwardclem, Jul 31, 2021)
9e542fb  removed doctest from calibration_error (edwardclem, Jul 31, 2021)
7a980a7  flake8/typing cleanup (edwardclem, Jul 31, 2021)
8f837ad  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 31, 2021)
6d711a2  fix docs (edwardclem, Jul 31, 2021)
9094203  Merge branch 'master' into master (SkafteNicki, Aug 2, 2021)
b50f155  Apply suggestions from code review (Borda, Aug 2, 2021)
5fcac0c  Merge branch 'master' into master (SkafteNicki, Aug 2, 2021)
59c0338  fix order (SkafteNicki, Aug 2, 2021)
9051a1a  flake8 + rendering (SkafteNicki, Aug 2, 2021)
f97be31  fix styling (SkafteNicki, Aug 2, 2021)
98fc849  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 2, 2021)
8682241  Apply suggestions from code review (Borda, Aug 2, 2021)
b11a80b  Update torchmetrics/classification/calibration_error.py (SkafteNicki, Aug 2, 2021)
e6cb17c  Merge branch 'master' into master (SkafteNicki, Aug 2, 2021)
88365ad  Merge branch 'master' into master (Borda, Aug 2, 2021)
a81252b  Merge branch 'master' into master (Borda, Aug 2, 2021)
086886f  Apply suggestions from code review (Borda, Aug 2, 2021)
77da9ce  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 2, 2021)
c11acc9  : (Borda, Aug 2, 2021)
9fa9863  Merge branch 'master' of https://github.com/edwardclem/metrics into e… (Borda, Aug 2, 2021)
53c58b6  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 2, 2021)
c0db244  ... (Borda, Aug 2, 2021)
3bbc9f5  Merge branch 'master' of https://github.com/edwardclem/metrics into e… (Borda, Aug 2, 2021)
f50ec75  Merge branch 'master' of github.com:edwardclem/metrics (edwardclem, Aug 3, 2021)
7fb4508  fixed class variable issue (edwardclem, Aug 3, 2021)
2d71884  added docstrings (edwardclem, Aug 3, 2021)
940fa6c  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 3, 2021)
8d5a4a1  more flake8 fixes (edwardclem, Aug 3, 2021)
939bb75  Merge branch 'master' of github.com:edwardclem/metrics (edwardclem, Aug 3, 2021)
4870b3f  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 3, 2021)
75cfcac  removed duplicate reference (edwardclem, Aug 3, 2021)
3d5e91a  Merge branch 'master' of github.com:edwardclem/metrics (edwardclem, Aug 3, 2021)
7e9cf6d  Apply suggestions from code review (Borda, Aug 3, 2021)
984b879  Merge branch 'master' into master (mergify[bot], Aug 3, 2021)
CHANGELOG.md (2 additions, 0 deletions)
@@ -31,6 +31,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Allowed passing labels in (n_samples, n_classes) to `AveragePrecision` ([#386](https://github.com/PyTorchLightning/metrics/issues/386))

- Added calibration error metrics ([#218](https://github.com/PyTorchLightning/metrics/issues/218))


- Added support for negative targets in `nDCG` metric ([#378](https://github.com/PyTorchLightning/metrics/pull/378))

docs/source/references/functional.rst (9 additions, 1 deletion)
@@ -69,6 +69,13 @@ average_precision [func]
:noindex:


calibration_error [func]
~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.calibration_error
:noindex:


cohen_kappa [func]
~~~~~~~~~~~~~~~~~~

@@ -167,8 +174,9 @@ recall [func]
.. autofunction:: torchmetrics.functional.recall
:noindex:


select_topk [func]
~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.utilities.data.select_topk
:noindex:
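For reference, a minimal sketch of how the new functional metric is invoked, mirroring the calls made in the tests added by this PR; the tensor values are invented for illustration, and the released defaults for `n_bins` and `norm` should be checked against the final docstring:

import torch
from torchmetrics.functional import calibration_error

preds = torch.tensor([0.25, 0.10, 0.80, 0.95])  # predicted probabilities of the positive class
target = torch.tensor([0, 0, 1, 1])             # binary ground-truth labels

ece = calibration_error(preds, target, n_bins=10, norm="l1")   # Expected Calibration Error
mce = calibration_error(preds, target, n_bins=10, norm="max")  # Maximum Calibration Error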
docs/source/references/modules.rst (6 additions, 0 deletions)
@@ -209,6 +209,12 @@ BinnedRecallAtFixedPrecision
.. autoclass:: torchmetrics.BinnedRecallAtFixedPrecision
:noindex:

CalibrationError
~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.CalibrationError
:noindex:

CohenKappa
~~~~~~~~~~

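The module-based counterpart accumulates state across batches before computing the final value. A minimal sketch, assuming the standard torchmetrics update/compute API and the `n_bins`/`norm` arguments exercised in the tests below (batch contents are invented):

import torch
from torchmetrics import CalibrationError

metric = CalibrationError(n_bins=10, norm="l1")
for _ in range(3):  # e.g. validation batches
    preds = torch.rand(8)               # probabilities of the positive class
    target = torch.randint(0, 2, (8,))  # binary labels
    metric.update(preds, target)
print(metric.compute())  # calibration error aggregated over all batches
metric.reset()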
tests/classification/test_calibration_error.py (117 additions, 0 deletions)
@@ -0,0 +1,117 @@
import functools
import re

import numpy as np
import pytest

from tests.classification.inputs import _input_binary_prob
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.helpers import seed_all

# TODO: replace this with official sklearn implementation after next sklearn release
from tests.helpers.non_sklearn_metrics import calibration_error as sk_calib
from tests.helpers.testers import THRESHOLD, MetricTester
from torchmetrics import CalibrationError
from torchmetrics.functional import calibration_error
from torchmetrics.utilities.checks import _input_format_classification
from torchmetrics.utilities.enums import DataType

seed_all(42)


def _sk_calibration(preds, target, n_bins, norm, debias=False):
    _, _, mode = _input_format_classification(preds, target, threshold=THRESHOLD)
    sk_preds, sk_target = preds.numpy(), target.numpy()

    if mode == DataType.MULTICLASS:
        # binary label is whether or not the predicted class is correct
        sk_target = np.equal(np.argmax(sk_preds, axis=1), sk_target)
        sk_preds = np.max(sk_preds, axis=1)
    elif mode == DataType.MULTIDIM_MULTICLASS:
        # reshape from shape (N, C, ...) to (N*EXTRA_DIMS, C)
        sk_preds = np.transpose(sk_preds, axes=(0, 2, 1))
        sk_preds = sk_preds.reshape(np.prod(sk_preds.shape[:-1]), sk_preds.shape[-1])
        # reshape from shape (N, ...) to (N*EXTRA_DIMS,)
        # binary label is whether or not the predicted class is correct
        sk_target = np.equal(np.argmax(sk_preds, axis=1), sk_target.flatten())
        sk_preds = np.max(sk_preds, axis=1)
    return sk_calib(y_true=sk_target, y_prob=sk_preds, norm=norm, n_bins=n_bins, reduce_bias=debias)


@pytest.mark.parametrize("n_bins", [10, 15, 20])
@pytest.mark.parametrize("norm", ["l1", "l2", "max"])
@pytest.mark.parametrize(
    "preds, target", [
        (_input_binary_prob.preds, _input_binary_prob.target),
        (_input_mcls_prob.preds, _input_mcls_prob.target),
        (_input_mdmc_prob.preds, _input_mdmc_prob.target),
    ]
)
class TestCE(MetricTester):

    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    def test_ce(self, preds, target, n_bins, ddp, dist_sync_on_step, norm):
        self.run_class_metric_test(
            ddp=ddp,
            preds=preds,
            target=target,
            metric_class=CalibrationError,
            sk_metric=functools.partial(_sk_calibration, n_bins=n_bins, norm=norm),
            dist_sync_on_step=dist_sync_on_step,
            metric_args={
                "n_bins": n_bins,
                "norm": norm
            }
        )

    def test_ce_functional(self, preds, target, n_bins, norm):
        self.run_functional_metric_test(
            preds,
            target,
            metric_functional=calibration_error,
            sk_metric=functools.partial(_sk_calibration, n_bins=n_bins, norm=norm),
            metric_args={
                "n_bins": n_bins,
                "norm": norm
            }
        )


@pytest.mark.parametrize("preds, targets", [(_input_mlb_prob.preds, _input_mlb_prob.target)])
def test_invalid_input(preds, targets):
    for p, t in zip(preds, targets):
        with pytest.raises(
            ValueError,
            match=re.escape(
                f"Calibration error is not well-defined for data with size {p.size()} and targets {t.size()}."
            )
        ):
            calibration_error(p, t)


@pytest.mark.parametrize(
    "preds, target", [
        (_input_binary_prob.preds, _input_binary_prob.target),
        (_input_mcls_prob.preds, _input_mcls_prob.target),
        (_input_mdmc_prob.preds, _input_mdmc_prob.target),
    ]
)
def test_invalid_norm(preds, target):
    with pytest.raises(ValueError, match="Norm l3 is not supported. Please select from l1, l2, or max. "):
        calibration_error(preds, target, norm="l3")


@pytest.mark.parametrize("n_bins", [-10, -1, "fsd"])
@pytest.mark.parametrize(
    "preds, targets", [
        (_input_binary_prob.preds, _input_binary_prob.target),
        (_input_mcls_prob.preds, _input_mcls_prob.target),
        (_input_mdmc_prob.preds, _input_mdmc_prob.target),
    ]
)
def test_invalid_bins(preds, targets, n_bins):
    for p, t in zip(preds, targets):
        with pytest.raises(ValueError, match=f"Expected argument `n_bins` to be a int larger than 0 but got {n_bins}"):
            calibration_error(p, t, n_bins=n_bins)
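To make the reshaping in `_sk_calibration` for the multidim-multiclass case concrete, here is a small shape-only sketch with made-up sizes (not part of the test suite):

import numpy as np

N, C, EXTRA = 4, 3, 5                         # hypothetical batch, class, and extra dims
preds = np.random.rand(N, C, EXTRA)           # (N, C, ...) predictions
target = np.random.randint(0, C, (N, EXTRA))  # (N, ...) class labels

preds = np.transpose(preds, axes=(0, 2, 1))                         # -> (N, EXTRA, C)
preds = preds.reshape(np.prod(preds.shape[:-1]), preds.shape[-1])   # -> (N * EXTRA, C)

is_correct = np.equal(np.argmax(preds, axis=1), target.flatten())   # binary "hit" labels, shape (20,)
confidence = np.max(preds, axis=1)                                  # top-1 probabilities, shape (20,)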
tests/helpers/non_sklearn_metrics.py (163 additions, 30 deletions)
@@ -1,9 +1,9 @@
"""File for non sklearn metrics that are to be used for reference for tests"""
from typing import Optional
from typing import Optional, Union

import numpy as np
from sklearn.metrics._regression import _check_reg_targets
from sklearn.utils.validation import check_consistent_length
from sklearn.utils import assert_all_finite, check_consistent_length, column_or_1d


def symmetric_mean_absolute_percentage_error(
@@ -12,40 +12,40 @@ def symmetric_mean_absolute_percentage_error(
sample_weight: Optional[np.ndarray] = None,
multioutput: str = "uniform_average",
):
r"""Symmetric mean absolute percentage error regression loss.
r"""
Symmetric mean absolute percentage error regression loss.
<https://en.wikipedia.org/wiki/Symmetric_mean_absolute_percentage_error>`_ (SMAPE):

.. math:: \text{SMAPE} = \frac{2}{n}\sum_1^n\frac{| y_i - \hat{y_i} |}{\max(| y_i | + | \hat{y_i} |, \epsilon)}

Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions.

Parameters
----------
y_true : array-like of shape (n_samples,) or (n_samples, n_outputs)
Ground truth (correct) target values.
y_pred : array-like of shape (n_samples,) or (n_samples, n_outputs)
Estimated target values.
sample_weight : array-like of shape (n_samples,), default=None
Sample weights.
multioutput : {'raw_values', 'uniform_average'} or array-like
Defines aggregating of multiple output values.
Array-like value defines weights used to average errors.
If input is list then the shape must be (n_outputs,).
'raw_values' :
Returns a full set of errors in case of multioutput input.
'uniform_average' :
Errors of all outputs are averaged with uniform weight.
Returns
-------
loss : float or ndarray of floats in the range [0, 1]
If multioutput is 'raw_values', then symmetric mean absolute percentage error
is returned for each output separately.
If multioutput is 'uniform_average' or an ndarray of weights, then the
weighted average of all output errors is returned.
MAPE output is non-negative floating point. The best value is 0.0.
But note that bad predictions can lead to arbitrarily large
MAPE values, especially if some y_true values are very close to zero.
Note that we return a large value instead of `inf` when y_true is zero.
Args:
y_true : array-like of shape (n_samples,) or (n_samples, n_outputs)
Ground truth (correct) target values.
y_pred : array-like of shape (n_samples,) or (n_samples, n_outputs)
Estimated target values.
sample_weight : array-like of shape (n_samples,), default=None
Sample weights.
multioutput : {'raw_values', 'uniform_average'} or array-like
Defines aggregating of multiple output values.
Array-like value defines weights used to average errors.
If input is list then the shape must be (n_outputs,).
'raw_values' :
Returns a full set of errors in case of multioutput input.
'uniform_average' :
Errors of all outputs are averaged with uniform weight.

Returns:
loss : float or ndarray of floats in the range [0, 1]
If multioutput is 'raw_values', then symmetric mean absolute percentage error
is returned for each output separately.
If multioutput is 'uniform_average' or an ndarray of weights, then the
weighted average of all output errors is returned.
MAPE output is non-negative floating point. The best value is 0.0.
But note that bad predictions can lead to arbitrarily large
MAPE values, especially if some y_true values are very close to zero.
Note that we return a large value instead of `inf` when y_true is zero.

"""
_, y_true, y_pred, multioutput = _check_reg_targets(y_true, y_pred, multioutput)
@@ -60,3 +60,136 @@ def symmetric_mean_absolute_percentage_error(
multioutput = None

return np.average(output_errors, weights=multioutput)
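Before the new calibration helper below, a quick hand-computed instance of the SMAPE formula documented above (a sketch; the epsilon value is an assumption, not necessarily the one used by the helper):

import numpy as np

y_true = np.array([2.0, 4.0, 8.0])
y_pred = np.array([1.0, 5.0, 8.0])
eps = 1e-8  # guards against a zero denominator

# per-element terms: 1/3, 1/9, 0  ->  SMAPE = (2 / 3) * (1/3 + 1/9 + 0) ≈ 0.296
smape = 2.0 / len(y_true) * np.sum(np.abs(y_true - y_pred) / np.maximum(np.abs(y_true) + np.abs(y_pred), eps))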


# sklearn reference function from
# https://github.com/samronsin/scikit-learn/blob/calibration-loss/sklearn/metrics/_classification.py.
# TODO: when the PR into sklearn is accepted, update this to use the official function.
def calibration_error(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    sample_weight: Optional[np.ndarray] = None,
    norm: str = 'l2',
    n_bins: int = 10,
    strategy: str = 'uniform',
    pos_label: Optional[Union[int, str]] = None,
    reduce_bias: bool = True
) -> float:
    """
    Compute calibration error of a binary classifier.
    Across all items in a set of N predictions, the calibration error measures
    the aggregated difference between (1) the average predicted probabilities
    assigned to the positive class, and (2) the frequencies
    of the positive class in the actual outcome.
    The calibration error is only appropriate for binary categorical outcomes.
    Which label is considered to be the positive label is controlled via the
    parameter pos_label, which defaults to 1.

    Args:
        y_true : array-like of shape (n_samples,)
            True targets of a binary classification task.
        y_prob : array-like of (n_samples,)
            Probabilities of the positive class.
        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights.
        norm : {'l1', 'l2', 'max'}, default='l2'
            Norm method. The l1-norm is the Expected Calibration Error (ECE),
            and the max-norm corresponds to Maximum Calibration Error (MCE).
        n_bins : int, default=10
            The number of bins to compute error on.
        strategy : {'uniform', 'quantile'}, default='uniform'
            Strategy used to define the widths of the bins.
            uniform
                All bins have identical widths.
            quantile
                All bins have the same number of points.
        pos_label : int or str, default=None
            Label of the positive class. If None, the maximum label is used as
            positive class.
        reduce_bias : bool, default=True
            Add debiasing term as in Verified Uncertainty Calibration, A. Kumar.
            Only effective for the l2-norm.

    Returns:
        score : float with calibration error
    """
    y_true = column_or_1d(y_true)
    y_prob = column_or_1d(y_prob)
    assert_all_finite(y_true)
    assert_all_finite(y_prob)
    check_consistent_length(y_true, y_prob, sample_weight)
    if any(y_prob < 0) or any(y_prob > 1):
        raise ValueError("y_prob has values outside of [0, 1] range")

    labels = np.unique(y_true)
    if len(labels) > 2:
        raise ValueError("Only binary classification is supported. Provided labels %s." % labels)

    if pos_label is None:
        pos_label = y_true.max()
    if pos_label not in labels:
        raise ValueError("pos_label=%r is not a valid label: %r" % (pos_label, labels))
    y_true = np.array(y_true == pos_label, int)

    norm_options = ('l1', 'l2', 'max')
    if norm not in norm_options:
        raise ValueError(f'norm has to be one of {norm_options}, got: {norm}.')

    remapping = np.argsort(y_prob)
    y_true = y_true[remapping]
    y_prob = y_prob[remapping]
    if sample_weight is not None:
        sample_weight = sample_weight[remapping]
    else:
        sample_weight = np.ones(y_true.shape[0])

    n_bins = int(n_bins)
    if strategy == 'quantile':
        quantiles = np.percentile(y_prob, np.arange(0, 1, 1.0 / n_bins) * 100)
    elif strategy == 'uniform':
        quantiles = np.arange(0, 1, 1.0 / n_bins)
    else:
        raise ValueError(
            f"Invalid entry to 'strategy' input. The strategy must be either quantile' or 'uniform'. Got {strategy} instead."
        )

    threshold_indices = np.searchsorted(y_prob, quantiles).tolist()
    threshold_indices.append(y_true.shape[0])
    avg_pred_true = np.zeros(n_bins)
    bin_centroid = np.zeros(n_bins)
    delta_count = np.zeros(n_bins)
    debias = np.zeros(n_bins)

    loss = 0.
    count = float(sample_weight.sum())
    for i, i_start in enumerate(threshold_indices[:-1]):
        i_end = threshold_indices[i + 1]
        # ignore empty bins
        if i_end == i_start:
            continue
        delta_count[i] = float(sample_weight[i_start:i_end].sum())
        avg_pred_true[i] = (np.dot(y_true[i_start:i_end], sample_weight[i_start:i_end]) / delta_count[i])
        bin_centroid[i] = (np.dot(y_prob[i_start:i_end], sample_weight[i_start:i_end]) / delta_count[i])
        if norm == "l2" and reduce_bias:
            # NOTE: I think there's a mistake in the original implementation.
            # delta_debias = (
            #     avg_pred_true[i] * (avg_pred_true[i] - 1) * delta_count[i]
            # )
            # delta_debias /= (count * delta_count[i] - 1)
            delta_debias = (avg_pred_true[i] * (avg_pred_true[i] - 1) * delta_count[i])
            delta_debias /= count * (delta_count[i] - 1)
            debias[i] = delta_debias

    if norm == "max":
        loss = np.max(np.abs(avg_pred_true - bin_centroid))
    elif norm == "l1":
        delta_loss = np.abs(avg_pred_true - bin_centroid) * delta_count
        loss = np.sum(delta_loss) / count
    elif norm == "l2":
        delta_loss = (avg_pred_true - bin_centroid)**2 * delta_count
        loss = np.sum(delta_loss) / count
        if reduce_bias:
            # convert nans to zero
            loss += np.sum(np.nan_to_num(debias))
        loss = np.sqrt(max(loss, 0.))
    return loss
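To sanity-check the binning logic above, here is a tiny hand-worked ECE (norm='l1') example evaluated with the reference `calibration_error` just defined; the numbers are invented for illustration:

import numpy as np

y_prob = np.array([0.15, 0.25, 0.65, 0.80, 0.90])  # positive-class probabilities (already sorted)
y_true = np.array([0, 0, 1, 0, 1])                 # binary labels

# With n_bins=2 and strategy='uniform', the bin edges are 0.0 and 0.5:
#   bin 0: probs {0.15, 0.25}       -> centroid 0.20,   accuracy 0,   weight 2/5
#   bin 1: probs {0.65, 0.80, 0.90} -> centroid ~0.783, accuracy 2/3, weight 3/5
# ECE = 2/5 * |0 - 0.20| + 3/5 * |2/3 - 0.783| = 0.08 + 0.07 = 0.15
ece_by_hand = 2 / 5 * abs(0.0 - 0.20) + 3 / 5 * abs(2 / 3 - (0.65 + 0.80 + 0.90) / 3)
ece_reference = calibration_error(y_true, y_prob, norm='l1', n_bins=2)
assert np.isclose(ece_by_hand, ece_reference)  # both ≈ 0.15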
torchmetrics/__init__.py (1 addition, 0 deletions)
@@ -23,6 +23,7 @@
BinnedAveragePrecision,
BinnedPrecisionRecallCurve,
BinnedRecallAtFixedPrecision,
CalibrationError,
CohenKappa,
ConfusionMatrix,
FBeta,
torchmetrics/classification/__init__.py (1 addition, 0 deletions)
@@ -18,6 +18,7 @@
from torchmetrics.classification.binned_precision_recall import BinnedAveragePrecision # noqa: F401
from torchmetrics.classification.binned_precision_recall import BinnedPrecisionRecallCurve # noqa: F401
from torchmetrics.classification.binned_precision_recall import BinnedRecallAtFixedPrecision # noqa: F401
from torchmetrics.classification.calibration_error import CalibrationError # noqa: F401
from torchmetrics.classification.cohen_kappa import CohenKappa # noqa: F401
from torchmetrics.classification.confusion_matrix import ConfusionMatrix # noqa: F401
from torchmetrics.classification.f_beta import F1, FBeta # noqa: F401