Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add OLMo November 2024 #34551

Merged
merged 26 commits into from
Nov 18, 2024
Merged

Conversation

2015aroras
Copy link
Contributor

@2015aroras 2015aroras commented Oct 31, 2024

What does this PR do?

An updated OLMo model will be released in November. The new model has a few small architecture changes compared to the existing model in transformers:

  • RMSNorm is used instead of standard layer norm.
  • Norm is applied to attention queries and keys.
  • Norm is applied after attention/feedforward rather than before.

The original PR #34497 updated the OLMo implementation in transformers to support the November release. This PR instead adds a new model using the modular approach.

@ArthurZucker

Fixes #34496

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@2015aroras
Copy link
Contributor Author

I tested this before we locked in the Olmo1124 naming conventions. Will update once tested.

@2015aroras 2015aroras marked this pull request as draft October 31, 2024 23:48
@2015aroras 2015aroras marked this pull request as ready for review November 4, 2024 23:42
@2015aroras
Copy link
Contributor Author

2015aroras commented Nov 4, 2024

Tests are passing, including slow ones (except for Olmo1124ModelTest::test_generate_compile_1_end_to_end, but this appears to be broken for base OLMo too so I'm considering it an existing problem).. I've used a test HF hub repo (shanearora/OLMo-7B-1124-hf) since the official final model is not ready yet.

@2015aroras
Copy link
Contributor Author

PR checks were passing before I merged main again, and PR check failures relate to other models.

@2015aroras
Copy link
Contributor Author

@ArthurZucker Gentle ping

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks marvellous thanks for your hard work, let's get this merged asap! 🤗 Left very small comments it's great.
apologies again for the delay


## Overview

The OLMo November 2024 model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these would need to be filled

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

new init should look more like thisL

# you may not use this file except in compliance with the License.

Should make it simpler 🤗

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

f29dc50 Done, though it took me a while to debug why it wasn't working. The simplified init requires __all__ to be explicitly set: 3960e35.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very very well down, the modular is super simple, makes it easy to identify differences!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other than some difficulties getting started/finished and a bit less docs, modular was a nice experience!

convert_and_export_with_cache,
)

olmo_1124_model = "shanearora/OLMo-7B-1124-hf"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this the final checkpoint? 🤗

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I just grabbed an intermediate checkpoint to use for the implementation. It's from pretty close to end of training.

We will upload the official final and intermediate checkpoints in an official HF Hub repo under the allenai org. Now that this PR is approved, I think we can start the uploading.

Comment on lines 438 to 446
generation_config=GenerationConfig(
use_cache=True,
cache_implementation=cache_implementation,
max_length=max_generation_length,
cache_config={
"batch_size": batch_size,
"max_cache_len": max_generation_length,
},
),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's create it outside the call!

Copy link
Contributor Author

@2015aroras 2015aroras Nov 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

5b7cad9 This was auto-generated from transformers-cli add-new-model-like, but fixed anyways.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker
Copy link
Collaborator

Something has gone wrong with the rebasing it seems 😓

@ArthurZucker
Copy link
Collaborator

We can merge once this is fixed!

@2015aroras
Copy link
Contributor Author

I'm just going to add something for the model card (better than the blank state it is, we can change it later).

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Merging! Thanks for the clean work 🤗

Comment on lines +26 to +28
- RMSNorm is used instead of standard layer norm.
- Norm is applied to attention queries and keys.
- Norm is applied after attention/feedforward layers rather than before.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💘 super clear, love this!

@ArthurZucker ArthurZucker merged commit 3ee24e2 into huggingface:main Nov 18, 2024
22 of 26 checks passed
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
* Add model skeletion with transformers-cli add-new-model-like

* Convert config to modular, add rms_norm_eps, delete clip_qkv

* Convert model to modular, add RMSNorm

* Add flash attention with qk norm and no qkv clipping

* Add decoder layer with RMSNorm after attention/feedforward layers

* Add base and causal model

* Add converter improvements from OLMo repo

* Update weight loading in OLMo to HF converter

* Set correct default for rms_norm_eps

* Set correct pipeline_model_mapping in test

* Run make fixup

* Fix model type

* Re-run modular conversion

* Manually set config docs to fix build errors

* Convert olmo-1124 to olmo_1124 to fix flash attention docs errors

* Start updating tests

* Update tests

* Copy upstream test_eager_matches_sdpa_inference_1_bfloat16 changes to olmo_1124

* Rename input_layernorm and post_attention_layernorm to reflect their ops better

* Use correct tokenizer

* Remove test unsupported by GPT2 tokenizer

* Create GenerationConfig outside of from_pretrained call

* Use simpler init file structure

* Add explicit __all__ to support simplified init

* Make safetensor serialization the default

* Update OLMo November 2024 docs
BernardZach pushed a commit to innovationcore/transformers that referenced this pull request Dec 6, 2024
* Add model skeletion with transformers-cli add-new-model-like

* Convert config to modular, add rms_norm_eps, delete clip_qkv

* Convert model to modular, add RMSNorm

* Add flash attention with qk norm and no qkv clipping

* Add decoder layer with RMSNorm after attention/feedforward layers

* Add base and causal model

* Add converter improvements from OLMo repo

* Update weight loading in OLMo to HF converter

* Set correct default for rms_norm_eps

* Set correct pipeline_model_mapping in test

* Run make fixup

* Fix model type

* Re-run modular conversion

* Manually set config docs to fix build errors

* Convert olmo-1124 to olmo_1124 to fix flash attention docs errors

* Start updating tests

* Update tests

* Copy upstream test_eager_matches_sdpa_inference_1_bfloat16 changes to olmo_1124

* Rename input_layernorm and post_attention_layernorm to reflect their ops better

* Use correct tokenizer

* Remove test unsupported by GPT2 tokenizer

* Create GenerationConfig outside of from_pretrained call

* Use simpler init file structure

* Add explicit __all__ to support simplified init

* Make safetensor serialization the default

* Update OLMo November 2024 docs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add support for OLMo November release
3 participants