Add an LLM fine-tuning example (#90)
* WIP: Add an LLM finetuning example
* WIP: add / rename more configs
* Finetuning example seems to be working
* Making progress, more self-contained example
* Works! (need to fix the hash used for path though)
* Improve hashing, reduce default block size
* Fix val_loss logging and add docstring
* Increase the number of dataloader workers
* Use smaller model for now
* Use FSDP in the example
* Fix bug in id generation from config classes
* Tweak config, try to setup mid-epoch checkpointing
* Rename `HFExample` -> `TextClassificationExample`
* Fix broken links in nav
* Remove "huggingface" datamodule config
* Fix issues in config/tests for text_classification
* Add an entry to test the llm_finetuning_example
* Fix issues in the text classification example
* Fix weird docstring issues with hydra-zen - mit-ll-responsible-ai/hydra-zen#750
* Fix test and config of text_classification_example
* Move test from main_test.py to example_test.py
* forward_pass is a method of LearningAlgorithmTests
* Various type hint fixes and tweaks
* WIP: Adding some tests for LLM finetuning example
* Fix issue in `jax.md`
* Add link to the example page in index.md
* Fix tests for the llm finetuning example
* Fix issue with tuples in regression files
* Fix test for `get_hash_of`
* Remove unused _field function
* Fix issue with built-in modules in autoref plugin
* Add a bit of info in the example doc
* Add more links in the doc of the module
* Fix issue with the text classification example
* Add skipif mark for LLM finetuning test
* Fix data_dir of text_classification_example
* Use the "auto" strategy for LLM Finetuning tests
* Fix error in fork_rng of LLM finetuning example
* Try a hacky fix for failing test
* Don't run llm finetuning tests on github Cloud CI
* Add missing regression files
* Rename llm_finetuning_example -> llm_finetuning
* Fix import error

---------

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
Showing 38 changed files with 1,319 additions and 349 deletions.
35 changes: 35 additions & 0 deletions
...project/datamodules/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_test.yaml
@@ -0,0 +1,35 @@
attention_mask:
  device: cpu
  max: 1
  mean: '1.021e-01'
  min: 0
  shape:
  - 32
  - 128
  sum: 418
input_ids:
  device: cpu
  max: 29043
  mean: '1.648e+02'
  min: 0
  shape:
  - 32
  - 128
  sum: 675172
labels:
  device: cpu
  max: -1
  mean: '-1.e+00'
  min: -1
  shape:
  - 32
  sum: -32
token_type_ids:
  device: cpu
  max: 0
  mean: '0.e+00'
  min: 0
  shape:
  - 32
  - 128
  sum: 0
35 changes: 35 additions & 0 deletions
...roject/datamodules/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_train.yaml
@@ -0,0 +1,35 @@
attention_mask:
  device: cpu
  max: 1
  mean: '8.374e-02'
  min: 0
  shape:
  - 32
  - 128
  sum: 343
input_ids:
  device: cpu
  max: 26101
  mean: '1.597e+02'
  min: 0
  shape:
  - 32
  - 128
  sum: 654306
labels:
  device: cpu
  max: 1
  mean: '7.188e-01'
  min: 0
  shape:
  - 32
  sum: 23
token_type_ids:
  device: cpu
  max: 0
  mean: '0.e+00'
  min: 0
  shape:
  - 32
  - 128
  sum: 0
35 changes: 35 additions & 0 deletions
...ect/datamodules/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_validate.yaml
@@ -0,0 +1,35 @@
attention_mask:
  device: cpu
  max: 1
  mean: '9.277e-02'
  min: 0
  shape:
  - 32
  - 128
  sum: 380
input_ids:
  device: cpu
  max: 29043
  mean: '1.362e+02'
  min: 0
  shape:
  - 32
  - 128
  sum: 557879
labels:
  device: cpu
  max: 1
  mean: '7.5e-01'
  min: 0
  shape:
  - 32
  sum: 24
token_type_ids:
  device: cpu
  max: 0
  mean: '0.e+00'
  min: 0
  shape:
  - 32
  - 128
  sum: 0
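
Each entry in the three regression files above records summary statistics (`min`, `max`, `mean`, `sum`, `shape`) for one tensor of the first batch. As a reading aid, here is a minimal sketch of a sanity check on one such entry: the recorded `mean` should equal `sum` divided by the number of elements. The helper name `summary_is_consistent` and its tolerance are hypothetical and not part of this commit; the values are copied from the test-split file.

```python
import math


def summary_is_consistent(entry: dict, rel_tol: float = 1e-3) -> bool:
    """Check that a tensor summary's `mean` matches `sum / number of elements`.

    Hypothetical helper for illustration; it is not part of the repository.
    """
    numel = math.prod(entry["shape"])  # e.g. 32 * 128 = 4096
    expected_mean = entry["sum"] / numel
    recorded_mean = float(entry["mean"])  # stored as a string such as '1.021e-01'
    mean_ok = math.isclose(recorded_mean, expected_mean, rel_tol=rel_tol, abs_tol=1e-12)
    bounds_ok = entry["min"] <= expected_mean <= entry["max"]
    return mean_ok and bounds_ok


# Values copied from the `attention_mask` and `labels` entries of the test-split file.
attention_mask = {"max": 1, "mean": "1.021e-01", "min": 0, "shape": [32, 128], "sum": 418}
labels = {"max": -1, "mean": "-1.e+00", "min": -1, "shape": [32], "sum": -32}

print(summary_is_consistent(attention_mask))
print(summary_is_consistent(labels))
```

The loose relative tolerance allows for the mean being stored with only four significant digits. Every label in the test-split file is `-1`, while the train and validate files contain labels in `{0, 1}`.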
@@ -1,10 +1,21 @@
+---
+additional_python_references:
+- project.algorithms.jax_rl_example
+- project.algorithms.example
+- project.algorithms.jax_example
+- project.algorithms.text_classification_example
+- project.algorithms.llm_finetuning
+- project.trainers.jax_trainer
+---
+
 # Examples

 This template includes examples that use either Jax, PyTorch, or both!

-| Example link | Research Area | Reference link | Frameworks |
-| --------------------------------------- | ------------------------------------------ | ------------------ | --------------- |
-| [ExampleAlgorithm](torch_sl_example.md) | Supervised Learning (image classification) | `ExampleAlgorithm` | Torch + ⚡ |
-| [JaxExample](jax_sl_example.md) | Supervised Learning (image classification) | `JaxExample` | Torch + Jax + ⚡ |
-| [HFExample](nlp.md) | NLP (text classification) | `HFExample` | Torch + 🤗 + ⚡ |
-| [JaxRLExample](jax_rl_example.md) | RL | `JaxRLExample` | Jax |
+| Example link | Research Area | Reference link | Frameworks |
+| --------------------------------------------------- | ------------------------------------------ | --------------------------- | --------------- |
+| [ExampleAlgorithm](torch_sl_example.md) | Supervised Learning (image classification) | `ExampleAlgorithm` | Torch + ⚡ |
+| [JaxExample](jax_sl_example.md) | Supervised Learning (image classification) | `JaxExample` | Torch + Jax + ⚡ |
+| [TextClassificationExample](text_classification.md) | NLP (text classification) | `TextClassificationExample` | Torch + 🤗 + ⚡ |
+| [JaxRLExample](jax_rl_example.md) | RL | `JaxRLExample` | Jax |
+| [LLMFinetuningExample](llm_finetuning.md) | NLP (Causal language modeling) | `LLMFineTuningExample` | Torch + 🤗 + ⚡ |
@@ -0,0 +1,22 @@
---
additional_python_references:
- project.algorithms.llm_finetuning
---
# Fine-tuning LLMs

This example is based on [this language modeling example from the HuggingFace transformers documentation](https://huggingface.co/docs/transformers/en/tasks/language_modeling).

To better understand what's going on in this example, it is a good idea to read through these tutorials first:
* [Causal language modeling simple example - HuggingFace docs](https://huggingface.co/docs/transformers/en/tasks/language_modeling)
* [Fine-tune a language model - Colab Notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling.ipynb#scrollTo=X6HrpprwIrIz)

The main difference between this example and the original example from HuggingFace is that the `LLMFinetuningExample` is a `LightningModule`, that is trained by a `lightning.Trainer`.

This also means that this example doesn't use [`accelerate`](https://huggingface.co/docs/accelerate/en/index) or the HuggingFace Trainer.


## Running the example

```console
python project/main.py experiment=llm_finetuning_example
```
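
The HuggingFace tutorial linked in the added page prepares data for causal language modeling by concatenating the tokenized texts and cutting them into fixed-size blocks (the commit message also mentions a "default block size"). The sketch below is a simplified, pure-Python variant of that preprocessing step, shown only as background. Whether `LLMFinetuningExample` does exactly this is an assumption, and `group_texts` as written here is illustrative, not the code from this commit.

```python
from itertools import chain


def group_texts(
    examples: dict[str, list[list[int]]], block_size: int
) -> dict[str, list[list[int]]]:
    """Concatenate tokenized texts, then cut them into blocks of `block_size` tokens.

    Simplified illustration of the preprocessing used in the HuggingFace
    causal language modeling tutorial; not taken from this repository.
    """
    concatenated = {key: list(chain.from_iterable(seqs)) for key, seqs in examples.items()}
    total_length = len(concatenated["input_ids"])
    # Drop the trailing tokens that do not fill a whole block.
    total_length = (total_length // block_size) * block_size
    result = {
        key: [tokens[i : i + block_size] for i in range(0, total_length, block_size)]
        for key, tokens in concatenated.items()
    }
    # For causal language modeling the labels are a copy of the inputs;
    # HuggingFace causal LM models shift them internally when computing the loss.
    result["labels"] = [block.copy() for block in result["input_ids"]]
    return result


batch = {
    "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8, 9]],
    "attention_mask": [[1, 1, 1], [1, 1, 1, 1], [1, 1]],
}
blocks = group_texts(batch, block_size=4)
print(blocks["input_ids"])  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```

With nine tokens and a block size of four, the ninth token is dropped, so every block has the same length and batches need no padding.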
This file was deleted.
@@ -0,0 +1,41 @@
# Text Classification ( + 🤗)

## Overview

The [TextClassificationExample][project.algorithms.text_classification_example.TextClassificationExample] is a [LightningModule][lightning.pytorch.core.module.LightningModule] for a simple text classification task.

It accepts a [TextClassificationDataModule][project.datamodules.text.TextClassificationDataModule] as input, along with a network.

??? note "Click to show the code for HFExample"
    {{ inline('project.algorithms.text_classification_example.TextClassificationExample', 4) }}

## Config files

### Algorithm config

??? note "Click to show the Algorithm config"
    Source: project/configs/algorithm/text_classification_example.yaml

    {{ inline('project/configs/algorithm/text_classification_example.yaml', 4) }}

### Datamodule config

??? note "Click to show the Datamodule config"
    Source: project/configs/datamodule/glue_cola.yaml

    {{ inline('project/configs/datamodule/glue_cola.yaml', 4) }}

## Running the example

Here is a configuration file that you can use to launch a simple experiment:

??? note "Click to show the yaml config file"
    Source: project/configs/experiment/text_classification_example.yaml

    {{ inline('project/configs/experiment/text_classification_example.yaml', 4) }}

You can use it like so:

```console
python project/main.py experiment=text_classification_example
```
@@ -1,13 +1,13 @@
 from .example import ExampleAlgorithm
-from .hf_example import HFExample
 from .jax_example import JaxExample
 from .jax_rl_example import JaxRLExample
 from .no_op import NoOp
+from .text_classification_example import TextClassificationExample

 __all__ = [
     "ExampleAlgorithm",
     "JaxExample",
     "NoOp",
-    "HFExample",
+    "TextClassificationExample",
     "JaxRLExample",
 ]