diff --git a/composer/trainer/_patch_pytorch.py b/composer/trainer/_patch_pytorch.py index 6771c5db4b..3f19df7d2a 100644 --- a/composer/trainer/_patch_pytorch.py +++ b/composer/trainer/_patch_pytorch.py @@ -933,7 +933,8 @@ def device_mesh__getitem__(self, mesh_dim_names: Union[str, tuple[str]]) -> 'Dev return submesh else: - from torch.distributed.device_mesh import _mesh_resources + from torch.utils._typing_utils import not_none + from torch.distributed.device_mesh import DeviceMesh, _mesh_resources def create_child_mesh( self, parent_mesh: 'DeviceMesh', submesh_dim_names: Tuple[str, ...],