Subtle memory leak in _ReversibleModuleFunction #1
Comments
Just an update: using …
Thank you very much for reporting this issue and also providing a fix! I am currently on vacation and will have a look at this as soon as I get back home.
Hi Robin, thanks for your feedback! In the meantime I found myself working with …
Hi @spezold, I have created Version 0.2.0 with the fix that you suggested (…
Hi, first of all: very nice work and congrats to your MICCAI paper!

I would like to point out to you a subtle memory leak in `_ReversibleModuleFunction`, which is due to not using `ctx.save_for_backward()` for storing `x`. The memory leak occurs under rare conditions, namely if a network output is not consumed by the loss term, thus it is not backpropagated through, and thus `del ctx.y` in `_ReversibleModuleFunction.backward()` never happens, as `_ReversibleModuleFunction.backward()` for this network output is never called in the first place (at least, this is my uneducated guess on the source of the leak).

Consider the following minimal example:
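(A minimal sketch of the pattern in question, using a hypothetical `LeakyFunction` as a stand-in for the library's `_ReversibleModuleFunction`: it keeps a tensor as a plain attribute on `ctx` and only releases it via `del` in `backward()`, so the tensor is never freed for an output that is not backpropagated through. Whether memory actually builds up may depend on the PyTorch version.)

```python
import torch

# Hypothetical stand-in for _ReversibleModuleFunction: the output tensor is kept
# as a plain Python attribute on ctx and is only released in backward().
class LeakyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = 2 * x
        ctx.y = y  # not managed by autograd; only freed by `del ctx.y` in backward()
        return y

    @staticmethod
    def backward(ctx, grad_y):
        del ctx.y  # never reached if this output is never backpropagated through
        return 2 * grad_y


x = torch.randn(1024, 1024, requires_grad=True)  # ~4 MiB per stored tensor
for step in range(1000):
    y1 = LeakyFunction.apply(x)
    y2 = LeakyFunction.apply(x)  # second "network output"
    loss = y1.sum()              # y2 is nowhere used in the loss ...
    loss.backward()              # ... so its backward() is never called
```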
As you can see, the second network output, `y2`, is nowhere used in the loss calculation, and memory consumption is building up. I found two ways to fix the leak:

1. "Consuming" all network outputs, i.e. backpropagating through each of them as part of the loss.
2. Using `ctx.save_for_backward()` and `ctx.saved_tensors` for storing and retrieving `x`, respectively, in `_ReversibleModuleFunction` (sketched below).
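(A minimal sketch of the second fix, again written against the hypothetical `LeakyFunction` above rather than the library's actual code: the tensor goes through `ctx.save_for_backward()`, so autograd manages its lifetime and it can be released even if `backward()` is never called for that output.)

```python
import torch

# Same hypothetical function as above, but storing the tensor the recommended way.
class FixedFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = 2 * x
        ctx.save_for_backward(y)  # lifetime managed by autograd, no plain attribute
        return y

    @staticmethod
    def backward(ctx, grad_y):
        (y,) = ctx.saved_tensors  # retrieve instead of reading ctx.y
        # (in a real reversible block, y would presumably be used here to
        # reconstruct the input activations)
        return 2 * grad_y
```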
Maybe you want to try to reproduce the memory leak, as I am not sure if it depends on the PyTorch version and/or operating system (my setup is PyTorch 1.2.0 on Windows 10). You may then want to decide whether you change the implementation of `_ReversibleModuleFunction` or whether you point out to the users the need to "consume" all network outputs, as described above.
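(One way to watch for the build-up while running a loop like the one above is to print the process's resident memory every few iterations; a sketch, using the third-party `psutil` package, which is not part of the original report:)

```python
import os

import psutil  # third-party: pip install psutil

_process = psutil.Process(os.getpid())

def report_memory(step):
    """Print the current resident set size of this process in MiB."""
    rss_mib = _process.memory_info().rss / 1024 ** 2
    print(f"step {step}: {rss_mib:.1f} MiB resident")
```

Called inside the loop, the reported value should stay roughly flat with the `ctx.save_for_backward()` variant and grow steadily if the leak occurs.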