Skip to content

Commit

Permalink
Clarify input shape expectation in classification for samplewise re…
Browse files Browse the repository at this point in the history
…duction (#2119)

Clarify input shape expectation in classification for `samplewise` reduction
  • Loading branch information
SkafteNicki authored Oct 2, 2023
1 parent bd8e556 commit 2c18063
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 22 deletions.
9 changes: 8 additions & 1 deletion src/torchmetrics/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ class BinaryAccuracy(BinaryStatScores):
If ``multidim_average`` is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar
value per sample.
Additional dimension ``...`` will be flattened into the batch dimension.
If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.
Args:
threshold: Threshold for transforming probability to binary {0,1} predictions
Expand Down Expand Up @@ -176,6 +177,9 @@ class MulticlassAccuracy(MulticlassStatScores):
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
- If ``average=None/'none'``, the shape will be ``(N, C)``
If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.
Args:
num_classes: Integer specifying the number of classes
average:
Expand Down Expand Up @@ -325,6 +329,9 @@ class MultilabelAccuracy(MultilabelStatScores):
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
- If ``average=None/'none'``, the shape will be ``(N, C)``
If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.
Args:
num_labels: Integer specifying the number of labels
threshold: Threshold for transforming probability to binary (0,1) predictions
Expand Down
8 changes: 6 additions & 2 deletions src/torchmetrics/classification/exact_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,16 @@ class MulticlassExactMatch(Metric):
probabilities/logits into an int tensor.
- ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``.
As output to ``forward`` and ``compute`` the metric returns the following output:
- ``mcem`` (:class:`~torch.Tensor`): A tensor whose returned shape depends on the ``multidim_average`` argument:
- If ``multidim_average`` is set to ``global`` the output will be a scalar tensor
- If ``multidim_average`` is set to ``samplewise`` the output will be a tensor of shape ``(N,)``
If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.
Args:
num_classes: Integer specifying the number of labels
multidim_average:
Expand Down Expand Up @@ -206,14 +208,16 @@ class MultilabelExactMatch(Metric):
sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in ``threshold``.
- ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)``.
As output to ``forward`` and ``compute`` the metric returns the following output:
- ``mlem`` (:class:`~torch.Tensor`): A tensor whose returned shape depends on the ``multidim_average`` argument:
- If ``multidim_average`` is set to ``global`` the output will be a scalar tensor
- If ``multidim_average`` is set to ``samplewise`` the output will be a tensor of shape ``(N,)``
If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.
Args:
num_labels: Integer specifying the number of labels
threshold: Threshold for transforming probability to binary (0,1) predictions
Expand Down
20 changes: 18 additions & 2 deletions src/torchmetrics/classification/f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ class BinaryFBetaScore(BinaryStatScores):
- If ``multidim_average`` is set to ``samplewise`` the output will be a tensor of shape ``(N,)`` consisting of
a scalar value per sample.
If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.
Args:
beta: Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight
threshold: Threshold for transforming probability to binary {0,1} predictions
Expand Down Expand Up @@ -202,7 +205,6 @@ class MulticlassFBetaScore(MulticlassStatScores):
probabilities/logits into an int tensor.
- ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``.
As output to ``forward`` and ``compute`` the metric returns the following output:
- ``mcfbs`` (:class:`~torch.Tensor`): A tensor whose returned shape depends on the ``average`` and
Expand All @@ -218,6 +220,9 @@ class MulticlassFBetaScore(MulticlassStatScores):
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
- If ``average=None/'none'``, the shape will be ``(N, C)``
If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.
Args:
beta: Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight
num_classes: Integer specifying the number of classes
Expand Down Expand Up @@ -382,7 +387,6 @@ class MultilabelFBetaScore(MultilabelStatScores):
per element. Additionally, we convert to int tensor with thresholding using the value in ``threshold``.
- ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)``.
As output to ``forward`` and ``compute`` the metric returns the following output:
- ``mlfbs`` (:class:`~torch.Tensor`): A tensor whose returned shape depends on the ``average`` and
Expand All @@ -398,6 +402,9 @@ class MultilabelFBetaScore(MultilabelStatScores):
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
- If ``average=None/'none'``, the shape will be ``(N, C)``
If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.
Args:
beta: Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight
num_labels: Integer specifying the number of labels
Expand Down Expand Up @@ -566,6 +573,9 @@ class BinaryF1Score(BinaryFBetaScore):
- If ``multidim_average`` is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar
value per sample.
If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.
Args:
threshold: Threshold for transforming probability to binary {0,1} predictions
multidim_average:
Expand Down Expand Up @@ -706,6 +716,9 @@ class MulticlassF1Score(MulticlassFBetaScore):
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
- If ``average=None/'none'``, the shape will be ``(N, C)``
If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.
Args:
preds: Tensor with predictions
target: Tensor with true labels
Expand Down Expand Up @@ -876,6 +889,9 @@ class MultilabelF1Score(MultilabelFBetaScore):
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
- If ``average=None/'none'``, the shape will be ``(N, C)```
If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.
Args:
num_labels: Integer specifying the number of labels
threshold: Threshold for transforming probability to binary (0,1) predictions
Expand Down
11 changes: 9 additions & 2 deletions src/torchmetrics/classification/hamming.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ class BinaryHammingDistance(BinaryStatScores):
- If ``multidim_average`` is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a
scalar value per sample.
If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.
Args:
threshold: Threshold for transforming probability to binary {0,1} predictions
multidim_average:
Expand Down Expand Up @@ -171,7 +174,6 @@ class MulticlassHammingDistance(MulticlassStatScores):
probabilities/logits into an int tensor.
- ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``.
As output to ``forward`` and ``compute`` the metric returns the following output:
- ``mchd`` (:class:`~torch.Tensor`): A tensor whose returned shape depends on the ``average`` and
Expand All @@ -187,6 +189,9 @@ class MulticlassHammingDistance(MulticlassStatScores):
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
- If ``average=None/'none'``, the shape will be ``(N, C)``
If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.
Args:
num_classes: Integer specifying the number of classes
average:
Expand Down Expand Up @@ -324,7 +329,6 @@ class MultilabelHammingDistance(MultilabelStatScores):
``threshold``.
- ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)``.
As output to ``forward`` and ``compute`` the metric returns the following output:
- ``mlhd`` (:class:`~torch.Tensor`): A tensor whose returned shape depends on the ``average`` and
Expand All @@ -340,6 +344,9 @@ class MultilabelHammingDistance(MultilabelStatScores):
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
- If ``average=None/'none'``, the shape will be ``(N, C)``
If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.
Args:
num_labels: Integer specifying the number of labels
threshold: Threshold for transforming probability to binary (0,1) predictions
Expand Down
18 changes: 18 additions & 0 deletions src/torchmetrics/classification/precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ class BinaryPrecision(BinaryStatScores):
value. If ``multidim_average`` is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a
scalar value per sample.
If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.
Args:
threshold: Threshold for transforming probability to binary {0,1} predictions
multidim_average:
Expand Down Expand Up @@ -187,6 +190,9 @@ class MulticlassPrecision(MulticlassStatScores):
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
- If ``average=None/'none'``, the shape will be ``(N, C)``
If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.
Args:
num_classes: Integer specifying the number of classes
average:
Expand Down Expand Up @@ -340,6 +346,9 @@ class MultilabelPrecision(MultilabelStatScores):
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
- If ``average=None/'none'``, the shape will be ``(N, C)``
If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.
Args:
num_labels: Integer specifying the number of labels
threshold: Threshold for transforming probability to binary (0,1) predictions
Expand Down Expand Up @@ -479,6 +488,9 @@ class BinaryRecall(BinaryStatScores):
value. If ``multidim_average`` is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of
a scalar value per sample.
If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.
Args:
threshold: Threshold for transforming probability to binary {0,1} predictions
multidim_average:
Expand Down Expand Up @@ -608,6 +620,9 @@ class MulticlassRecall(MulticlassStatScores):
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
- If ``average=None/'none'``, the shape will be ``(N, C)``
If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.
Args:
num_classes: Integer specifying the number of classes
average:
Expand Down Expand Up @@ -760,6 +775,9 @@ class MultilabelRecall(MultilabelStatScores):
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
- If ``average=None/'none'``, the shape will be ``(N, C)``
If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.
Args:
num_labels: Integer specifying the number of labels
threshold: Threshold for transforming probability to binary (0,1) predictions
Expand Down
10 changes: 9 additions & 1 deletion src/torchmetrics/classification/specificity.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ class BinarySpecificity(BinaryStatScores):
If ``multidim_average`` is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value
per sample.
If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.
Args:
threshold: Threshold for transforming probability to binary {0,1} predictions
multidim_average:
Expand Down Expand Up @@ -174,6 +177,9 @@ class MulticlassSpecificity(MulticlassStatScores):
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
- If ``average=None/'none'``, the shape will be ``(N, C)``
If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.
Args:
num_classes: Integer specifying the number of classes
average:
Expand Down Expand Up @@ -307,7 +313,6 @@ class MultilabelSpecificity(MultilabelStatScores):
per element. Additionally, we convert to int tensor with thresholding using the value in ``threshold``.
- ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)``
As output to ``forward`` and ``compute`` the metric returns the following output:
- ``mls`` (:class:`~torch.Tensor`): The returned shape depends on the ``average`` and ``multidim_average``
Expand All @@ -323,6 +328,9 @@ class MultilabelSpecificity(MultilabelStatScores):
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)``
- If ``average=None/'none'``, the shape will be ``(N, C)``
If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.
Args:
num_labels: Integer specifying the number of labels
threshold: Threshold for transforming probability to binary (0,1) predictions
Expand Down
43 changes: 29 additions & 14 deletions src/torchmetrics/classification/stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,11 @@ class BinaryStatScores(_AbstractStatScores):
to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The shape
depends on the ``multidim_average`` parameter:
- If ``multidim_average`` is set to ``global``, the shape will be ``(5,)``
- If ``multidim_average`` is set to ``samplewise``, the shape will be ``(N, 5)``
- If ``multidim_average`` is set to ``global``, the shape will be ``(5,)``
- If ``multidim_average`` is set to ``samplewise``, the shape will be ``(N, 5)``
If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.
Args:
threshold: Threshold for transforming probability to binary {0,1} predictions
Expand Down Expand Up @@ -208,12 +211,18 @@ class MulticlassStatScores(_AbstractStatScores):
to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The shape
depends on ``average`` and ``multidim_average`` parameters:
- If ``multidim_average`` is set to ``global``
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(5,)``
- If ``average=None/'none'``, the shape will be ``(C, 5)``
- If ``multidim_average`` is set to ``samplewise``
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N, 5)``
- If ``average=None/'none'``, the shape will be ``(N, C, 5)``
- If ``multidim_average`` is set to ``global``:
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(5,)``
- If ``average=None/'none'``, the shape will be ``(C, 5)``
- If ``multidim_average`` is set to ``samplewise``:
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N, 5)``
- If ``average=None/'none'``, the shape will be ``(N, C, 5)``
If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.
Args:
num_classes: Integer specifying the number of classes
Expand Down Expand Up @@ -352,12 +361,18 @@ class MultilabelStatScores(_AbstractStatScores):
to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The shape
depends on ``average`` and ``multidim_average`` parameters:
- If ``multidim_average`` is set to ``global``
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(5,)``
- If ``average=None/'none'``, the shape will be ``(C, 5)``
- If ``multidim_average`` is set to ``samplewise``
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N, 5)``
- If ``average=None/'none'``, the shape will be ``(N, C, 5)``
- If ``multidim_average`` is set to ``global``:
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(5,)``
- If ``average=None/'none'``, the shape will be ``(C, 5)``
- If ``multidim_average`` is set to ``samplewise``:
- If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N, 5)``
- If ``average=None/'none'``, the shape will be ``(N, C, 5)``
If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present,
which the reduction will then be applied over instead of the sample dimension ``N``.
Args:
num_labels: Integer specifying the number of labels
Expand Down

0 comments on commit 2c18063

Please sign in to comment.