* Adding FSDP example
* adding slurm cluster setup instruction
* adding setup model func
* added missing features
* summarization_dataset
* Updates training and remove unnecessary imports
* updating the wrapping policy
* Added Zero2 sharding
* updates from testing on clean machine
* updates from clean machine, add requirements.txt
* updates from clean machine
* added SentencePiece
* removed activation checkpointing and added check for bf16
* clean up
* removing cluster setup
* fix progress bars, update readme
* update progress bars, readme
* correct ordering for curr_val_loss evaluation and model save
* clean up the dataset links
* fixing the dataset links
* updates from clean machine
* reverting latest unnecessary changes
* moving to a new folder
* adding FSDP to dist folder
* updates to address comments
* adding utils and configs to make the code modular
* clean up

Co-authored-by: lessw2020 <lessw@etrillium.com>
1 parent 79ef786 · commit 7b7c708
Showing 18 changed files with 949 additions and 0 deletions.
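The wrapping-policy and bf16 pieces named in the commit message live in the `policies` package, which is not among the files rendered below. For orientation, a T5 auto-wrap policy and a bf16 mixed-precision policy built from the standard FSDP APIs would look roughly like this (a sketch of what `policies.get_t5_wrapper()` and `policies.bfSixteen` likely provide, not the actual package contents):

```python
# Sketch only: the policies/ package is not shown in this diff, so treat the
# bodies below as assumptions based on the standard FSDP wrapping APIs.
import functools
import torch
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.t5.modeling_t5 import T5Block

def get_t5_wrapper():
    # Wrap each T5Block (encoder/decoder layer) as its own FSDP unit.
    return functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={T5Block},
    )

# Keep parameters, gradient reduction, and buffers in bfloat16.
bfSixteen = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)
```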
.gitignore (new file)
@@ -0,0 +1,3 @@
__pycache__/
*.pt
*.csv
README.md (new file)
@@ -0,0 +1,24 @@
## FSDP T5

To run the T5 example with FSDP for text summarization:

## Get the wikihow dataset
```bash
sh download_dataset.sh
```

## Install the requirements:
```bash
pip install -r requirements.txt
```

## Ensure you are running a recent version of PyTorch:
See https://pytorch.org to install at least 1.12 and ideally a current nightly build.

Start the training with torchrun (adjust nproc_per_node to your GPU count):

```bash
torchrun --nnodes 1 --nproc_per_node 4 T5_training.py
```
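Before launching, it can help to confirm the environment, since the training script falls back to FP32 when bfloat16 is unavailable. A quick check along these lines works; the `bfloat_support()` helper in `utils` presumably performs a similar test:

```python
import torch

print(torch.__version__)               # should report 1.12 or newer
# The bfSixteen mixed-precision policy is only used when bf16 is usable;
# otherwise the example runs in FP32.
print(torch.cuda.is_bf16_supported())  # typically True on Ampere or newer GPUs
```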
T5_training.py (new file)
@@ -0,0 +1,215 @@
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoTokenizer, GPT2TokenizerFast
from transformers import T5Tokenizer, T5ForConditionalGeneration
import functools
from torch.optim.lr_scheduler import StepLR
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from transformers.models.t5.modeling_t5 import T5Block

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    CPUOffload,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
    FullStateDictConfig,
    StateDictType,
)

from functools import partial
from torch.utils.data import DataLoader
from pathlib import Path
from summarization_dataset import *
import policies
import model_checkpointing
from configs import fsdp_config, train_config
from utils import (bfloat_support, setup,
                   cleanup, get_date_of_run,
                   format_metrics_to_gb,
                   train, validation, setup_model)
from typing import Type
import time
import tqdm
from datetime import datetime


def get_policies(cfg, rank):

    """establish current policies for mixed precision and fsdp wrapping"""

    mixed_precision_policy = None
    wrapping_policy = None

    # mixed precision -----
    if cfg.mixed_precision:
        bfloat_available = bfloat_support()
        if bfloat_available and not cfg.use_fp16:
            mixed_precision_policy = policies.bfSixteen
            if rank == 0:
                print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
        elif cfg.use_fp16:
            mixed_precision_policy = policies.fpSixteen
            if rank == 0:
                print(f"FP16 enabled.")
        else:
            # mixed_precision_policy = policies.fpSixteen
            print(
                f"bFloat16 support not present. Will use FP32, and not mixed precision"
            )

    wrapping_policy = policies.get_t5_wrapper()

    return mixed_precision_policy, wrapping_policy


def fsdp_main(args):

    model, tokenizer = setup_model(train_config.model_name)

    local_rank = int(os.environ['LOCAL_RANK'])
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    dataset = load_dataset('wikihow', 'all', data_dir='data/')
    print(dataset.keys())
    print("Size of train dataset: ", dataset['train'].shape)
    print("Size of Validation dataset: ", dataset['validation'].shape)

    # wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False)
    train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
    val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False)

    sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
    sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size)

    setup()

    train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
    test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
    cuda_kwargs = {'num_workers': 2,
                   'pin_memory': True,
                   'shuffle': False}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

    train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)

    torch.cuda.set_device(local_rank)

    # Set up FSDP parameters
    mixed_precision_policy, t5_auto_wrap_policy = get_policies(train_config, rank)

    # Apply FSDP wrapping to the model
    model = FSDP(model,
                 auto_wrap_policy=t5_auto_wrap_policy,
                 mixed_precision=mixed_precision_policy,
                 sharding_strategy=fsdp_config.sharding_strategy,
                 device_id=torch.cuda.current_device(),
                 limit_all_gathers=fsdp_config.limit_all_gathers)

    if fsdp_config.fsdp_activation_checkpointing:
        policies.apply_fsdp_checkpointing(model)

    # Set up optimizer and scheduler
    optimizer = optim.AdamW(model.parameters(), lr=train_config.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
    best_val_loss = float("inf")
    curr_val_loss = float("inf")
    file_save_name = "T5-model-"

    if rank == 0:
        time_of_run = get_date_of_run()
        dur = []
        train_acc_tracking = []
        val_acc_tracking = []
        training_start_time = time.time()

    if rank == 0 and args.track_memory:
        mem_alloc_tracker = []
        mem_reserved_tracker = []

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        train_accuracy = train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
        if args.run_validation:
            curr_val_loss = validation(model, rank, world_size, val_loader)
        scheduler.step()

        if rank == 0:

            print(f"--> epoch {epoch} completed...entering save and stats zone")

            dur.append(time.time() - t0)
            train_acc_tracking.append(train_accuracy.item())

            if args.run_validation:
                val_acc_tracking.append(curr_val_loss.item())

            if args.track_memory:
                mem_alloc_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_allocated())
                )
                mem_reserved_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_reserved())
                )

        if train_config.save_model and curr_val_loss < best_val_loss:

            if fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
                model_checkpointing.save_model_checkpoint(
                    model, optimizer, rank, fsdp_config, epoch=1
                )
            elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
                model_checkpointing.save_model_and_optimizer_sharded(model, rank, fsdp_config)
                if fsdp_config.save_optimizer:
                    model_checkpointing.save_model_and_optimizer_sharded(model, rank, fsdp_config, optim=optimizer)

            if fsdp_config.save_optimizer:
                model_checkpointing.save_optimizer_checkpoint(
                    model, optimizer, rank, fsdp_config, epoch=1
                )

        if curr_val_loss < best_val_loss:
            best_val_loss = curr_val_loss
            if rank == 0:
                print(f"-->>>> New Val Loss Record: {best_val_loss}")

    dist.barrier()
    cleanup()


if __name__ == '__main__':
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch T5 FSDP Example')
    parser.add_argument('--batch-size', type=int, default=4, metavar='N',
                        help='input batch size for training (default: 4)')
    parser.add_argument('--test-batch-size', type=int, default=4, metavar='N',
                        help='input batch size for testing (default: 4)')
    parser.add_argument('--epochs', type=int, default=2, metavar='N',
                        help='number of epochs to train (default: 2)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--track_memory', action='store_false', default=True,
                        help='track the gpu memory')
    parser.add_argument('--run_validation', action='store_false', default=True,
                        help='running the validation')
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    fsdp_main(args)
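The `setup` and `cleanup` helpers imported from `utils` are not among the files rendered here; under torchrun they would usually amount to creating and destroying the process group. A minimal sketch, assuming the standard NCCL setup:

```python
# Hypothetical sketch of utils.setup / utils.cleanup; the actual utils module
# is not shown in this diff, so treat this as an assumption, not the real code.
import torch.distributed as dist

def setup():
    # torchrun exports RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT,
    # so init_process_group can read everything from the environment.
    dist.init_process_group("nccl")

def cleanup():
    dist.destroy_process_group()
```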
configs/__init__.py (new file)
@@ -0,0 +1,2 @@
from .fsdp import fsdp_config
from .training import train_config
configs/fsdp.py (new file)
@@ -0,0 +1,19 @@
from dataclasses import dataclass, field
from typing import ClassVar
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType


@dataclass
class fsdp_config:
    mixed_precision: bool = True
    use_fp16: bool = False
    seed: int = 42
    fsdp_activation_checkpointing: bool = True
    limit_all_gathers: bool = True
    sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD  # HYBRID_SHARD, SHARD_GRAD_OP
    checkpoint_type: StateDictType = StateDictType.FULL_STATE_DICT  # alternatively can use SHARDED_STATE_DICT to avoid OOMs
    save_optimizer: bool = False
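The commit message mentions Zero2 sharding; `sharding_strategy` above is where that choice is made. A sketch of selecting it, assuming the config is instantiated rather than edited as a class attribute:

```python
from torch.distributed.fsdp import ShardingStrategy
from configs import fsdp_config

# ZeRO-3-style (the default here): parameters, gradients, and optimizer state are all sharded.
zero3_cfg = fsdp_config()

# ZeRO-2-style: keep full parameters on each rank, shard only gradients and optimizer state.
zero2_cfg = fsdp_config(sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)
```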
configs/training.py (new file)
@@ -0,0 +1,19 @@
from dataclasses import dataclass
from typing import ClassVar


@dataclass
class train_config:
    model_name: str = "t5-base"
    run_validation: bool = True
    batch_size_training: int = 4
    num_workers_dataloader: int = 2
    lr: float = 0.002
    weight_decay: float = 0.0
    gamma: float = 0.85
    use_fp16: bool = False
    mixed_precision: bool = True
    save_model: bool = False
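`train_config.model_name` is consumed by `setup_model`, another `utils` helper not rendered here. Given the transformers imports in T5_training.py, it presumably just loads the pretrained T5 checkpoint and tokenizer; a minimal sketch under that assumption (note that `T5Tokenizer` requires the SentencePiece package, which the commit message mentions adding):

```python
# Hypothetical sketch of utils.setup_model; the real helper is not shown in this diff.
from transformers import T5ForConditionalGeneration, T5Tokenizer

def setup_model(model_name: str):
    model = T5ForConditionalGeneration.from_pretrained(model_name)   # e.g. "t5-base"
    tokenizer = T5Tokenizer.from_pretrained(model_name)              # needs SentencePiece installed
    return model, tokenizer
```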
download_dataset.sh (new file)
@@ -0,0 +1,8 @@
#!/bin/bash

# Create the "data" folder if it doesn't exist
mkdir -p data

# Download the files into the "data" folder
wget -P data https://public-nlp-datasets.s3.us-west-2.amazonaws.com/wikihowAll.csv
wget -P data https://public-nlp-datasets.s3.us-west-2.amazonaws.com/wikihowSep.csv
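These are the two CSVs the Hugging Face `wikihow` dataset builder expects to find locally; the training script simply points `data_dir` at the folder this script creates, as already done in T5_training.py:

```python
# After running download_dataset.sh, the dataset loads from the local CSVs,
# exactly as in T5_training.py.
from datasets import load_dataset

dataset = load_dataset('wikihow', 'all', data_dir='data/')
print(dataset['train'].shape, dataset['validation'].shape)
```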
model_checkpointing/__init__.py (new file)
@@ -0,0 +1,10 @@
from .checkpoint_handler import (
    load_model_checkpoint,
    save_model_checkpoint,
    save_distributed_model_checkpoint,
    load_distributed_model_checkpoint,
    load_optimizer_checkpoint,
    save_optimizer_checkpoint,
    save_model_and_optimizer_sharded,
    load_model_sharded,
)
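`checkpoint_handler.py` itself is not included in the rendered portion of this commit. For the FULL_STATE_DICT path selected in configs/fsdp.py, `save_model_checkpoint` would typically gather a CPU copy of the full state dict onto rank 0 before saving; a minimal sketch of that pattern, not the actual handler:

```python
# Sketch of the usual FULL_STATE_DICT save pattern; the real checkpoint_handler
# is not visible in this diff, so the names and signature here are assumptions.
import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullStateDictConfig,
    StateDictType,
)

def save_model_checkpoint(model, optimizer, rank, cfg, epoch=1):
    # Gather the full, unsharded state dict as CPU tensors, materialized only on rank 0.
    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
        cpu_state = model.state_dict()

    if rank == 0:
        torch.save(cpu_state, f"T5-model-epoch{epoch}.pt")
```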