
Commit

cast to float32
fabianlim committed Apr 11, 2024
1 parent c67249f commit fd81eb1
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions src/accelerate/accelerator.py
@@ -1447,6 +1447,22 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
                    ),
                    auto_wrap_policy=fsdp_plugin.auto_wrap_policy,
                )

                # NOTE: do we also need to check trainer.args.bf16 and trainer.args.fp16?
                # - check the mixed precision setting on the FSDP root wrapper
                if model.mixed_precision:
                    for module in model.modules():
                        if isinstance(module, FSDP):
                            # module.params holds a list of FlatParameters
                            for param in module.params:
                                if param.dtype != torch.float32 and param.device != torch.device("meta"):
                                    # TODO: issue this warning only once
                                    warnings.warn("Training FSDP with mixed precision. Upcasting parameters to full precision.")

                                    # Upcast to float32 since we are already training with mixed precision;
                                    # this is a passthrough if the dtype is already float32.
                                    param.data = param.data.to(torch.float32)

            # if the previous and current models are same, delete the previous one
            if len(self._models) > 1 and (self._models[-2] is self._models[-1]):
                del self._models[-2]
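
For reference, here is a minimal standalone sketch of the same upcast pattern; it is not part of the commit. It assumes an FSDP-wrapped model whose weights were loaded in low precision (e.g. bf16) while an FSDP MixedPrecision policy is active, and it walks the public nn.Module.parameters() view instead of the per-FSDP-module module.params access used above, since FSDP's internal attribute names vary across PyTorch versions. The helper name upcast_fsdp_params_to_fp32 is hypothetical.

    # Illustrative sketch only; mirrors the commit's checks (skip float32 and
    # meta-device parameters) but avoids relying on FSDP internals.
    import warnings

    import torch
    import torch.nn as nn


    def upcast_fsdp_params_to_fp32(model: nn.Module) -> int:
        """Upcast low-precision, non-meta parameters to float32; returns how many were upcast."""
        num_upcast = 0
        for param in model.parameters():
            # Mirror the commit's checks: leave float32 parameters and parameters
            # still on the "meta" device (deferred init) untouched.
            if param.dtype != torch.float32 and param.device != torch.device("meta"):
                param.data = param.data.to(torch.float32)
                num_upcast += 1
        if num_upcast > 0:
            warnings.warn("Training FSDP with mixed precision. Upcasting parameters to full precision.")
        return num_upcast

With an FSDP MixedPrecision policy set (e.g. param_dtype=torch.bfloat16), compute still runs in low precision while the optimizer now sees float32 master weights, which is the behavior the commit restores when the model was loaded in low precision.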
