Hi Henry, thanks a lot for your suggestion! You are right that the current way of computing example-wise gradients is unnecessarily inefficient. Using autograd-hacks could be a good workaround, but it looks like it currently only supports Linear and Conv2d layers (but not, e.g., BatchNorm).
Another possibility would be to implement the AND-mask in JAX, where vmap makes it easy to compute example-wise gradients natively.
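For illustration, here is a minimal JAX sketch of that idea: wrapping a per-example loss in `jax.grad` and then `jax.vmap` gives example-wise gradients without a Python loop. The `loss_fn`, parameter pytree, and shapes below are placeholders for illustration only, not taken from the repo.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Hypothetical single-example loss: linear model with squared error.
    pred = jnp.dot(x, params["w"]) + params["b"]
    return jnp.sum((pred - y) ** 2)

# Gradient w.r.t. params for one example, vmapped over the batch axis of
# x and y. Params are shared across examples, hence in_axes=(None, 0, 0).
per_example_grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))

params = {"w": jnp.zeros(3), "b": jnp.zeros(())}
xs = jnp.ones((8, 3))  # batch of 8 examples
ys = jnp.ones((8,))

grads = per_example_grads(params, xs, ys)
# grads["w"].shape == (8, 3): one gradient per example, ready for the AND-mask.
```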
Hi! This is really interesting work and it's great that you released the code like this. I just thought it would be worth mentioning: it seems like the way you calculate the gradients for the different environments is a bit inefficient (basically using a for loop, right?). It might be worth checking out https://github.com/cybertronai/autograd-hacks#per-example-gradients (which I came to from a thread here: https://discuss.pytorch.org/t/how-to-efficiently-compute-gradient-for-each-training-sample/60001), which in theory should allow you to compute per-example gradients efficiently.
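In case it helps, this is roughly how the autograd-hacks README describes its usage (a hedged sketch: `model`, `data`, `targets`, and `loss_fn` are placeholders here, and only layer types the library supports, such as Linear and Conv2d, get per-example gradients):

```python
import autograd_hacks  # https://github.com/cybertronai/autograd-hacks

# Register hooks that record the per-layer activations and backprops.
autograd_hacks.add_hooks(model)

output = model(data)                   # ordinary forward pass over the batch
loss_fn(output, targets).backward()    # ordinary backward pass

# Populate param.grad1 with per-example gradients for supported layers.
autograd_hacks.compute_grad1(model)
autograd_hacks.disable_hooks()

for param in model.parameters():
    if hasattr(param, "grad1"):
        # param.grad1[i] is the gradient w.r.t. example i;
        # averaging over the batch dimension recovers the usual param.grad.
        print(param.grad1.shape)  # (batch_size, *param.shape)
```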