
Incorrect result when a metric object is logged with PyTorch Lightning logging system #2231

Closed
laclouis5 opened this issue Nov 21, 2023 · 6 comments
Labels
bug / fix, help wanted, v1.2.x

Comments

@laclouis5

🐛 Bug

I have a Lightning module very similar to the one presented in the TorchMetrics PyTorch Lightning tutorial:

class MyModel(LightningModule):

    def __init__(self, num_classes):
        ...
        self.accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        ...
        # log step metric
        self.accuracy(preds, y)
        self.log('train_acc_step', self.accuracy)
        ...

    def on_train_epoch_end(self):
        # log epoch metric
        self.log('train_acc_epoch', self.accuracy)

Basically, I log the metric object directly with self.log, but this gives an incorrect result. If I manually compute the result with .compute(), the calculation is correct.
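
To clarify what "manually compute" means here, this is roughly the variant that gives me the correct result (just a sketch, reusing the self.accuracy attribute from the snippet above):

    def on_train_epoch_end(self):
        # compute() aggregates the state accumulated over the epoch
        acc = self.accuracy.compute()
        # log the plain tensor value instead of the Metric object
        self.log('train_acc_epoch', acc)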

I read the Common Pitfalls section but none of the cases apply to my setting, I think.

My workflow is a little more complex; I'll try to sum it up in the following section.

To Reproduce

I'm addressing a simple multi-class classification problem on images. For that, I'm using PyTorch Lightning with some standard classification metrics stored in a MetricCollection object. In the following, B is the batch size, C is the number of classes and (H, W) is the image size.

class ClassificationModel(LightningModule):
    def __init__(self, nb_classes: int):
        super().__init__()

        self.net = ...
 
        self.metrics = torchmetrics.MetricCollection(
            metrics={
                "Accuracy": torchmetrics.Accuracy(
                    task="multiclass", average="macro", num_classes=nb_classes
                ),
                "AP": torchmetrics.AveragePrecision(
                    task="multiclass", average="macro", num_classes=nb_classes
                ),
                "Recall": torchmetrics.Recall(
                    task="multiclass", average="macro", num_classes=nb_classes
                ),
                "Precision": torchmetrics.Precision(
                    task="multiclass", average="macro", num_classes=nb_classes
                ),
                "F1Score": torchmetrics.F1Score(
                    task="multiclass", average="macro", num_classes=nb_classes
                ),
            },
            compute_groups=False,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Returns a tensor logits of shape (B, C)
        ...

    def validation_step(self, batch: dict[str, torch.Tensor], _):
        image = batch["image"]  # (B, 3, H, W)
        label = batch["label"]  # (B,)

        output = self(image)  # (B, C)

        self.metrics.update(output, label)

    def on_validation_epoch_end(self):
        self.log_dict(dictionary=self.metrics)

Expected behavior

As explained above, this yields an incorrect result. If I change it to this, the result is correct:

    def on_validation_epoch_end(self):
        metrics = self.metrics.compute()
        self.log_dict(dictionary=metrics)

As advertised in the tutorial, logging the metric object directly should yield the exact same result.

Environment

  • TorchMetrics version: 1.2.0
  • Python & PyTorch Version: 3.11 and 2.0.1+cu118
  • Any other relevant information such as OS (e.g., Linux): Linux
laclouis5 added the bug / fix and help wanted labels Nov 21, 2023

Hi! Thanks for your contribution, great first issue!

@laclouis5
Author

laclouis5 commented Nov 21, 2023

This is another issue, but I also noticed that not providing compute_groups=False in the MetricCollection gives garbage results. This is very surprising since compute groups are enabled by default.
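
To be explicit about the difference (a minimal sketch; metrics_dict stands in for the metrics dictionary shown in the model above):

    # Default construction: compute groups are enabled, so metrics with compatible
    # states share their internal state to save memory and compute
    collection_default = torchmetrics.MetricCollection(metrics=metrics_dict)
    # What I had to use to get sensible numbers
    collection_no_groups = torchmetrics.MetricCollection(metrics=metrics_dict, compute_groups=False)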

Borda added the v1.2.x label Nov 22, 2023
@SkafteNicki
Member

This is another issue, but I also noticed that not providing compute_groups=False in the MetricCollection gives garbage results. This is very surprising since compute groups are enabled by default.

I think it is due to another reported issue #2206 that is being fixed in PR #2211

@SkafteNicki
Member

Hi @laclouis5, thanks for reporting this issue.
Could you try to be more specific? Maybe provide a reproducible example for the issue. I quickly tried to write a test that should cover what you are mentioning:

def test_something(tmpdir):
    """Testing something."""

    class TestModel(BoringModel):
        def __init__(self) -> None:
            super().__init__()
            self.metric = SumMetric()
            self.metric2 = SumMetric()
            self.sum = []

        def training_step(self, batch, batch_idx):
            x = batch
            s = x.sum()
            self.sum.append(s)
            self.metric(s)
            self.metric2(s)
            self.log("train_step_metric", self.metric)
            return self.step(x)

        def on_train_epoch_end(self):
            # log epoch metric
            self.log("train_epoch_metric", self.metric)
            val = self.metric2.compute()
            self.log("train_epoch_manual_metric", val)

    logger = CustomCSVLogger("tmpdir/logs")
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=5,
        limit_val_batches=0,
        max_epochs=1,
        log_every_n_steps=1,
        logger=logger,
    )
    model = TestModel()
    with no_warning_call(
        UserWarning,
        match="Torchmetrics v0.9 introduced a new argument class property called.*",
    ):
        trainer.fit(model)
    logged_metrics = logger._experiment.metrics
    assert logged_metrics[-1]["train_epoch_metric"] == logged_metrics[-1]["train_epoch_manual_metric"]

and the last assert does not fail, meaning that logging the metric object and logging the computed value give the same result for this basic example.

@laclouis5
Author

Here is a code snippet that seems to reproduce the issue on my machine. Note that I'm using a TensorBoard logger, so the issue can be seen in the TensorBoard time series.

However, I think I found the problem. It looks like adding a self.metric_2.reset() after logging the manual metric (the one obtained with .compute()) solves the issue. But it is weird that you do not observe this issue, since your example test code also does not call .reset().

from pathlib import Path
from typing import Any

import lightning
import torch
import torchmetrics
from lightning.pytorch.loggers import TensorBoardLogger
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset


class CustomNetwork(torch.nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.layer = torch.nn.Linear(10, 5)

    def forward(self, input):
        return self.layer(input)


class CustomDataset(Dataset):
    def __init__(self) -> None:
        super().__init__()

    def __getitem__(self, index):
        return torch.randn(size=(10,)), 3

    def __len__(self):
        return 100


class CustomDataModule(lightning.LightningDataModule):
    def __init__(self) -> None:
        super().__init__()

        self.train_dataset = CustomDataset()
        self.valid_dataset = CustomDataset()

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=2,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
            persistent_workers=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.valid_dataset,
            batch_size=2,
            shuffle=False,
            num_workers=4,
            pin_memory=True,
            persistent_workers=True,
        )


class CustomModel(lightning.LightningModule):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        self.model = CustomNetwork()

        self.metric_1 = torchmetrics.Accuracy(
            task="multiclass", average="macro", num_classes=5
        )

        self.metric_2 = torchmetrics.Accuracy(
            task="multiclass", average="macro", num_classes=5
        )

        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, input):
        return self.model(input)

    def configure_optimizers(self):
        return AdamW(self.model.parameters(), lr=0.001)

    def training_step(self, batch, index):
        input, gt = batch
        gt = gt.to(dtype=torch.long)
        pred = self(input)

        return self.loss_fn(pred, gt)

    def validation_step(self, batch, index):
        input, gt = batch
        gt = gt.to(dtype=torch.long)
        pred = self(input)

        self.metric_1.update(pred, gt)
        self.metric_2.update(pred, gt)

    def on_validation_epoch_end(self):
        self.log("M1", self.metric_1)

        m2 = self.metric_2.compute()
        self.log("M2", m2)

        # This solves the issue
        # self.metric_2.reset()


def main():
    torch.use_deterministic_algorithms(True, warn_only=True)
    lightning.seed_everything(31415, workers=True)
    torch.set_float32_matmul_precision("medium")

    Path("test_runs/").mkdir(exist_ok=True)
    logger = TensorBoardLogger(save_dir=Path.cwd(), name="test_runs/")

    trainer = lightning.Trainer(
        precision="16-mixed",
        logger=logger,
        max_epochs=100,
        deterministic="warn",
        log_every_n_steps=1,
    )

    model = CustomModel()
    datamodule = CustomDataModule()

    trainer.fit(model, datamodule=datamodule)


if __name__ == "__main__":
    main()
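
For completeness, here is the corrected hook with the reset applied; only this method changes, the rest of the script stays the same:

    def on_validation_epoch_end(self):
        # Logging the Metric object: Lightning calls compute() and resets the metric for me
        self.log("M1", self.metric_1)

        # Computing manually returns a plain tensor, so I have to reset the state
        # myself, otherwise it keeps accumulating across epochs
        m2 = self.metric_2.compute()
        self.log("M2", m2)
        self.metric_2.reset()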

@laclouis5
Author

I'm closing the issue since I think that the error is on my side here.
