
Unable to plot MetricCollection containing prefix using MetricCollection.plot() #2411

Closed
damos97 opened this issue Feb 27, 2024 · 1 comment · Fixed by #2429
Labels: bug / fix, help wanted, v1.3.x


damos97 commented Feb 27, 2024

🐛 Bug

I want to evaluate and plot the same metrics, namely a confusion matrix and an ROC curve, during the training, validation and testing phases. To do this, I set up a base MetricCollection and clone it three times, adding the respective prefix (e.g. 'train_') through the prefix argument of MetricCollection.clone().

When I plot the prefixed MetricCollection with MetricCollection.plot(), a KeyError is raised. Without a prefix, MetricCollection.plot() works as expected.

To Reproduce

Steps to reproduce the behavior...

Fails: with prefix

Code sample
```python
import torchmetrics
import torch
import matplotlib.pyplot as plt

# test data
preds = torch.tensor([[ 0.,  0.,  1.],
                      [ 1.,  0.,  0.],
                      [ 0.,  1.,  0.],
                      [ 0.,  0.,  1.],
                      [ 0.,  0.,  1.],
                      [ 1.,  0.,  0.]])
target = torch.tensor([2, 0, 1, 2, 0, 1])
# base metric collection
conf_and_roc = torchmetrics.MetricCollection([torchmetrics.ROC(task="multiclass", num_classes=3),
                                              torchmetrics.ConfusionMatrix(task="multiclass", num_classes=3)])
# clone base metric collection and add prefix
train_conf_and_roc = conf_and_roc.clone(prefix='train_')
# update metrics in collection
train_conf_and_roc.update(preds, target)
# plot metrics (raises KeyError)
res = train_conf_and_roc.plot()
# show resulting figures
plt.show()
```
Error message
```
  File "/Users/daniel/lightning_minimal_example/demo.py", line 21, in <module>
    res = train_conf_and_roc.plot()
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/daniel/anaconda3/envs/lightning_example/lib/python3.11/site-packages/torchmetrics/collections.py", line 657, in plot
    f, a = m.plot(val[k], ax=ax[i] if ax is not None else ax)
                  ~~~^^^
KeyError: 'MulticlassROC'
```
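The traceback suggests a key mismatch: `val[k]` is indexed with the base name `'MulticlassROC'`, while the computed results of a prefixed collection are presumably keyed as `'train_MulticlassROC'`. The snippet below is a hypothetical, simplified model of that lookup using plain dicts (the helpers `compute_results` and `buggy_plot_lookup` are made up for illustration and are not torchmetrics code); it reproduces the same KeyError only when a prefix is set.

```python
# Hypothetical, simplified model of the failing lookup; this is NOT the
# torchmetrics source, only an illustration of the suspected key mismatch.


def compute_results(base_names, prefix=""):
    """Mimic compute(): every result key carries the collection's prefix."""
    return {f"{prefix}{name}": f"value of {name}" for name in base_names}


def buggy_plot_lookup(base_names, prefix=""):
    """Mimic a plot loop that iterates over the base (unprefixed) names."""
    val = compute_results(base_names, prefix)
    # Indexing with the base name only works while the prefix is empty.
    return [val[name] for name in base_names]


names = ["MulticlassROC", "MulticlassConfusionMatrix"]
print(buggy_plot_lookup(names))  # no prefix: both values are found
try:
    buggy_plot_lookup(names, prefix="train_")
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'MulticlassROC'
```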

Works: without prefix

Code sample
```python
import torchmetrics
import torch
import matplotlib.pyplot as plt

# test data
preds = torch.tensor([[ 0.,  0.,  1.],
                      [ 1.,  0.,  0.],
                      [ 0.,  1.,  0.],
                      [ 0.,  0.,  1.],
                      [ 0.,  0.,  1.],
                      [ 1.,  0.,  0.]])
target = torch.tensor([2, 0, 1, 2, 0, 1])
# base metric collection
conf_and_roc = torchmetrics.MetricCollection([torchmetrics.ROC(task="multiclass", num_classes=3),
                                              torchmetrics.ConfusionMatrix(task="multiclass", num_classes=3)])
# clone base metric collection without a prefix
train_conf_and_roc = conf_and_roc.clone()
# update metrics in collection
train_conf_and_roc.update(preds, target)
# plot metrics
res = train_conf_and_roc.plot()
# show resulting figures
plt.show()
```

Expected behavior

No KeyError when plotting a MetricCollection with a set prefix property.
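As a sketch of the expected behavior, assuming the fix is simply to build the same decorated key that the computed results use before indexing, the hypothetical helper `plot_lookup` below (not torchmetrics code) returns identical values for a plain and a prefixed collection. Handling `postfix` the same way is an assumption based on MetricCollection also accepting a `postfix` argument; the actual change in #2429 may differ.

```python
# Hypothetical sketch of the expected behavior: look each metric up under the
# same key that the computed results use, so a prefix (or postfix) no longer matters.


def plot_lookup(computed, base_names, prefix="", postfix=""):
    """Return {base_name: value}, reading values via the decorated keys."""
    return {name: computed[f"{prefix}{name}{postfix}"] for name in base_names}


names = ["MulticlassROC", "MulticlassConfusionMatrix"]
plain = {name: f"value of {name}" for name in names}
prefixed = {f"train_{name}": f"value of {name}" for name in names}

# Both collections yield the same values once the key is built correctly.
assert plot_lookup(plain, names) == plot_lookup(prefixed, names, prefix="train_")
```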

Environment

  • OS: macOS Sonoma 14.3.1
  • python: 3.11.5
  • torch: 2.2.1
  • torchmetrics: 1.3.1

Hi! Thanks for your contribution, great first issue!
