[doc][c10d] fixup fsdp tutorial (#1297)
Summary:
Fix up the FSDP tutorial to get it functional again.
1. Add the missing import for `load_dataset`.
2. Use `checkpoint` instead of `_shard.checkpoint` to get rid of a
   warning.
3. Add `nlp` to requirements.txt.
4. Drop `load_metric`, as this function no longer exists in the new
   `datasets` module.
5. Add `legacy=False` to get rid of tokenizer warnings.
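Taken together, the fixes amount to the following import and call-site updates (a minimal sketch of the changed lines only; `t5-base` stands in for the tutorial's configured model name):
```
# Dataset loading now comes from the legacy `nlp` package (pinned in
# requirements.txt); the `load_metric` import from `datasets` is simply dropped.
from nlp import load_dataset

# Distributed checkpointing moves off the private path to avoid a warning.
import torch.distributed.checkpoint as dist_cp  # previously torch.distributed._shard.checkpoint

# Tokenizer construction opts out of legacy behaviour to silence warnings.
from transformers import T5Tokenizer
tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=False)
```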

Test Plan:
Ran the tutorial as follows and confirmed that it completed successfully:
```
torchrun --nnodes=1 --nproc_per_node=2 T5_training.py
W1031 09:46:49.166000 2847649 torch/distributed/run.py:793]
W1031 09:46:49.166000 2847649 torch/distributed/run.py:793]
*****************************************
W1031 09:46:49.166000 2847649 torch/distributed/run.py:793] Setting
OMP_NUM_THREADS environment variable for each process to be 1 in
default, to avoid your system being overloaded, please further tune the
variable for optimal performance in your application as needed.
W1031 09:46:49.166000 2847649 torch/distributed/run.py:793]
*****************************************
dict_keys(['train', 'validation', 'test'])
Size of train dataset:  (157252, 3)
Size of Validation dataset:  (5599, 3)
dict_keys(['train', 'validation', 'test'])
Size of train dataset:  (157252, 3)
Size of Validation dataset:  (5599, 3)
bFloat16 enabled for mixed precision - using bfSixteen policy
```
c-p-i-o authored Nov 8, 2024
1 parent 47d0c2e commit 1bef748
Showing 5 changed files with 42 additions and 41 deletions.
23 changes: 12 additions & 11 deletions distributed/FSDP/T5_training.py
@@ -14,6 +14,7 @@
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from transformers.models.t5.modeling_t5 import T5Block
from nlp import load_dataset

from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
@@ -86,11 +87,11 @@ def fsdp_main(args):
print("Size of train dataset: ", dataset['train'].shape)
print("Size of Validation dataset: ", dataset['validation'].shape)


#wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False)
train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False)

sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size)

@@ -107,20 +108,20 @@ def fsdp_main(args):

train_loader = torch.utils.data.DataLoader(train_dataset,**train_kwargs)
val_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)

torch.cuda.set_device(local_rank)

# Set up FSDP parameters
mixed_precision_policy, t5_auto_wrap_policy = get_policies(train_config, rank)

# Apply FSDP wrapping to the model
model = FSDP(model,
auto_wrap_policy=t5_auto_wrap_policy,
mixed_precision=mixed_precision_policy,
sharding_strategy=fsdp_config.sharding_strategy,
device_id=torch.cuda.current_device(),
limit_all_gathers=fsdp_config.limit_all_gathers)

# Enabling this causes https://github.com/pytorch/examples/issues/1210
if fsdp_config.fsdp_activation_checkpointing:
policies.apply_fsdp_checkpointing(model)
@@ -150,7 +151,7 @@ def fsdp_main(args):
if args.run_validation:
curr_val_loss = validation(model, rank, world_size, val_loader)
scheduler.step()

if rank == 0:

print(f"--> epoch {epoch} completed...entering save and stats zone")
@@ -170,7 +171,7 @@ def fsdp_main(args):
)

if train_config.save_model and curr_val_loss < best_val_loss:

if fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
model_checkpointing.save_model_checkpoint(
model, optimizer, rank, fsdp_config, epoch=1
@@ -183,7 +184,7 @@ def fsdp_main(args):
if fsdp_config.save_optimizer:
model_checkpointing.save_optimizer_checkpoint(
model, optimizer, rank, fsdp_config, epoch=1
)
)
if curr_val_loss < best_val_loss:

best_val_loss = curr_val_loss
@@ -212,5 +213,5 @@ def fsdp_main(args):
args = parser.parse_args()

torch.manual_seed(args.seed)

fsdp_main(args)
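For context, the wrapping that `get_policies` drives in this file boils down to auto-wrapping each `T5Block` and running in bfloat16 mixed precision (the "bfSixteen" policy reported in the test output). A minimal sketch follows; `wrap_t5_with_fsdp` is a hypothetical helper name, and the sharding strategy and `limit_all_gathers` are shown with typical values rather than read from `fsdp_config`:
```
import functools
import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.t5.modeling_t5 import T5Block

# bfloat16 for parameters, gradient reduction, and buffers ("bfSixteen").
bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# Wrap every T5Block as its own FSDP unit.
t5_wrap_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls={T5Block}
)

def wrap_t5_with_fsdp(model):
    """Shard the model roughly the way T5_training.py does (call after init_process_group)."""
    return FSDP(
        model,
        auto_wrap_policy=t5_wrap_policy,
        mixed_precision=bf16_policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        device_id=torch.cuda.current_device(),
        limit_all_gathers=True,
    )
```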
16 changes: 8 additions & 8 deletions distributed/FSDP/model_checkpointing/checkpoint_handler.py
@@ -11,7 +11,7 @@
# ShardedStateDictConfig, # un-flattened param but shards, usable by other parallel schemes.
)

from torch.distributed._shard.checkpoint import (
from torch.distributed.checkpoint import (
FileSystemReader,
FileSystemWriter,
save_state_dict,
@@ -24,7 +24,7 @@


from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
import torch.distributed._shard.checkpoint as dist_cp
import torch.distributed.checkpoint as dist_cp
import torch.distributed as dist


@@ -65,7 +65,7 @@ def load_model_sharded(model, rank, cfg, verbose=True):
if rank == 0:
ck = checkpoint.keys()
print(f" checkpoint key len = {len(ck)} and \n keys = {ck}")

dist_cp.load_state_dict(
state_dict=checkpoint,
storage_reader=reader,
@@ -108,7 +108,7 @@ def save_model_and_optimizer_sharded(model, rank, cfg,optim=None, verbose=True):
state_dict=state_dict,
storage_writer=distributed_writer,
planner=DefaultSavePlanner(),

)
dist.barrier()
t1 = time.perf_counter()
@@ -117,7 +117,7 @@ def save_model_and_optimizer_sharded(model, rank, cfg,optim=None, verbose=True):
print(
f"Checkpoint Time = {t1-t0:.4f}\n using {cfg.save_using_num_threads=} total threads"
)

def save_model_checkpoint(
model,
optimizer,
@@ -138,7 +138,7 @@ def save_model_checkpoint(

if cfg.verbose:
print(f"saving process: rank {rank} done w model state_dict\n")


if rank == 0:
print(f"--> saving model ...")
@@ -153,7 +153,7 @@ def save_model_checkpoint(

if cfg.verbose:
print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")



def load_model_checkpoint(model, rank, cfg, verbose=True):
@@ -299,7 +299,7 @@ def save_distributed_model_checkpoint(model, rank, cfg, epoch=1):
StateDictType.LOCAL_STATE_DICT,
):
state_dict = model.state_dict()


# write out distributed checkpoint
save_state_dict(state_dict, writer)
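The public `torch.distributed.checkpoint` module exposes the same readers, writers, and planners the handler already used under the old `_shard` path, so only the imports change. A minimal save/load sketch for the sharded case; the helper names here are illustrative, not functions from this file:
```
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

def save_sharded(model, save_dir):
    # Each rank writes its own shards under save_dir.
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        state_dict = {"model": model.state_dict()}
        dist_cp.save_state_dict(state_dict=state_dict,
                                storage_writer=FileSystemWriter(save_dir))

def load_sharded(model, load_dir):
    # Re-populate the sharded state dict in place, then load it into the model.
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        state_dict = {"model": model.state_dict()}
        dist_cp.load_state_dict(state_dict=state_dict,
                                storage_reader=FileSystemReader(load_dir))
        model.load_state_dict(state_dict["model"])
```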
1 change: 1 addition & 0 deletions distributed/FSDP/requirements.txt
@@ -3,3 +3,4 @@ datasets
tqdm
protobuf
SentencePiece
nlp
39 changes: 19 additions & 20 deletions distributed/FSDP/summarization_dataset.py
@@ -14,8 +14,7 @@
import torch
from torch.utils.data import Dataset, DataLoader

from datasets import load_dataset, load_metric

from nlp import load_dataset

from transformers import (
AdamW,
@@ -25,59 +24,59 @@
)

class wikihow(Dataset):
def __init__(self, tokenizer, type_path, num_samples, input_length, output_length, print_text=False):
def __init__(self, tokenizer, type_path, num_samples, input_length, output_length, print_text=False):
self.dataset = load_dataset('wikihow', 'all', data_dir='data/', split=type_path)
if num_samples:
self.dataset = self.dataset.select(list(range(0, num_samples)))
self.input_length = input_length
self.tokenizer = tokenizer
self.output_length = output_length
self.print_text = print_text

def __len__(self):
return self.dataset.shape[0]

def clean_text(self, text):
text = text.replace('Example of text:', '')
text = text.replace('Example of Summary:', '')
text = text.replace('\n','')
text = text.replace('``', '')
text = text.replace('"', '')

return text


def convert_to_features(self, example_batch):
# Tokenize contexts and questions (as pairs of inputs)

if self.print_text:
print("Input Text: ", self.clean_text(example_batch['text']))
# input_ = self.clean_text(example_batch['text']) + " </s>"
# target_ = self.clean_text(example_batch['headline']) + " </s>"

input_ = self.clean_text(example_batch['text'])
target_ = self.clean_text(example_batch['headline'])
source = self.tokenizer.batch_encode_plus([input_], max_length=self.input_length,

source = self.tokenizer.batch_encode_plus([input_], max_length=self.input_length,
padding='max_length', truncation=True, return_tensors="pt")
targets = self.tokenizer.batch_encode_plus([target_], max_length=self.output_length,

targets = self.tokenizer.batch_encode_plus([target_], max_length=self.output_length,
padding='max_length', truncation=True, return_tensors="pt")


return source, targets

def __getitem__(self, index):
source, targets = self.convert_to_features(self.dataset[index])

source_ids = source["input_ids"].squeeze()
target_ids = targets["input_ids"].squeeze()

src_mask = source["attention_mask"].squeeze()
target_mask = targets["attention_mask"].squeeze()

return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids, "target_mask": target_mask}

def get_dataset(tokenizer, type_path, num_samples, args):
return wikihow(tokenizer=tokenizer, type_path=type_path, num_samples=num_samples, input_length=max_input_length,
return wikihow(tokenizer=tokenizer, type_path=type_path, num_samples=num_samples, input_length=max_input_length,
output_length=max_output_length)
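As a usage note, the class above is what T5_training.py instantiates for the train and validation splits. A small standalone sketch (`t5-base` again stands in for the configured model name, and `data/` matches the `data_dir` passed to `load_dataset` in `__init__`):
```
from transformers import T5Tokenizer
from summarization_dataset import wikihow

tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=False)

# 1500 training samples, 512-token articles, 150-token summaries.
train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)

print(len(train_dataset))
sample = train_dataset[0]
print(sample.keys())  # source_ids, source_mask, target_ids, target_mask
```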
4 changes: 2 additions & 2 deletions distributed/FSDP/utils/train_utils.py
@@ -36,7 +36,7 @@ def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler
model.train()
local_rank = int(os.environ['LOCAL_RANK'])
fsdp_loss = torch.zeros(2).to(local_rank)

if sampler:
sampler.set_epoch(epoch)
if rank==0:
@@ -98,5 +98,5 @@ def validation(model, rank, world_size, val_loader):

def setup_model(model_name):
model = T5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = T5Tokenizer.from_pretrained(model_name)
tokenizer = T5Tokenizer.from_pretrained(model_name, legacy=False)
return model, tokenizer
