Metric not moved to device #531
Hi! Thanks for your contribution, great first issue!
Interestingly, any states registered inside the metric are moved to the right device, but …
See #340.
@Borda The linked issue does not resolve this problem; I am already doing what it recommends. The documentation claims that the metric's device will be updated, but it is not. I consider this a bug report, not a question: either the documentation or the implementation is wrong. Please reopen this issue.
I see, then the issue is in the docs: no metric is moved automatically unless you use it with PL logging. Would you mind sending a PR to fix it?
The example code at https://torchmetrics.readthedocs.io/en/latest/pages/overview.html#metrics-and-devices does not say anything about having to use PL logging. Quoting the relevant parts:
It sounds a bit odd that you would have to use PL logging for a metric to be moved to the correct device. Can you point me to the relevant code in PL that moves the metric?
The problem seems to be that calling `.cuda()` on a module does not run `for m in self.modules(): m.cuda()`; instead it runs `for module in self.children(): module._apply(fn)`. This moves the metric states to the correct device, but currently the …
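To illustrate the mechanism described above, here is a toy, pure-Python model (not PyTorch's actual code). It assumes only the behavior quoted in the comment: the parent's `cuda()` reaches children through their `_apply`, never through their own `cuda()`. The classes `ToyModule`, `MetricTrackingInCuda`, `MetricTrackingInApply` and the `location`/`device` attributes are hypothetical stand-ins for tensor states and a metric's device bookkeeping.

```python
class ToyModule:
    """Hypothetical stand-in for nn.Module; `location` plays the role of tensor states."""

    def __init__(self):
        self._children = {}
        self.location = "cpu"

    def add_child(self, name, module):
        self._children[name] = module

    def children(self):
        return iter(self._children.values())

    def _apply(self, fn):
        # Recurse into direct children through their _apply, then move own "states".
        for module in self.children():
            module._apply(fn)
        self.location = fn(self.location)
        return self

    def cuda(self):
        # Only this module's _apply is called; children never see .cuda().
        return self._apply(lambda loc: "cuda")


class MetricTrackingInCuda(ToyModule):
    """Updates its device attribute only when its own cuda() is called."""

    def __init__(self):
        super().__init__()
        self.device = "cpu"

    def cuda(self):
        self.device = "cuda"
        return super().cuda()


class MetricTrackingInApply(ToyModule):
    """Updates its device attribute inside _apply, so the parent's recursion reaches it."""

    def __init__(self):
        super().__init__()
        self.device = "cpu"

    def _apply(self, fn):
        super()._apply(fn)
        self.device = fn(self.device)
        return self


model = ToyModule()
model.add_child("m1", MetricTrackingInCuda())
model.add_child("m2", MetricTrackingInApply())
model.cuda()

m1, m2 = model._children["m1"], model._children["m2"]
print(m1.location, m1.device)  # cuda cpu   (states moved, device attribute stale)
print(m2.location, m2.device)  # cuda cuda
```

The design point the toy makes: any per-module bookkeeping that must follow a device move belongs in an `_apply` override, because that is the only hook the parent's recursion calls on each child.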
This is not resolved in 2.0.1. Setting up a
and logging with
gives the error
where the stack trace errors on
@jlehrer1 in your case, it has to do with the initialization, which should use a `torch.nn.ModuleDict`:
```python
self.metrics = torch.nn.ModuleDict({
    "train": {"accuracy", Accuracy(task="binary")},
    "val": {"accuracy", Accuracy(task="binary")}
})
```
You can read more about why here: https://torchmetrics.readthedocs.io/en/stable/pages/overview.html#metrics-and-devices
@SkafteNicki, it doesn't seem like your latest code snippet works with PyTorch 2.0+, since a …
Hi @amorehead, in my last example I do not refer to a standard …
Hi, @SkafteNicki. When I referred to the …
@amorehead thanks, and sorry for the confusion on my part; you are indeed correct :)
No worries! Thanks for the original suggestion. It reminded me to organize my torchmetrics more cleanly :)
🐛 Bug
Version 1.4.7
Per https://torchmetrics.readthedocs.io/en/latest/pages/overview.html#metrics-and-devices, if a metric is properly defined (identified as a child of a module), it is supposed to be moved automatically to the same device as the module. Unfortunately, in my own project this does not happen.
When I run this code:
I get:
Expected behavior
I expect the metric to be on `cuda:0`.
Environment
- How you installed PyTorch (`conda`, `pip`, source): pip