Consider $W' = Wa \odot Wb$ (the Hadamard product); then $rank(W') \le rank(Wa) \times rank(Wb)$.
We then apply the conventional low-rank decomposition to $Wa$ and $Wb$ separately, which means that with 2x the parameters we can reach the square of the rank.
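A rough parameter count makes the trade-off concrete (a sketch, assuming $W \in \mathbb{R}^{m \times n}$ and both $Wa$ and $Wb$ decomposed with rank $r$):

$$
\underbrace{k(m+n)}_{\text{single low-rank pair, } rank \le k}
\qquad\text{vs.}\qquad
\underbrace{2r(m+n)}_{\text{Hadamard of two pairs, } rank \le r^2}
$$

At the same parameter budget ($k = 2r$), a single low-rank pair is limited to rank $2r$, while the Hadamard form can reach rank $r^2$.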
Rank is not the same as information capacity, but the two are related.
Based on the experimental results in the paper, although $rank(Wa) \times rank(Wb)$ is only an upper bound, the resulting $dW$ reaches rank $rank(Wa) \times rank(Wb)$ almost every time. A quick sanity check of this is sketched below.
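A minimal sketch of that check (the names and sizes here are arbitrary, not taken from the actual code base):

```python
import torch

torch.manual_seed(0)
dim, r = 64, 4  # any dim >= r * r shows the effect

# Two independent rank-r factorizations of dim x dim matrices.
Wa = torch.randn(dim, r) @ torch.randn(r, dim)
Wb = torch.randn(dim, r) @ torch.randn(r, dim)

dW = Wa * Wb  # Hadamard product

print(torch.linalg.matrix_rank(Wa).item())  # 4
print(torch.linalg.matrix_rank(Wb).item())  # 4
print(torch.linalg.matrix_rank(dW).item())  # typically 16 = 4 * 4
```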
Why custom backward
With $dW = (Wa_1 \cdot Wa_2) \odot (Wb_1 \cdot Wb_2)$, computing the backward pass requires $\Delta{dW}$ and the full $Wa$ to obtain $\Delta{Wb}$, and likewise the full $Wb$ to obtain $\Delta{Wa}$.
With PyTorch's autograd, this kind of operation caches the full $Wa$ and $Wb$ for the backward pass, which means it stores 2x the size of the original weight just for backward.
To avoid this, I implemented a custom backward that reconstructs $Wa$ and $Wb$ from the low-rank factors only when they are actually needed, which saves a lot of memory.
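A minimal sketch of that idea (a hypothetical `HadamardLowRankFn`, not the actual implementation in this repo): only the small factors are saved between forward and backward, and the full-size $Wa$, $Wb$ are rebuilt inside `backward`.

```python
import torch


class HadamardLowRankFn(torch.autograd.Function):
    """dW = (Wa1 @ Wa2) * (Wb1 @ Wb2), saving only the low-rank factors."""

    @staticmethod
    def forward(ctx, Wa1, Wa2, Wb1, Wb2):
        # Save the small factors instead of the two full-size matrices.
        ctx.save_for_backward(Wa1, Wa2, Wb1, Wb2)
        return (Wa1 @ Wa2) * (Wb1 @ Wb2)

    @staticmethod
    def backward(ctx, grad_out):
        Wa1, Wa2, Wb1, Wb2 = ctx.saved_tensors
        # Reconstruct Wa and Wb only now, when they are actually needed.
        Wa = Wa1 @ Wa2
        Wb = Wb1 @ Wb2
        grad_a = grad_out * Wb  # dL/dWa
        grad_b = grad_out * Wa  # dL/dWb
        return (
            grad_a @ Wa2.T,  # dL/dWa1
            Wa1.T @ grad_a,  # dL/dWa2
            grad_b @ Wb2.T,  # dL/dWb1
            Wb1.T @ grad_b,  # dL/dWb2
        )


# Usage: dW = HadamardLowRankFn.apply(Wa1, Wa2, Wb1, Wb2)
```

With plain autograd, the elementwise multiply would keep both full-size operands alive until backward; here they exist only transiently inside `backward`.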