From 2c1806376a2f0c4ab18b43cdf339a93b899f4ce2 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Mon, 2 Oct 2023 15:14:24 +0200 Subject: [PATCH] Clarify input shape expectation in classification for `samplewise` reduction (#2119) Clarify input shape expectation in classification for `samplewise` reduction --- src/torchmetrics/classification/accuracy.py | 9 +++- .../classification/exact_match.py | 8 +++- src/torchmetrics/classification/f_beta.py | 20 ++++++++- src/torchmetrics/classification/hamming.py | 11 ++++- .../classification/precision_recall.py | 18 ++++++++ .../classification/specificity.py | 10 ++++- .../classification/stat_scores.py | 43 +++++++++++++------ 7 files changed, 97 insertions(+), 22 deletions(-) diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py index 60188aff5c9..117a89cb667 100644 --- a/src/torchmetrics/classification/accuracy.py +++ b/src/torchmetrics/classification/accuracy.py @@ -49,7 +49,8 @@ class BinaryAccuracy(BinaryStatScores): If ``multidim_average`` is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. - Additional dimension ``...`` will be flattened into the batch dimension. + If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present, + which the reduction will then be applied over instead of the sample dimension ``N``. Args: threshold: Threshold for transforming probability to binary {0,1} predictions @@ -176,6 +177,9 @@ class MulticlassAccuracy(MulticlassStatScores): - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` - If ``average=None/'none'``, the shape will be ``(N, C)`` + If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present, + which the reduction will then be applied over instead of the sample dimension ``N``. + Args: num_classes: Integer specifying the number of classes average: @@ -325,6 +329,9 @@ class MultilabelAccuracy(MultilabelStatScores): - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` - If ``average=None/'none'``, the shape will be ``(N, C)`` + If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present, + which the reduction will then be applied over instead of the sample dimension ``N``. + Args: num_labels: Integer specifying the number of labels threshold: Threshold for transforming probability to binary (0,1) predictions diff --git a/src/torchmetrics/classification/exact_match.py b/src/torchmetrics/classification/exact_match.py index 481441a820c..20ab5344373 100644 --- a/src/torchmetrics/classification/exact_match.py +++ b/src/torchmetrics/classification/exact_match.py @@ -54,7 +54,6 @@ class MulticlassExactMatch(Metric): probabilities/logits into an int tensor. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``. - As output to ``forward`` and ``compute`` the metric returns the following output: - ``mcem`` (:class:`~torch.Tensor`): A tensor whose returned shape depends on the ``multidim_average`` argument: @@ -62,6 +61,9 @@ class MulticlassExactMatch(Metric): - If ``multidim_average`` is set to ``global`` the output will be a scalar tensor - If ``multidim_average`` is set to ``samplewise`` the output will be a tensor of shape ``(N,)`` + If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present, + which the reduction will then be applied over instead of the sample dimension ``N``. + Args: num_classes: Integer specifying the number of labels multidim_average: @@ -206,7 +208,6 @@ class MultilabelExactMatch(Metric): sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in ``threshold``. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)``. - As output to ``forward`` and ``compute`` the metric returns the following output: - ``mlem`` (:class:`~torch.Tensor`): A tensor whose returned shape depends on the ``multidim_average`` argument: @@ -214,6 +215,9 @@ class MultilabelExactMatch(Metric): - If ``multidim_average`` is set to ``global`` the output will be a scalar tensor - If ``multidim_average`` is set to ``samplewise`` the output will be a tensor of shape ``(N,)`` + If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present, + which the reduction will then be applied over instead of the sample dimension ``N``. + Args: num_labels: Integer specifying the number of labels threshold: Threshold for transforming probability to binary (0,1) predictions diff --git a/src/torchmetrics/classification/f_beta.py b/src/torchmetrics/classification/f_beta.py index 16f8a1408b4..0386a8b2eb9 100644 --- a/src/torchmetrics/classification/f_beta.py +++ b/src/torchmetrics/classification/f_beta.py @@ -66,6 +66,9 @@ class BinaryFBetaScore(BinaryStatScores): - If ``multidim_average`` is set to ``samplewise`` the output will be a tensor of shape ``(N,)`` consisting of a scalar value per sample. + If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present, + which the reduction will then be applied over instead of the sample dimension ``N``. + Args: beta: Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight threshold: Threshold for transforming probability to binary {0,1} predictions @@ -202,7 +205,6 @@ class MulticlassFBetaScore(MulticlassStatScores): probabilities/logits into an int tensor. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``. - As output to ``forward`` and ``compute`` the metric returns the following output: - ``mcfbs`` (:class:`~torch.Tensor`): A tensor whose returned shape depends on the ``average`` and @@ -218,6 +220,9 @@ class MulticlassFBetaScore(MulticlassStatScores): - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` - If ``average=None/'none'``, the shape will be ``(N, C)`` + If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present, + which the reduction will then be applied over instead of the sample dimension ``N``. + Args: beta: Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight num_classes: Integer specifying the number of classes @@ -382,7 +387,6 @@ class MultilabelFBetaScore(MultilabelStatScores): per element. Additionally, we convert to int tensor with thresholding using the value in ``threshold``. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)``. - As output to ``forward`` and ``compute`` the metric returns the following output: - ``mlfbs`` (:class:`~torch.Tensor`): A tensor whose returned shape depends on the ``average`` and @@ -398,6 +402,9 @@ class MultilabelFBetaScore(MultilabelStatScores): - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` - If ``average=None/'none'``, the shape will be ``(N, C)`` + If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present, + which the reduction will then be applied over instead of the sample dimension ``N``. + Args: beta: Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight num_labels: Integer specifying the number of labels @@ -566,6 +573,9 @@ class BinaryF1Score(BinaryFBetaScore): - If ``multidim_average`` is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. + If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present, + which the reduction will then be applied over instead of the sample dimension ``N``. + Args: threshold: Threshold for transforming probability to binary {0,1} predictions multidim_average: @@ -706,6 +716,9 @@ class MulticlassF1Score(MulticlassFBetaScore): - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` - If ``average=None/'none'``, the shape will be ``(N, C)`` + If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present, + which the reduction will then be applied over instead of the sample dimension ``N``. + Args: preds: Tensor with predictions target: Tensor with true labels @@ -876,6 +889,9 @@ class MultilabelF1Score(MultilabelFBetaScore): - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` - If ``average=None/'none'``, the shape will be ``(N, C)``` + If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present, + which the reduction will then be applied over instead of the sample dimension ``N``. + Args: num_labels: Integer specifying the number of labels threshold: Threshold for transforming probability to binary (0,1) predictions diff --git a/src/torchmetrics/classification/hamming.py b/src/torchmetrics/classification/hamming.py index dd577d92b76..340a647aa8d 100644 --- a/src/torchmetrics/classification/hamming.py +++ b/src/torchmetrics/classification/hamming.py @@ -58,6 +58,9 @@ class BinaryHammingDistance(BinaryStatScores): - If ``multidim_average`` is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. + If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present, + which the reduction will then be applied over instead of the sample dimension ``N``. + Args: threshold: Threshold for transforming probability to binary {0,1} predictions multidim_average: @@ -171,7 +174,6 @@ class MulticlassHammingDistance(MulticlassStatScores): probabilities/logits into an int tensor. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``. - As output to ``forward`` and ``compute`` the metric returns the following output: - ``mchd`` (:class:`~torch.Tensor`): A tensor whose returned shape depends on the ``average`` and @@ -187,6 +189,9 @@ class MulticlassHammingDistance(MulticlassStatScores): - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` - If ``average=None/'none'``, the shape will be ``(N, C)`` + If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present, + which the reduction will then be applied over instead of the sample dimension ``N``. + Args: num_classes: Integer specifying the number of classes average: @@ -324,7 +329,6 @@ class MultilabelHammingDistance(MultilabelStatScores): ``threshold``. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)``. - As output to ``forward`` and ``compute`` the metric returns the following output: - ``mlhd`` (:class:`~torch.Tensor`): A tensor whose returned shape depends on the ``average`` and @@ -340,6 +344,9 @@ class MultilabelHammingDistance(MultilabelStatScores): - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` - If ``average=None/'none'``, the shape will be ``(N, C)`` + If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present, + which the reduction will then be applied over instead of the sample dimension ``N``. + Args: num_labels: Integer specifying the number of labels threshold: Threshold for transforming probability to binary (0,1) predictions diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py index d3530b4c769..d221584c336 100644 --- a/src/torchmetrics/classification/precision_recall.py +++ b/src/torchmetrics/classification/precision_recall.py @@ -57,6 +57,9 @@ class BinaryPrecision(BinaryStatScores): value. If ``multidim_average`` is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. + If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present, + which the reduction will then be applied over instead of the sample dimension ``N``. + Args: threshold: Threshold for transforming probability to binary {0,1} predictions multidim_average: @@ -187,6 +190,9 @@ class MulticlassPrecision(MulticlassStatScores): - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` - If ``average=None/'none'``, the shape will be ``(N, C)`` + If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present, + which the reduction will then be applied over instead of the sample dimension ``N``. + Args: num_classes: Integer specifying the number of classes average: @@ -340,6 +346,9 @@ class MultilabelPrecision(MultilabelStatScores): - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` - If ``average=None/'none'``, the shape will be ``(N, C)`` + If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present, + which the reduction will then be applied over instead of the sample dimension ``N``. + Args: num_labels: Integer specifying the number of labels threshold: Threshold for transforming probability to binary (0,1) predictions @@ -479,6 +488,9 @@ class BinaryRecall(BinaryStatScores): value. If ``multidim_average`` is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. + If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present, + which the reduction will then be applied over instead of the sample dimension ``N``. + Args: threshold: Threshold for transforming probability to binary {0,1} predictions multidim_average: @@ -608,6 +620,9 @@ class MulticlassRecall(MulticlassStatScores): - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` - If ``average=None/'none'``, the shape will be ``(N, C)`` + If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present, + which the reduction will then be applied over instead of the sample dimension ``N``. + Args: num_classes: Integer specifying the number of classes average: @@ -760,6 +775,9 @@ class MultilabelRecall(MultilabelStatScores): - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` - If ``average=None/'none'``, the shape will be ``(N, C)`` + If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present, + which the reduction will then be applied over instead of the sample dimension ``N``. + Args: num_labels: Integer specifying the number of labels threshold: Threshold for transforming probability to binary (0,1) predictions diff --git a/src/torchmetrics/classification/specificity.py b/src/torchmetrics/classification/specificity.py index 31d736881cf..d9124968cfc 100644 --- a/src/torchmetrics/classification/specificity.py +++ b/src/torchmetrics/classification/specificity.py @@ -50,6 +50,9 @@ class BinarySpecificity(BinaryStatScores): If ``multidim_average`` is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. + If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present, + which the reduction will then be applied over instead of the sample dimension ``N``. + Args: threshold: Threshold for transforming probability to binary {0,1} predictions multidim_average: @@ -174,6 +177,9 @@ class MulticlassSpecificity(MulticlassStatScores): - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` - If ``average=None/'none'``, the shape will be ``(N, C)`` + If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present, + which the reduction will then be applied over instead of the sample dimension ``N``. + Args: num_classes: Integer specifying the number of classes average: @@ -307,7 +313,6 @@ class MultilabelSpecificity(MultilabelStatScores): per element. Additionally, we convert to int tensor with thresholding using the value in ``threshold``. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)`` - As output to ``forward`` and ``compute`` the metric returns the following output: - ``mls`` (:class:`~torch.Tensor`): The returned shape depends on the ``average`` and ``multidim_average`` @@ -323,6 +328,9 @@ class MultilabelSpecificity(MultilabelStatScores): - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` - If ``average=None/'none'``, the shape will be ``(N, C)`` + If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present, + which the reduction will then be applied over instead of the sample dimension ``N``. + Args: num_labels: Integer specifying the number of labels threshold: Threshold for transforming probability to binary (0,1) predictions diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index 7c72725d57a..ce671e202bf 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -107,8 +107,11 @@ class BinaryStatScores(_AbstractStatScores): to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The shape depends on the ``multidim_average`` parameter: - - If ``multidim_average`` is set to ``global``, the shape will be ``(5,)`` - - If ``multidim_average`` is set to ``samplewise``, the shape will be ``(N, 5)`` + - If ``multidim_average`` is set to ``global``, the shape will be ``(5,)`` + - If ``multidim_average`` is set to ``samplewise``, the shape will be ``(N, 5)`` + + If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present, + which the reduction will then be applied over instead of the sample dimension ``N``. Args: threshold: Threshold for transforming probability to binary {0,1} predictions @@ -208,12 +211,18 @@ class MulticlassStatScores(_AbstractStatScores): to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The shape depends on ``average`` and ``multidim_average`` parameters: - - If ``multidim_average`` is set to ``global`` - - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(5,)`` - - If ``average=None/'none'``, the shape will be ``(C, 5)`` - - If ``multidim_average`` is set to ``samplewise`` - - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N, 5)`` - - If ``average=None/'none'``, the shape will be ``(N, C, 5)`` + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(5,)`` + - If ``average=None/'none'``, the shape will be ``(C, 5)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N, 5)`` + - If ``average=None/'none'``, the shape will be ``(N, C, 5)`` + + If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present, + which the reduction will then be applied over instead of the sample dimension ``N``. Args: num_classes: Integer specifying the number of classes @@ -352,12 +361,18 @@ class MultilabelStatScores(_AbstractStatScores): to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The shape depends on ``average`` and ``multidim_average`` parameters: - - If ``multidim_average`` is set to ``global`` - - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(5,)`` - - If ``average=None/'none'``, the shape will be ``(C, 5)`` - - If ``multidim_average`` is set to ``samplewise`` - - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N, 5)`` - - If ``average=None/'none'``, the shape will be ``(N, C, 5)`` + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(5,)`` + - If ``average=None/'none'``, the shape will be ``(C, 5)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N, 5)`` + - If ``average=None/'none'``, the shape will be ``(N, C, 5)`` + + If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present, + which the reduction will then be applied over instead of the sample dimension ``N``. Args: num_labels: Integer specifying the number of labels