-
Notifications
You must be signed in to change notification settings - Fork 2.5k
/
baichuan2_7b.py
287 lines (242 loc) · 10.6 KB
/
baichuan2_7b.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional
import nemo_run as run
import pytorch_lightning as pl
import torch
from megatron.core.distributed import DistributedDataParallelConfig
from pytorch_lightning.callbacks.callback import Callback
from nemo import lightning as nl
from nemo.collections.llm import Baichuan2Config7B, Baichuan2Model
from nemo.collections.llm.api import finetune, pretrain
from nemo.collections.llm.gpt.data.mock import MockDataModule
from nemo.collections.llm.peft.lora import LoRA
from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe
from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.utils.exp_manager import TimingCallback
NAME = "baichuan2_7b"
@run.cli.factory(name=NAME)
def model() -> run.Config[pl.LightningModule]:
    """
    Build the configuration object for the Baichuan2 7B model.

    The returned config targets ``Baichuan2Model`` and passes it a nested
    ``Baichuan2Config7B`` configuration as its ``config`` argument.

    Returns:
        run.Config[pl.LightningModule]: Configuration for the Baichuan2 7B model.

    Examples:
        CLI usage:
            $ nemo llm pretrain model=baichuan2_7b ...

        Python API usage:
            >>> model_config = model()
            >>> print(model_config)
    """
    architecture = run.Config(Baichuan2Config7B)
    return run.Config(Baichuan2Model, config=architecture)
def trainer(
    tensor_parallelism: int = 1,
    pipeline_parallelism: int = 1,
    pipeline_parallelism_type: Optional[torch.dtype] = None,
    virtual_pipeline_parallelism: Optional[int] = None,
    context_parallelism: int = 2,
    sequence_parallelism: bool = False,
    num_nodes: int = 1,
    num_gpus_per_node: int = 8,
    max_steps: int = 1168251,
    callbacks: Optional[list[run.Config[Callback]]] = None,
) -> run.Config[nl.Trainer]:
    """
    Assemble the NeMo Lightning Trainer configuration for Baichuan2 7B.

    Three nested configs are produced: a ``DistributedDataParallelConfig``, a
    ``MegatronStrategy`` that receives the parallelism arguments, and the
    ``nl.Trainer`` config that ties them together with the ``bf16_mixed()`` plugin.

    Args:
        tensor_parallelism (int): Passed to the strategy as ``tensor_model_parallel_size``.
        pipeline_parallelism (int): Passed to the strategy as ``pipeline_model_parallel_size``.
        pipeline_parallelism_type (Optional[torch.dtype]): Passed to the strategy as ``pipeline_dtype``.
        virtual_pipeline_parallelism (Optional[int]): Passed to the strategy as
            ``virtual_pipeline_model_parallel_size``.
        context_parallelism (int): Passed to the strategy as ``context_parallel_size``.
        sequence_parallelism (bool): Passed to the strategy as ``sequence_parallel``.
        num_nodes (int): Number of compute nodes to use.
        num_gpus_per_node (int): Number of GPUs per node (the trainer's ``devices``).
        max_steps (int): Maximum number of training steps.
        callbacks (Optional[list[run.Config[Callback]]]): Callback configurations handed to the trainer as-is.

    Returns:
        run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer.

    Examples:
        CLI usage:
            $ nemo llm pretrain trainer=baichuan2_7b ...

        Python API usage:
            >>> trainer_config = trainer(num_nodes=2, num_gpus_per_node=8)
            >>> print(trainer_config)

    Note:
        For more information on distributed training strategies, refer to the
        NeMo documentation on multi-GPU and multi-node training.
    """
    # Gradient-reduction settings consumed by the Megatron strategy below.
    ddp_settings = run.Config(
        DistributedDataParallelConfig,
        check_for_nan_in_grad=True,
        grad_reduce_in_fp32=True,
        overlap_grad_reduce=True,
        overlap_param_gather=True,
    )

    # Caller-controlled parallel layout first, then the fixed checkpoint/DDP options.
    megatron_strategy = run.Config(
        nl.MegatronStrategy,
        tensor_model_parallel_size=tensor_parallelism,
        pipeline_model_parallel_size=pipeline_parallelism,
        pipeline_dtype=pipeline_parallelism_type,
        virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism,
        context_parallel_size=context_parallelism,
        sequence_parallel=sequence_parallelism,
        gradient_as_bucket_view=True,
        ckpt_async_save=True,
        ckpt_parallel_load=True,
        ddp=ddp_settings,
    )

    return run.Config(
        nl.Trainer,
        accelerator="gpu",
        accumulate_grad_batches=1,
        callbacks=callbacks,
        devices=num_gpus_per_node,
        limit_test_batches=50,
        limit_val_batches=32,
        log_every_n_steps=10,
        max_steps=max_steps,
        num_nodes=num_nodes,
        plugins=bf16_mixed(),
        strategy=megatron_strategy,
        use_distributed_sampler=False,
        val_check_interval=2000,
    )
@run.cli.factory(target=pretrain, name=NAME)
def pretrain_recipe(
    dir: Optional[str] = None,
    name: str = "default",
    num_nodes: int = 1,
    num_gpus_per_node: int = 8,
    fn=pretrain,
) -> run.Partial:
    """
    Assemble the pre-training recipe for the Baichuan2 7B model.

    The recipe binds ``fn`` to the Baichuan2 7B model config, a trainer with a
    ``TimingCallback``, a ``MockDataModule`` (sequence length 8192, global batch
    size 512, micro batch size 1), default logging with a TensorBoard logger,
    the distributed fused Adam / cosine annealing optimizer helper with
    ``max_lr=3e-4``, and the default resume settings.

    Args:
        dir (Optional[str]): Directory for saving logs and checkpoints.
        name (str): Name of the pre-training run.
        num_nodes (int): Number of compute nodes to use.
        num_gpus_per_node (int): Number of GPUs per node.
        fn (Callable): The pre-training function to use.

    Returns:
        run.Partial: Partial configuration for pre-training.

    Examples:
        CLI usage:
            $ nemo llm pretrain --factory baichuan2_7b
            $ nemo llm pretrain --factory "baichuan2_7b(num_nodes=2, name='my_pretrain')"

        Python API usage:
            >>> recipe = pretrain_recipe(name="baichuan2_7b_pretrain", num_nodes=2)
            >>> print(recipe)

    Note:
        For more details on pre-training LLMs with NeMo, see the pre-training
        guide in the `examples/llm/pretrain/` directory.
    """
    # Build each component up front, in the same order the partial receives them.
    model_cfg = model()
    trainer_cfg = trainer(
        num_nodes=num_nodes,
        num_gpus_per_node=num_gpus_per_node,
        callbacks=[run.Config(TimingCallback)],
    )
    data_cfg = run.Config(MockDataModule, seq_length=8192, global_batch_size=512, micro_batch_size=1)
    log_cfg = default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name))
    optim_cfg = distributed_fused_adam_with_cosine_annealing(max_lr=3e-4)
    resume_cfg = default_resume()

    return run.Partial(
        fn,
        model=model_cfg,
        trainer=trainer_cfg,
        data=data_cfg,
        log=log_cfg,
        optim=optim_cfg,
        resume=resume_cfg,
    )
@run.cli.factory(target=pretrain, name=NAME + "_optimized")
def pretrain_recipe_performance(
    dir: Optional[str] = None,
    name: str = "default",
    num_nodes: int = 1,
    num_gpus_per_node: int = 8,
    fn: Callable = pretrain,
) -> run.Partial:
    """
    Assemble the performance-oriented pre-training recipe for Baichuan2 7B.

    Starts from ``pretrain_recipe`` with the same arguments and appends a
    ``MegatronCommOverlapCallback`` (configured with ``tp_comm_overlap=False``)
    to the trainer's callback list.

    Args:
        dir (Optional[str]): Directory for saving logs and checkpoints.
        name (str): Name of the pre-training run.
        num_nodes (int): Number of compute nodes to use.
        num_gpus_per_node (int): Number of GPUs per node.
        fn (Callable): The pre-training function to use.

    Returns:
        run.Partial: Partial configuration for performance-optimized pre-training.

    Examples:
        CLI usage:
            $ nemo llm pretrain --factory baichuan2_7b_optimized

        Python API usage:
            >>> recipe = pretrain_recipe_performance(name="baichuan2_7b_perf", num_nodes=4)
            >>> print(recipe)

    Note:
        Use this recipe with caution and only when you need maximum performance.
        It may not be suitable for all hardware configurations or use cases.
    """
    recipe = pretrain_recipe(
        dir=dir,
        name=name,
        num_nodes=num_nodes,
        num_gpus_per_node=num_gpus_per_node,
        fn=fn,
    )

    # The base recipe always supplies a callbacks list, so it can be extended in place.
    comm_overlap = run.Config(MegatronCommOverlapCallback, tp_comm_overlap=False)
    recipe.trainer.callbacks.append(comm_overlap)
    return recipe
@run.cli.factory(target=finetune, name=NAME)
def finetune_recipe(
    dir: Optional[str] = None,
    name: str = "default",
    num_nodes: int = 1,
    num_gpus_per_node: int = 8,
    peft_scheme: Optional[str] = 'lora',
    packed_sequence: bool = False,
) -> run.Partial:
    """
    Assemble the fine-tuning recipe for the Baichuan2 7B model.

    The base recipe comes from ``default_finetune_recipe`` using the
    "baichuan-inc/Baichuan2-7B-Base" identifier, then is adjusted per scheme:

    * ``None`` / ``'none'`` (full fine-tuning): tensor model parallel size is set
      to 2 and the learning rate to 5e-6.
    * ``'lora'``: a ``LoRA`` config is attached as ``recipe.peft`` and the
      learning rate is set to 1e-4.

    Scheme names are matched case-insensitively.

    Args:
        dir (Optional[str]): Directory for saving logs and checkpoints.
        name (str): Name of the fine-tuning run.
        num_nodes (int): Number of compute nodes to use.
        num_gpus_per_node (int): Number of GPUs per node.
        peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning.
            Allowed values: 'lora', 'none'/None.
        packed_sequence (bool): Forwarded to ``default_finetune_recipe``; packs multiple training
            sequences into one long sequence for training efficiency.

    Returns:
        run.Partial: Partial configuration for fine-tuning.

    Raises:
        ValueError: If ``peft_scheme`` is neither 'lora' nor 'none'/None.

    Examples:
        CLI usage:
            $ nemo llm finetune --factory baichuan2_7b

        Python API usage:
            >>> recipe = finetune_recipe(name="baichuan2_7b_finetune", num_nodes=2)
            >>> print(recipe)

    Note:
        The dataset and remaining defaults are chosen inside ``default_finetune_recipe``.
        For more information on fine-tuning LLMs with NeMo, see the fine-tuning guide
        in the `examples/llm/finetune/` directory.
    """
    recipe = default_finetune_recipe(
        model(), "baichuan-inc/Baichuan2-7B-Base", dir, name, num_nodes, num_gpus_per_node, packed_sequence
    )

    # Fold None into the 'none' spelling so each scheme is checked exactly once.
    scheme = 'none' if peft_scheme is None else peft_scheme.lower()

    if scheme == 'none':
        recipe.trainer.strategy.tensor_model_parallel_size = 2
        recipe.optim.config.lr = 5e-6
        return recipe

    if scheme == 'lora':
        recipe.peft = run.Config(LoRA)
        recipe.optim.config.lr = 1e-4
        return recipe

    raise ValueError(f"Unrecognized peft scheme: {peft_scheme}")