Add convert_module to FSDP
tshu-w committed Oct 6, 2024
1 parent 5dea36c commit 757a579
Showing 2 changed files with 27 additions and 3 deletions.
12 changes: 9 additions & 3 deletions src/lightning/fabric/plugins/precision/fsdp.py
@@ -71,7 +71,13 @@ def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"]
             "16-true": torch.float16,
             "32-true": torch.float32,
         }
-        self._desired_input_dtype = precision_to_type[self.precision]
+        self._desired_dtype = precision_to_type[self.precision]
+
+    @override
+    def convert_module(self, module: Module) -> Module:
+        if "true" in self.precision:
+            return module.to(dtype=self._desired_dtype)
+        return module
 
     @property
     def mixed_precision_config(self) -> "TorchMixedPrecision":
@@ -101,7 +107,7 @@ def mixed_precision_config(self) -> "TorchMixedPrecision":
 
     @override
     def tensor_init_context(self) -> ContextManager:
-        return _DtypeContextManager(self._desired_input_dtype)
+        return _DtypeContextManager(self._desired_dtype)
 
     @override
     def module_init_context(self) -> ContextManager:
@@ -115,7 +121,7 @@ def forward_context(self) -> ContextManager:
 
     @override
     def convert_input(self, data: Any) -> Any:
-        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)
+        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype)
 
     @override
     def convert_output(self, data: Any) -> Any:
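For context, here is a minimal usage sketch of the new hook on Fabric's FSDPPrecision plugin (illustrative only, not part of the commit, and assuming an installation of lightning that includes this change): only the "*-true" settings cast the parameters eagerly, while the mixed settings leave the module in float32 and defer casting to the FSDP MixedPrecision config.

import torch
from lightning.fabric.plugins.precision.fsdp import FSDPPrecision

# "bf16-true": convert_module casts the parameters eagerly.
module = torch.nn.Linear(4, 4)  # parameters start out in float32
module = FSDPPrecision(precision="bf16-true").convert_module(module)
assert module.weight.dtype == torch.bfloat16

# "bf16-mixed": convert_module is a no-op; casting is handled by the
# MixedPrecision config during forward/backward instead.
module = torch.nn.Linear(4, 4)
module = FSDPPrecision(precision="bf16-mixed").convert_module(module)
assert module.weight.dtype == torch.float32
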
18 changes: 18 additions & 0 deletions tests/tests_pytorch/plugins/precision/test_fsdp.py
@@ -40,6 +40,24 @@ def test_fsdp_precision_config(precision, expected):
     assert config.reduce_dtype == expected[2]
 
 
+@pytest.mark.parametrize(
+    ("precision", "expected_dtype"),
+    [
+        ("32-true", torch.float32),
+        ("bf16-mixed", torch.float32),
+        ("16-mixed", torch.float32),
+        ("bf16-true", torch.bfloat16),
+        ("16-true", torch.float16),
+    ],
+)
+def test_convert_module(precision, expected_dtype):
+    precision = FSDPPrecision(precision=precision)
+    module = torch.nn.Linear(2, 2)
+    assert module.weight.dtype == module.bias.dtype == torch.float32
+    module = precision.convert_module(module)
+    assert module.weight.dtype == module.bias.dtype == expected_dtype
+
+
 def test_fsdp_precision_default_scaler():
     from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 
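The new test can be exercised locally with pytest's keyword filter, for example: python -m pytest tests/tests_pytorch/plugins/precision/test_fsdp.py -k test_convert_module (assuming a checkout of the repository that includes this commit).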
