
Improving the efficiency of the gradient calculation #1

Open
henrycharlesworth opened this issue Aug 10, 2021 · 1 comment

@henrycharlesworth
Hi! This is really interesting work, and it's great that you released the code like this. I just thought it would be worth mentioning: it seems like the way you calculate the gradients for the different environments is a bit inefficient (basically using a for loop, right?). It might be worth checking out https://github.com/cybertronai/autograd-hacks#per-example-gradients (which I came to from a thread here: https://discuss.pytorch.org/t/how-to-efficiently-compute-gradient-for-each-training-sample/60001), which in theory should allow you to compute per-example gradients efficiently.
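
For reference, a minimal sketch of the intended usage, adapted from the autograd-hacks README (the toy model, data, and loss below are just placeholders, not the ones from this repo):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

import autograd_hacks  # from the repo linked above

# Toy model/data only; autograd-hacks currently hooks Linear and Conv2d layers.
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
data, targets = torch.randn(64, 10), torch.randn(64, 1)

autograd_hacks.add_hooks(model)          # register the per-example hooks
loss = F.mse_loss(model(data), targets)
loss.backward()
autograd_hacks.compute_grad1(model)      # fill in per-example gradients

for p in model.parameters():
    # p.grad1 has shape [batch_size, *p.shape]: one gradient per example;
    # per the README, averaging over the batch recovers the usual p.grad.
    print(p.grad1.shape, torch.allclose(p.grad1.mean(dim=0), p.grad, atol=1e-5))
```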

@neitzal
Collaborator

neitzal commented Aug 11, 2021

Hi Henry, thanks a lot for your suggestion! You are right that the current way of computing example-wise gradients is unnecessarily inefficient. Using autograd-hacks could be a good workaround, but it looks like it currently only supports Linear and Conv2d layers (but not, e.g., BatchNorm).
Another possibility would be to implement the AND-mask in JAX, where vmap makes it easy to compute example-wise gradients natively, roughly along the lines of the sketch below.
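
A rough, untested sketch (the loss-function signature, the agreement threshold, and the helper names here are placeholders, not code from this repo):

```python
import jax
import jax.numpy as jnp

def per_env_grads(loss_fn, params, env_batches):
    # loss_fn(params, batch) -> scalar loss on one environment's batch.
    # env_batches: a pytree whose leaves are stacked along a leading
    # environment axis, so a single vmap call maps over environments.
    return jax.vmap(jax.grad(loss_fn), in_axes=(None, 0))(params, env_batches)

def and_mask(env_grads, agreement_threshold=1.0):
    # Keep only gradient components whose sign agrees across a large
    # enough fraction of environments, then average over environments.
    def mask_leaf(g):  # g has shape [n_envs, ...]
        agreement = jnp.abs(jnp.mean(jnp.sign(g), axis=0))
        mask = (agreement >= agreement_threshold).astype(g.dtype)
        return mask * jnp.mean(g, axis=0)
    return jax.tree_util.tree_map(mask_leaf, env_grads)
```

The masked gradients returned by `and_mask(per_env_grads(loss_fn, params, env_batches))` could then be fed to the optimizer in place of the ordinary averaged gradient.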
