
[feat] DeepNorm/DeepNet support (#227) #230

Merged: 3 commits merged into main on Mar 14, 2022
Conversation

@blefaudeux (Contributor) commented on Mar 8, 2022

What does this PR do?

  • Implements DeepNorm (aka the DeepNet paper, 1000-layer Transformers) as a "layer norm style" config flag (see the sketch after this section)
  • Cleans up the full-model weight init

Open question: should we support anything other than Xavier init for the weights in the full-model case?
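For illustration, here is a minimal sketch of how the new style could be selected through a block config. The keys below mirror the existing xFormers factory configs, but the exact field names and the "deepnorm" value shown here are assumptions, not the authoritative schema:

# Hypothetical block config sketch: the DeepNorm wiring sits alongside the
# existing "pre" / "post" layer-norm styles. Fields are illustrative, not the
# full set the factory requires.
encoder_block_config = {
    "block_type": "encoder",
    "num_layers": 12,
    "dim_model": 384,
    "layer_norm_style": "deepnorm",  # previously "pre" or "post"
    "multi_head_config": {
        "num_heads": 6,
        "attention": {"name": "scaled_dot_product"},
    },
    "feedforward_config": {"name": "MLP", "activation": "gelu"},
}
# This dict would then go through xFormerConfig / xFormer.from_config as usual
# (remaining required fields omitted here for brevity).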

TODO

  • Implement cleanish support
  • Get some curves out on ViT/CIFAR and GPT

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?
  • Did you update the changelog? (if needed)

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 8, 2022
@blefaudeux blefaudeux marked this pull request as draft March 8, 2022 04:35
@blefaudeux (Contributor, Author) commented:

cc @jramapuram: I think that, at a minimum, the default weight init was probably not the best, but since it was overridden in your example it probably does not explain #219.

blefaudeux added a commit that referenced this pull request Mar 8, 2022
@blefaudeux blefaudeux linked an issue Mar 8, 2022 that may be closed by this pull request
@blefaudeux blefaudeux force-pushed the feat_227 branch 3 times, most recently from f977704 to 33a5131 on March 8, 2022 05:40
codecov-commenter commented on Mar 8, 2022

Codecov Report

Merging #230 (7f0c810) into main (db1ce91) will increase coverage by 0.06%.
The diff coverage is 94.66%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #230      +/-   ##
==========================================
+ Coverage   92.17%   92.24%   +0.06%     
==========================================
  Files          60       60              
  Lines        3247     3313      +66     
==========================================
+ Hits         2993     3056      +63     
- Misses        254      257       +3     
Flag Coverage Δ
Python 92.24% <94.66%> (+0.06%) ⬆️

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
xformers/factory/model_factory.py 95.76% <91.17%> (-1.92%) ⬇️
xformers/factory/block_factory.py 93.95% <94.73%> (+0.01%) ⬆️
xformers/components/residual.py 97.10% <100.00%> (+0.94%) ⬆️
xformers/triton/layer_norm.py 88.13% <0.00%> (+1.69%) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update db1ce91...7f0c810.

@blefaudeux (Contributor, Author) commented:

Example with the microGPT case: not much change at 8 layers, but at 12 layers the pre-layer-norm case NaNs, while with this PR ("DeepNorm") it's smooth sailing, as claimed by the paper.

[screenshot: microGPT training curves, pre-layer-norm vs DeepNorm]

Same hyperparameters, just changing the layer_norm config.

@blefaudeux (Contributor, Author) commented, quoting the Codecov report at the time:

Codecov Report

Merging #230 (33a5131) into main (f5c1d01) will decrease coverage by 1.11%.
The diff coverage is 41.66%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #230      +/-   ##
==========================================
- Coverage   92.06%   90.94%   -1.12%     
==========================================
  Files          60       60              
  Lines        3227     3292      +65     
==========================================
+ Hits         2971     2994      +23     
- Misses        256      298      +42     

Flag Coverage Δ
Python 90.94% <41.66%> (-1.12%) ⬇️

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
xformers/factory/model_factory.py 78.63% <30.30%> (-19.05%) ⬇️
xformers/components/residual.py 80.00% <50.00%> (-15.35%) ⬇️
xformers/factory/block_factory.py 89.56% <52.63%> (-4.38%) ⬇️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update f5c1d01...33a5131.

Extended the unit tests since then, should be better now.

@blefaudeux (Contributor, Author) commented:

ViT/CIFAR example, pre/post/deep norm, same hyperparameters. The validation accuracy seems likely to end up close (kind of a dummy example), but here again the dynamics look a fair bit better with deepnorm.

[screenshot: ViT/CIFAR training curves, pre vs post vs deep norm]

@@ -248,7 +248,7 @@ def test_step(self, batch, _):

# compute total number of steps
batch_size = BATCH * GPUS
steps = dm.num_samples // batch_size * MAX_EPOCHS
@blefaudeux (Contributor, Author) commented on this diff:

This should have been fixed earlier; I had a PR stack and somehow lost it in rebases.

@@ -172,7 +172,7 @@ def test_pytorch_tranformer_parity(device=torch.device("cuda")):
dim_feedforward=4 * EMB,
dropout=DROP,
activation=ACTIVATION,
layer_norm_eps=1e-05,
@blefaudeux (Contributor, Author) commented on this diff:

This should have been part of #221 (1e-6 is the default everywhere now). It's unrelated to this PR, but I spotted it while writing this one.



# CREDITS: the following is inspired by FastAI's Transformer implementation
class Residual(nn.Module):
"""Object-oriented handling of the residual path"""
"""Object-oriented handling of the residual path.
@blefaudeux (Contributor, Author) commented on this diff:

PR change 1: support scaling the residual path (see the sketch below).
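To make this concrete, here is a minimal sketch of a residual wrapper with a scalable skip connection, which is what DeepNorm needs (alpha * x + sublayer(x)). This is an illustration, not the exact xFormers Residual class:

import torch
import torch.nn as nn


class ScaledResidual(nn.Module):
    """Residual wrapper whose skip connection can be scaled:
    scale * x + sublayer(x). With scale=1.0 this is the usual residual;
    DeepNorm sets scale to the alpha coefficient from the paper."""

    def __init__(self, sublayer: nn.Module, scale: float = 1.0):
        super().__init__()
        self.sublayer = sublayer
        self.scale = scale

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        return self.scale * x + self.sublayer(x, *args, **kwargs)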

DeepNormCoefficients = namedtuple("DeepNormCoefficients", ["alpha", "beta"])


def get_deepnorm_coefficients(
@blefaudeux (Contributor, Author) commented on this diff:

PR change 2: compute the residual scaling and the init scaling for the whole model, following the paper (see the sketch below).
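For reference, these coefficients come from the DeepNet paper; below is a minimal sketch of the single-stack case (encoder-only or decoder-only). The function name and signature are assumptions, and the actual xFormers helper also covers the coupled encoder-decoder case:

from collections import namedtuple

DeepNormCoefficients = namedtuple("DeepNormCoefficients", ["alpha", "beta"])


def deepnorm_coefficients(num_layers: int) -> DeepNormCoefficients:
    # alpha scales the residual branch (alpha * x + sublayer(x)),
    # beta is the gain used for the Xavier init of selected projections.
    # Single-stack formula from the DeepNet paper: alpha = (2N)^(1/4),
    # beta = (8N)^(-1/4); the encoder-decoder case is omitted here.
    return DeepNormCoefficients(
        alpha=(2 * num_layers) ** 0.25,
        beta=(8 * num_layers) ** -0.25,
    )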

if layer_norm_style == LayerNormStyle.Pre
else PostNorm(d_model, Residual(sublayer), use_triton)
)
if layer_norm_style == LayerNormStyle.Pre:
@blefaudeux (Contributor, Author) commented on this diff:

PR change 3: handle the three layer-norm options (see the sketch below).
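A rough sketch of the three wiring options, reusing the ScaledResidual sketch above. nn.Sequential and nn.LayerNorm stand in for the PreNorm/PostNorm wrappers purely for illustration; the real wrappers forward extra arguments and can use the Triton layer norm:

import torch.nn as nn


def wrap_sublayer(sublayer: nn.Module, d_model: int, layer_norm_style: str, alpha: float = 1.0) -> nn.Module:
    # Illustrative only: assumes the sublayer takes a single tensor input.
    if layer_norm_style == "pre":
        # x + sublayer(norm(x))
        return ScaledResidual(nn.Sequential(nn.LayerNorm(d_model), sublayer), scale=1.0)
    if layer_norm_style == "post":
        # norm(x + sublayer(x))
        return nn.Sequential(ScaledResidual(sublayer, scale=1.0), nn.LayerNorm(d_model))
    if layer_norm_style == "deepnorm":
        # norm(alpha * x + sublayer(x)): post-norm with a scaled skip connection
        return nn.Sequential(ScaledResidual(sublayer, scale=alpha), nn.LayerNorm(d_model))
    raise ValueError(f"Unknown layer_norm_style: {layer_norm_style}")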

@classmethod
def from_config(cls, config: xFormerConfig):
return cls(config.stack_configs, config.tie_embedding_weights)

def _deepnorm_weight_init(self):
@blefaudeux (Contributor, Author) commented on this diff:

PR change 4: handle the required init change, as per the paper (see the sketch below).
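A minimal sketch of what the DeepNorm init amounts to, per the paper: Xavier-normal with gain beta on the feedforward weights and on the attention value/output projections, plain Xavier elsewhere. The module-name matching below is a guess at typical names, not the exact xFormers code:

import torch


def deepnorm_weight_init(model: torch.nn.Module, beta: float) -> None:
    # Xavier-normal with gain=beta on the feedforward and on the attention
    # value/output projections, plain Xavier (gain=1) elsewhere, biases zeroed.
    scaled_keys = ("v_proj", "out_proj", "feedforward", "mlp")
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            gain = beta if any(key in name for key in scaled_keys) else 1.0
            torch.nn.init.xavier_normal_(module.weight, gain=gain)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)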

@blefaudeux (Contributor, Author) commented:

cc @stephenroller @suchenzang, just in case you're interested. Small-scale tests, but they do seem to confirm the paper.

@blefaudeux blefaudeux requested review from jieru-hu, dianaml0 and fmassa and removed request for jieru-hu March 8, 2022 06:58
@blefaudeux blefaudeux changed the title [DRAFT] DeepNorm support (#227) [feat] DeepNorm support (#227) Mar 8, 2022
@blefaudeux blefaudeux marked this pull request as ready for review March 8, 2022 06:58
@blefaudeux blefaudeux changed the title [feat] DeepNorm support (#227) [feat] DeepNorm/DeepNet support (#227) Mar 9, 2022
@blefaudeux (Contributor, Author) commented:

Rebased on top of #233, I'll redo some curves for a better comparison.

@blefaudeux (Contributor, Author) commented:

It does not really change the accuracy difference: the NaN case for pre-norm seems gone, but there is still an accuracy gap in that case in favor of deepnorm.

@blefaudeux blefaudeux mentioned this pull request Mar 12, 2022
@blefaudeux (Contributor, Author) commented:

Ping reviewers @fmassa @dianaml0 @jieru-hu

@@ -33,7 +33,7 @@ class VisionTransformer(pl.LightningModule):
def __init__(
self,
steps,
learning_rate=1e-3,
@blefaudeux (Contributor, Author) commented on this diff:

This should have been part of #234, lost in a rebase I guess. It's on purpose and seems to work well.

@dianaml0 (Contributor) left a comment:

Great to have this! :) 🚀

@blefaudeux blefaudeux merged commit c8baac0 into main Mar 14, 2022
@blefaudeux blefaudeux deleted the feat_227 branch March 20, 2022 20:21
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Development

Successfully merging this pull request may close these issues.

[feat] Add DeepNorm/DeepNet residual path
4 participants