I think the MetricCollection was designed to be used as follows:
```python
metrics = MetricCollection([...])
for _ in range(...):
    metrics.update(...)
metrics.compute()
metrics.reset()
```
But I used it the following way:
```python
metrics = MetricCollection([...])
for _ in range(...):
    metrics.update(...)
    metrics.compute()
metrics.reset()
```
This does not work. After the first `compute` call, only the first metric in the collection returns a new value. Calling `.compute()` on the collection calls `.compute()` on every metric, which stores the result in that metric's `._computed` attribute. If the metrics are in the same compute group, the next `.update()` call on the collection invokes `.update()` only on the first metric of the group, yet it is exactly this method that resets `._computed` to `None`. For the other metrics `._computed` is never cleared, so they keep returning the cached value from the first computation.
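The caching interaction described above can be mimicked with a small pure-Python sketch. It is a simplified model under stated assumptions, not torchmetrics code: `FakeMetric` and `FakeCollection` are hypothetical stand-ins, all metrics share one state dict as if they formed a single compute group, and the `fix` flag shows one possible remedy (clearing `_computed` on every group member after the group is updated), not necessarily how the library resolves it.

```python
# Hypothetical, simplified model of the compute-group caching problem.
# FakeMetric, FakeCollection and `fix` are illustrative names, not torchmetrics APIs.


class FakeMetric:
    def __init__(self, name, state):
        self.name = name
        self.state = state        # shared dict, standing in for a compute group's shared state
        self._computed = None     # cached result of compute()

    def update(self, value):
        # As described in the issue: update() is what clears the cache.
        self.state["total"] += value
        self._computed = None

    def compute(self):
        # Return the cached value if present, otherwise compute and cache it.
        if self._computed is None:
            self._computed = self.state["total"]
        return self._computed


class FakeCollection:
    def __init__(self, names, fix=False):
        shared = {"total": 0}     # every metric sits in the same (hypothetical) compute group
        self.metrics = [FakeMetric(n, shared) for n in names]
        self.fix = fix

    def update(self, value):
        # Only the first metric of the group is updated; the others read the shared state.
        self.metrics[0].update(value)
        if self.fix:
            # Hypothetical remedy: invalidate the cache of the remaining group members too.
            for m in self.metrics[1:]:
                m._computed = None

    def compute(self):
        return {m.name: m.compute() for m in self.metrics}


buggy_results, fixed_results = [], []
for collection, results in ((FakeCollection(["recall", "accuracy"]), buggy_results),
                            (FakeCollection(["recall", "accuracy"], fix=True), fixed_results)):
    for v in (1, 2, 3):
        collection.update(v)
        results.append(collection.compute())

print(buggy_results)
# [{'recall': 1, 'accuracy': 1}, {'recall': 3, 'accuracy': 1}, {'recall': 6, 'accuracy': 1}]
print(fixed_results)
# [{'recall': 1, 'accuracy': 1}, {'recall': 3, 'accuracy': 3}, {'recall': 6, 'accuracy': 6}]
```

In the buggy run, `accuracy` stays frozen at the value from its first `compute()`, matching the symptom in the issue, while `recall` (the first group member) tracks the shared state.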
To Reproduce
```python
from torch import Tensor
import numpy as np
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryPrecision, BinaryAccuracy, BinaryRecall, BinarySpecificity

# Generate random predictions and labels for a binary classification problem
np.random.seed(42)  # For reproducibility
n_samples = 100  # Number of samples

# Metrics calculation
metrics = MetricCollection([
    BinaryRecall(),
    BinaryAccuracy(),
    BinaryPrecision(),
    BinarySpecificity()
])

n_epochs = 10
for _ in range(n_epochs):
    # Generate random predictions (0 or 1)
    predictions = Tensor(np.random.randint(2, size=n_samples))
    # Generate random true labels (0 or 1)
    true_labels = Tensor(np.random.randint(2, size=n_samples))
    metrics.update(predictions, true_labels)
    print(metrics.compute())
```
Expected behavior
The compute method should return the new calculations for all metrics.
Environment
TorchMetrics 1.2.0, installed with pip
Python 3.11
PyTorch 2.1.0
macOS
Additional context