
binary_precision_recall_curve with thresholds: memory issues #1722

Closed
jogepari opened this issue Apr 20, 2023 · 4 comments
Labels
bug / fix Something isn't working help wanted Extra attention is needed v0.11.x

Comments

@jogepari

🐛 Bug

Using binary_precision_recall_curve with thresholds set results in huge memory consumption (and rapidly diminishing time savings).

To Reproduce

Notebook demo using memory_profiler

import torchmetrics
torchmetrics.__version__
'0.11.4'
import torch
from sklearn.metrics import precision_recall_curve
from torchmetrics.functional.classification import binary_precision_recall_curve

%load_ext memory_profiler
n_samples = 20_000_000
proba = torch.rand(n_samples)
labels = torch.randint(2, (n_samples, ))

sklearn for comparison

%time %memit precision_recall_curve(labels.numpy(), proba.numpy())
peak memory: 1371.58 MiB, increment: 733.07 MiB
CPU times: total: 4.11 s
Wall time: 6.65 s

No thresholds

%time %memit binary_precision_recall_curve(proba, labels)
peak memory: 1313.06 MiB, increment: 672.86 MiB
CPU times: total: 6.95 s
Wall time: 6.74 s

Thresholds with steps: 0.1, 0.05, 0.025

ths_defined = torch.arange(0.1, 1.00000001, step=0.1)
%time %memit binary_precision_recall_curve(proba, labels, thresholds=ths_defined)
peak memory: 5166.79 MiB, increment: 4522.11 MiB
CPU times: total: 4.17 s
Wall time: 2.09 s
ths_defined = torch.arange(0.1, 1.00000001, step=0.05)
%time %memit binary_precision_recall_curve(proba, labels, thresholds=ths_defined)
peak memory: 9196.10 MiB, increment: 8551.08 MiB
CPU times: total: 7.94 s
Wall time: 2.94 s
ths_defined = torch.arange(0.1, 1.00000001, step=0.025)
%time %memit binary_precision_recall_curve(proba, labels, thresholds=ths_defined)
peak memory: 17557.59 MiB, increment: 16912.54 MiB
CPU times: total: 15.4 s
Wall time: 5.59 s
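For intuition on where the memory goes: the increments above grow roughly linearly with the number of thresholds, which is consistent with the vectorized path comparing every sample against every threshold at once. This is an illustrative sketch of that pattern, not the actual TorchMetrics code:

```python
import torch

# Small-scale illustration: comparing all samples against all thresholds
# in one vectorized step materializes an (n_samples, n_thresholds) tensor.
preds = torch.rand(1000)
thresholds = torch.linspace(0.1, 1.0, 37)
broadcast = preds.unsqueeze(-1) >= thresholds  # shape (1000, 37)
print(broadcast.shape)  # torch.Size([1000, 37])

# Scaled to the benchmark above: 20M samples x 37 thresholds with int64
# intermediates costs ~5.5 GiB per temporary copy, so a handful of such
# temporaries matches the ~17 GiB peak reported for step=0.025.
gib = 20_000_000 * 37 * 8 / 2**30
print(f"{gib:.1f} GiB per int64 temporary")
```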

Expected behavior

With thresholds set, the computation should be both faster and more memory efficient than without them.

Environment

  • TorchMetrics version (and how you installed TM, e.g. conda, pip, build from source): 0.11.4, conda-forge
  • Python & PyTorch Version (e.g., 1.0): python 3.10.10, pytorch 2.0.0
  • Any other relevant information such as OS (e.g., Linux): Windows 11
@jogepari jogepari added bug / fix Something isn't working help wanted Extra attention is needed labels Apr 20, 2023
@github-actions

Hi! Thanks for your contribution, great first issue!

@SkafteNicki
Member

Hi @jogepari,
Happy to report that this has already been fixed in PR #1493. Here is a snippet running your code on the master branch:
[screenshot: the same benchmark run on master]
As you can see, memory is now constant beyond a certain point.

To get the changes now, please install from master:
pip install https://github.com/Lightning-AI/metrics/archive/master.zip
or wait for the next release.
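(Editor's note: for intuition, a constant-memory variant can be sketched by accumulating confusion counts one threshold at a time. The function name and code below are hypothetical, a sketch of the general technique rather than the actual TorchMetrics implementation.)

```python
import torch

def pr_curve_lowmem(preds, target, thresholds):
    """Hypothetical sketch: one pass per threshold, so peak memory stays
    O(n_samples) instead of O(n_samples * n_thresholds)."""
    precisions, recalls = [], []
    pos = target.sum()  # total positives, fixed across thresholds
    for t in thresholds:
        pred_pos = preds >= t
        tp = (pred_pos & (target == 1)).sum()
        fp = (pred_pos & (target == 0)).sum()
        precisions.append(tp / torch.clamp(tp + fp, min=1))
        recalls.append(tp / torch.clamp(pos, min=1))
    return torch.stack(precisions), torch.stack(recalls)

preds = torch.tensor([0.1, 0.4, 0.6, 0.9])
target = torch.tensor([0, 0, 1, 1])
p, r = pr_curve_lowmem(preds, target, torch.tensor([0.5]))
print(p, r)  # tensor([1.]) tensor([1.])
```

The tradeoff is visible in the loop: memory no longer scales with the threshold count, but the work does.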

@jogepari
Author

Thank you, @SkafteNicki! May I ask: seeing the fix was merged into master on Feb 23rd, how come it wasn't included in either of the two releases since then?

Also, on your screenshot, calculation time with step=0.025 already exceeds the time without thresholds. Is this normal behaviour?

@SkafteNicki
Member

@jogepari I can see it's due to the PR being listed in the changed section https://github.com/Lightning-AI/torchmetrics/blob/master/CHANGELOG.md#changed and not the fixed section. Only PRs linked in the fixed section are included in bugfix releases.
We could argue about whether this is a bugfix, but the author of the PR saw it more as an improvement/change, because the algorithm was actually working; it was just consuming a lot of memory.

Also, on your screenshot, calculation time with step=0.025 already exceeds the time without thresholds. Is this normal behaviour?

The consequence of lowering the memory requirements is that we need an alternative algorithm, one that is slower when many thresholds are used. So yes, I would say the results are expected.
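(Editor's note: the crossover point can be estimated with rough operation counts. This is illustrative arithmetic only, assuming a sort-dominated path without thresholds and a linear per-threshold pass with them, not a description of the actual implementation.)

```python
import math

n = 20_000_000
# Without thresholds: dominated by sorting the scores, ~n*log2(n) comparisons.
sort_ops = n * math.log2(n)
# With t thresholds and one pass per threshold: ~n*t comparisons.
for t in (10, 20, 40):
    print(f"t={t:>2}: loop/sort ratio ~ {n * t / sort_ops:.2f}")
```

Under these assumptions the two paths break even around t ≈ 24 thresholds, consistent with the observation that step=0.025 (37 thresholds) is already slower than the no-thresholds path.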

@Borda Borda added the v0.11.x label Aug 25, 2023