HausdorffDTLoss leads to GPU memory leak. #7480
+1 to this. I'm also running into the same exact issue, where there appears to be a GPU memory leak in MONAI's `HausdorffDTLoss`. Things that I tried to alleviate this issue:
Thank you for bringing up this issue. I took the time to delve deeper into the situation using your provided sample code. From my findings, there doesn't appear to be a noticeable GPU memory leak.
Thank you for your quick response! I appreciate you taking the time to get back to us. I modified the code sample above slightly to plot the memory allocated by PyTorch during the process:

```python
import numpy as np
import torch
import matplotlib.pyplot as plt
from monai.networks.utils import one_hot
from monai.losses.hausdorff_loss import HausdorffDTLoss

gpu_consumption = []
steps = []

for i in range(0, 100):
    B, C, H, W = 16, 5, 512, 512
    input = torch.rand(B, C, H, W)
    target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long()
    target = one_hot(target_idx[:, None, ...], num_classes=C)
    self = HausdorffDTLoss(include_background=True, reduction="none", softmax=True)
    loss = self(input.to("cuda"), target.to("cuda"))
    assert np.broadcast_shapes(loss.shape, input.shape) == input.shape
    memory_consumption = torch.cuda.max_memory_allocated(device=None) / (1e9)
    gpu_consumption.append(memory_consumption)
    steps.append(i)
    print(f"GPU max memory allocated: {memory_consumption} GB")

plt.plot(steps, gpu_consumption)
plt.title("GPU consumption (in GB) vs. Steps")
plt.show()
```

This generated the following graph: the memory consumption continuously increases with the number of steps, eventually leading to a crash because the system runs out of available GPU memory. Please let me know if the initialization should be handled differently. The person above instantiates a new loss in every loop iteration, which might be linked to what you're saying about instantiation taking some additional GPU memory, but I've noticed this independently in my training runs as well. Thanks!
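As an aside, to rule out the repeated instantiation mentioned above as the cause, a minimal variant of the same loop that constructs the loss a single time outside the loop could look like the sketch below (the name `loss_fn` and the reduced loop are illustrative; shapes and arguments are copied from the script above):

```python
import numpy as np
import torch
from monai.networks.utils import one_hot
from monai.losses.hausdorff_loss import HausdorffDTLoss

# Construct the loss once instead of once per iteration.
loss_fn = HausdorffDTLoss(include_background=True, reduction="none", softmax=True)

for i in range(100):
    B, C, H, W = 16, 5, 512, 512
    input = torch.rand(B, C, H, W)
    target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long()
    target = one_hot(target_idx[:, None, ...], num_classes=C)
    loss = loss_fn(input.to("cuda"), target.to("cuda"))
    assert np.broadcast_shapes(loss.shape, input.shape) == input.shape
    print(f"GPU max memory allocated: {torch.cuda.max_memory_allocated() / 1e9} GB")
```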
Hi @SarthakJShetty-path, I used the same code you shared; the graph I get does not show the increase you see. What's your PyTorch and MONAI version?
Interesting! Here are the versions:

```
(venv) ┌─[pc@home] - [~/projects/]
└─[$] <git:(feature_branch*)> python
Python 3.8.10 (default, Nov 22 2023, 10:22:35)
[GCC 9.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import monai
>>> torch.__version__
'2.1.2+cu121'
>>> monai.__version__
'1.3.0'
>>>
```
I've conducted tests using a new 1.3.0 image and unfortunately, I've been unable to reproduce your reported issue. Could I kindly recommend attempting the same process in a fresh environment on your end?
Sure @KumoLiu, I can make a fresh build and recheck this issue in a bit. Just to confirm: you weren't able to replicate the issue even with the code sample above? Thank you!
Yes, with the code sample you shared.
@SarthakJShetty-path any updates?
Sorry about the delay with this. I haven't gotten around to pulling the Docker image and trying, but several members on our team are reporting this issue, with approximately the same CUDA + Torch version. I will pull the Docker image by EoD today and get back to you. Thank you, and sorry for the delay.
@SarthakJShetty-path I did switch to #4205 ShapeLoss, but I'm not sure it brings the expected results.
Can you try running this piece of code and posting the results?

```python
import numpy as np
import torch
import matplotlib.pyplot as plt
from monai.networks.utils import one_hot
from monai.losses.hausdorff_loss import HausdorffDTLoss

gpu_consumption = []
steps = []

for i in range(0, 100):
    B, C, H, W = 16, 5, 512, 512
    input = torch.rand(B, C, H, W)
    target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long()
    target = one_hot(target_idx[:, None, ...], num_classes=C)
    self = HausdorffDTLoss(include_background=True, reduction="none", softmax=True)
    loss = self(input.to("cuda"), target.to("cuda"))
    assert np.broadcast_shapes(loss.shape, input.shape) == input.shape
    memory_consumption = torch.cuda.max_memory_allocated(device=None) / (1e9)
    gpu_consumption.append(memory_consumption)
    steps.append(i)
    print(f"GPU max memory allocated: {memory_consumption} GB")

plt.plot(steps, gpu_consumption)
plt.title("GPU consumption (in GB) vs. Steps")
plt.show()
```

It looks like @KumoLiu got a much different graph from what I received.
GPU max memory allocated: 0.575145984 GB (Google Colab notebook)
GPU max memory allocated: 0.58143744 GB (local: Windows 11, RTX3060, PyTorch 2.1.1+cu121)
Chiming in here, since I worked on #7008 trying to make the HausdorffLoss work with cucim.
I just tested this myself, and I also get an increase in GPU memory usage with each step when running your script (Windows 11, WSL2, MONAI 1.3.0, PyTorch 2.2.1+cuda12.1, Python 3.11.8):

```python
import gc

import numpy as np
import torch
import matplotlib.pyplot as plt
from monai.networks.utils import one_hot
from monai.losses.hausdorff_loss import HausdorffDTLoss

gpu_consumption = []
steps = []

for i in range(0, 10):
    B, C, H, W = 16, 5, 512, 512
    input = torch.rand(B, C, H, W)
    target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long()
    target = one_hot(target_idx[:, None, ...], num_classes=C)
    self = HausdorffDTLoss(include_background=True, reduction="none", softmax=True)
    loss = self(input.to("cuda"), target.to("cuda"))
    assert np.broadcast_shapes(loss.shape, input.shape) == input.shape
    gc.collect()
    memory_consumption = torch.cuda.max_memory_allocated(device=None) / (1e9)
    gpu_consumption.append(memory_consumption)
    steps.append(i)
    print(f"GPU max memory allocated: {memory_consumption} GB")

plt.plot(steps, gpu_consumption)
plt.title("GPU consumption (in GB) vs. Steps")
plt.show()
```

This seems to indicate some problem with recognizing unused tensors. Maybe there is some issue with cupy/cucim interoperability and weakrefs created by that?

```python
import gc

import numpy as np
import matplotlib.pyplot as plt
from monai.networks.utils import one_hot
from monai.losses.hausdorff_loss import HausdorffDTLoss
from torch.profiler import profile, ProfilerActivity
import torch
import torch.nn
import torch.optim
import torch.profiler
import torch.utils.data

gpu_consumption = []
steps = []


def calculate():
    B, C, H, W = 16, 5, 512, 512
    input = torch.rand(B, C, H, W)
    target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long()
    target = one_hot(target_idx[:, None, ...], num_classes=C)
    self = HausdorffDTLoss(include_background=True, reduction="none", softmax=True)
    loss = self(input.to("cuda"), target.to("cuda"))
    assert np.broadcast_shapes(loss.shape, input.shape) == input.shape


with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    # Enabling with stack creates bad traces
    # with_stack=True,
    profile_memory=True,
    record_shapes=True,
    # on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs/memleak'),
) as prof:
    for i in range(0, 50):
        prof.step()
        calculate()
        # Adding this line fixes the memory leak
        # gc.collect()
        memory_consumption = torch.cuda.max_memory_allocated(device=None) / (1e9)
        gpu_consumption.append(memory_consumption)
        steps.append(i)
        print(f"GPU max memory allocated: {memory_consumption} GB")

prof.export_chrome_trace("trace.json")

plt.plot(steps, gpu_consumption)
plt.title("GPU consumption (in GB) vs. Steps")
plt.show()
```

When running it without stack traces, the memory view looks like this without the `gc.collect()` call:
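A side note on the measurement itself (not from the thread): `torch.cuda.max_memory_allocated()` reports the peak allocation since program start (or the last reset), so it can only increase even when memory is being freed. A small sketch like the one below, which resets the peak counter each step and also records the currently allocated memory, can help separate a genuine leak (current allocation growing) from a merely rising peak; it assumes the `calculate()` helper defined in the script above.

```python
import gc

import torch

current_gb, peak_gb = [], []
for i in range(50):
    torch.cuda.reset_peak_memory_stats()  # reset the peak counter for this step
    calculate()
    # Optional workaround discussed above: force a collection each iteration.
    # gc.collect()
    current_gb.append(torch.cuda.memory_allocated() / 1e9)   # memory still held now
    peak_gb.append(torch.cuda.max_memory_allocated() / 1e9)  # peak within this step
    print(f"step {i}: current={current_gb[-1]:.3f} GB, peak={peak_gb[-1]:.3f} GB")
```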
Thanks! I have tried with cucim on a Kaggle Notebook (installed via `!pip install cucim-cu12`) and there is no GPU memory leak (without `gc.collect()`). But doesn't HausdorffLoss already use `distance_transform_edt`, which uses cucim?
It does use cucim based on this logic, so both cucim and cupy have to be installed:

```python
distance_transform_edt, has_cucim = optional_import(
    "cucim.core.operations.morphology", name="distance_transform_edt"
)
use_cp = has_cp and has_cucim and isinstance(img, torch.Tensor) and img.device.type == "cuda"
```

Sorry, I'm not quite following. So before that, you ran it without cucim installed?
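For reference, a quick way to check in your own environment whether this GPU code path can be taken is to mirror that condition with `optional_import`; the snippet below is only an illustrative check (the tensor `img` stands in for whatever you would pass to the loss), not part of MONAI's public API.

```python
import torch
from monai.utils import optional_import

# Mirror the dispatch condition: cupy and cucim's distance_transform_edt must
# both import, and the tensor must live on a CUDA device.
_, has_cp = optional_import("cupy")
_, has_cucim = optional_import(
    "cucim.core.operations.morphology", name="distance_transform_edt"
)

img = torch.rand(1, 1, 64, 64, device="cuda" if torch.cuda.is_available() else "cpu")
use_cp = has_cp and has_cucim and isinstance(img, torch.Tensor) and img.device.type == "cuda"
print(f"cupy: {has_cp}, cucim: {has_cucim}, GPU distance transform path: {use_cp}")
```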
Ok, so I just checked that myself, and it seems you are right. If you have both cucim and cupy installed, there does not seem to be a memory leak anymore.
@KumoLiu I was able to reproduce the issue using the projectmonai/monai:1.3.0 container. If you run the same script with cupy installed, the memory leak goes away. On the other hand, if you hardcode `use_cp` in monai/transforms/utils.py:2112 to False with cupy installed, the memory leak does not occur either. So it seems that cupy being installed changes something somewhere in memory deallocation or garbage collection, although I am not sure where.
Thank you for looking into this @johnzielke! So this means that installing cupy should resolve the memory leak?
I mean, test it on your setup, but I guess so. You should also get a nice ~10x performance boost in the calculation of the loss with both cupy and cucim installed, since it will run on the GPU then. There should probably be a warning, or at least some more docs, that explain how to get the calculation onto the GPU.
Hi @johnzielke, thanks for the detailed report. Your findings are insightful and indeed point to an interaction between CuPy, garbage collection, and memory deallocation which could be the root cause of the memory leak issue. I agree, looking into this could lead both to resolving the memory leak and potentially offering a substantial performance boost by enabling the loss calculation to run on the GPU.
No worries @KumoLiu, thank you (and @johnzielke) for taking a look at this nonetheless. I'll try installing CuPy and double-check that it avoids this GPU leak.
Describe the bug
Using this loss with the Trainer from the transformers library (PyTorch) and YOLOv8 (PyTorch) causes training to crash shortly after it starts due to CUDA running out of memory.
With 16 GB of GPU memory and a batch size of 1 with 128×128 images, training crashes after ~100 iterations.
Environment
Kaggle Notebook, Python 3.10.12, latest MONAI version from pip.
I also reproduced this bug under Windows 11 with the code from the example:
it consumed about 5 GB of memory, and on the GPU consumption graph it looks like a flat line with several rises.