Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix memory leak in TraceEnumELBO #3131

Merged
merged 6 commits into from
Aug 29, 2022
Merged

Fix memory leak in TraceEnumELBO #3131

merged 6 commits into from
Aug 29, 2022

Conversation

fehiepsi
Copy link
Member

Fixes #3068 and fixes #3014. I'm not sure what changes in PyTorch caused the issue but replacing x[...] with x.clone() fixes the memory leak.

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great sleuthing! How did you find the offending line? Did you log tensor count or use any detective tooling code?

pyro/ops/einsum/torch_log.py Show resolved Hide resolved
tests/ops/einsum/test_adjoint.py Outdated Show resolved Hide resolved
@fehiepsi
Copy link
Member Author

How did you find the offending line?

Hi Fritz, I used your trick here to check for memory leaks in this example. After removing some gpu codes, I found that the leak also happens in CPU. Then I first copied the TraceEnum_ELBO implementation and removed line by line to see where is the offending code. That leads to Dice.compute_expectations and then this line with root._pyro_backward() call. When I print out the root, and root._pyro_backward I found that they are pretty simple expressions that only invoke _EinsumBackward and _LeafBackward, which are independent of the rest of pyro codes. I tried to play a little bit with _LeafBackward but couldn't replicate the leak so I guessed that the leak happened in some tensor operator of torch_marginal._EinsumBackward. It is lucky that only pyro.ops.einsum.torch_log.einsum is called and inside it, only operands[0][...] is called, so I found the leak. This process took me several hours though. (-:

Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your debugging story!

@fritzo fritzo merged commit 7102cf5 into pyro-ppl:dev Aug 29, 2022
@qinqian
Copy link

qinqian commented Aug 29, 2022

Thank you @fehiepsi for the detailed fixing history!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Memory leak using TraceEnum_ELBO [bug] Memory leak on GPU
3 participants