diff --git a/ignite/handlers/lr_finder.py b/ignite/handlers/lr_finder.py index ddd53034ca4..2b3e58c05ae 100644 --- a/ignite/handlers/lr_finder.py +++ b/ignite/handlers/lr_finder.py @@ -161,13 +161,19 @@ def _log_lr_and_loss(self, trainer: Engine, output_transform: Callable, smooth_f raise ValueError( "if output of the engine is torch.Tensor, then " "it must be 0d torch.Tensor or 1d torch.Tensor with 1 element, " - f"but got torch.Tensor of shape {loss.shape}" + f"but got torch.Tensor of shape {loss.shape}." ) else: raise TypeError( "output of the engine should be of type float or 0d torch.Tensor " "or 1d torch.Tensor with 1 element, " f"but got output of type {type(loss).__name__}" + ".\nYou may wish to use the output_transform kwarg with the attach method e.g.\n" + """ + lr_finder = FastaiLRFinder() + with lr_finder.attach(trainer, output_transform=lambda x:x["train_loss"]) as trainer_with_lr_finder: + trainer_with_lr_finder.run(dataloader_train) + """ ) loss = idist.all_reduce(loss) lr = self._lr_schedule.get_param()