Merge pull request #3583 from ZipRecruiter/jeffp.multi-gpu-fixes
Multigpu: Fix gradient accumulation and learning rate aggregation
alanakbik authored Dec 19, 2024
2 parents 29feea4 + a2edb9e commit 0becfed
Showing 2 changed files with 38 additions and 29 deletions.
flair/trainers/plugins/functional/checkpoints.py (2 changes: 0 additions, 2 deletions)
@@ -30,8 +30,6 @@ def after_training_epoch(self, epoch, **kw):
         )
         model_name = "model_epoch_" + str(epoch) + ".pt"
         self.model.save(self.base_path / model_name, checkpoint=self.save_optimizer_state)
-        if torch.distributed.is_initialized():
-            torch.distributed.barrier()  # Prevent any process from loading a model until writing is complete
 
     @property
     def attach_to_all_processes(self) -> bool:
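
Removing the barrier here is consistent with two things visible elsewhere in the diff: torch.distributed.barrier() is a collective call that every rank must reach, so a barrier issued from a plugin that runs on only some processes would hang (the attach_to_all_processes property above controls where the plugin runs, though its value lies outside this hunk), and the save-then-barrier idiom is retained in Trainer._save_model in the second file. A minimal sketch of that idiom in a generic DDP script, not Flair's actual plugin wiring (the function name and the plain torch.save call are illustrative):

import torch
import torch.distributed as dist

def save_checkpoint_all_ranks(model, path):
    """Rank 0 writes the checkpoint; every rank waits before anyone reads it.

    barrier() must be entered by ALL ranks in the process group, otherwise
    the ranks that did call it block forever.
    """
    if not dist.is_initialized() or dist.get_rank() == 0:
        torch.save(model.state_dict(), path)  # only one process writes
    if dist.is_initialized():
        dist.barrier()  # nobody loads until writing is complete
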
flair/trainers/trainer.py (65 changes: 38 additions, 27 deletions)
@@ -656,37 +656,43 @@ def train_custom(
                 batch_steps = self.get_batch_steps(batch, mini_batch_chunk_size=mini_batch_chunk_size)
 
                 # forward and backward for batch
-                for batch_step in batch_steps:
-                    # forward pass
-                    with torch.autocast(device_type=flair.device.type, enabled=use_amp):
-                        if multi_gpu:
-                            # We need to __call__ ddp_model() because this triggers hooks that sync gradients.
-                            # But that calls forward rather than forward_loss. So we patch forward to redirect
-                            # to forward_loss. Then undo the patch in case forward_loss itself calls forward.
-                            def wrapped_forward_loss(*args, **kwargs2):
-                                self.model.forward = original_forward
-                                return self.model.forward_loss(*args, **kwargs2)
-
-                            self.model.forward = wrapped_forward_loss
-                            loss, datapoint_count = self.ddp_model(batch_step)
-                        else:
-                            loss, datapoint_count = self.model.forward_loss(batch_step)
-
-                    batch_train_samples += datapoint_count
-                    batch_train_loss += loss.item()
-
-                    self._backward(scaler.scale(loss))
-
-                    # identify dynamic embeddings (always deleted) on first sentence
-                    if dynamic_embeddings is None:
-                        dynamic_embeddings = identify_dynamic_embeddings(batch)
-
-                    # depending on memory mode, embeddings are moved to CPU, GPU or deleted
-                    store_embeddings(batch_step, embeddings_storage_mode, dynamic_embeddings)
+                for batch_step_no, batch_step in enumerate(batch_steps):
+                    disable_gradient_sync = multi_gpu and batch_step_no < len(batch_steps) - 1
+                    grad_sync = self.ddp_model.no_sync() if disable_gradient_sync else contextlib.nullcontext()
+                    with grad_sync:
+                        # forward pass
+                        with torch.autocast(device_type=flair.device.type, enabled=use_amp):
+                            if multi_gpu:
+                                # We need to __call__ ddp_model() because this triggers hooks that sync gradients.
+                                # But that calls forward rather than forward_loss. So we patch forward to redirect
+                                # to forward_loss. Then undo the patch in case forward_loss itself calls forward.
+                                def wrapped_forward_loss(*args, **kwargs2):
+                                    self.model.forward = original_forward
+                                    return self.model.forward_loss(*args, **kwargs2)
+
+                                self.model.forward = wrapped_forward_loss
+                                loss, datapoint_count = self.ddp_model(batch_step)
+                            else:
+                                loss, datapoint_count = self.model.forward_loss(batch_step)
+
+                        batch_train_samples += datapoint_count
+                        batch_train_loss += loss.item()
+
+                        self._backward(scaler.scale(loss))
+
+                        # identify dynamic embeddings (always deleted) on first sentence
+                        if dynamic_embeddings is None:
+                            dynamic_embeddings = identify_dynamic_embeddings(batch)
+
+                        # depending on memory mode, embeddings are moved to CPU, GPU or deleted
+                        store_embeddings(batch_step, embeddings_storage_mode, dynamic_embeddings)

                 self.dispatch("before_training_optimizer_step", **batch_kw)
 
                 # do the optimizer step
+                if multi_gpu:
+                    # DDP averages across processes but we want the sum
+                    self._scale_gradients(torch.distributed.get_world_size())
                 scaler.unscale_(self.optimizer)
                 if max_grad_norm is not None:
                     gradient_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
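
The core of the hunk above is DistributedDataParallel.no_sync(): without it, every accumulation chunk's backward pass triggers a gradient all-reduce, while wrapping all but the last chunk in no_sync() defers communication so gradients are all-reduced exactly once per optimizer step. A stripped-down sketch of the same pattern, with model, optimizer, chunks, and criterion standing in for Flair's ddp_model, optimizer, batch_steps, and forward_loss (names here are illustrative, not Flair's API):

import contextlib
import torch

def accumulate_and_step(ddp_model, optimizer, chunks, criterion):
    """Gradient accumulation under DDP: synchronize only on the last chunk.

    ddp_model is assumed to be torch.nn.parallel.DistributedDataParallel;
    no_sync() disables its gradient all-reduce hooks for forward/backward
    passes run inside the context, so only the final backward communicates.
    """
    optimizer.zero_grad()
    for i, (inputs, targets) in enumerate(chunks):
        last_chunk = i == len(chunks) - 1
        ctx = contextlib.nullcontext() if last_chunk else ddp_model.no_sync()
        with ctx:
            loss = criterion(ddp_model(inputs), targets)
            loss.backward()  # gradients accumulate locally; all-reduce fires only on the last chunk
    optimizer.step()

Note that no_sync() must cover both the forward and the backward of the skipped chunks, which is why the diff moves the entire batch-step body under with grad_sync:.
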
@@ -985,3 +991,8 @@ def _save_model(self, model_file: Union[str, Path], checkpoint: bool = False) ->
         self.model.save(model_file, checkpoint)
         if torch.distributed.is_initialized():
             torch.distributed.barrier()  # Prevent any process from loading a model until writing is complete
+
+    def _scale_gradients(self, constant):
+        for param in self.model.parameters():
+            if param.grad is not None:
+                param.grad.data.mul_(constant)
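
By default DistributedDataParallel's all-reduce leaves each rank with the gradient averaged over all processes; multiplying by get_world_size(), as the new _scale_gradients does, turns that mean back into the sum the comment above asks for ("DDP averages across processes but we want the sum"). In the first hunk the rescaling runs before scaler.unscale_ and clip_grad_norm_, so max_grad_norm is applied to the summed gradients. A small self-contained check of the arithmetic on plain tensors, simulating the all-reduce without a process group:

import torch

# Simulate DDP's gradient all-reduce across 4 ranks, without a process group.
world_size = 4
per_rank_grads = [torch.full((3,), float(r + 1)) for r in range(world_size)]  # ranks hold grads 1..4

summed = torch.stack(per_rank_grads).sum(dim=0)    # the sum we want: 1+2+3+4 = 10
ddp_averaged = summed / world_size                 # what DDP leaves on each rank: 2.5

# _scale_gradients(world_size) multiplies the averaged gradient back up to the sum:
rescaled = ddp_averaged * world_size
assert torch.allclose(rescaled, summed)
print(ddp_averaged.tolist(), "->", rescaled.tolist())  # [2.5, 2.5, 2.5] -> [10.0, 10.0, 10.0]
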
