Add an LLM fine-tuning example (#90)
* WIP: Add an LLM finetuning example
* WIP: add / rename more configs
* Finetuning example seems to be working
* Making progress, more self-contained example
* Works! (need to fix the hash used for path though)
* Improve hashing, reduce default block size
* Fix val_loss logging and add docstring
* Increase the number of dataloader workers
* Use smaller model for now
* Use FSDP in the example
* Fix bug in id generation from config classes
* Tweak config, try to setup mid-epoch checkpointing
* Rename `HFExample` -> `TextClassificationExample`
* Fix broken links in nav
* Remove "huggingface" datamodule config
* Fix issues in config/tests for text_classification
* Add an entry to test the llm_finetuning_example
* Fix issues in the text classification example
* Fix weird docstring issues with hydra-zen
  - mit-ll-responsible-ai/hydra-zen#750
* Fix test and config of text_classification_example
* Move test from main_test.py to example_test.py
* forward_pass is a method of LearningAlgorithmTests
* Various type hint fixes and tweaks
* WIP: Adding some tests for LLM finetuning example
* Fix issue in `jax.md`
* Add link to the example page in index.md
* Fix tests for the llm finetuning example
* Fix issue with tuples in regression files
* Fix test for `get_hash_of`
* Remove unused _field function
* Fix issue with built-in modules in autoref plugin
* Add a bit of info in the example doc
* Add more links in the doc of the module
* Fix issue with the text classification example
* Add skipif mark for LLM finetuning test
* Fix data_dir of text_classification_example
* Use the "auto" strategy for LLM Finetuning tests
* Fix error in fork_rng of LLM finetuning example
* Try a hacky fix for failing test
* Don't run llm finetuning tests on github Cloud CI
* Add missing regression files
* Rename llm_finetuning_example -> llm_finetuning
* Fix import error

---------

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
lebrice authored Nov 15, 2024
1 parent b7b52f5 commit f8e3a22
Showing 38 changed files with 1,319 additions and 349 deletions.
@@ -0,0 +1,35 @@
attention_mask:
  device: cpu
  max: 1
  mean: '1.021e-01'
  min: 0
  shape:
  - 32
  - 128
  sum: 418
input_ids:
  device: cpu
  max: 29043
  mean: '1.648e+02'
  min: 0
  shape:
  - 32
  - 128
  sum: 675172
labels:
  device: cpu
  max: -1
  mean: '-1.e+00'
  min: -1
  shape:
  - 32
  sum: -32
token_type_ids:
  device: cpu
  max: 0
  mean: '0.e+00'
  min: 0
  shape:
  - 32
  - 128
  sum: 0
@@ -0,0 +1,35 @@
attention_mask:
  device: cpu
  max: 1
  mean: '8.374e-02'
  min: 0
  shape:
  - 32
  - 128
  sum: 343
input_ids:
  device: cpu
  max: 26101
  mean: '1.597e+02'
  min: 0
  shape:
  - 32
  - 128
  sum: 654306
labels:
  device: cpu
  max: 1
  mean: '7.188e-01'
  min: 0
  shape:
  - 32
  sum: 23
token_type_ids:
  device: cpu
  max: 0
  mean: '0.e+00'
  min: 0
  shape:
  - 32
  - 128
  sum: 0
@@ -0,0 +1,35 @@
attention_mask:
  device: cpu
  max: 1
  mean: '9.277e-02'
  min: 0
  shape:
  - 32
  - 128
  sum: 380
input_ids:
  device: cpu
  max: 29043
  mean: '1.362e+02'
  min: 0
  shape:
  - 32
  - 128
  sum: 557879
labels:
  device: cpu
  max: 1
  mean: '7.5e-01'
  min: 0
  shape:
  - 32
  sum: 24
token_type_ids:
  device: cpu
  max: 0
  mean: '0.e+00'
  min: 0
  shape:
  - 32
  - 128
  sum: 0
3 changes: 2 additions & 1 deletion docs/SUMMARY.md
@@ -9,7 +9,8 @@
 * [Examples 🧪](examples/index.md)
     * [Image Classification (⚡)](examples/torch_sl_example.md)
     * [Image Classification (jax+⚡)](examples/jax_sl_example.md)
-    * [NLP (🤗+⚡)](examples/nlp.md)
+    * [Text Classification (🤗+⚡)](examples/text_classification.md)
+    * [Fine-tuning an LLM (🤗+⚡)](examples/llm_finetuning.md)
     * [RL (jax)](examples/jax_rl_example.md)
     * [Running sweeps](examples/sweeps.md)
     * [Profiling your code📎](examples/profiling.md)
23 changes: 17 additions & 6 deletions docs/examples/index.md
@@ -1,10 +1,21 @@
+---
+additional_python_references:
+- project.algorithms.jax_rl_example
+- project.algorithms.example
+- project.algorithms.jax_example
+- project.algorithms.text_classification_example
+- project.algorithms.llm_finetuning
+- project.trainers.jax_trainer
+---
+
 # Examples
 
 This template includes examples that use either Jax, PyTorch, or both!
 
-| Example link | Research Area | Reference link | Frameworks |
-| --------------------------------------- | ------------------------------------------ | ------------------ | --------------- |
-| [ExampleAlgorithm](torch_sl_example.md) | Supervised Learning (image classification) | `ExampleAlgorithm` | Torch + ⚡ |
-| [JaxExample](jax_sl_example.md) | Supervised Learning (image classification) | `JaxExample` | Torch + Jax + ⚡ |
-| [HFExample](nlp.md) | NLP (text classification) | `HFExample` | Torch + 🤗 + ⚡ |
-| [JaxRLExample](jax_rl_example.md) | RL | `JaxRLExample` | Jax |
+| Example link | Research Area | Reference link | Frameworks |
+| --------------------------------------------------- | ------------------------------------------ | --------------------------- | --------------- |
+| [ExampleAlgorithm](torch_sl_example.md) | Supervised Learning (image classification) | `ExampleAlgorithm` | Torch + ⚡ |
+| [JaxExample](jax_sl_example.md) | Supervised Learning (image classification) | `JaxExample` | Torch + Jax + ⚡ |
+| [TextClassificationExample](text_classification.md) | NLP (text classification) | `TextClassificationExample` | Torch + 🤗 + ⚡ |
+| [JaxRLExample](jax_rl_example.md) | RL | `JaxRLExample` | Jax |
+| [LLMFinetuningExample](llm_finetuning.md) | NLP (Causal language modeling) | `LLMFinetuningExample` | Torch + 🤗 + ⚡ |
22 changes: 22 additions & 0 deletions docs/examples/llm_finetuning.md
@@ -0,0 +1,22 @@
---
additional_python_references:
- project.algorithms.llm_finetuning
---
# Fine-tuning LLMs

This example is based on [this language modeling example from the HuggingFace transformers documentation](https://huggingface.co/docs/transformers/en/tasks/language_modeling).

To better understand what's going on in this example, it is a good idea to read through these tutorials first:
* [Causal language modeling simple example - HuggingFace docs](https://huggingface.co/docs/transformers/en/tasks/language_modeling)
* [Fine-tune a language model - Colab Notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling.ipynb#scrollTo=X6HrpprwIrIz)

The main difference between this example and the original HuggingFace example is that `LLMFinetuningExample` is a `LightningModule` that is trained by a `lightning.Trainer`.

This also means that this example doesn't use [`accelerate`](https://huggingface.co/docs/accelerate/en/index) or the HuggingFace Trainer.
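
Concretely, such a module has roughly the following shape. This is only a minimal sketch, with an assumed model name, hyperparameters and method bodies; the actual implementation lives in `project.algorithms.llm_finetuning`:

```python
# Minimal sketch of a causal-LM fine-tuning LightningModule (illustrative
# names and defaults; not the exact code of `LLMFinetuningExample`).
import lightning
import torch
from transformers import AutoModelForCausalLM


class CausalLMFinetuningSketch(lightning.LightningModule):
    def __init__(self, model_name_or_path: str = "facebook/opt-350m", learning_rate: float = 2e-5):
        super().__init__()
        self.save_hyperparameters()
        self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path)

    def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        # HF causal LMs compute the language-modeling loss themselves when
        # `labels` are included in the batch.
        outputs = self.model(**batch)
        self.log("train/loss", outputs.loss)
        return outputs.loss

    def validation_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        outputs = self.model(**batch)
        self.log("val/loss", outputs.loss, prog_bar=True)
        return outputs.loss

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.AdamW(self.parameters(), lr=self.hparams["learning_rate"])
```

Since the loss comes straight from the HF model output, the module stays small, and the `lightning.Trainer` takes care of checkpointing, logging and distributed strategies such as FSDP.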


## Running the example

```console
python project/main.py experiment=llm_finetuning_example
```
42 changes: 0 additions & 42 deletions docs/examples/nlp.md

This file was deleted.

41 changes: 41 additions & 0 deletions docs/examples/text_classification.md
@@ -0,0 +1,41 @@
# Text Classification (⚡ + 🤗)

## Overview

The [TextClassificationExample][project.algorithms.text_classification_example.TextClassificationExample] is a [LightningModule][lightning.pytorch.core.module.LightningModule] for a simple text classification task.

It accepts a [TextClassificationDataModule][project.datamodules.text.TextClassificationDataModule] as input, along with a network.

??? note "Click to show the code for TextClassificationExample"
    {{ inline('project.algorithms.text_classification_example.TextClassificationExample', 4) }}

## Config files

### Algorithm config

??? note "Click to show the Algorithm config"
    Source: project/configs/algorithm/text_classification_example.yaml

    {{ inline('project/configs/algorithm/text_classification_example.yaml', 4) }}

### Datamodule config

??? note "Click to show the Datamodule config"
    Source: project/configs/datamodule/glue_cola.yaml

    {{ inline('project/configs/datamodule/glue_cola.yaml', 4) }}

## Running the example

Here is a configuration file that you can use to launch a simple experiment:

??? note "Click to show the yaml config file"
    Source: project/configs/experiment/text_classification_example.yaml

    {{ inline('project/configs/experiment/text_classification_example.yaml', 4) }}

You can use it like so:

```console
python project/main.py experiment=text_classification_example
```
16 changes: 9 additions & 7 deletions docs/features/jax.md
@@ -3,7 +3,7 @@ additional_python_references:
 - project.algorithms.jax_rl_example
 - project.algorithms.example
 - project.algorithms.jax_example
-- project.algorithms.hf_example
+- project.algorithms.text_classification_example
 - project.trainers.jax_trainer
 ---
 
@@ -13,12 +13,14 @@
 
 This template includes examples that use either Jax, PyTorch, or both!
 
-| Example link | Reference | Framework | Lightning? |
-| ------------------------------------------------- | ------------------ | ----------- | ------------ |
-| [ExampleAlgorithm](../examples/jax_sl_example.md) | `ExampleAlgorithm` | Torch | yes |
-| [JaxExample](../examples/jax_sl_example.md) | `JaxExample` | Torch + Jax | yes |
-| [HFExample](../examples/nlp.md) | `HFExample` | Torch + 🤗 | yes |
-| [JaxRLExample](../examples/jax_rl_example.md) | `JaxRLExample` | Jax | no (almost!) |
+<!-- TODO: De-duplicate: This is a bit like a duplicate of the table from the examples/index.md -->
+
+| Example link | Reference | Framework | Lightning? |
+| --------------------------------------------------------------- | --------------------------- | ----------- | ------------ |
+| [ExampleAlgorithm](../examples/jax_sl_example.md) | `ExampleAlgorithm` | Torch | yes |
+| [JaxExample](../examples/jax_sl_example.md) | `JaxExample` | Torch + Jax | yes |
+| [TextClassificationExample](../examples/text_classification.md) | `TextClassificationExample` | Torch + 🤗 | yes |
+| [JaxRLExample](../examples/jax_rl_example.md) | `JaxRLExample` | Jax | no (almost!) |
 
 
 In fact, here you can mix and match both Jax and Torch code. For example, you can use Jax for your dataloading, your network, or the learning algorithm, all while still benefiting from the nice stuff that comes from using PyTorch-Lightning.
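
For instance, a torch training step can hand tensors to a jitted jax function through DLPack. The helpers below are a hand-rolled sketch with assumed names, not the template's own interop utilities:

```python
# Sketch of torch <-> jax interop via DLPack (illustrative helper names).
import jax
import jax.numpy as jnp
import torch
from torch.utils import dlpack as torch_dlpack


@jax.jit
def jax_network(x: jnp.ndarray) -> jnp.ndarray:
    return jnp.tanh(x)  # stand-in for a real jax network


def torch_to_jax(t: torch.Tensor) -> jnp.ndarray:
    # Zero-copy handoff of the tensor's memory to jax.
    return jax.dlpack.from_dlpack(torch_dlpack.to_dlpack(t.contiguous()))


def jax_to_torch(a: jnp.ndarray) -> torch.Tensor:
    return torch_dlpack.from_dlpack(jax.dlpack.to_dlpack(a))


x = torch.randn(32, 128)
y = jax_to_torch(jax_network(torch_to_jax(x)))  # usable in a torch training step
```

Note that gradients do not flow across this boundary by themselves; making jax networks trainable from torch requires extra plumbing (e.g. a custom `torch.autograd.Function`).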
4 changes: 2 additions & 2 deletions project/algorithms/__init__.py
@@ -1,13 +1,13 @@
 from .example import ExampleAlgorithm
-from .hf_example import HFExample
 from .jax_example import JaxExample
 from .jax_rl_example import JaxRLExample
 from .no_op import NoOp
+from .text_classification_example import TextClassificationExample
 
 __all__ = [
     "ExampleAlgorithm",
     "JaxExample",
     "NoOp",
-    "HFExample",
+    "TextClassificationExample",
     "JaxRLExample",
 ]
8 changes: 8 additions & 0 deletions project/algorithms/callbacks/samples_per_second.py
@@ -1,6 +1,8 @@
 import time
 from typing import Any, Literal
 
+import jax
+import torch
 from lightning import LightningModule, Trainer
 from torch import Tensor
 from torch.optim import Optimizer
@@ -90,6 +92,12 @@ def log(
     def get_num_samples(self, batch: BatchType) -> int:
         if is_sequence_of(batch, Tensor):
             return batch[0].shape[0]
+        if isinstance(batch, dict):
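+            # A dict batch (e.g. a tokenized HuggingFace batch): return the leading
+            # (batch) dimension of the first tensor leaf that has more than one
+            # dimension, skipping 1-d tensors such as `labels`.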
+            return next(
+                v.shape[0]
+                for v in jax.tree.leaves(batch)
+                if isinstance(v, torch.Tensor) and v.ndim > 1
+            )
         raise NotImplementedError(
             f"Don't know how many 'samples' there are in batch of type {type(batch)}"
         )
17 changes: 17 additions & 0 deletions project/algorithms/example_test.py
@@ -1,9 +1,13 @@
"""Example showing how the test suite can be used to add tests for a new algorithm."""

import pytest
import torch
from transformers import PreTrainedModel

from project.algorithms.testsuites.algorithm_tests import LearningAlgorithmTests
from project.configs import Config
from project.conftest import command_line_overrides
from project.datamodules.image_classification.cifar10 import CIFAR10DataModule
from project.datamodules.image_classification.image_classification import (
ImageClassificationDataModule,
)
@@ -12,6 +16,19 @@
 from .example import ExampleAlgorithm
 
 
+@pytest.mark.parametrize(
+    command_line_overrides.__name__, ["algorithm=example datamodule=cifar10"], indirect=True
+)
+def test_example_experiment_defaults(experiment_config: Config) -> None:
+    """Test to check that the datamodule is required (even when just an algorithm is set?!)."""
+
+    assert experiment_config.algorithm["_target_"] == (
+        ExampleAlgorithm.__module__ + "." + ExampleAlgorithm.__qualname__
+    )
+
+    assert isinstance(experiment_config.datamodule, CIFAR10DataModule)
+
+
 @run_for_all_configs_of_type("algorithm", ExampleAlgorithm)
 @run_for_all_configs_of_type("datamodule", ImageClassificationDataModule)
 @run_for_all_configs_of_type("algorithm/network", torch.nn.Module, excluding=PreTrainedModel)