diff --git a/fairscale/nn/misc/param_bucket.py b/fairscale/nn/misc/param_bucket.py index 2b811ef8a..e9344b7f0 100644 --- a/fairscale/nn/misc/param_bucket.py +++ b/fairscale/nn/misc/param_bucket.py @@ -73,8 +73,12 @@ def add_param(self, param: torch.Tensor) -> None: @torch.no_grad() def _add_param_as_view(self, param: torch.Tensor, keep_existing_value: bool = True) -> None: assert self.buffer is not None - assert param.dtype == self.buffer.dtype - assert param.device == self.buffer.device + assert ( + param.dtype == self.buffer.dtype + ), f"Different types for the bucket and the param, cannot proceed: {param.dtype} - {self.buffer.dtype}" + assert ( + param.device == self.buffer.device + ), f"Different devices for the bucket and the param, cannot proceed: {param.device} - {self.buffer.device}" fill_next = self._fill + param.numel() assert fill_next <= self.buffer.numel() diff --git a/pyproject.toml b/pyproject.toml index a26fbcc19..ae9d8a8cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,4 +28,4 @@ use_parentheses = true skip_glob = ["build/*", "stubs/*"] # Don't split "import" and "from". force_sort_within_sections = true -known_third_party = ["benchmark_dataset", "datasets", "golden_configs", "helpers", "models", "numpy", "parameterized", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"] +known_third_party = ["benchmark_dataset", "dataclasses", "datasets", "golden_configs", "helpers", "models", "numpy", "parameterized", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]