Multigpu: Fix gradient accumulation and learning rate aggregation #3583

Merged · 3 commits · Dec 19, 2024
flair/trainers/plugins/functional/checkpoints.py: 0 additions & 2 deletions
@@ -30,8 +30,6 @@ def after_training_epoch(self, epoch, **kw):
             )
             model_name = "model_epoch_" + str(epoch) + ".pt"
             self.model.save(self.base_path / model_name, checkpoint=self.save_optimizer_state)
-            if torch.distributed.is_initialized():
-                torch.distributed.barrier()  # Prevent any process from loading a model until writing is complete
 
     @property
     def attach_to_all_processes(self) -> bool:
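The two deleted lines drop the collective `torch.distributed.barrier()` from this plugin hook; the save-then-barrier pattern itself still sits next to the save call in the trainer's `_save_model` (see the second hunk of the trainer.py diff below). A minimal sketch of that pattern, assuming an already initialized process group and a hypothetical `model.save(path)` method; since `barrier()` is a collective call, every rank has to reach it:

```python
import torch.distributed as dist

def save_checkpoint(model, path):
    # Only the main process writes the file.
    if not dist.is_initialized() or dist.get_rank() == 0:
        model.save(path)
    if dist.is_initialized():
        # Collective call: all ranks wait here, so no process can start
        # loading the checkpoint before the write has finished.
        dist.barrier()
```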
flair/trainers/trainer.py: 38 additions & 27 deletions
@@ -656,37 +656,43 @@ def train_custom(
                     batch_steps = self.get_batch_steps(batch, mini_batch_chunk_size=mini_batch_chunk_size)
 
                     # forward and backward for batch
-                    for batch_step in batch_steps:
-                        # forward pass
-                        with torch.autocast(device_type=flair.device.type, enabled=use_amp):
-                            if multi_gpu:
-                                # We need to __call__ ddp_model() because this triggers hooks that sync gradients.
-                                # But that calls forward rather than forward_loss. So we patch forward to redirect
-                                # to forward_loss. Then undo the patch in case forward_loss itself calls forward.
-                                def wrapped_forward_loss(*args, **kwargs2):
-                                    self.model.forward = original_forward
-                                    return self.model.forward_loss(*args, **kwargs2)
-
-                                self.model.forward = wrapped_forward_loss
-                                loss, datapoint_count = self.ddp_model(batch_step)
-                            else:
-                                loss, datapoint_count = self.model.forward_loss(batch_step)
-
-                        batch_train_samples += datapoint_count
-                        batch_train_loss += loss.item()
-
-                        self._backward(scaler.scale(loss))
-
-                        # identify dynamic embeddings (always deleted) on first sentence
-                        if dynamic_embeddings is None:
-                            dynamic_embeddings = identify_dynamic_embeddings(batch)
-
-                        # depending on memory mode, embeddings are moved to CPU, GPU or deleted
-                        store_embeddings(batch_step, embeddings_storage_mode, dynamic_embeddings)
+                    for batch_step_no, batch_step in enumerate(batch_steps):
+                        disable_gradient_sync = multi_gpu and batch_step_no < len(batch_steps) - 1
+                        grad_sync = self.ddp_model.no_sync() if disable_gradient_sync else contextlib.nullcontext()
+                        with grad_sync:
+                            # forward pass
+                            with torch.autocast(device_type=flair.device.type, enabled=use_amp):
+                                if multi_gpu:
+                                    # We need to __call__ ddp_model() because this triggers hooks that sync gradients.
+                                    # But that calls forward rather than forward_loss. So we patch forward to redirect
+                                    # to forward_loss. Then undo the patch in case forward_loss itself calls forward.
+                                    def wrapped_forward_loss(*args, **kwargs2):
+                                        self.model.forward = original_forward
+                                        return self.model.forward_loss(*args, **kwargs2)
+
+                                    self.model.forward = wrapped_forward_loss
+                                    loss, datapoint_count = self.ddp_model(batch_step)
+                                else:
+                                    loss, datapoint_count = self.model.forward_loss(batch_step)
+
+                            batch_train_samples += datapoint_count
+                            batch_train_loss += loss.item()
+
+                            self._backward(scaler.scale(loss))
+
+                            # identify dynamic embeddings (always deleted) on first sentence
+                            if dynamic_embeddings is None:
+                                dynamic_embeddings = identify_dynamic_embeddings(batch)
+
+                            # depending on memory mode, embeddings are moved to CPU, GPU or deleted
+                            store_embeddings(batch_step, embeddings_storage_mode, dynamic_embeddings)
 
                     self.dispatch("before_training_optimizer_step", **batch_kw)
 
                     # do the optimizer step
+                    if multi_gpu:
+                        # DDP averages across processes but we want the sum
+                        self._scale_gradients(torch.distributed.get_world_size())
                     scaler.unscale_(self.optimizer)
                     if max_grad_norm is not None:
                         gradient_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
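The rewritten chunk loop wraps every chunk except the last in DDP's `no_sync()`, so gradients accumulate locally and the cross-process all-reduce fires only once per optimizer step. A standalone sketch of the same pattern, with hypothetical `ddp_model`, `chunks`, and `compute_loss` standing in for the trainer's internals:

```python
import contextlib

def accumulate_and_step(ddp_model, optimizer, chunks, compute_loss):
    optimizer.zero_grad()
    for i, chunk in enumerate(chunks):
        # Suppress the gradient all-reduce on all but the last chunk.
        sync_ctx = ddp_model.no_sync() if i < len(chunks) - 1 else contextlib.nullcontext()
        with sync_ctx:
            loss = compute_loss(ddp_model, chunk)
            loss.backward()  # under no_sync(), gradients only accumulate locally
    # The final backward() ran with syncing enabled, so each .grad now holds
    # the reduced gradient for the whole batch; step once.
    optimizer.step()
```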
@@ -985,3 +991,8 @@ def _save_model(self, model_file: Union[str, Path], checkpoint: bool = False) ->
             self.model.save(model_file, checkpoint)
         if torch.distributed.is_initialized():
             torch.distributed.barrier()  # Prevent any process from loading a model until writing is complete
+
+    def _scale_gradients(self, constant):
+        for param in self.model.parameters():
+            if param.grad is not None:
+                param.grad.data.mul_(constant)
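DDP's all-reduce averages gradients over the participating processes, while the comment in the diff above notes that the trainer wants the sum; multiplying every `.grad` by `torch.distributed.get_world_size()` before unscaling and clipping turns that average back into the sum a single process accumulating the same data would have produced. A toy illustration of the arithmetic (no process group needed):

```python
import torch

world_size = 4
# Per-rank gradients of one parameter, e.g. from 4 GPUs.
per_rank_grads = [torch.tensor([1.0, 2.0]) * rank for rank in range(1, world_size + 1)]

ddp_grad = torch.stack(per_rank_grads).mean(dim=0)     # what DDP's all-reduce produces
summed_grad = torch.stack(per_rank_grads).sum(dim=0)   # what a single process would accumulate

# Rescaling the averaged gradient by world_size recovers the sum.
assert torch.allclose(ddp_grad * world_size, summed_grad)
```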