diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py
index 6f7fc0a..7af3b69 100644
--- a/aitextgen/aitextgen.py
+++ b/aitextgen/aitextgen.py
@@ -1,3 +1,4 @@
+from pytorch_lightning.plugins import DeepSpeedPlugin
 from transformers import (
     GPT2LMHeadModel,
     GPT2TokenizerFast,
@@ -651,24 +652,21 @@ def train(
         if not is_gpu_used:
             n_gpu = 0
 
-        # use the deepseed plugin if installed and specified
+        # use the DeepSpeed plugin if installed and specified
         deepspeed_plugin = None
-        # if is_gpu_used and use_deepspeed:
-        #     deepspeed_config = gen_deepspeed_config(
-        #         self.get_device(), learning_rate, weight_decay
-        #     )
-        #     deepspeed_plugin = DeepSpeedPlugin(deepseed_config)
-        #     logger.info("Using DeepSpeed training.")
-        #     logger.warning(
-        #         "deepspeed was attempted to be used, but was not installed. "
-        #         + "Using normal training behavior."
-        #     )
+        if is_gpu_used and use_deepspeed:
+            deepspeed_plugin = DeepSpeedPlugin()
+            logger.info("Using DeepSpeed training.")
+            if not fp16:
+                logger.info("Setting FP16 to True for DeepSpeed ZeRO Training.")
+                fp16 = True
+
         train_params = dict(
             accumulate_grad_batches=gradient_accumulation_steps,
             gpus=n_gpu,
             max_steps=num_steps,
-            gradient_clip_val=max_grad_norm if not fp16 else 0,
+            gradient_clip_val=max_grad_norm,
             checkpoint_callback=False,
             logger=loggers if loggers else False,
             weights_summary=None,
diff --git a/aitextgen/utils.py b/aitextgen/utils.py
index aebd582..92a20da 100644
--- a/aitextgen/utils.py
+++ b/aitextgen/utils.py
@@ -172,47 +172,3 @@ def skip_special_tokens(tensor, device, special_token_ids):
     return tensor[
         ~tensor.unsqueeze(1).eq(special_token_id_tensor.unsqueeze(1)).any(1)
     ].tolist()
-
-
-def gen_deepspeed_config(device, lr, weight_decay):
-    """Deepspeed OneBitAdam config.
-
-    Adapted from https://pytorch-lightning.readthedocs.io/en/stable/advanced/multi_gpu.html#deepspeed
-
-    Args:
-        device ([type]): Device for training
-        lr ([type]): Learning rate
-        weight_decay ([type]): Weight decay
-    """
-
-    deepspeed_config = {
-        "zero_allow_untested_optimizer": True,
-        "optimizer": {
-            "type": "OneBitAdam",
-            "params": {
-                "lr": lr,
-                "betas": [0.998, 0.999],
-                "eps": 1e-5,
-                "weight_decay": weight_decay,
-                "cuda_aware": "cuda" in device,
-            },
-        },
-        "scheduler": {
-            "type": "WarmupLR",
-            "params": {
-                "last_batch_iteration": -1,
-                "warmup_min_lr": 0,
-                "warmup_max_lr": 3e-5,
-                "warmup_num_steps": 100,
-            },
-        },
-        "zero_optimization": {
-            "stage": 2,  # Enable Stage 2 ZeRO (Optimizer/Gradient state partitioning)
-            "cpu_offload": True,  # Enable Offloading optimizer state/calculation to the host CPU
-            "contiguous_gradients": True,  # Reduce gradient fragmentation.
-            "overlap_comm": True,  # Overlap reduce/backward operation of gradients for speed.
-            "allgather_bucket_size": 2e8,  # Number of elements to all gather at once.
-            "reduce_bucket_size": 2e8,  # Number of elements we reduce/allreduce at once.
-        },
-    }
-
-    return deepspeed_config
diff --git a/requirements.txt b/requirements.txt
index c87cc73..202d6ae 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
 transformers>=4.3.0
 fire>=0.3.0
-pytorch-lightning>=1.2.0
+pytorch-lightning>=1.2.3
 torch>=1.6.0
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 0586412..767343f 100644
--- a/setup.py
+++ b/setup.py
@@ -19,7 +19,7 @@
     install_requires=[
         "transformers>=4.3.0",
         "fire>=0.3.0",
-        "pytorch-lightning>=1.2.0",
+        "pytorch-lightning>=1.2.3",
         "torch>=1.6.0",
     ],
 )
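
For context, a minimal sketch of how a DeepSpeedPlugin created as in the patched train() is typically handed to a pytorch-lightning (>=1.2.3) Trainer. This is not part of the patch: the plugins= and precision= Trainer arguments are the standard pytorch-lightning API, but the specific values (gpus=1, max_steps=1000) are placeholders rather than aitextgen's real train_params.

# Illustrative sketch only; requires `pip install deepspeed` and a CUDA GPU.
import pytorch_lightning as pl
from pytorch_lightning.plugins import DeepSpeedPlugin

# Default DeepSpeed/ZeRO configuration, matching DeepSpeedPlugin() in the diff above.
deepspeed_plugin = DeepSpeedPlugin()

train_params = dict(
    gpus=1,                    # placeholder; aitextgen passes its computed n_gpu
    max_steps=1000,            # placeholder; aitextgen passes num_steps
    plugins=deepspeed_plugin,  # enable DeepSpeed as the training plugin
    precision=16,              # DeepSpeed ZeRO expects FP16, hence fp16 is forced to True above
)

trainer = pl.Trainer(**train_params)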