Convergence issue with LR Finder #14824

Closed
awaelchli opened this issue Sep 21, 2022 · 2 comments
Labels: breaking change, tuner

Comments

@awaelchli
Contributor

awaelchli commented Sep 21, 2022

Bug description

Prior to #14113, the model below would converge to the optimal value using the learning rate suggested by the LR Finder. On the latest version (master), this is no longer the case; PR #14113 introduced a regression.

How to reproduce the bug

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD
from torch.utils.data import DataLoader, TensorDataset

import lightning as L
from pytorch_lightning.utilities.seed import seed_everything

seed_everything(seed=42)


class BasicLightningTrain(L.LightningModule):
    def __init__(self):
        super().__init__()

        # Fixed (non-trainable) weights of a small hand-crafted network; only final_bias is trained.
        self.w00 = nn.Parameter(torch.tensor(1.7), requires_grad=False)
        self.b00 = nn.Parameter(torch.tensor(-0.85), requires_grad=False)
        self.w01 = nn.Parameter(torch.tensor(-40.8), requires_grad=False)

        self.w10 = nn.Parameter(torch.tensor(12.6), requires_grad=False)
        self.b10 = nn.Parameter(torch.tensor(0.0), requires_grad=False)
        self.w11 = nn.Parameter(torch.tensor(2.7), requires_grad=False)

        self.final_bias = nn.Parameter(torch.tensor(0.0), requires_grad=True)  # optimal value should be -16.0

        self.learning_rate = 0.01

    def forward(self, input):

        input_to_top_relu = input * self.w00 + self.b00
        top_relu_output = F.relu(input_to_top_relu)
        scaled_top_relu_output = top_relu_output * self.w01

        input_to_bottom_relu = input * self.w10 + self.b10
        bottom_relu_output = F.relu(input_to_bottom_relu)
        scaled_bottom_relu_output = bottom_relu_output * self.w11

        input_to_final_relu = scaled_top_relu_output + scaled_bottom_relu_output + self.final_bias

        output = F.relu(input_to_final_relu)

        return output

    def configure_optimizers(self):
        return SGD(self.parameters(), lr=self.learning_rate)

    def training_step(self, batch, batch_idx):

        input_i, label_i = batch
        output_i = self.forward(input_i)
        loss = (output_i - label_i) ** 2
        return loss


model = BasicLightningTrain()

inputs = torch.tensor([0.0, 0.5, 1.0])
labels = torch.tensor([0.0, 1.0, 0.0])

dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset)

trainer = L.Trainer(max_epochs=34)

lr_find_results = trainer.tuner.lr_find(
    model, train_dataloaders=dataloader, min_lr=0.001, max_lr=1.0, early_stop_threshold=None
)
new_lr = lr_find_results.suggestion()

print(f"lr_find() suggests {new_lr:.5f} for the learning rate.")  # 0.00214

model.learning_rate = new_lr
trainer.fit(model, train_dataloaders=dataloader)

# v1.7.3 or below:  tensor(-16.0098)
# master:           tensor(-2.1706)
print(model.final_bias.data)

Error messages and logs

The model no longer converges to the optimal value, even though the LR Finder suggests the same learning rate as in the previous version of Lightning.

Important info


- Lightning Component: Trainer
- PyTorch Lightning Version: 1.7.6 and master (1.8.0dev)
- Lightning App Version: -
- PyTorch Version: 1.11
- Python version: 3.10
- OS: macOS
- CUDA/cuDNN version: -
- GPU models and configuration: -
- How you installed Lightning (`conda`, `pip`, source): from source
- Running environment of LightningApp: -

@awaelchli added the bug, needs triage, tuner and priority: 0 labels and removed the needs triage label on Sep 21, 2022
@awaelchli added this to the pl:1.7.x milestone on Sep 21, 2022
@awaelchli
Contributor Author

awaelchli commented Sep 21, 2022

Previously, on <= 1.7.3, if you instantiated a new Trainer after tuning, you got the same result as on master today. So the fix in #14113 is correct in the sense that the LR scheduler used during tuning now gets removed properly before the real fit call.
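
For reference, a minimal sketch of that check, assuming the model and dataloader defined in the reproduction above; on <= 1.7.3 the fresh Trainer in the second step gives roughly the same final_bias as master does today:

# Run the sweep with one Trainer, then fit with a *fresh* Trainer so no tuning
# state (in particular the sweep's LR scheduler) can carry over into fit().
tune_trainer = L.Trainer(max_epochs=34)
results = tune_trainer.tuner.lr_find(
    model, train_dataloaders=dataloader, min_lr=0.001, max_lr=1.0, early_stop_threshold=None
)
model.learning_rate = results.suggestion()

fit_trainer = L.Trainer(max_epochs=34)  # fresh Trainer: no leftover scheduler from tuning
fit_trainer.fit(model, train_dataloaders=dataloader)
print(model.final_bias.data)  # on <= 1.7.3 this roughly matches the master value quoted above (~ -2.17)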

@awaelchli
Contributor Author

I guess the previous version only worked because the scheduler left behind by the tuner would help the model converge, even though the suggestion from the LR finder was bad. Now that the tuner no longer leaves the scheduler behind, the bad learning rate stays constant and hence lets the model diverge.
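
To illustrate the difference, a standalone PyTorch sketch (not Lightning internals; the gamma value is made up): a scheduler left attached after an exponential LR sweep keeps changing the effective learning rate during fit, whereas the fixed suggestion stays at the possibly too-small value.

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR

param = torch.nn.Parameter(torch.zeros(1))
opt = SGD([param], lr=0.00214)          # the suggested LR from the reproduction above
sched = ExponentialLR(opt, gamma=1.05)  # hypothetical leftover schedule; gamma is illustrative only

for step in range(5):
    opt.step()                          # no grads here; this only demonstrates the LR bookkeeping
    sched.step()
    print(step, opt.param_groups[0]["lr"])  # LR changes every step instead of staying constant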

@awaelchli removed the priority: 0 label on Sep 21, 2022
@awaelchli removed this from the pl:1.7.x milestone on Sep 21, 2022
@awaelchli added the breaking change label and removed the bug label on Sep 21, 2022