From ff2c6137355117fdca2324e2e8c876e974031a84 Mon Sep 17 00:00:00 2001 From: Ramgopal Venkateswaran Date: Thu, 24 Oct 2024 16:28:29 -0700 Subject: [PATCH 01/10] Not yet working attempt at MultiGPU --- examples/alpaca/train_multigpu.py | 183 ++++++++++++++++++++++++++++++ 1 file changed, 183 insertions(+) create mode 100644 examples/alpaca/train_multigpu.py diff --git a/examples/alpaca/train_multigpu.py b/examples/alpaca/train_multigpu.py new file mode 100644 index 0000000..a4c16ad --- /dev/null +++ b/examples/alpaca/train_multigpu.py @@ -0,0 +1,183 @@ +import copy +import logging +from dataclasses import dataclass, field +from typing import Dict, Optional, Sequence + +import torch +import transformers +from torch.utils.data import Dataset +from transformers import Trainer + +from pyreft import ( + TaskType, + get_reft_model, + ReftConfig, + ReftTrainerForCausalLM, + LoreftIntervention, + ReftDataCollator, + ReftSupervisedDataset, + ReftModel +) +import pyvene as pv +import os +import sys +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel as DDP + +def count_parameters(model): + """Count parameters of a model that require gradients""" + return sum(p.numel() for p in model.parameters() if p.requires_grad) + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="yahma/llama-7b-hf") + + +@dataclass +class DataArguments: + data_path: str = field(default=None, metadata={"help": "Path to the training data."}) + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + model_max_length: int = field( + default=512, + metadata={"help": "Maximum sequence length. 
Sequences will be right padded (and possibly truncated)."}, + ) + + layers: str = field( + default="all", + metadata={"help": "Intervening layers."}, + ) + position: str = field( + default="f1+l1", + metadata={"help": "Intervening position string."}, + ) + share_weights: bool = field(default=False) + remove_unused_columns: bool = field(default=False) + rank: int = field(default=1) + max_n_train_example: int = field(default=None) + + +def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, model, layers, training_args, data_args) -> Dict: + """Make dataset and collator for supervised fine-tuning.""" + train_dataset = ReftSupervisedDataset( + "alpaca", data_args.data_path, tokenizer, data_split="train", seed=training_args.seed, + max_n_example=training_args.max_n_train_example, + input_field="input", instruction_field="instruction", output_field="output", + **{"num_interventions": len(layers), "position": training_args.position, + "share_weights": training_args.share_weights} + ) + data_collator_fn = transformers.DataCollatorForSeq2Seq( + tokenizer=tokenizer, + model=model, + label_pad_token_id=-100, + padding="longest" + ) + data_collator = ReftDataCollator(data_collator=data_collator_fn) + return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator) + + +def train(rank, world_size): + device_id = rank + device = torch.device(f'cuda:{device_id}') + parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # parsing layers arg + if training_args.layers != "all": + layers = [int(l) for l in training_args.layers.split(";")] + else: + temp_config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path) + layers = [l for l in range(temp_config.num_hidden_layers)] + if "+" in training_args.position and not training_args.share_weights: + layers += layers + + # get tokenizer + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + model_max_length=training_args.model_max_length, + padding_side="right", + use_fast=False, + ) + tokenizer.pad_token = tokenizer.unk_token + + # get reft model + model = transformers.AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + torch_dtype=torch.bfloat16, + device_map=device + ) + representations = [{ + "layer": l, "component": "block_output", + # this is needed for loading although dummy. 
+ "low_rank_dimension": training_args.rank, + "intervention": LoreftIntervention( + embed_dim=model.config.hidden_size, + low_rank_dimension=training_args.rank, + ) + } for l in layers] + + reft_config = ReftConfig(representations=representations) + reft_model = get_reft_model(model, reft_config) + reft_model.print_trainable_parameters() + reft_model_ddp = DDP(reft_model) # , device_ids=[device_id], find_unused_parameters=False) + + # check params and devices + original_params = {name for name, _ in reft_model.named_parameters()} + ddp_params = {name for name, _ in reft_model_ddp.named_parameters()} + + missing_in_ddp = original_params - ddp_params + missing_in_ddp = sorted(missing_in_ddp) + print("Missing in DDP is", missing_in_ddp) + print("Printing original params") + for x in reft_model.named_parameters(): + print(f"{x[0]} -> {x[1].device}") + print("Printing DDP params") + for x in reft_model_ddp.named_parameters(): + print(f"{x[0]} -> {x[1].device}") + + # get training data + data_module = make_supervised_data_module( + tokenizer=tokenizer, model=model, layers=layers, + training_args=training_args, data_args=data_args) + + # train + trainer = ReftTrainerForCausalLM( + model=reft_model_ddp, tokenizer=tokenizer, args=training_args, **data_module) + trainer.train() + print("Rank is", rank) + if rank == 0: + trainer.save_state() + reft_model_ddp.module.save(save_directory=training_args.output_dir) + # uncomment this line to only saving the interventons, + # you need to reinit the reft model with random init + # interventions mounted then load these weights + # trainer.save_model(output_dir=training_args.output_dir) + + # test if we can load. + ReftModel.load(training_args.output_dir, model) + +def setup(rank, world_size): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + dist.init_process_group("gloo", rank = rank, world_size = world_size) + +def cleanup(): + dist.destroy_process_group() + +def process_fn(rank, world_size): + setup(rank, world_size) + print("Rank", rank, "world size", world_size) + train(rank, world_size) + cleanup() + +if __name__ == "__main__": + assert torch.cuda.is_available(), "MultiGPU script needs CUDA to run" + n_gpus = torch.cuda.device_count() + assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}" + world_size = n_gpus + mp.spawn(process_fn, args=(world_size,), nprocs=world_size, join=True) From be355386c6af63fdfd6f0a0ebb1fb530a7eed40a Mon Sep 17 00:00:00 2001 From: Ramgopal Venkateswaran Date: Thu, 24 Oct 2024 16:40:44 -0700 Subject: [PATCH 02/10] More recent debugging --- examples/alpaca/train.py | 3 ++- examples/alpaca/train_multigpu.py | 15 ++++++++------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/examples/alpaca/train.py b/examples/alpaca/train.py index cff17cc..0255fe9 100644 --- a/examples/alpaca/train.py +++ b/examples/alpaca/train.py @@ -125,6 +125,7 @@ def train(): model=reft_model, tokenizer=tokenizer, args=training_args, **data_module) trainer.train() trainer.save_state() + # print(reft_model) # uncomment this line to only saving the interventons, # you need to reinit the reft model with random init @@ -138,4 +139,4 @@ def train(): if __name__ == "__main__": - train() \ No newline at end of file + train() diff --git a/examples/alpaca/train_multigpu.py b/examples/alpaca/train_multigpu.py index a4c16ad..3cf51f0 100644 --- a/examples/alpaca/train_multigpu.py +++ b/examples/alpaca/train_multigpu.py @@ -132,13 +132,14 @@ def train(rank, world_size): missing_in_ddp = original_params 
- ddp_params missing_in_ddp = sorted(missing_in_ddp) - print("Missing in DDP is", missing_in_ddp) - print("Printing original params") - for x in reft_model.named_parameters(): - print(f"{x[0]} -> {x[1].device}") - print("Printing DDP params") - for x in reft_model_ddp.named_parameters(): - print(f"{x[0]} -> {x[1].device}") + if rank == 0: + print("Missing in DDP is", missing_in_ddp) + print("Printing original params") + for x in reft_model.named_parameters(): + print(f"{x[0]} -> {x[1].device}") + print("Printing DDP params") + for x in reft_model_ddp.named_parameters(): + print(f"{x[0]} -> {x[1].device}") # get training data data_module = make_supervised_data_module( From 02672a2dc28bd3fbbbc48431f03c64b6b576e5be Mon Sep 17 00:00:00 2001 From: Ramgopal Venkateswaran Date: Thu, 14 Nov 2024 07:05:17 -0800 Subject: [PATCH 03/10] Initial working multigpu model --- examples/alpaca/train.py | 1 - examples/alpaca/train_multigpu.py | 67 +++++++++++++++---------------- pyreft/__init__.py | 1 + pyreft/reft_trainer.py | 38 ++++++++++++++---- 4 files changed, 64 insertions(+), 43 deletions(-) diff --git a/examples/alpaca/train.py b/examples/alpaca/train.py index 0255fe9..2009a45 100644 --- a/examples/alpaca/train.py +++ b/examples/alpaca/train.py @@ -125,7 +125,6 @@ def train(): model=reft_model, tokenizer=tokenizer, args=training_args, **data_module) trainer.train() trainer.save_state() - # print(reft_model) # uncomment this line to only saving the interventons, # you need to reinit the reft model with random init diff --git a/examples/alpaca/train_multigpu.py b/examples/alpaca/train_multigpu.py index 3cf51f0..622cf69 100644 --- a/examples/alpaca/train_multigpu.py +++ b/examples/alpaca/train_multigpu.py @@ -12,7 +12,7 @@ TaskType, get_reft_model, ReftConfig, - ReftTrainerForCausalLM, + ReftTrainerForCausalLMDistributed, LoreftIntervention, ReftDataCollator, ReftSupervisedDataset, @@ -59,6 +59,7 @@ class TrainingArguments(transformers.TrainingArguments): share_weights: bool = field(default=False) remove_unused_columns: bool = field(default=False) rank: int = field(default=1) + # local_rank: int = field(default=-1) max_n_train_example: int = field(default=None) @@ -123,8 +124,12 @@ def train(rank, world_size): reft_config = ReftConfig(representations=representations) reft_model = get_reft_model(model, reft_config) - reft_model.print_trainable_parameters() - reft_model_ddp = DDP(reft_model) # , device_ids=[device_id], find_unused_parameters=False) + reft_model.set_device(device) + reft_model.train() + reft_model.model.train() + reft_model.training = True + reft_model = reft_model.to(rank) + reft_model_ddp = DDP(reft_model, device_ids=[device_id]) # check params and devices original_params = {name for name, _ in reft_model.named_parameters()} @@ -132,29 +137,35 @@ def train(rank, world_size): missing_in_ddp = original_params - ddp_params missing_in_ddp = sorted(missing_in_ddp) - if rank == 0: - print("Missing in DDP is", missing_in_ddp) - print("Printing original params") - for x in reft_model.named_parameters(): - print(f"{x[0]} -> {x[1].device}") - print("Printing DDP params") - for x in reft_model_ddp.named_parameters(): - print(f"{x[0]} -> {x[1].device}") + for param_name in missing_in_ddp: + param = dict(reft_model.named_parameters())[param_name] + new_param_name = param_name.replace(".", "_") + reft_model_ddp.register_parameter(new_param_name, param) + dist.broadcast(param.data, src=0) # Broadcast from rank 0 to other processes + + reft_model_ddp.train() + reft_model_ddp.module.train() + 
reft_model_ddp.training = True # get training data data_module = make_supervised_data_module( tokenizer=tokenizer, model=model, layers=layers, training_args=training_args, data_args=data_args) - # train - trainer = ReftTrainerForCausalLM( + trainer = ReftTrainerForCausalLMDistributed( model=reft_model_ddp, tokenizer=tokenizer, args=training_args, **data_module) + + # assert all parameters on same device + for (n, p) in trainer.model.named_parameters(): + print(n, p.get_device(), rank) + assert(p.get_device() == rank) + + # train trainer.train() - print("Rank is", rank) if rank == 0: trainer.save_state() reft_model_ddp.module.save(save_directory=training_args.output_dir) - # uncomment this line to only saving the interventons, + # uncomment this line to only save the interventons, # you need to reinit the reft model with random init # interventions mounted then load these weights # trainer.save_model(output_dir=training_args.output_dir) @@ -162,23 +173,11 @@ def train(rank, world_size): # test if we can load. ReftModel.load(training_args.output_dir, model) -def setup(rank, world_size): - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = '12355' - dist.init_process_group("gloo", rank = rank, world_size = world_size) - -def cleanup(): - dist.destroy_process_group() - -def process_fn(rank, world_size): - setup(rank, world_size) - print("Rank", rank, "world size", world_size) - train(rank, world_size) - cleanup() - if __name__ == "__main__": - assert torch.cuda.is_available(), "MultiGPU script needs CUDA to run" - n_gpus = torch.cuda.device_count() - assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}" - world_size = n_gpus - mp.spawn(process_fn, args=(world_size,), nprocs=world_size, join=True) + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + dist.init_process_group("nccl") + rank = dist.get_rank() + print("Starting on rank", rank) + train(rank, -1) + dist.destroy_process_group() + print("Finished on rank", rank) diff --git a/pyreft/__init__.py b/pyreft/__init__.py index e9a0741..bec6ffa 100644 --- a/pyreft/__init__.py +++ b/pyreft/__init__.py @@ -11,6 +11,7 @@ from .reft_trainer import ( ReftTrainer, ReftTrainerForCausalLM, + ReftTrainerForCausalLMDistributed, ReftTrainerForSequenceClassification ) diff --git a/pyreft/reft_trainer.py b/pyreft/reft_trainer.py index 7f0b9b1..0036521 100644 --- a/pyreft/reft_trainer.py +++ b/pyreft/reft_trainer.py @@ -1,6 +1,7 @@ import pyvene as pv import torch.nn as nn -from torch.utils.data import DataLoader +from torch.utils.data.sampler import Sampler +from torch.utils.data import DataLoader, DistributedSampler from transformers import ( Trainer, TrainingArguments, @@ -15,7 +16,7 @@ ) from datasets import Dataset from dataclasses import dataclass -from typing import Dict, Optional, Sequence +from typing import Dict, Optional, Sequence, Union, Iterable from tqdm import tqdm import os @@ -25,6 +26,7 @@ import numpy as np from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.utils import logging +import torch.distributed as dist logger = logging.get_logger(__name__) @@ -52,7 +54,13 @@ def make_data_collator(tokenizer, model) -> ReftDataCollator: return ReftDataCollator(data_collator=data_collator_fn) -def make_dataloader(dataset: Dataset, batch_size: int, collate_fn: DataCollatorForSeq2Seq, shuffle: bool) -> DataLoader: +def make_dataloader( + dataset: Dataset, + batch_size: int, + collate_fn: DataCollatorForSeq2Seq, + shuffle: bool, + sampler: Union[Sampler, Iterable, None]=None +) 
-> DataLoader: return DataLoader(dataset, shuffle=shuffle, batch_size=batch_size, collate_fn=collate_fn) @@ -78,7 +86,8 @@ def compute_loss( inputs, return_outputs=False ): - # run intervened forward pass + rank = dist.get_rank() + device = torch.device(f'cuda:{rank}') unit_locations = None if "intervention_locations" in inputs: if inputs["intervention_locations"].dim() == 3: @@ -91,12 +100,12 @@ def compute_loss( unit_locations={"sources->base": (None, 0)} base_outputs, cf_outputs = intervenable( { - "input_ids": inputs["input_ids"], - "attention_mask": inputs["attention_mask"] + "input_ids": inputs["input_ids"], # .to(device), + "attention_mask": inputs["attention_mask"], # .to(device), }, unit_locations=unit_locations, - labels=inputs["labels"], - subspaces=inputs["subspaces"].permute(1, 0, 2).tolist() if "subspaces" in inputs else None + labels=inputs["labels"], # .to(device), + subspaces=inputs["subspaces"].permute(1, 0, 2).tolist() if "subspaces" in inputs else None # .to(device) ) # return output = cf_outputs @@ -109,6 +118,19 @@ class ReftTrainerForCausalLM(ReftTrainer): def get_train_dataloader(self) -> DataLoader: return make_dataloader(self.train_dataset, self._train_batch_size, self.data_collator, shuffle=True) +class ReftTrainerForCausalLMDistributed(ReftTrainer): + def save_model(self, output_dir, _internal_call=False): + if dist.get_rank() == 0: + super().save_model(output_dir, _internal_call) + + def get_train_dataloader(self) -> DataLoader: + return make_dataloader( + self.train_dataset, + self._train_batch_size, + self.data_collator, + shuffle=False, + sampler=DistributedSampler(self.train_dataset, shuffle=True), + ) class ReftTrainerForSequenceClassification(ReftTrainer): def compute_loss( From 72b80128e4d5e246b9ee0001d698dd0dc3a128c4 Mon Sep 17 00:00:00 2001 From: Ramgopal Venkateswaran Date: Mon, 18 Nov 2024 07:18:14 -0800 Subject: [PATCH 04/10] fix comments --- examples/alpaca/train_multigpu.py | 19 +++++++++++++------ pyreft/reft_trainer.py | 11 +++++------ 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/examples/alpaca/train_multigpu.py b/examples/alpaca/train_multigpu.py index 622cf69..5f055bd 100644 --- a/examples/alpaca/train_multigpu.py +++ b/examples/alpaca/train_multigpu.py @@ -59,7 +59,6 @@ class TrainingArguments(transformers.TrainingArguments): share_weights: bool = field(default=False) remove_unused_columns: bool = field(default=False) rank: int = field(default=1) - # local_rank: int = field(default=-1) max_n_train_example: int = field(default=None) @@ -147,6 +146,14 @@ def train(rank, world_size): reft_model_ddp.module.train() reft_model_ddp.training = True + # log to wandb from main process only + if rank == 0: + training_args.report_to = ['wandb'] + training_args.run_name = 'multigpu_reft_alpaca_example' + training_args.logging_steps = 1 + else: + training_args.report_to = [] + # get training data data_module = make_supervised_data_module( tokenizer=tokenizer, model=model, layers=layers, @@ -154,24 +161,24 @@ def train(rank, world_size): trainer = ReftTrainerForCausalLMDistributed( model=reft_model_ddp, tokenizer=tokenizer, args=training_args, **data_module) - # assert all parameters on same device for (n, p) in trainer.model.named_parameters(): - print(n, p.get_device(), rank) assert(p.get_device() == rank) # train trainer.train() if rank == 0: + print("Saving") trainer.save_state() reft_model_ddp.module.save(save_directory=training_args.output_dir) # uncomment this line to only save the interventons, # you need to reinit the reft model 
with random init
         # interventions mounted then load these weights
         # trainer.save_model(output_dir=training_args.output_dir)
-
-    # test if we can load.
-    ReftModel.load(training_args.output_dir, model)
+        print("Loading")
+        # test if we can load.
+        ReftModel.load(training_args.output_dir, model)
+        print("Complete")
 
 if __name__ == "__main__":
     torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
diff --git a/pyreft/reft_trainer.py b/pyreft/reft_trainer.py
index 0036521..e979377 100644
--- a/pyreft/reft_trainer.py
+++ b/pyreft/reft_trainer.py
@@ -86,8 +86,7 @@ def compute_loss(
         inputs,
         return_outputs=False
     ):
-        rank = dist.get_rank()
-        device = torch.device(f'cuda:{rank}')
+        # run intervened forward pass
         unit_locations = None
         if "intervention_locations" in inputs:
             if inputs["intervention_locations"].dim() == 3:
@@ -100,12 +99,12 @@ def compute_loss(
             unit_locations={"sources->base": (None, 0)}
         base_outputs, cf_outputs = intervenable(
             {
-                "input_ids": inputs["input_ids"], # .to(device),
-                "attention_mask": inputs["attention_mask"], # .to(device),
+                "input_ids": inputs["input_ids"],
+                "attention_mask": inputs["attention_mask"]
             },
             unit_locations=unit_locations,
-            labels=inputs["labels"], # .to(device),
-            subspaces=inputs["subspaces"].permute(1, 0, 2).tolist() if "subspaces" in inputs else None # .to(device)
+            labels=inputs["labels"],
+            subspaces=inputs["subspaces"].permute(1, 0, 2).tolist() if "subspaces" in inputs else None
         )
         # return
         output = cf_outputs
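Note on launching: the entry point above no longer spawns worker processes itself. It expects one process per GPU, reads LOCAL_RANK from the environment, and joins an NCCL process group, which corresponds to a torchrun-style launch, for example: torchrun --nproc_per_node=2 examples/alpaca/train_multigpu.py --data_path <...> --output_dir <...> plus the usual training arguments (the exact flags depend on your setup). A minimal, generic sketch of that bootstrapping pattern, independent of pyreft:

import os

import torch
import torch.distributed as dist

# torchrun sets LOCAL_RANK, RANK and WORLD_SIZE for every process it launches.
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)          # bind this process to one GPU
dist.init_process_group(backend="nccl")    # join the default process group via env vars

rank = dist.get_rank()
world_size = dist.get_world_size()
# ... build the model, wrap it in DDP, construct the trainer, and train here ...

dist.destroy_process_group()               # tear down cleanly on every rank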
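One detail worth calling out before the next patch: ReftTrainerForCausalLMDistributed builds a DistributedSampler so that each rank trains on its own shard of the dataset, but up to this point make_dataloader accepts the sampler argument and never forwards it to the DataLoader, so every rank still iterates the full dataset in the same order. The following patch re-includes the sampler. For reference, a minimal sketch of the intended pattern in plain PyTorch; the function name and arguments are illustrative, and it assumes the process group is already initialized:

from torch.utils.data import DataLoader, DistributedSampler

def make_sharded_dataloader(dataset, batch_size, collate_fn):
    # The sampler hands each rank a disjoint 1/world_size shard and owns the
    # shuffling, so the DataLoader itself must not also be asked to shuffle.
    sampler = DistributedSampler(dataset, shuffle=True)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,        # mutually exclusive with shuffle=True
        collate_fn=collate_fn,
    )

# Call sampler.set_epoch(epoch) at the start of each epoch so the shuffling
# order changes between epochs while staying consistent across ranks.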
From 005b9dfbca187d6ebbf4a2c66588b8d66d27bf30 Mon Sep 17 00:00:00 2001
From: Ramgopal Venkateswaran
Date: Mon, 18 Nov 2024 16:48:49 -0800
Subject: [PATCH 09/10] Re-include sampler change

---
 pyreft/reft_trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyreft/reft_trainer.py b/pyreft/reft_trainer.py
index 90286aa..1afca4c 100644
--- a/pyreft/reft_trainer.py
+++ b/pyreft/reft_trainer.py
@@ -61,7 +61,7 @@ def make_dataloader(
     shuffle: bool,
     sampler: Union[Sampler, Iterable, None]=None
 ) -> DataLoader:
-    return DataLoader(dataset, shuffle=shuffle, batch_size=batch_size, collate_fn=collate_fn)
+    return DataLoader(dataset, shuffle=shuffle, batch_size=batch_size, sampler=sampler, collate_fn=collate_fn)
 
 
 class ReftTrainer(Trainer):

From 539ff0efdb241c655851d4bf7033141e67b4a5f9 Mon Sep 17 00:00:00 2001
From: Ramgopal Venkateswaran
Date: Fri, 13 Dec 2024 22:59:18 -0800
Subject: [PATCH 10/10] Fix seed for both train and train multigpu

---
 examples/alpaca/train.py          | 4 ++++
 examples/alpaca/train_multigpu.py | 4 ++++
 2 files changed, 8 insertions(+)

diff --git a/examples/alpaca/train.py b/examples/alpaca/train.py
index 2009a45..f452d95 100644
--- a/examples/alpaca/train.py
+++ b/examples/alpaca/train.py
@@ -74,6 +74,10 @@ def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, mod
 
 
 def train():
+    seed = 42
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
     parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
     model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
diff --git a/examples/alpaca/train_multigpu.py b/examples/alpaca/train_multigpu.py
index 5f055bd..76a724f 100644
--- a/examples/alpaca/train_multigpu.py
+++ b/examples/alpaca/train_multigpu.py
@@ -82,6 +82,10 @@ def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, mod
 
 
 def train(rank, world_size):
+    seed = 42
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
     device_id = rank
     device = torch.device(f'cuda:{device_id}')
     parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
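With the seed pinned in both entry points, single- and multi-GPU runs can be compared more directly. To round out the series, a minimal sketch of reloading the interventions that train_multigpu.py saves; this mirrors the ReftModel.load smoke test at the end of train(). The checkpoint path is a placeholder, and the base model must be instantiated the same way as during training:

import torch
import transformers

from pyreft import ReftModel

device = torch.device("cuda:0")

# Recreate the frozen base model exactly as in training (bfloat16, one device).
base_model = transformers.AutoModelForCausalLM.from_pretrained(
    "yahma/llama-7b-hf", torch_dtype=torch.bfloat16, device_map=device
)

# Mount the saved ReFT interventions back onto the base model; "<output_dir>"
# stands for the --output_dir that was passed to training.
reft_model = ReftModel.load("<output_dir>", base_model)
reft_model.set_device(device)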