[feat] DeepNorm/DeepNet support (#227) #230
Conversation
cc @jramapuram, I think that at a minimum the default weight init was probably not the best, but since it was overridden in your example it probably does not explain #219
following up on that with #230
f977704 to 33a5131
Codecov Report

@@            Coverage Diff             @@
##             main     #230      +/-   ##
==========================================
+ Coverage   92.17%   92.24%    +0.06%
==========================================
  Files          60       60
  Lines        3247     3313       +66
==========================================
+ Hits         2993     3056       +63
- Misses        254      257        +3

Flags with carried forward coverage won't be shown. Continue to review the full report at Codecov.
extended the unit test since then, should be better now
examples/microViT.py
Outdated
@@ -248,7 +248,7 @@ def test_step(self, batch, _):

    # compute total number of steps
    batch_size = BATCH * GPUS
    steps = dm.num_samples // batch_size * MAX_EPOCHS
this should have been fixed earlier, I had a PR stack and somehow lost it with rebases..
@@ -172,7 +172,7 @@ def test_pytorch_tranformer_parity(device=torch.device("cuda")):
        dim_feedforward=4 * EMB,
        dropout=DROP,
        activation=ACTIVATION,
        layer_norm_eps=1e-05,
should have been part of #221 (1e-6 is the default everywhere now); unrelated to this PR, but I spotted it while writing this one
xformers/components/residual.py
Outdated
# CREDITS: the following is inspired by FastAI's Transformer implementation
class Residual(nn.Module):
    """Object-oriented handling of the residual path"""
    """Object-oriented handling of the residual path.
PR change 1: support scaling the residual path
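For readers following the thread, here is a minimal sketch of what a scaled residual wrapper can look like (the names below are illustrative, not the exact xformers API):

```python
import torch.nn as nn


class ScaledResidual(nn.Module):
    """Residual path whose identity branch can be scaled: scale * x + f(x)."""

    def __init__(self, layer: nn.Module, scale: float = 1.0):
        super().__init__()
        self.layer = layer
        self.scale = scale  # scale == 1.0 recovers the classic residual

    def forward(self, x, **kwargs):
        return x * self.scale + self.layer(x, **kwargs)
```

DeepNorm sets this scale to the alpha coefficient computed below; with scale=1.0 it falls back to the usual pre/post-norm residual.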
DeepNormCoefficients = namedtuple("DeepNormCoefficients", ["alpha", "beta"])


def get_deepnorm_coefficients(
PR change 2: get the residual scaling and init scaling given the whole model, following the paper
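As a rough sketch of the paper's formulas for the encoder-only and decoder-only cases (the actual helper in this PR may differ in signature and also has to cover the encoder-decoder case):

```python
from collections import namedtuple

DeepNormCoefficients = namedtuple("DeepNormCoefficients", ["alpha", "beta"])


def deepnorm_coefficients_sketch(n_encoder_layers: int = 0, n_decoder_layers: int = 0):
    """Return (encoder_coefficients, decoder_coefficients) for the simple cases."""
    assert n_encoder_layers > 0 or n_decoder_layers > 0

    if n_decoder_layers == 0:
        # Encoder-only (BERT-like): alpha = (2N)^(1/4), beta = (8N)^(-1/4)
        n = n_encoder_layers
        return DeepNormCoefficients(alpha=(2 * n) ** 0.25, beta=(8 * n) ** -0.25), None

    if n_encoder_layers == 0:
        # Decoder-only (GPT-like): same form, using the number of decoder layers
        m = n_decoder_layers
        return None, DeepNormCoefficients(alpha=(2 * m) ** 0.25, beta=(8 * m) ** -0.25)

    raise NotImplementedError("encoder-decoder coefficients are left out of this sketch")
```

alpha scales the residual branch, and beta is the gain used for the weight init change further down.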
        if layer_norm_style == LayerNormStyle.Pre
        else PostNorm(d_model, Residual(sublayer), use_triton)
    )
    if layer_norm_style == LayerNormStyle.Pre:
PR change 3: handle 3 layernorm options
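Schematically, the three options look like this (illustrative strings instead of the LayerNormStyle enum, and plain callables instead of the PreNorm/PostNorm wrappers used in the PR):

```python
import torch.nn as nn


def wrap_sublayer(style: str, d_model: int, sublayer: nn.Module, alpha: float = 1.0):
    """Return a callable applying the sublayer with the requested norm placement."""
    norm = nn.LayerNorm(d_model, eps=1e-6)
    if style == "pre":
        return lambda x: x + sublayer(norm(x))          # Pre-norm: x + f(LN(x))
    if style == "post":
        return lambda x: norm(x + sublayer(x))          # Post-norm: LN(x + f(x))
    if style == "deepnorm":
        return lambda x: norm(alpha * x + sublayer(x))  # DeepNorm: LN(alpha * x + f(x))
    raise ValueError(f"unknown layer norm style: {style}")
```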
    @classmethod
    def from_config(cls, config: xFormerConfig):
        return cls(config.stack_configs, config.tie_embedding_weights)

    def _deepnorm_weight_init(self):
PR change 4: handle the required init change, as per the paper
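Roughly, the init described in the paper scales Xavier init by beta on the value/output projections of attention and on the feedforward weights. A hedged sketch, with illustrative parameter names that will not match the xformers modules exactly:

```python
import torch.nn as nn


def deepnorm_weight_init_sketch(model: nn.Module, beta: float) -> None:
    """Xavier init everywhere, with gain=beta on the DeepNorm-scaled projections."""
    scaled_keys = ("v_proj", "out_proj", "mlp", "feedforward")  # illustrative names
    for name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        gain = beta if any(key in name for key in scaled_keys) else 1.0
        nn.init.xavier_normal_(module.weight, gain=gain)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
```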
cc @stephenroller @suchenzang, just in case you're interested. Small-scale tests, but they seem to confirm the paper.
rebased on top of #233, I'll redo some curves for a better comparison
it does not really change the accuracy difference: the NaN case for pre-norm seems gone, but there's still an accuracy gap in that case in favor of DeepNorm
@@ -33,7 +33,7 @@ class VisionTransformer(pl.LightningModule):
    def __init__(
        self,
        steps,
        learning_rate=1e-3,
this should have been part of #234, lost in a rebase I guess. It's on purpose and seems to work well
Great to have this! :) 🚀
What does this PR do?
Open question: should we support something other than Xavier init for the weights in the full-model case?
TODO
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.