From 00b0d3621ab0059e892f2b4320820db1f8cdc3a9 Mon Sep 17 00:00:00 2001
From: Jeff Picard
Date: Tue, 17 Dec 2024 14:20:11 -0800
Subject: [PATCH 1/3] fix: Gradient accumulation skips syncs to increase speed

---
 flair/trainers/trainer.py | 57 ++++++++++++++++++++-------------------
 1 file changed, 30 insertions(+), 27 deletions(-)

diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py
index 03879a2b1..6d197a617 100644
--- a/flair/trainers/trainer.py
+++ b/flair/trainers/trainer.py
@@ -656,33 +656,36 @@ def train_custom(
                 batch_steps = self.get_batch_steps(batch, mini_batch_chunk_size=mini_batch_chunk_size)

                 # forward and backward for batch
-                for batch_step in batch_steps:
-                    # forward pass
-                    with torch.autocast(device_type=flair.device.type, enabled=use_amp):
-                        if multi_gpu:
-                            # We need to __call__ ddp_model() because this triggers hooks that sync gradients.
-                            # But that calls forward rather than forward_loss. So we patch forward to redirect
-                            # to forward_loss. Then undo the patch in case forward_loss itself calls forward.
-                            def wrapped_forward_loss(*args, **kwargs2):
-                                self.model.forward = original_forward
-                                return self.model.forward_loss(*args, **kwargs2)
-
-                            self.model.forward = wrapped_forward_loss
-                            loss, datapoint_count = self.ddp_model(batch_step)
-                        else:
-                            loss, datapoint_count = self.model.forward_loss(batch_step)
-
-                    batch_train_samples += datapoint_count
-                    batch_train_loss += loss.item()
-
-                    self._backward(scaler.scale(loss))
-
-                    # identify dynamic embeddings (always deleted) on first sentence
-                    if dynamic_embeddings is None:
-                        dynamic_embeddings = identify_dynamic_embeddings(batch)
-
-                    # depending on memory mode, embeddings are moved to CPU, GPU or deleted
-                    store_embeddings(batch_step, embeddings_storage_mode, dynamic_embeddings)
+                for batch_step_no, batch_step in enumerate(batch_steps):
+                    skip_sync = multi_gpu and batch_step_no < len(batch_steps) - 1
+                    gradient_sync = contextlib.nullcontext() if skip_sync else self.ddp_model.no_sync()
+                    with gradient_sync:
+                        # forward pass
+                        with torch.autocast(device_type=flair.device.type, enabled=use_amp):
+                            if multi_gpu:
+                                # We need to __call__ ddp_model() because this triggers hooks that sync gradients.
+                                # But that calls forward rather than forward_loss. So we patch forward to redirect
+                                # to forward_loss. Then undo the patch in case forward_loss itself calls forward.
+                                def wrapped_forward_loss(*args, **kwargs2):
+                                    self.model.forward = original_forward
+                                    return self.model.forward_loss(*args, **kwargs2)
+
+                                self.model.forward = wrapped_forward_loss
+                                loss, datapoint_count = self.ddp_model(batch_step)
+                            else:
+                                loss, datapoint_count = self.model.forward_loss(batch_step)
+
+                        batch_train_samples += datapoint_count
+                        batch_train_loss += loss.item()
+
+                        self._backward(scaler.scale(loss))
+
+                        # identify dynamic embeddings (always deleted) on first sentence
+                        if dynamic_embeddings is None:
+                            dynamic_embeddings = identify_dynamic_embeddings(batch)
+
+                        # depending on memory mode, embeddings are moved to CPU, GPU or deleted
+                        store_embeddings(batch_step, embeddings_storage_mode, dynamic_embeddings)

                 self.dispatch("before_training_optimizer_step", **batch_kw)


From 56b2a9e8bb837f906fa4d552772dfd5db1a409be Mon Sep 17 00:00:00 2001
From: Jeff Picard
Date: Tue, 17 Dec 2024 15:22:23 -0800
Subject: [PATCH 2/3] fix: Sum learning rate instead of averaging on multi gpu sync

---
 flair/trainers/trainer.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py
index 6d197a617..3313966ce 100644
--- a/flair/trainers/trainer.py
+++ b/flair/trainers/trainer.py
@@ -657,9 +657,9 @@ def train_custom(

                 # forward and backward for batch
                 for batch_step_no, batch_step in enumerate(batch_steps):
-                    skip_sync = multi_gpu and batch_step_no < len(batch_steps) - 1
-                    gradient_sync = contextlib.nullcontext() if skip_sync else self.ddp_model.no_sync()
-                    with gradient_sync:
+                    enable_gradient_sync = multi_gpu and batch_step_no == len(batch_steps) - 1
+                    sync_context = self.ddp_model.no_sync() if enable_gradient_sync else contextlib.nullcontext()
+                    with sync_context:
                         # forward pass
                         with torch.autocast(device_type=flair.device.type, enabled=use_amp):
                             if multi_gpu:
@@ -690,6 +690,8 @@ def wrapped_forward_loss(*args, **kwargs2):
                 self.dispatch("before_training_optimizer_step", **batch_kw)

                 # do the optimizer step
+                if multi_gpu:
+                    self._scale_gradients(torch.distributed.get_world_size())  # DDP averages across processes but we want the sum
                 scaler.unscale_(self.optimizer)
                 if max_grad_norm is not None:
                     gradient_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
@@ -988,3 +990,8 @@ def _save_model(self, model_file: Union[str, Path], checkpoint: bool = False) ->
         self.model.save(model_file, checkpoint)
         if torch.distributed.is_initialized():
             torch.distributed.barrier()  # Prevent any process from loading a model until writing is complete
+
+    def _scale_gradients(self, constant):
+        for param in self.model.parameters():
+            if param.grad is not None:
+                param.grad.data.mul_(constant)

From a2edb9ed889445175560781ca7299dcaf397e443 Mon Sep 17 00:00:00 2001
From: Jeff Picard
Date: Tue, 17 Dec 2024 17:31:54 -0800
Subject: [PATCH 3/3] Fix deadlock in checkpoint plugin

---
 flair/trainers/plugins/functional/checkpoints.py | 2 --
 flair/trainers/trainer.py                        | 9 +++++----
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/flair/trainers/plugins/functional/checkpoints.py b/flair/trainers/plugins/functional/checkpoints.py
index a8179edbc..4261a56a2 100644
--- a/flair/trainers/plugins/functional/checkpoints.py
+++ b/flair/trainers/plugins/functional/checkpoints.py
@@ -30,8 +30,6 @@ def after_training_epoch(self, epoch, **kw):
         )
         model_name = "model_epoch_" + str(epoch) + ".pt"
         self.model.save(self.base_path / model_name, checkpoint=self.save_optimizer_state)
-        if torch.distributed.is_initialized():
-            torch.distributed.barrier()  # Prevent any process from loading a model until writing is complete

     @property
     def attach_to_all_processes(self) -> bool:
diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py
index 3313966ce..4f5d5ff7b 100644
--- a/flair/trainers/trainer.py
+++ b/flair/trainers/trainer.py
@@ -657,9 +657,9 @@ def train_custom(

                 # forward and backward for batch
                 for batch_step_no, batch_step in enumerate(batch_steps):
-                    enable_gradient_sync = multi_gpu and batch_step_no == len(batch_steps) - 1
-                    sync_context = self.ddp_model.no_sync() if enable_gradient_sync else contextlib.nullcontext()
-                    with sync_context:
+                    disable_gradient_sync = multi_gpu and batch_step_no < len(batch_steps) - 1
+                    grad_sync = self.ddp_model.no_sync() if disable_gradient_sync else contextlib.nullcontext()
+                    with grad_sync:
                         # forward pass
                         with torch.autocast(device_type=flair.device.type, enabled=use_amp):
                             if multi_gpu:
@@ -691,7 +691,8 @@ def wrapped_forward_loss(*args, **kwargs2):

                 # do the optimizer step
                 if multi_gpu:
-                    self._scale_gradients(torch.distributed.get_world_size())  # DDP averages across processes but we want the sum
+                    # DDP averages across processes but we want the sum
+                    self._scale_gradients(torch.distributed.get_world_size())
                 scaler.unscale_(self.optimizer)
                 if max_grad_norm is not None:
                     gradient_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
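For reference, a minimal standalone sketch of the pattern these three commits converge on: suppress DDP's gradient all-reduce with no_sync() on every accumulation micro-step except the last, then rescale the averaged gradients by the world size so accumulation across processes behaves like a sum. This is not flair's API; it assumes an already-initialized torch.distributed process group, a DistributedDataParallel-wrapped model whose forward returns a scalar loss, and illustrative names (accumulation_step, micro_batches).

import contextlib

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel


def accumulation_step(
    ddp_model: DistributedDataParallel,
    optimizer: torch.optim.Optimizer,
    micro_batches: list,
) -> None:
    """One optimizer step over several micro-batches, syncing gradients only once."""
    optimizer.zero_grad()
    for step_no, micro_batch in enumerate(micro_batches):
        last_step = step_no == len(micro_batches) - 1
        # no_sync() skips the gradient all-reduce; only the last backward syncs.
        sync = contextlib.nullcontext() if last_step else ddp_model.no_sync()
        with sync:
            loss = ddp_model(micro_batch)  # assumes the wrapped forward returns a scalar loss
            loss.backward()
    # DDP averages gradients across processes; rescale so the result is their sum.
    for param in ddp_model.parameters():
        if param.grad is not None:
            param.grad.mul_(dist.get_world_size())
    optimizer.step()

The checkpoint-plugin change in the third commit follows the same collective-call rule: torch.distributed.barrier() only returns once every rank enters it, so a barrier issued from a plugin that apparently does not run on all processes (see its attach_to_all_processes property) can leave the main rank waiting forever, while the barrier in _save_model, which every process reaches, is kept.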