From 974837897d2cabba669895722ac0c001d5fc783a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 4 Aug 2021 21:03:59 +0200 Subject: [PATCH 1/5] Fix deprecated attribute usage --- .../models/self_supervised/moco/moco2_module.py | 14 ++++++++++---- pl_bolts/utils/__init__.py | 1 + 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/pl_bolts/models/self_supervised/moco/moco2_module.py b/pl_bolts/models/self_supervised/moco/moco2_module.py index 7a50fad853..8cd0bfca3c 100644 --- a/pl_bolts/models/self_supervised/moco/moco2_module.py +++ b/pl_bolts/models/self_supervised/moco/moco2_module.py @@ -26,7 +26,7 @@ Moco2TrainImagenetTransforms, Moco2TrainSTL10Transforms, ) -from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils import _TORCHVISION_AVAILABLE, _PL_GREATER_EQUAL_1_4 from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -150,7 +150,7 @@ def _momentum_update_key_encoder(self): @torch.no_grad() def _dequeue_and_enqueue(self, keys, queue_ptr, queue): # gather keys before updating queue - if self.trainer.use_ddp or self.trainer.use_ddp2: + if self._use_ddp_or_ddp2(self.trainer): keys = concat_all_gather(keys) batch_size = keys.shape[0] @@ -229,14 +229,14 @@ def forward(self, img_q, img_k, queue): with torch.no_grad(): # no gradient to keys # shuffle for making use of BN - if self.trainer.use_ddp or self.trainer.use_ddp2: + if self._use_ddp_or_ddp2(self.trainer): img_k, idx_unshuffle = self._batch_shuffle_ddp(img_k) k = self.encoder_k(img_k) # keys: NxC k = nn.functional.normalize(k, dim=1) # undo shuffle - if self.trainer.use_ddp or self.trainer.use_ddp2: + if self._use_ddp_or_ddp2(self.trainer): k = self._batch_unshuffle_ddp(k, idx_unshuffle) # compute logits @@ -335,6 +335,12 @@ def add_model_specific_args(parent_parser): return parser + def _use_ddp_or_ddp2(self, trainer: Trainer) -> bool: + # for backwards compatibility + if _PL_GREATER_EQUAL_1_4: + return trainer.accelerator_connector.use_ddp or trainer.accelerator_connector.use_ddp2 + return trainer.use_ddp or trainer.use_ddp2 + # utils @torch.no_grad() diff --git a/pl_bolts/utils/__init__.py b/pl_bolts/utils/__init__.py index f404a6108f..96906a41b2 100644 --- a/pl_bolts/utils/__init__.py +++ b/pl_bolts/utils/__init__.py @@ -38,5 +38,6 @@ def _compare_version(package: str, op, version) -> bool: _WANDB_AVAILABLE: bool = _module_available("wandb") _MATPLOTLIB_AVAILABLE: bool = _module_available("matplotlib") _TORCHVISION_LESS_THAN_0_9_1: bool = _compare_version("torchvision", operator.lt, "0.9.1") +_PL_GREATER_EQUAL_1_4 = _compare_version("pytorch_lightning", operator.ge, "1.4.0") __all__ = ["BatchGradientVerification"] From 3ae927476fd50d7d7a6f67c9112627ad1678ac5f Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 4 Aug 2021 21:07:33 +0200 Subject: [PATCH 2/5] More usage --- pl_bolts/models/rl/double_dqn_model.py | 2 +- pl_bolts/models/rl/dqn_model.py | 10 ++++++++-- pl_bolts/models/rl/per_dqn_model.py | 3 ++- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/pl_bolts/models/rl/double_dqn_model.py b/pl_bolts/models/rl/double_dqn_model.py index b8f4c34ad2..0b4290964c 100644 --- a/pl_bolts/models/rl/double_dqn_model.py +++ b/pl_bolts/models/rl/double_dqn_model.py @@ -58,7 +58,7 @@ def training_step(self, batch: Tuple[Tensor, Tensor], _) -> OrderedDict: # calculates training loss loss = double_dqn_loss(batch, self.net, self.target_net, self.gamma) - if self.trainer.use_dp or self.trainer.use_ddp2: + if self._use_dp_or_ddp2(self.trainer): loss = loss.unsqueeze(0) # Soft update of target network diff --git a/pl_bolts/models/rl/dqn_model.py b/pl_bolts/models/rl/dqn_model.py index 9231dfa16f..a4b6008789 100644 --- a/pl_bolts/models/rl/dqn_model.py +++ b/pl_bolts/models/rl/dqn_model.py @@ -20,7 +20,7 @@ from pl_bolts.models.rl.common.gym_wrappers import make_environment from pl_bolts.models.rl.common.memory import MultiStepBuffer from pl_bolts.models.rl.common.networks import CNN -from pl_bolts.utils import _GYM_AVAILABLE +from pl_bolts.utils import _GYM_AVAILABLE, _PL_GREATER_EQUAL_1_4 from pl_bolts.utils.warnings import warn_missing_pkg if _GYM_AVAILABLE: @@ -272,7 +272,7 @@ def training_step(self, batch: Tuple[Tensor, Tensor], _) -> OrderedDict: # calculates training loss loss = dqn_loss(batch, self.net, self.target_net, self.gamma) - if self.trainer.use_dp or self.trainer.use_ddp2: + if self._use_dp_or_ddp2(self.trainer): loss = loss.unsqueeze(0) # Soft update of target network @@ -404,6 +404,12 @@ def add_model_specific_args(arg_parser: argparse.ArgumentParser, ) -> argparse.A return arg_parser + def _use_dp_or_ddp2(self, trainer: Trainer) -> bool: + # for backwards compatibility + if _PL_GREATER_EQUAL_1_4: + return trainer.accelerator_connector.use_dp or trainer.accelerator_connector.use_ddp2 + return trainer.use_dp or trainer.use_ddp2 + def cli_main(): parser = argparse.ArgumentParser(add_help=False) diff --git a/pl_bolts/models/rl/per_dqn_model.py b/pl_bolts/models/rl/per_dqn_model.py index ac086a6a22..2853f5c662 100644 --- a/pl_bolts/models/rl/per_dqn_model.py +++ b/pl_bolts/models/rl/per_dqn_model.py @@ -14,6 +14,7 @@ from pl_bolts.losses.rl import per_dqn_loss from pl_bolts.models.rl.common.memory import Experience, PERBuffer from pl_bolts.models.rl.dqn_model import DQN +from pl_bolts.utils import _PL_GREATER_EQUAL_1_4 class PERDQN(DQN): @@ -116,7 +117,7 @@ def training_step(self, batch, _) -> OrderedDict: # calculates training loss loss, batch_weights = per_dqn_loss(samples, weights, self.net, self.target_net, self.gamma) - if self.trainer.use_dp or self.trainer.use_ddp2: + if self._use_dp_or_ddp2(self.trainer): loss = loss.unsqueeze(0) # update priorities in buffer From 46efb279083f9078b2b1e7eaf8dcf07438347930 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Aug 2021 19:10:12 +0000 Subject: [PATCH 3/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pl_bolts/models/self_supervised/moco/moco2_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/models/self_supervised/moco/moco2_module.py b/pl_bolts/models/self_supervised/moco/moco2_module.py index 8cd0bfca3c..9f8432de02 100644 --- a/pl_bolts/models/self_supervised/moco/moco2_module.py +++ b/pl_bolts/models/self_supervised/moco/moco2_module.py @@ -26,7 +26,7 @@ Moco2TrainImagenetTransforms, Moco2TrainSTL10Transforms, ) -from pl_bolts.utils import _TORCHVISION_AVAILABLE, _PL_GREATER_EQUAL_1_4 +from pl_bolts.utils import _PL_GREATER_EQUAL_1_4, _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: From f531595fb9920a88579d1b01bbaa77a7323fc404 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 4 Aug 2021 21:10:14 +0200 Subject: [PATCH 4/5] flake8 --- pl_bolts/models/rl/per_dqn_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pl_bolts/models/rl/per_dqn_model.py b/pl_bolts/models/rl/per_dqn_model.py index 2853f5c662..6a1befd2e6 100644 --- a/pl_bolts/models/rl/per_dqn_model.py +++ b/pl_bolts/models/rl/per_dqn_model.py @@ -14,7 +14,6 @@ from pl_bolts.losses.rl import per_dqn_loss from pl_bolts.models.rl.common.memory import Experience, PERBuffer from pl_bolts.models.rl.dqn_model import DQN -from pl_bolts.utils import _PL_GREATER_EQUAL_1_4 class PERDQN(DQN): From 82d080ef307df52842a8217d261933b6b0d3c4c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 5 Aug 2021 18:15:16 +0200 Subject: [PATCH 5/5] Apply suggestions from code review --- pl_bolts/models/rl/dqn_model.py | 3 ++- pl_bolts/models/self_supervised/moco/moco2_module.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pl_bolts/models/rl/dqn_model.py b/pl_bolts/models/rl/dqn_model.py index a4b6008789..69e1cb3a55 100644 --- a/pl_bolts/models/rl/dqn_model.py +++ b/pl_bolts/models/rl/dqn_model.py @@ -404,7 +404,8 @@ def add_model_specific_args(arg_parser: argparse.ArgumentParser, ) -> argparse.A return arg_parser - def _use_dp_or_ddp2(self, trainer: Trainer) -> bool: + @staticmethod + def _use_dp_or_ddp2(trainer: Trainer) -> bool: # for backwards compatibility if _PL_GREATER_EQUAL_1_4: return trainer.accelerator_connector.use_dp or trainer.accelerator_connector.use_ddp2 diff --git a/pl_bolts/models/self_supervised/moco/moco2_module.py b/pl_bolts/models/self_supervised/moco/moco2_module.py index 9f8432de02..9cf89ac7c5 100644 --- a/pl_bolts/models/self_supervised/moco/moco2_module.py +++ b/pl_bolts/models/self_supervised/moco/moco2_module.py @@ -335,7 +335,8 @@ def add_model_specific_args(parent_parser): return parser - def _use_ddp_or_ddp2(self, trainer: Trainer) -> bool: + @staticmethod + def _use_ddp_or_ddp2(trainer: Trainer) -> bool: # for backwards compatibility if _PL_GREATER_EQUAL_1_4: return trainer.accelerator_connector.use_ddp or trainer.accelerator_connector.use_ddp2