From 769cfe33900452a550dcc84c0a5cf89e44570632 Mon Sep 17 00:00:00 2001
From: Lisa Jin
Date: Fri, 8 Jul 2022 15:09:06 -0400
Subject: [PATCH] [T104292598] Refactor the "LRA" training code -> Pytorch
 Lightning (#343)

* First attempt at PL trainer

* Blocksparse switch revisions (#342)

* minor cleanup; updated changelog

* fixed mypy error

* added checking for blocksparse availability

Co-authored-by: Chris Yuan
Co-authored-by: Chris Yuan

* Finish PL refactor

* Fix coding style, remove unused imports

* Fix flake8 error

* Make isort happy

* Let pre-commit handle formatting...

* Add type hints, fix eval behavior

* Evaluate PL refactor with batch_submit.py

Co-authored-by: Chris Yuan
Co-authored-by: Chris Yuan
---
 .../benchmarks/LRA/batch_fetch_results.py     |  44 +-
 xformers/benchmarks/LRA/batch_submit.py       |   9 -
 xformers/benchmarks/LRA/code/model_wrapper.py | 162 +++--
 xformers/benchmarks/LRA/run_tasks.py          | 632 +++++-------------
 xformers/benchmarks/LRA/run_with_submitit.py  |   2 +-
 xformers/components/multi_head_dispatch.py    |   3 +
 xformers/factory/model_factory.py             |   4 +-
 7 files changed, 312 insertions(+), 544 deletions(-)

diff --git a/xformers/benchmarks/LRA/batch_fetch_results.py b/xformers/benchmarks/LRA/batch_fetch_results.py
index ccb99d0302..88227ac331 100644
--- a/xformers/benchmarks/LRA/batch_fetch_results.py
+++ b/xformers/benchmarks/LRA/batch_fetch_results.py
@@ -10,16 +10,6 @@
 from pathlib import Path
 from typing import Any, Dict
 
-reference_steps = {
-    "image": 35176,
-    "listops": 10000,
-    "pathfinder32-curv_contour_length_14": 62400,
-    "pathfinder32-curv_baseline": 62400,
-    "pathfinder32-curv_contour_length_9": 62400,
-    "text": 20000,
-    "retrieval": 30000,
-}
-
 if __name__ == "__main__":
     # Get the user requests
     parser = argparse.ArgumentParser(
@@ -38,10 +28,10 @@
     for attention in filter(lambda x: x.is_dir(), root.iterdir()):
         logging.info(f"\nFound results for {attention.stem}")
-        task_logs = attention.glob("*/*.log")
+        task_jsons = attention.glob("*/test_eval_summary.json")
         results[attention.stem] = {}
 
-        for task in filter(lambda x: "__0" in str(x), task_logs):
+        for task in task_jsons:
             task_name = task.stem.split("__")[0]
             logging.info(f"Logs found for task: {task_name}")
             results[attention.stem][task_name] = -1
@@ -49,25 +39,17 @@
 
             # - collect the individual results
             with open(task, "r") as result_file:
-                for line in reversed(result_file.readlines()):
-                    if '"component": "test"' in line:
-                        found_result = True
-
-                        # Check that all the steps are done
-                        res = json.loads(line)
-
-                        if res["train_step_idx"] == reference_steps[task_name]:
-                            results[attention.stem][task_name] = res["best_accu"]
-                            logging.info(
-                                f"Final result found for {task_name}: {results[attention.stem][task_name]}"
-                            )
-                        else:
-                            logging.info(
-                                "Current step: {}/{}. Not finished".format(
-                                    res["train_step_idx"], reference_steps[task_name]
-                                )
-                            )
-                        break
+                dct = json.load(result_file)
+                if "test_accu_mean" in dct:
+                    found_result = True
+                    results[attention.stem][task_name] = dct["test_accu_mean"]
+
+                    logging.info(
+                        f"Final result found for {task_name} at epoch {dct['train_step_idx']}: "
+                        f"{results[attention.stem][task_name]}"
+                    )
+                else:
+                    break
 
         # - report an error if no result was found
         if not found_result:
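Note: with this change, batch_fetch_results.py reads a single test_eval_summary.json
per task directory instead of scanning training logs. A minimal sketch of the
producer/consumer round trip — the key names (test_accu_mean, train_step_idx) follow
the code above, but the values and the flat file layout are illustrative only:

    import json

    # Illustrative stand-in for <checkpoint_dir>/<attention>/<task>/test_eval_summary.json;
    # run_tasks.py fills it from trainer.callback_metrics, the values here are made up.
    summary = {"train_step_idx": 35176, "test_accu_mean": 0.6412, "test_loss_mean": 1.07}
    with open("test_eval_summary.json", "w") as f:
        json.dump(summary, f)

    # The lookup performed above, reduced to one file:
    with open("test_eval_summary.json", "r") as result_file:
        dct = json.load(result_file)
    if "test_accu_mean" in dct:
        print(f"Final result at epoch {dct['train_step_idx']}: {dct['test_accu_mean']}")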
diff --git a/xformers/benchmarks/LRA/batch_submit.py b/xformers/benchmarks/LRA/batch_submit.py
index 964a516ce2..a3077aa62a 100644
--- a/xformers/benchmarks/LRA/batch_submit.py
+++ b/xformers/benchmarks/LRA/batch_submit.py
@@ -37,14 +37,6 @@ def get_default_shared_folder() -> str:
     parser.add_argument(
         "--partition", default="a100", type=str, help="Partition where to submit"
     )
-    parser.add_argument(
-        "-tb",
-        "--tb_path",
-        type=str,
-        help="Path to the tensorboard directory",
-        dest="tb_dir",
-        default=f"/{default_checkpoint_path}/{os.getenv('USER')}/xformers/tb",
-    )
     args = parser.parse_args()
 
     for attention in args.attentions:
@@ -54,5 +46,4 @@ def get_default_shared_folder() -> str:
             + f" --attention {attention} --task {task} --config {args.config_path}"
             + f" --checkpoint_dir {args.checkpoint_path}/{attention}/{task}"
             + f" --partition {args.partition}"
-            + f" --tb_dir {args.tb_dir}/{attention}/{task}"
         )

diff --git a/xformers/benchmarks/LRA/code/model_wrapper.py b/xformers/benchmarks/LRA/code/model_wrapper.py
index cf0be64e18..5eb3e2ca74 100755
--- a/xformers/benchmarks/LRA/code/model_wrapper.py
+++ b/xformers/benchmarks/LRA/code/model_wrapper.py
@@ -8,7 +8,9 @@
 # https://github.com/mlpen/Nystromformer
 
 from enum import Enum
+from typing import Dict, Union
 
+import pytorch_lightning as pl
 import torch
 import torch.nn as nn
 
@@ -17,6 +19,8 @@
 from xformers.factory import xFormer, xFormerConfig, xFormerEncoderConfig
 from xformers.utils import generate_matching_config
 
+PLOutput = Dict[str, Union[float, torch.Tensor]]
+
 
 class Pooling(str, Enum):
     MEAN = "mean"
@@ -113,11 +117,12 @@ def forward(self, inp_0: torch.Tensor, inp_1: torch.Tensor):
         return seq_score
 
 
-class ModelTrunk(nn.Module):
+class ModelTrunk(pl.LightningModule):
     def __init__(self, config, model_name):
         super().__init__()
 
         config_model = config["model"]
+        self.config_training = config["training"]
         self.enable_amp = config["training"]["mixed_precision"]
         self.pooling_mode = Pooling(config_model["pooling_mode"])
 
@@ -134,6 +139,72 @@ def __init__(self, config, model_name):
             * ff_config["hidden_layer_multiplier"]
         )
 
+    def training_step(  # type: ignore
+        self, batch: Dict[str, torch.Tensor], batch_idx: int
+    ) -> PLOutput:
+        outputs = self(**batch)
+        self.logger.log_metrics({f"train_{k}": v for k, v in outputs.items()})  # type: ignore
+        self.log("train_accu", outputs["accu"], sync_dist=True)
+        return outputs
+
+    def training_epoch_end(self, outputs):
+        logs = self.eval_epoch_end(outputs)
+        self.log("train_accu_mean", logs["accu"], sync_dist=True)
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.AdamW(
+            self.parameters(),
+            lr=self.config_training["learning_rate"],
+            betas=(0.9, 0.999),
+            eps=1e-6,
+            weight_decay=self.config_training["weight_decay"],
+        )
+
+        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
+            optimizer=optimizer,
+            max_lr=self.config_training["learning_rate"],
+            pct_start=self.config_training["warmup"]
+            / self.config_training["num_train_steps"],
+            anneal_strategy=self.config_training["lr_decay"],
+            total_steps=self.config_training["num_train_steps"],
+        )
+
+        return [optimizer], [lr_scheduler]
+
+    def eval_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> PLOutput:
+        outputs = self(**batch)
+        return outputs
+
+    def eval_epoch_end(self, outputs, prefix: str = "train"):
+        logs = {}
+        counts = torch.tensor([x["count"] for x in outputs]).float()
+        logs["count"] = counts.sum()
+        for k in ("accu", "loss"):
+            logs[k] = (torch.tensor([x[k] for x in outputs]) * counts).sum() / logs[
+                "count"
+            ]
+            self.log(f"{prefix}_{k}_mean", logs[k], sync_dist=True)
+        return logs
+
+    def validation_step(  # type: ignore
+        self, batch: Dict[str, torch.Tensor], batch_idx: int
+    ) -> PLOutput:
+        outputs = self.eval_step(batch, batch_idx)
+        self.logger.log_metrics({f"val_{k}": v for k, v in outputs.items()})  # type: ignore
+        self.log("val_accu", outputs["accu"], sync_dist=True, prog_bar=True)
+        return outputs
+
+    def validation_epoch_end(self, outputs):
+        self.eval_epoch_end(outputs, prefix="val")
+
+    def test_step(  # type: ignore
+        self, batch: Dict[str, torch.Tensor], batch_idx: int
+    ) -> PLOutput:
+        return self.eval_step(batch, batch_idx)
+
+    def test_epoch_end(self, outputs):
+        self.eval_epoch_end(outputs, prefix="test")
+
 
 class ModelForSC(ModelTrunk):
     def __init__(self, config, model_name):
@@ -146,25 +217,26 @@ def __init__(self, config, model_name):
             dim_mlp=self.dim_mlp,
         )
 
-    def forward(self, input_ids_0, mask_0, label):
+    def forward(  # type: ignore
+        self, input_ids_0: torch.Tensor, mask_0: torch.Tensor, label: torch.Tensor
+    ):
 
-        with torch.cuda.amp.autocast(enabled=self.enable_amp):
-            if self.pooling_mode == Pooling.CLS:
-                input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size)
+        if self.pooling_mode == Pooling.CLS:
+            input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size)
 
-            token_out = self.norm(
-                self.model(input_ids_0, encoder_input_mask=mask_0)
-            ) * mask_0.unsqueeze(-1)
+        token_out = self.norm(
+            self.model(input_ids_0, encoder_input_mask=mask_0)
+        ) * mask_0.unsqueeze(-1)
 
-            seq_scores = self.seq_classifer(token_out)
+        seq_scores = self.seq_classifer(token_out)
 
-            seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label)
-            seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32)
-            outputs = {
-                "loss": seq_loss.mean(),
-                "accu": seq_accu.mean(),
-                "count": label.size(0),
-            }
+        seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label)
+        seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32)
+        outputs = {
+            "loss": seq_loss.mean(),
+            "accu": seq_accu.mean(),
+            "count": label.size(0),
+        }
 
         return outputs
 
@@ -180,31 +252,37 @@ def __init__(self, config, model_name):
             dim_mlp=self.dim_mlp,
         )
 
-    def forward(self, input_ids_0, input_ids_1, mask_0, mask_1, label):
-
-        with torch.cuda.amp.autocast(enabled=self.enable_amp):
-            mask_0, mask_1 = mask_0.long(), mask_1.long()
-
-            if self.pooling_mode == Pooling.CLS:
-                input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size)
-                input_ids_1, mask_1 = append_cls(input_ids_1, mask_1, self.vocab_size)
-
-            # Concatenate the two inputs into one batch
-            input_ids = torch.cat([input_ids_0, input_ids_1], dim=0)
-            masks = torch.cat([mask_0, mask_1], dim=0)
-
-            tokens_out = self.norm(
-                self.model(input_ids, encoder_input_mask=masks)
-            ) * masks.unsqueeze(-1)
-
-            seq_scores = self.seq_classifer(*torch.chunk(tokens_out, 2, dim=0))
-
-            seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label)
-            seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32)
-            outputs = {
-                "loss": seq_loss.mean(),
-                "accu": seq_accu.mean(),
-                "count": label.size(0),
-            }
+    def forward(  # type: ignore
+        self,
+        input_ids_0: torch.Tensor,
+        input_ids_1: torch.Tensor,
+        mask_0: torch.Tensor,
+        mask_1: torch.Tensor,
+        label: torch.Tensor,
+    ):
+
+        mask_0, mask_1 = mask_0.long(), mask_1.long()
+
+        if self.pooling_mode == Pooling.CLS:
+            input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size)
+            input_ids_1, mask_1 = append_cls(input_ids_1, mask_1, self.vocab_size)
+
+        # Concatenate the two inputs into one batch
+        input_ids = torch.cat([input_ids_0, input_ids_1], dim=0)
+        masks = torch.cat([mask_0, mask_1], dim=0)
+
+        tokens_out = self.norm(
+            self.model(input_ids, encoder_input_mask=masks)
+        ) * masks.unsqueeze(-1)
+
+        seq_scores = self.seq_classifer(*torch.chunk(tokens_out, 2, dim=0))
+
+        seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label)
+        seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32)
+        outputs = {
+            "loss": seq_loss.mean(),
+            "accu": seq_accu.mean(),
+            "count": label.size(0),
+        }
 
         return outputs
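Note: one Pytorch Lightning subtlety in configure_optimizers above. Returning a bare
[optimizer], [lr_scheduler] makes Lightning call lr_scheduler.step() once per epoch,
while this OneCycleLR is sized in optimizer steps (total_steps=num_train_steps). If
per-step annealing is intended, the scheduler-dict form is the usual spelling — a
sketch under that assumption, not part of this patch:

    import pytorch_lightning as pl
    import torch


    class TrunkWithStepLR(pl.LightningModule):
        # Hypothetical variant of ModelTrunk.configure_optimizers: identical
        # optimizer and scheduler, but the dict form asks Lightning to call
        # lr_scheduler.step() after every optimizer step instead of every epoch.
        def configure_optimizers(self):
            optimizer = torch.optim.AdamW(
                self.parameters(),
                lr=self.config_training["learning_rate"],
                betas=(0.9, 0.999),
                eps=1e-6,
                weight_decay=self.config_training["weight_decay"],
            )
            lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
                optimizer=optimizer,
                max_lr=self.config_training["learning_rate"],
                pct_start=self.config_training["warmup"]
                / self.config_training["num_train_steps"],
                anneal_strategy=self.config_training["lr_decay"],
                total_steps=self.config_training["num_train_steps"],
            )
            return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}]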

diff --git a/xformers/benchmarks/LRA/run_tasks.py b/xformers/benchmarks/LRA/run_tasks.py
index 96c52357b8..8f6e26edb7 100644
--- a/xformers/benchmarks/LRA/run_tasks.py
+++ b/xformers/benchmarks/LRA/run_tasks.py
@@ -4,34 +4,25 @@
 # LICENSE file in the root directory of this source tree.
 
 
-# CREDITS: adapted from the Nystromformer repo
-# https://github.com/mlpen/Nystromformer
-
 import argparse
-import datetime
 import json
 import logging
-import math
 import os
-import random
-import sys
-import time
-from contextlib import suppress
 from enum import Enum
 from pathlib import Path
-from typing import Any, Dict
+from typing import Dict, Tuple
 
-import numpy as np
+import pytorch_lightning as pl
 import torch
-import torch.distributed as dist
 import torch.nn as nn
 from fvcore.nn import FlopCountAnalysis, flop_count_str
-from torch.utils.data import DataLoader, DistributedSampler
-from torch.utils.tensorboard import SummaryWriter
+from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
+from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.strategies import DDPStrategy
+from torch.utils.data import DataLoader
 
 from xformers.benchmarks.LRA.code.dataset import LRADataset
 from xformers.benchmarks.LRA.code.model_wrapper import ModelForSC, ModelForSCDual
-from xformers.benchmarks.utils import temp_files_ctx
 from xformers.components.attention import ATTENTION_REGISTRY
 
@@ -60,16 +51,16 @@ def build_model(args: argparse.Namespace, config: Dict) -> nn.Module:
     task = args.task
     attention_name = args.attention
 
-    if task == Task.Retrieval:
-        model: nn.Module = ModelForSCDual(config[f"{task}"], attention_name)
-    else:
-        model = ModelForSC(config[f"{task}"], attention_name)
-
-    args.logger.info(model)
-    args.logger.info(
-        f"num_parameter: {np.sum([np.prod(weight.size()) for weight in model.parameters()]) // 1e3 / 1e3}M"
+    model: pl.LightningModule = (
+        ModelForSCDual(config[f"{task}"], attention_name)
+        if task == Task.Retrieval
+        else ModelForSC(config[f"{task}"], attention_name)
     )
 
+    logging.info(model)
+    summary = pl.utilities.model_summary.LayerSummary(model)
+    logging.info(f"num_parameter: {summary.num_parameters // 1e3 / 1e3}M")
+
     with torch.no_grad():
         # Check the flops
         seq_len = config[f"{task}"]["model"]["common"]["seq_len"]
         x = torch.rand(1, seq_len).long()
         mask = torch.rand(1, seq_len).long()
         indices = torch.rand(1, seq_len).long()
         flops = FlopCountAnalysis(model.model, (x, mask, indices))
-        args.logger.info(f"complexity: {round(flops.total()/1e9, 3)} GFlops")
-        args.logger.info(flop_count_str(flops))
+        logging.info(f"complexity: {round(flops.total()/1e9, 3)} GFlops")
+        logging.info(flop_count_str(flops))
 
     return model
 
 
-def build_training_setup(
-    config_training: Dict,
-    task: Task,
-    model: nn.Module,
-    rank: int = 0,
-    world_size: int = 1,
-):
-    datasets = {}
-    samplers = {}
-
-    for component in ["train", "test", "dev"]:
-        dataset = LRADataset(
-            file_path=f"datasets/{task}.{component}.pickle",
-            seq_len=config_training["seq_len"],
-        )
-
-        sampler = DistributedSampler(
-            dataset,
-            num_replicas=world_size,
-            rank=rank,
-            shuffle=(component == "train"),
-            drop_last=(component == "train"),
-        )  # type:ignore
-        datasets[component] = dataset
-        samplers[component] = sampler
-
-    logging.info(f"Learning rate: {config_training['learning_rate']}")
-
-    optimizer = torch.optim.AdamW(
-        model.parameters(),
-        lr=config_training["learning_rate"],
-        betas=(0.9, 0.999),
-        eps=1e-6,
-        weight_decay=config_training["weight_decay"],
-    )
-
-    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(  # type: ignore
-        optimizer=optimizer,
-        max_lr=config_training["learning_rate"],
-        pct_start=config_training["warmup"] / config_training["num_train_steps"],
-        anneal_strategy=config_training["lr_decay"],
-        total_steps=config_training["num_train_steps"],
-    )
-
-    amp_scaler = torch.cuda.amp.GradScaler(enabled=config_training["mixed_precision"])
-
-    logging.info(f"Dataloader ready. Rank {rank} of {world_size}")
-
-    return datasets, samplers, optimizer, lr_scheduler, amp_scaler
-
-
-def print_summary(
-    summary,
-    save_if_improved,
-    train_step_idx,
-    model,
-    checkpoint_path,
-    logger,
-    tb_logger=None,
-):
-
-    summary["loss"] = np.average(summary["loss"], weights=summary["count"])
-    summary["accu"] = np.average(summary["accu"], weights=summary["count"])
-    summary["count"] = np.sum(summary["count"]).astype(float)
-
-    if summary["accu"] > summary["best_accu"]:
-        summary["best_accu"] = summary["accu"]
-        if save_if_improved:
-            best_accu = summary["best_accu"]
-            torch.save(
-                {"model_state_dict": model.state_dict()},
-                checkpoint_path,
-            )
-            logger.info(f"best_accu={best_accu:.3f}. Saved best model")
-
-    summary["max_memory_mb"] = torch.cuda.max_memory_allocated() // 1e3 / 1e3
-
-    summary_round = {"train_step_idx": train_step_idx}
-    for key in summary:
-        if type(summary[key]) is str:
-            summary_round[key] = summary[key]
-        else:
-            summary_round[key] = round(summary[key], 4)
-
-    if tb_logger:
-        tb_logger.add_scalar("acc", summary["accu"], train_step_idx)
-        tb_logger.add_scalar("loss", summary["loss"], train_step_idx)
-        tb_logger.add_scalar("max_mem", summary["max_memory_mb"], train_step_idx)
-        tb_logger.add_scalar("count", summary["count"], train_step_idx)
-
-    logger.info(summary_round)
-    logger.info(json.dumps(summary_round, sort_keys=True) + "\n")
-
-    summary["t"] = 0
-    summary["loss"] = []
-    summary["accu"] = []
-    summary["count"] = []
-
-
-def setup_log(args, rank, attention_name, task):
-    log_f = Path(
-        os.path.join(
-            args.checkpoint_dir, f"{task}__{attention_name}__{rank}_output.log"
-        )
-    )
-    if not log_f.exists():
-        log_f.parent.mkdir(parents=True, exist_ok=True)
-        with open(log_f, "x") as _:
-            pass
-
-    logger = torch.multiprocessing.get_logger()
-    logger.setLevel(level=logging.INFO)
-    logger.addHandler(logging.FileHandler(filename=str(log_f)))
-    if rank == 0:
-        logger.addHandler(logging.StreamHandler(sys.stdout))
-    return log_f.absolute(), logger
-
-
-def eval_model(model, dataloaders, component, config, step):
-    model.eval()
-
-    for dev_step_idx, batch_dev in enumerate(dataloaders[component]):
-        _ = step(
-            batch_dev,
-            component,
-            step_idx=dev_step_idx,
-            step_max=config["num_eval_steps"],
-        )
-
-        if dev_step_idx == config["num_eval_steps"]:
-            break
-
-    model.train()
-
-
-def rewrite_hyper(config, rewrites):
-    def replace(config_dict, k, v):
-        if len(k.split(":")) == 1:
-            config_dict[k] = v
-            return
-        first_key = k.split(":")[0]
-        assert first_key in config_dict, first_key
-        k = k[len(first_key) + 1 :]
-        replace(config_dict[first_key], k, v)
-
-    for k, v in rewrites.items():
-        replace(config, k, v)
-    return config
-
-
-def seed_worker(_: int):
-    # Make sure that non-pytorch random generators are properly set
-    worker_seed = torch.initial_seed() % 2**32
-    np.random.seed(worker_seed)
-    random.seed(worker_seed)
-
-
-def benchmark(rank, args):
-    # Setup multiprocessing
-    dist.init_process_group(
-        init_method="file://" + args.temp_file,
-        backend="NCCL",
-        rank=rank,
-        world_size=args.world_size,
-    )
-    try:
-        torch.cuda.set_device(args.gpu)
-    except AttributeError:
-        # Single node launcher
-        torch.cuda.set_device(rank)
-
-    task = args.task
-    attention_name = args.attention
-
-    # Build the problem
-    log_f_path, logger = setup_log(args, rank, attention_name, task)
-    args.logger = logger
-    config = load_config(args.config)
-
-    config_task = config[f"{task}"]
-    if args.sweep_parameters is not None:
-        logger.info("Replacing hyperparameters")
-        rewrite_hyper(config_task, args.sweep_parameters)
-
-    config_training = config_task["training"]
-    config_training["seq_len"] = config_task["model"]["common"]["seq_len"]
-    model = build_model(args, config)
-
-    torch.manual_seed(config_training.get("seed", 0))  # also sets the cuda seed
-    np.random.seed(config_training.get("seed", 0))
-    torch.backends.cudnn.enabled = True
-    torch.backends.cudnn.deterministic = True
-    torch.backends.cudnn.benchmark = False
-    torch.cuda.reset_peak_memory_stats()
-
-    # tensorboard
-    tb_logger = SummaryWriter(args.tb_dir)
-
-    torch.manual_seed(config_training.get("seed", 0))  # also sets the cuda seed
-    np.random.seed(config_training.get("seed", 0))
-    torch.backends.cudnn.enabled = True
-    torch.backends.cudnn.deterministic = True
-    torch.backends.cudnn.benchmark = False
-    torch.cuda.reset_peak_memory_stats()
-
-    # tensorboard
-    tb_logger = SummaryWriter(args.tb_dir)
-
-    # Setup the training
-    device_ids = list(range(torch.cuda.device_count()))
-    logger.info(f"GPU list: {device_ids}")
-    model = model.cuda()
-    model = nn.parallel.DistributedDataParallel(
-        model, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True
-    )
-
-    (
-        datasets,
-        samplers,
-        optimizer,
-        lr_scheduler,
-        amp_scaler,
-    ) = build_training_setup(config_training, task, model, rank, args.world_size)
-
-    init_t = time.time()
-
-    # Messenger structure which will be moved around to collect metrics
-    summary = {
-        comp: {
-            "t": 0,
-            "loss": [],
-            "accu": [],
-            "count": [],
-            "best_accu": 0,
-            "component": comp,
-        }
-        for comp in ["train", "dev", "test"]
-    }
-
-    # Setup the dataloaders
-    accumu_steps = config_task["training"]["gradient_accumulation"]
-    per_gpu_batch_size = (
-        config_training["batch_size"] // args.world_size // accumu_steps
-    )
-    logging.warning(
-        "Requested batch size: {}. Given world size and grad accumulation, per-gpu batch is {}".format(
-            config_training["batch_size"], per_gpu_batch_size
-        )
-    )
-
-    # reset train/eval steps if using gradient accumulation
-    if accumu_steps > 1:
-        config_training["num_train_steps"] *= accumu_steps
-        config_training["num_eval_steps"] *= accumu_steps
-
-    epochs = math.ceil(
-        config_training["num_train_steps"]
-        * config_training["batch_size"]
-        / len(datasets["train"])
-    )
-
-    logging.warning(
-        "Requested train steps: {}. Given dataset, this translates into {} epochs".format(
-            config_training["num_train_steps"], epochs
-        )
-    )
-
-    logger.info(f"accumu_steps={accumu_steps}")
-    model_path = str(log_f_path).replace(".log", ".model")
-    g = torch.Generator()
-    g.manual_seed(config_training.get("seed", 0))
-
-    dataloaders = {
-        k: DataLoader(
-            datasets[k],
-            sampler=samplers[k],
-            batch_size=per_gpu_batch_size,
-            shuffle=False,
-            pin_memory=True,
-            num_workers=1,
-            worker_init_fn=seed_worker,
-            generator=g,
-        )
-        for k in datasets.keys()
-    }
-
-    # Our step function
-    def step(
-        batch: Dict[str, Any],
-        component: str,
-        step_idx: int,
-        step_max: int,
-        accumulate: bool = False,
-    ):
-        if step_idx > step_max:
-            logger.warning(
-                "Calling `step` beyond the training schedule, this is probably a mistake"
-            )
-            return
-
-        t0 = time.time()
-        batch_size = batch[list(batch.keys())[0]].size(0)
-
-        for key in batch:
-            batch[key] = batch[key].cuda()
-
-        if component == "train":
-            acc_context = model.no_sync() if accumulate else suppress()
-
-            with acc_context, torch.autograd.set_detect_anomaly(args.debug):
-                outputs = model(**batch)
-                amp_scaler.scale(outputs["loss"]).backward()
-
-                if not accumulate:
-                    amp_scaler.step(optimizer)
-                    optimizer.zero_grad()
-                    amp_scaler.update()
-                    lr_scheduler.step()
-
-        else:
-            with torch.no_grad():
-                outputs = model(**batch)
-
-        t1 = time.time()
-
-        t_escape = t1 - t0
-        learning_rate = optimizer.param_groups[0]["lr"]
-        loss = outputs["loss"].item()
-        accu = outputs["accu"].item()
-        cnt = outputs["count"]
-        time_since_start = time.time() - init_t
-        eta = (
-            datetime.timedelta(
-                seconds=round(time_since_start / (step_idx + 1) * step_max)
-            )
-            if component == "train"
-            else -1
-        )
-
-        if not step_idx % 10:
-            logger.info(
-                f"{component}: step={step_idx}/{step_max}, total_time={time_since_start:.1f},"
-                + f" eta={eta},"
-                + f" batch_time={t_escape:.3f}, bs={batch_size}, lr={learning_rate:.6f},"
-                + f" loss={loss:.4f}, accu={accu:.4f}",
-            )
accu={accu:.4f}", - ) - - summary[component]["t"] += t_escape - summary[component]["loss"].append(loss) - summary[component]["accu"].append(accu) - summary[component]["count"].append(cnt) - - if not accumulate: - step_idx += 1 - - return loss, step_idx - - # Start training or evaluating - train_step_idx = 0 - if not args.skip_train: - try: - model.train() - for epoch in range(epochs): - logger.info(f"\nEpoch {epoch}") - - # Make sure that per-rank sampling is really random - for sampler in samplers.values(): - sampler.set_epoch(epoch) - - for i_batch, batch in enumerate(dataloaders["train"]): - grad_accumulate = ( - i_batch % config_training["gradient_accumulation"] != 0 - ) - - _, train_step_idx = step( - batch, - component="train", - step_idx=train_step_idx, - step_max=config_training["num_train_steps"], - accumulate=grad_accumulate, - ) - - if not (train_step_idx + 1) % config_training["eval_frequency"]: - print_summary( - summary["train"], - False, - train_step_idx, - model, - model_path, - logger, - ) - - eval_model(model, dataloaders, "dev", config_training, step) - - print_summary( - summary["dev"], - True, - train_step_idx, - model, - model_path, - logger, - tb_logger, - ) - - if train_step_idx == config_training["num_train_steps"]: - break - - except KeyboardInterrupt as e: - print(e) - - checkpoint = torch.load(model_path, map_location="cpu") - model.load_state_dict(checkpoint["model_state_dict"]) - model.eval() - try: - eval_model(model, dataloaders, "test", config_training, step) - except StopIteration: - pass - - print_summary(summary["test"], False, train_step_idx, model, model_path, logger) - - def get_arg_parser(): parser = argparse.ArgumentParser() parser.add_argument( @@ -542,6 +112,11 @@ def get_arg_parser(): dest="checkpoint_dir", default=f"/checkpoints/{os.getenv('USER')}/xformers", ) + parser.add_argument( + "--checkpoint_path", + type=str, + help="Path to checkpoint", + ) parser.add_argument( "--debug", help="Make it easier to debug a possible issue", @@ -563,23 +138,160 @@ def get_arg_parser(): type=dict, default=None, ) - parser.add_argument( - "--tb_dir", - type=str, - help="Path to the tensorboard directory", - dest="tb_dir", - default=f"/checkpoints/{os.getenv('USER')}/xformers/tb", - ) return parser +def setup_log(args, attention_name, task) -> Tuple[str, TensorBoardLogger]: + experiment_name = f"{task}__{attention_name}" + logger = TensorBoardLogger( + save_dir=args.checkpoint_dir, + name="", # remove lightning_logs subdirectory + version=experiment_name, + ) + log_dir = os.path.join(logger._save_dir, experiment_name) + return log_dir, logger + + +def rewrite_hyper(config, rewrites): + def replace(config_dict, k, v): + if len(k.split(":")) == 1: + config_dict[k] = v + return + first_key = k.split(":")[0] + assert first_key in config_dict, first_key + k = k[len(first_key) + 1 :] + replace(config_dict[first_key], k, v) + + for k, v in rewrites.items(): + replace(config, k, v) + return config + + +def build_dataloaders( + args: argparse.Namespace, + config_training: Dict, + num_workers: int = 4, +) -> Dict[str, DataLoader]: + datasets = {} + for component in ("train", "dev", "test"): + datasets[component] = LRADataset( + file_path=f"datasets/{args.task}.{component}.pickle", + seq_len=config_training["seq_len"], + ) + + # Gradient accumulation + accumu_steps = config_training["gradient_accumulation"] + logging.info(f"accumu_steps={accumu_steps}") + + # Batch size + per_gpu_batch_size = ( + config_training["batch_size"] // args.world_size // accumu_steps + ) + 
+    logging.warning(
+        f"Requested batch size: {config_training['batch_size']}. Given world\
+        size and grad accumulation, per-gpu batch is\
+        {per_gpu_batch_size}"
+    )
+
+    dataloaders = {
+        k: DataLoader(
+            v,
+            batch_size=per_gpu_batch_size,
+            shuffle=False,
+            pin_memory=True,
+            num_workers=num_workers,
+        )
+        for k, v in datasets.items()
+    }
+    return dataloaders
+
+
+def get_eval_summary(trainer: pl.Trainer) -> Dict[str, float]:
+    eval_summary: Dict[str, float] = {"train_step_idx": trainer.global_step}
+    for k, v in trainer.callback_metrics.items():
+        eval_summary[k] = v.item()
+    return eval_summary
+
+
+class BasicProgressBar(TQDMProgressBar):
+    def get_metrics(self, trainer, model):
+        items = super().get_metrics(trainer, model)
+        items.pop("v_num", None)
+        return items
+
+
+def benchmark(args):
+    log_dir, logger = setup_log(args, f"{args.attention}", f"{args.task}")
+    args.logger = logger
+
+    config = load_config(args.config)
+
+    config_task = config[f"{args.task}"]
+    if args.sweep_parameters is not None:
+        logging.info("Replacing hyperparameters")
+        rewrite_hyper(config_task, args.sweep_parameters)
+
+    config_training = config_task["training"]
+    config_training["seq_len"] = config_task["model"]["common"]["seq_len"]
+    logging.info(f"Learning rate: {config_training['learning_rate']}")
+
+    pl.seed_everything(config_training.get("seed", 0))
+    dataloaders = build_dataloaders(args, config_training)
+
+    model = build_model(args, config)
+
+    progress_bar = BasicProgressBar()
+    checkpoint_callback = ModelCheckpoint(
+        monitor="val_accu",
+        mode="max",
+        dirpath=args.checkpoint_dir,
+        filename="{epoch}-{val_accu:.2f}",
+        every_n_train_steps=config_training["eval_frequency"],
+    )
+
+    trainer = pl.Trainer(
+        accelerator="gpu",
+        strategy=DDPStrategy(find_unused_parameters=args.debug)
+        if not args.skip_train
+        else None,
+        accumulate_grad_batches=config_training["gradient_accumulation"],
+        callbacks=[progress_bar, checkpoint_callback],
+        detect_anomaly=args.debug,
+        deterministic=True,
+        gpus=args.world_size,
+        limit_val_batches=config_training["num_eval_steps"],
+        logger=logger,
+        max_steps=config_training["num_train_steps"],
+        num_sanity_val_steps=int(not args.skip_train),
+        precision=16 if config_training["mixed_precision"] else 32,
+        val_check_interval=config_training["eval_frequency"]
+        / float(len(dataloaders["train"])),
+    )
+
+    if not args.skip_train:
+        trainer.fit(
+            model,
+            train_dataloaders=dataloaders["train"],
+            val_dataloaders=dataloaders["dev"],
+        )
+        ckpt_path = checkpoint_callback.best_model_path
+    else:
+        ckpt_path = args.checkpoint_path
+
+    trainer.test(
+        model,
+        dataloaders=dataloaders["test"],
+        ckpt_path=ckpt_path,
+    )
+
+    eval_summary = get_eval_summary(trainer)
+    with open(os.path.join(log_dir, "test_eval_summary.json"), "w") as f:
+        logging.info(f"Saving test results at {f.name}")
+        json.dump(eval_summary, f)
+
+
 if __name__ == "__main__":
     parser = get_arg_parser()
     args = parser.parse_args()
-    setup_log(args, "main", f"{args.attention}", f"{args.task}")
-
-    with temp_files_ctx(num=1) as temp_files:
-        args.temp_file = temp_files[0]
-        torch.multiprocessing.spawn(
-            benchmark, args=(args,), nprocs=args.world_size, join=True
-        )
+    if args.skip_train and args.checkpoint_path is None:
+        raise parser.error("Must provide --checkpoint_path if --skip_train=True")
+    benchmark(args)
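Note: two Trainer arguments above encode the old step-based schedule in epoch-based
terms, which is easy to misread. Lightning interprets a float val_check_interval as a
fraction of the training epoch, so dividing eval_frequency by the number of train
batches recovers "validate every eval_frequency steps", while limit_val_batches caps
each validation pass at num_eval_steps batches. A self-contained sketch of the
arithmetic, with made-up sizes:

    # Made-up numbers, mirroring the val_check_interval computation above.
    eval_frequency = 500        # desired optimizer steps between validations
    num_train_batches = 12_500  # len(dataloaders["train"]) for one epoch

    # A float val_check_interval is read as a fraction of the epoch:
    val_check_interval = eval_frequency / float(num_train_batches)  # 0.04

    # Validation then runs every val_check_interval * num_train_batches batches,
    # i.e. every 500 batches, matching the old step-based eval loop.
    assert round(val_check_interval * num_train_batches) == eval_frequency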

diff --git a/xformers/benchmarks/LRA/run_with_submitit.py b/xformers/benchmarks/LRA/run_with_submitit.py
index d7af681e86..13945aac6c 100644
--- a/xformers/benchmarks/LRA/run_with_submitit.py
+++ b/xformers/benchmarks/LRA/run_with_submitit.py
@@ -76,7 +76,7 @@ def __init__(self, args):
 
     def __call__(self):
         self._setup_gpu_args()
-        benchmark(self.args.rank, self.args)
+        benchmark(self.args)
 
     def checkpoint(self):
         self.args.dist_url = get_init_file().as_uri()

diff --git a/xformers/components/multi_head_dispatch.py b/xformers/components/multi_head_dispatch.py
index 57413f3f92..f3eaffc915 100644
--- a/xformers/components/multi_head_dispatch.py
+++ b/xformers/components/multi_head_dispatch.py
@@ -31,6 +31,9 @@ class MultiHeadDispatchConfig:
     use_rotary_embeddings: Optional[bool]
     out_proj: Optional[nn.Module]
 
+    def __getitem__(self, item):
+        return getattr(self, item)
+
 
 # Move head forward and fold into batch dim. dimensions become (B * nh, S, hs)
 def _fold_heads(t: torch.Tensor, B: int, S: int, H: int, Hs: int):

diff --git a/xformers/factory/model_factory.py b/xformers/factory/model_factory.py
index 0b0d642a46..8b9e6bdcdc 100644
--- a/xformers/factory/model_factory.py
+++ b/xformers/factory/model_factory.py
@@ -277,7 +277,9 @@ def forward(
 
         # Apply the optional input masking
         if encoder_input_mask is not None:
-            x += encoder_input_mask.unsqueeze(0).unsqueeze(-1)
+            if x.dim() - encoder_input_mask.dim() > 1:
+                encoder_input_mask = encoder_input_mask.unsqueeze(0)
+            x += encoder_input_mask.unsqueeze(-1)
 
         x = encoders(x)
         memory = torch.stack(x.chunk(2, dim=-1)).mean(dim=0)
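Note: the guarded unsqueeze in the model_factory.py hunk only matters when the mask
lacks a batch dimension, and the result must be reassigned since Tensor.unsqueeze is
not in-place. A small standalone sketch of the intended shapes, with made-up sizes:

    import torch

    B, S, E = 2, 8, 16        # batch, sequence, embedding -- made-up sizes
    x = torch.zeros(B, S, E)  # encoder input
    mask = torch.zeros(S)     # an additive mask shared across the batch

    # Mirror of the patched logic in xFormer.forward: promote a per-sequence
    # mask to (1, S) so the final unsqueeze(-1) broadcasts against (B, S, E).
    if x.dim() - mask.dim() > 1:
        mask = mask.unsqueeze(0)  # (S,) -> (1, S)
    assert (x + mask.unsqueeze(-1)).shape == (B, S, E)

A mask that already carries a batch dimension, shape (B, S), skips the promotion and
broadcasts directly as (B, S, 1), which is the case exercised by the LRA models above.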