diff --git a/distributed/FSDP/.gitignore b/distributed/FSDP/.gitignore
new file mode 100644
index 0000000000..d3c28380f9
--- /dev/null
+++ b/distributed/FSDP/.gitignore
@@ -0,0 +1,3 @@
+__pycache__/
+*.pt
+*.csv
\ No newline at end of file
diff --git a/distributed/FSDP/README.md b/distributed/FSDP/README.md
new file mode 100644
index 0000000000..b6c07c6efe
--- /dev/null
+++ b/distributed/FSDP/README.md
@@ -0,0 +1,24 @@
+## FSDP T5
+
+This example runs T5 training for text summarization with FSDP.
+
+## Get the WikiHow dataset
+```bash
+sh download_dataset.sh
+```
+
+## Install the requirements
+```bash
+pip install -r requirements.txt
+```
+
+## Ensure you are running a recent version of PyTorch
+See https://pytorch.org to install PyTorch 1.12 or later; a current nightly build is recommended.
+
+Start the training with torchrun (adjust `--nproc_per_node` to your GPU count):
+
+```bash
+torchrun --nnodes 1 --nproc_per_node 4 T5_training.py
+```
diff --git a/distributed/FSDP/T5_training.py b/distributed/FSDP/T5_training.py
new file mode 100644
index 0000000000..1aae5d0990
--- /dev/null
+++ b/distributed/FSDP/T5_training.py
@@ -0,0 +1,215 @@
+import os
+import argparse
+import time
+
+import torch
+import torch.distributed as dist
+import torch.optim as optim
+from torch.optim.lr_scheduler import StepLR
+from torch.utils.data.distributed import DistributedSampler
+from torch.distributed.fsdp import (
+    FullyShardedDataParallel as FSDP,
+    ShardingStrategy,
+    StateDictType,
+)
+
+from summarization_dataset import *
+import policies
+import model_checkpointing
+from configs import fsdp_config, train_config
+from utils import (bfloat_support, setup,
+                   cleanup, get_date_of_run,
+                   format_metrics_to_gb,
+                   train, validation, setup_model)
+
+
+def get_policies(cfg, rank):
+    """Establish current policies for mixed precision and FSDP wrapping."""
+
+    mixed_precision_policy = None
+    wrapping_policy = None
+
+    # mixed precision -----
+    if cfg.mixed_precision:
+        bfloat_available = bfloat_support()
+        if bfloat_available and not cfg.use_fp16:
+            mixed_precision_policy = policies.bfSixteen
+            if rank == 0:
+                print("bFloat16 enabled for mixed precision - using bfSixteen policy")
+        elif cfg.use_fp16:
+            mixed_precision_policy = policies.fpSixteen
+            if rank == 0:
+                print("FP16 enabled.")
+        else:
+            print("bFloat16 support not present. Will use FP32, and not mixed precision")
+
+    wrapping_policy = policies.get_t5_wrapper()
+
+    return mixed_precision_policy, wrapping_policy
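For context, a minimal sketch of how the policy pair above is consumed (illustrative only, not part of this diff; assumes a CUDA build of PyTorch and the names defined in this file):

```python
# Sketch: on a bf16-capable GPU (e.g. A100), get_policies returns
# (policies.bfSixteen, a functools.partial over transformer_auto_wrap_policy keyed on T5Block).
from configs import train_config

mp_policy, wrap_policy = get_policies(train_config, rank=0)
print(mp_policy)  # MixedPrecision(param_dtype=torch.bfloat16, ...), or None without bf16 support
```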
+
+
+def fsdp_main(args):
+
+    model, tokenizer = setup_model(train_config.model_name)
+
+    local_rank = int(os.environ['LOCAL_RANK'])
+    rank = int(os.environ['RANK'])
+    world_size = int(os.environ['WORLD_SIZE'])
+
+    dataset = load_dataset('wikihow', 'all', data_dir='data/')
+    if rank == 0:
+        print(dataset.keys())
+        print("Size of train dataset: ", dataset['train'].shape)
+        print("Size of Validation dataset: ", dataset['validation'].shape)
+
+    # wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False)
+    train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
+    val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False)
+
+    sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
+    sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size)
+
+    setup()
+
+    train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
+    test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
+    cuda_kwargs = {'num_workers': 2,
+                   'pin_memory': True,
+                   'shuffle': False}
+    train_kwargs.update(cuda_kwargs)
+    test_kwargs.update(cuda_kwargs)
+
+    train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
+    val_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)
+
+    torch.cuda.set_device(local_rank)
+
+    # Set up FSDP parameters
+    mixed_precision_policy, t5_auto_wrap_policy = get_policies(train_config, rank)
+
+    # Apply FSDP wrapping to the model
+    model = FSDP(model,
+                 auto_wrap_policy=t5_auto_wrap_policy,
+                 mixed_precision=mixed_precision_policy,
+                 sharding_strategy=fsdp_config.sharding_strategy,
+                 device_id=torch.cuda.current_device(),
+                 limit_all_gathers=fsdp_config.limit_all_gathers)
+
+    if fsdp_config.fsdp_activation_checkpointing:
+        policies.apply_fsdp_checkpointing(model)
+
+    # Set up optimizer and scheduler
+    optimizer = optim.AdamW(model.parameters(), lr=train_config.lr)
+    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
+
+    best_val_loss = float("inf")
+    curr_val_loss = float("inf")
+
+    if rank == 0:
+        time_of_run = get_date_of_run()
+        dur = []
+        train_acc_tracking = []
+        val_acc_tracking = []
+        training_start_time = time.time()
+
+    if rank == 0 and args.track_memory:
+        mem_alloc_tracker = []
+        mem_reserved_tracker = []
+
+    for epoch in range(1, args.epochs + 1):
+        t0 = time.time()
+        train_accuracy = train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
+        if args.run_validation:
+            curr_val_loss = validation(model, rank, world_size, val_loader)
+        scheduler.step()
+
+        if rank == 0:
+            print(f"--> epoch {epoch} completed...entering save and stats zone")
+
+            dur.append(time.time() - t0)
+            train_acc_tracking.append(train_accuracy.item())
+
+            if args.run_validation:
+                val_acc_tracking.append(curr_val_loss.item())
+
+            if args.track_memory:
+                mem_alloc_tracker.append(
+                    format_metrics_to_gb(torch.cuda.memory_allocated())
+                )
+                mem_reserved_tracker.append(
+                    format_metrics_to_gb(torch.cuda.memory_reserved())
+                )
+
+        if train_config.save_model and curr_val_loss < best_val_loss:
+            if fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
+                model_checkpointing.save_model_checkpoint(
+                    model, optimizer, rank, fsdp_config, epoch=epoch
+                )
+                if fsdp_config.save_optimizer:
+                    model_checkpointing.save_optimizer_checkpoint(
+                        model, optimizer, rank, fsdp_config, epoch=epoch
+                    )
+            elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
+                if fsdp_config.save_optimizer:
+                    model_checkpointing.save_model_and_optimizer_sharded(model, rank, fsdp_config, optim=optimizer)
+                else:
+                    model_checkpointing.save_model_and_optimizer_sharded(model, rank, fsdp_config)
+
+        if curr_val_loss < best_val_loss:
+            best_val_loss = curr_val_loss
+            if rank == 0:
+                print(f"-->>>> New Val Loss Record: {best_val_loss}")
+
+    dist.barrier()
+    cleanup()
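Should you later need to reload the full-state checkpoint this loop writes (the default `FULL_STATE_DICT` path), the handler expects the load to happen on rank 0 before FSDP wrapping. A sketch under that assumption, reusing names from this file:

```python
# Sketch, not part of the diff: resume from a saved full_state_dict checkpoint.
model, tokenizer = setup_model(train_config.model_name)
model_checkpointing.load_model_checkpoint(model, rank, fsdp_config)  # rank0 CPU load, pre-FSDP
model = FSDP(model,
             auto_wrap_policy=t5_auto_wrap_policy,
             mixed_precision=mixed_precision_policy,
             sharding_strategy=fsdp_config.sharding_strategy,
             device_id=torch.cuda.current_device())
```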
+
+
+if __name__ == '__main__':
+    # Training settings
+    parser = argparse.ArgumentParser(description='PyTorch T5 FSDP Example')
+    parser.add_argument('--batch-size', type=int, default=4, metavar='N',
+                        help='input batch size for training (default: 4)')
+    parser.add_argument('--test-batch-size', type=int, default=4, metavar='N',
+                        help='input batch size for testing (default: 4)')
+    parser.add_argument('--epochs', type=int, default=2, metavar='N',
+                        help='number of epochs to train (default: 2)')
+    parser.add_argument('--seed', type=int, default=1, metavar='S',
+                        help='random seed (default: 1)')
+    parser.add_argument('--track_memory', action='store_false', default=True,
+                        help='disable GPU memory tracking (on by default)')
+    parser.add_argument('--run_validation', action='store_false', default=True,
+                        help='disable validation (on by default)')
+    args = parser.parse_args()
+
+    torch.manual_seed(args.seed)
+
+    fsdp_main(args)
diff --git a/distributed/FSDP/configs/__init__.py b/distributed/FSDP/configs/__init__.py
new file mode 100644
index 0000000000..70dba21cc9
--- /dev/null
+++ b/distributed/FSDP/configs/__init__.py
@@ -0,0 +1,2 @@
+from .fsdp import fsdp_config
+from .training import train_config
diff --git a/distributed/FSDP/configs/fsdp.py b/distributed/FSDP/configs/fsdp.py
new file mode 100644
index 0000000000..301771cd26
--- /dev/null
+++ b/distributed/FSDP/configs/fsdp.py
@@ -0,0 +1,19 @@
+from dataclasses import dataclass
+from torch.distributed.fsdp import ShardingStrategy
+from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
+
+@dataclass
+class fsdp_config:
+    mixed_precision: bool=True
+    use_fp16: bool=False
+    seed: int=42
+    fsdp_activation_checkpointing: bool=True
+    limit_all_gathers: bool=True
+    sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD  # HYBRID_SHARD, SHARD_GRAD_OP
+    checkpoint_type: StateDictType = StateDictType.FULL_STATE_DICT  # alternatively use SHARDED_STATE_DICT to avoid OOMs
+    save_optimizer: bool=False
+    verbose: bool=False
+    # fields consumed by model_checkpointing/checkpoint_handler.py
+    # (default values assumed here; adjust to your setup)
+    model_name: str="t5-base"
+    checkpoint_folder: str="model_checkpoints"
+    model_save_name: str="t5-model"
+    checkpoint_model_filename: str="t5-model-1.pt"
+    optimizer_name: str="AdamW"
+    optimizer_checkpoint_file: str="AdamW-t5-model-1.pt"
+    dist_checkpoint_root_folder: str="dist_checkpoints"
+    dist_checkpoint_folder: str="t5"
+    save_using_num_threads: int=1
diff --git a/distributed/FSDP/configs/training.py b/distributed/FSDP/configs/training.py
new file mode 100644
index 0000000000..99b7bfdceb
--- /dev/null
+++ b/distributed/FSDP/configs/training.py
@@ -0,0 +1,19 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class train_config:
+    model_name: str="t5-base"
+    run_validation: bool=True
+    batch_size_training: int=4
+    num_workers_dataloader: int=2
+    lr: float=0.002
+    weight_decay: float=0.0
+    gamma: float=0.85
+    use_fp16: bool=False
+    mixed_precision: bool=True
+    save_model: bool=False
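Both config classes are plain dataclasses read as namespaces by the training script, so experiments can be tweaked without editing it. A small sketch:

```python
# Sketch, not part of the diff: override config defaults before launching.
from torch.distributed.fsdp import ShardingStrategy
from configs import fsdp_config, train_config

fsdp_config.sharding_strategy = ShardingStrategy.SHARD_GRAD_OP  # shard grads/optimizer state only
train_config.model_name = "t5-large"
train_config.lr = 1e-3
```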
diff --git a/distributed/FSDP/download_dataset.sh b/distributed/FSDP/download_dataset.sh
new file mode 100644
index 0000000000..f8d3ebd7b4
--- /dev/null
+++ b/distributed/FSDP/download_dataset.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+
+# Create the "data" folder if it doesn't exist
+mkdir -p data
+
+# Download the files into the "data" folder
+wget -P data https://public-nlp-datasets.s3.us-west-2.amazonaws.com/wikihowAll.csv
+wget -P data https://public-nlp-datasets.s3.us-west-2.amazonaws.com/wikihowSep.csv
diff --git a/distributed/FSDP/model_checkpointing/__init__.py b/distributed/FSDP/model_checkpointing/__init__.py
new file mode 100644
index 0000000000..9af8dbbd54
--- /dev/null
+++ b/distributed/FSDP/model_checkpointing/__init__.py
@@ -0,0 +1,10 @@
+from .checkpoint_handler import (
+    load_model_checkpoint,
+    save_model_checkpoint,
+    save_distributed_model_checkpoint,
+    load_distributed_model_checkpoint,
+    load_optimizer_checkpoint,
+    save_optimizer_checkpoint,
+    save_model_and_optimizer_sharded,
+    load_model_sharded,
+)
diff --git a/distributed/FSDP/model_checkpointing/checkpoint_handler.py b/distributed/FSDP/model_checkpointing/checkpoint_handler.py
new file mode 100644
index 0000000000..5f6858476f
--- /dev/null
+++ b/distributed/FSDP/model_checkpointing/checkpoint_handler.py
@@ -0,0 +1,307 @@
+from pathlib import Path
+from datetime import datetime
+import time
+
+import torch
+import torch.distributed as dist
+import torch.distributed._shard.checkpoint as dist_cp
+
+from torch.distributed.fsdp import (
+    FullyShardedDataParallel as FSDP,
+    StateDictType,
+    FullStateDictConfig,   # general model non-sharded, non-flattened params
+    LocalStateDictConfig,  # flattened params, usable only by FSDP
+    # ShardedStateDictConfig,  # un-flattened but sharded params, usable by other parallel schemes
+)
+
+from torch.distributed._shard.checkpoint import (
+    FileSystemReader,
+    FileSystemWriter,
+    save_state_dict,
+    load_state_dict,
+)
+from torch.distributed.checkpoint.default_planner import (
+    DefaultSavePlanner,
+    DefaultLoadPlanner,
+)
+
+
+def get_date_of_run():
+    """Create date and time for file save uniqueness,
+    example: 2022-05-07-08:31:12_PM
+    """
+    date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p")
+    print(f"--> current date and time of run = {date_of_run}")
+    return date_of_run
+
+
+# create singleton saving policies to avoid making them over and over
+fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+
+
+def load_model_sharded(model, rank, cfg, verbose=True):
+    folder_name = (
+        cfg.dist_checkpoint_root_folder
+        + "/"
+        + cfg.dist_checkpoint_folder
+        + "-"
+        + cfg.model_name
+    )
+
+    load_dir = Path.cwd() / folder_name
+
+    if not load_dir.exists():
+        if rank == 0:
+            print("No sharded_state_dict checkpoint directory found...skipping")
+        return
+
+    reader = FileSystemReader(load_dir)
+
+    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
+        checkpoint = model.state_dict()
+        if rank == 0:
+            ck = checkpoint.keys()
+            print(f" checkpoint key len = {len(ck)} and \n keys = {ck}")
+
+        dist_cp.load_state_dict(
+            state_dict=checkpoint,
+            storage_reader=reader,
+        )
+        if rank == 0:
+            print("checkpoint after load_state_dict()")
+            ck = checkpoint.keys()
+            print(f" checkpoint key len = {len(ck)} and \n keys = {ck}")
+        model.load_state_dict(checkpoint)
+    if rank == 0:
+        print(f"Sharded state checkpoint loaded from {load_dir}")
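Note that `load_model_sharded` enters `FSDP.state_dict_type`, so it only works on an already-wrapped model and must run on every rank. A usage sketch:

```python
# Sketch, not part of the diff: sharded load happens *after* FSDP wrapping, on all ranks.
model = FSDP(model, auto_wrap_policy=policies.get_t5_wrapper())  # same wrapping as at save time
load_model_sharded(model, rank, fsdp_config)
```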
+
+
+def save_model_and_optimizer_sharded(model, rank, cfg, optim=None, verbose=True):
+    """Save model and optimizer via sharded_state_dict to save_dir."""
+    folder_name = (
+        cfg.dist_checkpoint_root_folder
+        + "/"
+        + cfg.dist_checkpoint_folder
+        + "-"
+        + cfg.model_name
+    )
+
+    save_dir = Path.cwd() / folder_name
+    if rank == 0:
+        print(f"Saving model to {save_dir}")
+
+    distributed_writer = dist_cp.FileSystemWriter(
+        save_dir,
+    )
+    t0 = time.perf_counter()
+
+    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
+
+        state_dict = {"model": model.state_dict()}
+        if optim is not None:
+            state_dict["optim"] = FSDP.optim_state_dict(model, optim)
+
+        dist_cp.save_state_dict(
+            state_dict=state_dict,
+            storage_writer=distributed_writer,
+            planner=DefaultSavePlanner(),
+        )
+    dist.barrier()
+    t1 = time.perf_counter()
+    if rank == 0:
+        print(f"Sharded state checkpoint saved to {save_dir}")
+        print(
+            f"Checkpoint Time = {t1-t0:.4f}\n using {cfg.save_using_num_threads=} total threads"
+        )
+
+
+def save_model_checkpoint(
+    model,
+    optimizer,
+    rank,
+    cfg,
+    epoch=1,
+):
+    """Save model via rank0 CPU streaming and full_state_dict."""
+
+    # saving with rank0 cpu
+    if not cfg.checkpoint_type == StateDictType.FULL_STATE_DICT:
+        print(f" unable to handle checkpoint type {cfg.checkpoint_type}, aborting")
+        return
+
+    with FSDP.state_dict_type(
+        model, StateDictType.FULL_STATE_DICT, fullstate_save_policy
+    ):
+        cpu_state = model.state_dict()
+
+    if cfg.verbose:
+        print(f"saving process: rank {rank} done w model state_dict\n")
+
+    if rank == 0:
+        print("--> saving model ...")
+        # create save path
+        save_dir = Path.cwd() / cfg.checkpoint_folder
+        save_dir.mkdir(parents=True, exist_ok=True)
+        save_name = cfg.model_save_name + "-" + str(epoch) + ".pt"
+        save_full_path = save_dir / save_name
+
+        # save model
+        torch.save(cpu_state, save_full_path)
+
+        if cfg.verbose:
+            print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")
+
+
+def load_model_checkpoint(model, rank, cfg, verbose=True):
+    """Load a local checkpoint to rank0 CPU;
+    must be called *before* passing the model to FSDP."""
+
+    if rank != 0:
+        return
+
+    # where is the checkpoint at...
+    full_state_dict_model_path = (
+        Path.cwd() / cfg.checkpoint_folder / cfg.checkpoint_model_filename
+    )
+    # is it present...
+    if not full_state_dict_model_path.is_file():
+        print(
+            f"model checkpoint {full_state_dict_model_path} not present. Returning..."
+        )
+        return
+
+    model_checkpoint = torch.load(full_state_dict_model_path)
+    # integrate into loaded model
+    model.load_state_dict(model_checkpoint)
+
+    if cfg.verbose:
+        print("model checkpoint loaded to rank0 cpu")
+
+
+def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
+    """Save optimizer state via full state dict."""
+
+    if cfg.verbose:
+        print(f"--> optim state call on rank {rank}\n")
+
+    # pull all sharded optimizer states to rank0 cpu...
+    optim_state = FSDP.full_optim_state_dict(model, optimizer)
+
+    if cfg.verbose:
+        print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n")
+
+    if rank == 0:
+        save_dir = Path.cwd() / cfg.checkpoint_folder
+        save_dir.mkdir(parents=True, exist_ok=True)
+
+        opt_save_name = (
+            cfg.optimizer_name + "-" + cfg.model_save_name + "-" + str(epoch) + ".pt"
+        )
+        opt_save_full_path = save_dir / opt_save_name
+
+        print("--> saving optimizer state...")
+        torch.save(optim_state, opt_save_full_path)
+        print(f"--> saved {opt_save_full_path} to disk")
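The `fullstate_save_policy` used above (CPU offload, rank-0 only) is what keeps a full-state save from materializing every parameter on a single GPU; the same context manager works for ad-hoc inspection. A sketch, assuming an already-wrapped `model` and a `rank` variable in scope:

```python
# Sketch, not part of the diff: gather a full, CPU-resident state_dict on rank 0 only.
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, fullstate_save_policy):
    cpu_state = model.state_dict()
if rank == 0:
    print(f"{sum(t.numel() for t in cpu_state.values()):,} parameters gathered")
```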
" + ) + return + + full_osd = None + + if rank == 0: + full_osd = torch.load(opt_file_path) + + if cfg.verbose: + print(f"loaded full osd on rank 0") + + # called from all ranks, though only rank0 has a valid param for full_osd + sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model) + + if cfg.verbose: + print(f"optimizer shard loaded on rank {rank}") + + + +def load_distributed_model_checkpoint(model, rank, cfg): + if cfg.checkpoint_type == StateDictType.LOCAL_STATE_DICT: + print(f"loading distributed checkpoint, rank {rank}...") + folder_name = ( + cfg.dist_checkpoint_root_folder + + "/" + + cfg.dist_checkpoint_folder + + "-" + + cfg.model_name + ) + + checkdir = Path.cwd() / folder_name + + if not checkdir.exists(): + if rank == 0: + print(f"No checkpoint directory found...skipping") + return + + + reader = FileSystemReader(checkdir) + + with FSDP.state_dict_type( + model, + StateDictType.LOCAL_STATE_DICT, + ): + state_dict = model.state_dict() + load_state_dict(state_dict, reader) + model.load_state_dict(state_dict) + + print(f"--> local state loaded on rank {rank}") + + return + + +def save_distributed_model_checkpoint(model, rank, cfg, epoch=1): + # distributed checkpoint saving + + # confirm type of checkpoint and save + if cfg.checkpoint_type == StateDictType.LOCAL_STATE_DICT: + # create writer to current path + folder_name = ( + cfg.dist_checkpoint_root_folder + + "/" + + cfg.dist_checkpoint_folder + + "-" + + cfg.model_name + ) + save_dir = Path.cwd() / folder_name + + writer = FileSystemWriter( + save_dir, + ) + + with FSDP.state_dict_type( + model, + StateDictType.LOCAL_STATE_DICT, + ): + state_dict = model.state_dict() + + + # write out distributed checkpoint + save_state_dict(state_dict, writer) + + return diff --git a/distributed/FSDP/policies/__init__.py b/distributed/FSDP/policies/__init__.py new file mode 100644 index 0000000000..8109e6a747 --- /dev/null +++ b/distributed/FSDP/policies/__init__.py @@ -0,0 +1,3 @@ +from .mixed_precision import * +from .wrapping import * +from .activation_checkpointing_functions import apply_fsdp_checkpointing diff --git a/distributed/FSDP/policies/activation_checkpointing_functions.py b/distributed/FSDP/policies/activation_checkpointing_functions.py new file mode 100644 index 0000000000..e041d26276 --- /dev/null +++ b/distributed/FSDP/policies/activation_checkpointing_functions.py @@ -0,0 +1,31 @@ +import torch +import os +import torch.distributed as dist +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper, + CheckpointImpl, + apply_activation_checkpointing, +) + +from transformers.models.t5.modeling_t5 import T5Block + +from functools import partial + +non_reentrant_wrapper = partial( + checkpoint_wrapper, + offload_to_cpu=False, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, +) + +check_fn = lambda submodule: isinstance(submodule, T5Block) + + +def apply_fsdp_checkpointing(model): + """apply activation checkpointing to model + returns None as model is updated directly + """ + print(f"--> applying fdsp activation checkpointing...") + + apply_activation_checkpointing( + model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn + ) diff --git a/distributed/FSDP/policies/mixed_precision.py b/distributed/FSDP/policies/mixed_precision.py new file mode 100644 index 0000000000..76e59982ec --- /dev/null +++ b/distributed/FSDP/policies/mixed_precision.py @@ -0,0 +1,38 @@ +import torch + +from torch.distributed.fsdp import ( + # FullyShardedDataParallel as FSDP, + # CPUOffload, + 
diff --git a/distributed/FSDP/policies/mixed_precision.py b/distributed/FSDP/policies/mixed_precision.py
new file mode 100644
index 0000000000..76e59982ec
--- /dev/null
+++ b/distributed/FSDP/policies/mixed_precision.py
@@ -0,0 +1,38 @@
+import torch
+
+from torch.distributed.fsdp import MixedPrecision
+
+# requires grad scaler in main loop
+fpSixteen = MixedPrecision(
+    param_dtype=torch.float16,
+    # Gradient communication precision.
+    reduce_dtype=torch.float16,
+    # Buffer precision.
+    buffer_dtype=torch.float16,
+)
+
+bfSixteen = MixedPrecision(
+    param_dtype=torch.bfloat16,
+    # Gradient communication precision.
+    reduce_dtype=torch.bfloat16,
+    # Buffer precision.
+    buffer_dtype=torch.bfloat16,
+)
+
+# keeps params in fp32 while communicating in bf16
+bfSixteen_working = MixedPrecision(
+    param_dtype=torch.float32,
+    reduce_dtype=torch.bfloat16,
+    buffer_dtype=torch.bfloat16,
+)
+
+fp32_policy = MixedPrecision(
+    param_dtype=torch.float32,
+    reduce_dtype=torch.float32,
+    buffer_dtype=torch.float32,
+)
diff --git a/distributed/FSDP/policies/wrapping.py b/distributed/FSDP/policies/wrapping.py
new file mode 100644
index 0000000000..5e1f0d89c8
--- /dev/null
+++ b/distributed/FSDP/policies/wrapping.py
@@ -0,0 +1,47 @@
+# holds various wrapping policies for FSDP
+
+import functools
+
+from transformers.models.t5.modeling_t5 import T5Block
+
+from torch.distributed.fsdp.wrap import (
+    transformer_auto_wrap_policy,
+    size_based_auto_wrap_policy,
+)
+
+
+def get_size_policy(min_params=1e8):
+    num_wrap_policy = functools.partial(
+        size_based_auto_wrap_policy, min_num_params=min_params
+    )
+    return num_wrap_policy
+
+
+def get_t5_wrapper():
+    """Register our main layer class and use the FSDP transformer wrapping policy;
+    this ensures embedding layers sit in the root FSDP unit for shared access,
+    and that FSDP units map to transformer layers.
+    """
+    # ==== use new transformer wrapper
+
+    t5_auto_wrap_policy = functools.partial(
+        transformer_auto_wrap_policy,
+        transformer_layer_cls={
+            T5Block,
+        },
+    )
+
+    return t5_auto_wrap_policy
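Both helpers return callables for the same `auto_wrap_policy` argument; the size-based fallback suits models without a clean transformer-block class. A sketch:

```python
# Sketch, not part of the diff: the two policies are drop-in alternatives at the FSDP call site.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = FSDP(model, auto_wrap_policy=get_t5_wrapper())        # one FSDP unit per T5Block
# model = FSDP(model, auto_wrap_policy=get_size_policy(1e8))  # or per ~100M-parameter group
```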
diff --git a/distributed/FSDP/requirements.txt b/distributed/FSDP/requirements.txt
new file mode 100644
index 0000000000..a59c5bacb2
--- /dev/null
+++ b/distributed/FSDP/requirements.txt
@@ -0,0 +1,5 @@
+transformers
+datasets
+tqdm
+protobuf
+SentencePiece
diff --git a/distributed/FSDP/summarization_dataset.py b/distributed/FSDP/summarization_dataset.py
new file mode 100644
index 0000000000..679ea48ec0
--- /dev/null
+++ b/distributed/FSDP/summarization_dataset.py
@@ -0,0 +1,83 @@
+from torch.utils.data import Dataset
+
+from datasets import load_dataset
+
+
+class wikihow(Dataset):
+    def __init__(self, tokenizer, type_path, num_samples, input_length, output_length, print_text=False):
+        self.dataset = load_dataset('wikihow', 'all', data_dir='data/', split=type_path)
+        if num_samples:
+            self.dataset = self.dataset.select(list(range(0, num_samples)))
+        self.input_length = input_length
+        self.tokenizer = tokenizer
+        self.output_length = output_length
+        self.print_text = print_text
+
+    def __len__(self):
+        return self.dataset.shape[0]
+
+    def clean_text(self, text):
+        text = text.replace('Example of text:', '')
+        text = text.replace('Example of Summary:', '')
+        text = text.replace('\n', '')
+        text = text.replace('``', '')
+        text = text.replace('"', '')
+
+        return text
+
+    def convert_to_features(self, example_batch):
+        # Tokenize contexts and questions (as pairs of inputs)
+
+        if self.print_text:
+            print("Input Text: ", self.clean_text(example_batch['text']))
+
+        input_ = self.clean_text(example_batch['text'])
+        target_ = self.clean_text(example_batch['headline'])
+
+        source = self.tokenizer.batch_encode_plus([input_], max_length=self.input_length,
+                                                  padding='max_length', truncation=True, return_tensors="pt")
+
+        targets = self.tokenizer.batch_encode_plus([target_], max_length=self.output_length,
+                                                   padding='max_length', truncation=True, return_tensors="pt")
+
+        return source, targets
+
+    def __getitem__(self, index):
+        source, targets = self.convert_to_features(self.dataset[index])
+
+        source_ids = source["input_ids"].squeeze()
+        target_ids = targets["input_ids"].squeeze()
+
+        src_mask = source["attention_mask"].squeeze()
+        target_mask = targets["attention_mask"].squeeze()
+
+        return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids, "target_mask": target_mask}
+
+
+def get_dataset(tokenizer, type_path, num_samples, input_length, output_length):
+    return wikihow(tokenizer=tokenizer, type_path=type_path, num_samples=num_samples,
+                   input_length=input_length, output_length=output_length)
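Before a distributed launch, it can help to eyeball one tokenized example in a plain Python session. A sketch, assuming the WikiHow CSVs are already under `data/`:

```python
# Sketch, not part of the diff: inspect a single tokenized sample.
from transformers import T5Tokenizer

tok = T5Tokenizer.from_pretrained("t5-base")
ds = wikihow(tok, 'validation', num_samples=8, input_length=512, output_length=150)
sample = ds[0]
print(sample["source_ids"].shape, sample["target_ids"].shape)  # torch.Size([512]) torch.Size([150])
```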
diff --git a/distributed/FSDP/utils/__init__.py b/distributed/FSDP/utils/__init__.py
new file mode 100644
index 0000000000..c9811ecf98
--- /dev/null
+++ b/distributed/FSDP/utils/__init__.py
@@ -0,0 +1,4 @@
+from .environment import bfloat_support
+from .train_utils import setup, cleanup, get_date_of_run, format_metrics_to_gb, train, validation, setup_model
diff --git a/distributed/FSDP/utils/environment.py b/distributed/FSDP/utils/environment.py
new file mode 100644
index 0000000000..1e00da3acd
--- /dev/null
+++ b/distributed/FSDP/utils/environment.py
@@ -0,0 +1,29 @@
+# Copyright (c) 2022 Meta Platforms, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the Apache-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This is a simple check to confirm that your current server has full bfloat16 support -
+# both native GPU support and network communication support.
+
+# Be warned that without a check like this, running on a V100 silently falls back to
+# non-native bfloat16 and suffers significant performance degradation (no error is raised).
+# Hence the reason for a checker!
+
+from pkg_resources import packaging
+import torch
+import torch.cuda.nccl as nccl
+import torch.distributed as dist
+
+# global flag that confirms Ampere architecture, CUDA version and
+# NCCL version, to verify that bfloat16 native support is ready
+
+def bfloat_support():
+    return (
+        torch.version.cuda
+        and torch.cuda.is_bf16_supported()
+        and packaging.version.parse(torch.version.cuda).release >= (11, 0)
+        and dist.is_nccl_available()
+        and nccl.version() >= (2, 10)
+    )
diff --git a/distributed/FSDP/utils/train_utils.py b/distributed/FSDP/utils/train_utils.py
new file mode 100644
index 0000000000..aaf0127dd8
--- /dev/null
+++ b/distributed/FSDP/utils/train_utils.py
@@ -0,0 +1,102 @@
+import os
+from datetime import datetime
+
+import torch
+import torch.distributed as dist
+import tqdm
+from transformers import T5Tokenizer, T5ForConditionalGeneration
+
+g_gigabyte = 1024**3
+
+
+def setup():
+    # initialize the process group
+    dist.init_process_group("nccl")
+
+
+def cleanup():
+    dist.destroy_process_group()
+
+
+def get_date_of_run():
+    """Create date and time for file save uniqueness,
+    example: 2022-05-07-08:31:12_PM
+    """
+    date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p")
+    print(f"--> current date and time of run = {date_of_run}")
+    return date_of_run
+
+
+def format_metrics_to_gb(item):
+    """Quick function to format numbers to gigabytes and round to 4-digit precision."""
+    metric_num = item / g_gigabyte
+    metric_num = round(metric_num, ndigits=4)
+    return metric_num
+
+
+def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
+    model.train()
+    local_rank = int(os.environ['LOCAL_RANK'])
+    fsdp_loss = torch.zeros(2).to(local_rank)  # [loss sum, example count]
+
+    if sampler:
+        sampler.set_epoch(epoch)
+    if rank == 0:
+        inner_pbar = tqdm.tqdm(
+            range(len(train_loader)), colour="blue", desc="r0 Training Epoch"
+        )
+    for batch in train_loader:
+        for key in batch.keys():
+            batch[key] = batch[key].to(local_rank)
+        optimizer.zero_grad()
+        output = model(input_ids=batch["source_ids"], attention_mask=batch["source_mask"], labels=batch["target_ids"])
+        loss = output["loss"]
+        loss.backward()
+        optimizer.step()
+        # loss is the batch mean; weight by batch size and count examples
+        # (len(batch) would count the dict keys, not the examples)
+        batch_size = batch["source_ids"].size(0)
+        fsdp_loss[0] += loss.item() * batch_size
+        fsdp_loss[1] += batch_size
+        if rank == 0:
+            inner_pbar.update(1)
+
+    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
+    # mean training loss per example across all ranks
+    # (kept as train_accuracy for the tracking lists in T5_training.py)
+    train_accuracy = fsdp_loss[0] / fsdp_loss[1]
+
+    if rank == 0:
+        inner_pbar.close()
+        print(
+            f"Train Epoch: \t{epoch}, Loss: \t{train_accuracy:.4f}"
+        )
+    return train_accuracy
+
+
+def validation(model, rank, world_size, val_loader):
+    model.eval()
+    local_rank = int(os.environ['LOCAL_RANK'])
+    fsdp_loss = torch.zeros(2).to(local_rank)  # [loss sum, example count]
+    if rank == 0:
+        inner_pbar = tqdm.tqdm(
+            range(len(val_loader)), colour="green", desc="Validation Epoch"
+        )
+    with torch.no_grad():
+        for batch in val_loader:
+            for key in batch.keys():
+                batch[key] = batch[key].to(local_rank)
+            output = model(input_ids=batch["source_ids"], attention_mask=batch["source_mask"], labels=batch["target_ids"])
+            batch_size = batch["source_ids"].size(0)
+            fsdp_loss[0] += output["loss"].item() * batch_size  # sum up batch loss
+            fsdp_loss[1] += batch_size
+
+            if rank == 0:
+                inner_pbar.update(1)
+
+    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
+    val_loss = fsdp_loss[0] / fsdp_loss[1]
+    if rank == 0:
+        inner_pbar.close()
+        print(f"Validation Loss: {val_loss:.4f}")
+    return val_loss
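The `[sum, count]` buffer plus a single `all_reduce` is the metric-aggregation idiom both loops share, and it generalizes to any per-rank statistic. A sketch with placeholder names:

```python
# Sketch, not part of the diff: distributed mean via a two-slot buffer.
# `loss_sum` and `num_examples` are hypothetical per-rank values.
buf = torch.zeros(2).to(local_rank)
buf[0], buf[1] = loss_sum, num_examples
dist.all_reduce(buf, op=dist.ReduceOp.SUM)   # sums both slots across all ranks
global_mean = (buf[0] / buf[1]).item()
```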
+
+
+def setup_model(model_name):
+    model = T5ForConditionalGeneration.from_pretrained(model_name)
+    tokenizer = T5Tokenizer.from_pretrained(model_name)
+    return model, tokenizer
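`setup_model` also makes a handy single-process smoke test before committing GPUs to a full run. A sketch:

```python
# Sketch, not part of the diff: CPU-only sanity check of the summarization model.
model, tokenizer = setup_model("t5-base")
ids = tokenizer("summarize: FSDP shards parameters, gradients and optimizer state across ranks.",
                return_tensors="pt").input_ids
out = model.generate(ids, max_length=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```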