Validation with torchmetrics extremely slow #1413

Closed
lukazso opened this issue Dec 19, 2022 · 24 comments · Fixed by #2184
Labels
enhancement (New feature or request), help wanted (Extra attention is needed)

Comments

@lukazso
Contributor

lukazso commented Dec 19, 2022

Bug description

Hi all,

I recently tried to implement a DeepLabV3 training pipeline. I wanted to use the built-in torchmetrics.JaccardIndex as my evaluation metric. My LightningModule looks like this:

from statistics import mean
from timeit import default_timer as timer

import torch
import torch.nn as nn
import torchmetrics
from pytorch_lightning import LightningModule
from torchvision.models.segmentation.deeplabv3 import deeplabv3_resnet50


class DeepLabV3LightningModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = deeplabv3_resnet50(
            num_classes=38,
            aux_loss=False
        )
        self.loss = nn.CrossEntropyLoss(ignore_index=255, reduction="mean")
        self.iou_metric = torchmetrics.JaccardIndex(
            task="multiclass", 
            threshold=0.5, 
            num_classes=38,
            average="macro",
        )

    def training_step(self, batch, batch_idx):
        imgs, masks = batch
        out = self.model(imgs)
        preds = out["out"]
        loss = self.loss(preds, masks)       
        return loss

    def validation_step(self, batch, batch_idx):
        imgs, masks = batch
        out = self.model(imgs)
        preds = out["out"]
        loss = self.loss(preds, masks)
        preds = torch.softmax(preds, dim=1)
        pred_labels = torch.argmax(preds, dim=1)
        
        # measure runtime of metric update
        start = timer()
        self.iou_metric.update(pred_labels, masks)
        elapsed = timer() - start
        return elapsed

    def validation_epoch_end(self, outputs):
        avg_runtime = round(mean(outputs), 4)
        print(f"GPU {self.local_rank}: {avg_runtime} seconds")

When using this validation procedure, it is extremely slow. On average, the update step of the metric takes 23.4 seconds. However, the first 3 updates are very fast (<1 second), then they become slow.

I tried to reproduce this behavior in a MWE:

from timeit import default_timer as timer
from statistics import mean
import torchmetrics
import torch

num_classes = 38

iou_metric = torchmetrics.JaccardIndex(
    task="multiclass",
    threshold=0.5, 
    num_classes=num_classes,
    average="macro"
).to("cuda")

# dummy labels in shape [b, h, w]
label_mask = torch.randint(low=0, high=num_classes-1, size=(8, 480, 640), device="cuda")

# dummy predicted labels in shape [b, h, w]
pred_mask = torch.randint(low=0, high=num_classes-1, size=(8, 480, 640), device="cuda")


runtime_hist = []
for i in range(100):
    start = timer()
    iou_metric.update(label_mask, pred_mask)
    elapsed = timer() - start
    runtime_hist.append(elapsed)


avg_runtime = round(mean(runtime_hist), 2)
print(avg_runtime)

Here I get an average update duration of 0.03 seconds, so I do not encounter the extremely slow update as in my LightningModule above. To me this looks like something is wrong, but at this point I am not sure where exactly.
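
Note that CUDA kernels are launched asynchronously, so a plain wall-clock timer around update() may not capture the full GPU cost unless the device is synchronized first. A minimal sketch of a synchronized measurement (assuming a CUDA device is available):

from timeit import default_timer as timer
import torch
import torchmetrics

iou_metric = torchmetrics.JaccardIndex(
    task="multiclass", num_classes=38, average="macro"
).to("cuda")

preds = torch.randint(0, 38, (8, 480, 640), device="cuda")
target = torch.randint(0, 38, (8, 480, 640), device="cuda")

torch.cuda.synchronize()   # make sure earlier GPU work has finished
start = timer()
iou_metric.update(preds, target)
torch.cuda.synchronize()   # wait for the update kernels themselves
print(f"update took {timer() - start:.4f} s")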

Here is some training information for my pytorch-lightning training pipeline:

  • OS: Ubuntu 20.04.4
  • CUDA 11.3
  • DDP training strategy
  • GPUs: 4x V100
  • batch size: 8
  • image size (width x height): 640 x 480
  • number of workers in dataloader: 8

My package versions:

  • pytorch lightning: 1.8.4.post0 (installed via pip)
  • torch: 1.13.0
  • torchvision: 0.14.0
  • torchmetrics: 0.11.0
  • numpy: 1.23.5

Thanks so much!
Lukas

How to reproduce the bug

No response

Error messages and logs

No response

Environment

No response

More info

No response

cc @carmocca @justusschock @awaelchli @Borda

@carmocca
Contributor

Your best course of action would be to profile the update call in both examples and compare the differences. If you share a ready-to-run repro script, we can help you.

Are you sure it's not related to your data? Did you also try using a random dataset with your LightningModule?
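
A hedged sketch of one way to profile the update call in isolation (this uses the standard torch.profiler API rather than anything Lightning-specific; shapes mirror the snippets above):

import torch
import torchmetrics
from torch.profiler import profile, ProfilerActivity

metric = torchmetrics.JaccardIndex(task="multiclass", num_classes=38).to("cuda")
preds = torch.randint(0, 38, (8, 480, 640), device="cuda")
target = torch.randint(0, 38, (8, 480, 640), device="cuda")

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    metric.update(preds, target)
    torch.cuda.synchronize()  # make sure the CUDA kernels are captured

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))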

@lukazso
Contributor Author

lukazso commented Dec 20, 2022

Hi @carmocca, thanks for the reply. I created a repro script with a dummy dataset:

from statistics import mean
from timeit import default_timer as timer

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import torchmetrics 
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, LightningDataModule
from pytorch_lightning.strategies.ddp import DDPStrategy

from torchvision.models.segmentation.deeplabv3 import deeplabv3_resnet50


torch.random.manual_seed(123)


class DummyDataset(Dataset):
    def __init__(self) -> None:
        super().__init__()
        self.random_img = torch.rand((3, 480, 640), dtype=torch.float32)
        self.mask = torch.ones((480, 640), dtype=torch.long)
    
    def __len__(self):
        return 320
    
    def __getitem__(self, index):
        return self.random_img, self.mask


class DummyDataModule(LightningDataModule):
    def __init__(self, bs: int = 8, num_workers: int = 8) -> None:
        super().__init__()
        self.bs = bs
        self.num_workers = num_workers
    
    def train_dataloader(self):
        dataset = DummyDataset()
        dataloader = DataLoader(dataset, self.bs, shuffle=True, num_workers=self.num_workers)
        return dataloader
    
    def val_dataloader(self):
        dataset = DummyDataset()
        dataloader = DataLoader(dataset, self.bs, shuffle=False, num_workers=self.num_workers)
        return dataloader


class DeepLabV3LightningModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = deeplabv3_resnet50(
            num_classes=38,
            aux_loss=False
        )
        self.loss = nn.CrossEntropyLoss(ignore_index=255, reduction="mean")
        self.iou_metric = torchmetrics.JaccardIndex(
            task="multiclass", 
            threshold=0.5, 
            num_classes=38,
            average="macro",
        )

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            params=self.model.parameters(),
            lr=0.001
        )
        return optimizer        

    def training_step(self, batch, batch_idx):
        imgs, masks = batch
        out = self.model(imgs)
        preds = out["out"]
        loss = self.loss(preds, masks)       
        return loss

    def validation_step(self, batch, batch_idx):
        imgs, masks = batch
        out = self.model(imgs)
        preds = out["out"]
        loss = self.loss(preds, masks)

        # preds [num_classes, h, w] -> [h, w]
        preds = torch.softmax(preds, dim=1)
        pred_labels = torch.argmax(preds, dim=1)
        
        # measure runtime of metric update
        start = timer()
        self.iou_metric.update(pred_labels, masks)
        elapsed = timer() - start
        return elapsed

    def validation_epoch_end(self, outputs):
        avg_runtime = round(mean(outputs), 4)
        print(f"GPU {self.local_rank}: {avg_runtime} seconds")


if __name__ == "__main__":
    data_module = DummyDataModule(bs=8, num_workers=8)
    model = DeepLabV3LightningModule()

    strategy = DDPStrategy(find_unused_parameters=False)
    trainer = pl.Trainer(
        max_epochs=1, accelerator="gpu", devices=[0, 1, 2, 3], num_sanity_val_steps=0, strategy=strategy, profiler="simple"
    )
    trainer.fit(model, data_module)

Now I get a strange behavior. When I run it like this, the validation part is still extremely slow. However, according to the profiler and my runtime measurement of the metric update step, the metric now is not the problem. This is what the profiler says (I omitted the last part of the profiler output for brevity):

----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                                                               |  Mean duration (s)    |  Num calls            |  Total time (s)       |  Percentage %         |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                                                                                                                                                |  -                    |  701                  |  475.28               |  100 %                |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  run_training_epoch                                                                                                                                                                                   |  462.19               |  1                    |  462.19               |  97.245               |
|  [Strategy]DDPStrategy.batch_to_device                                                                                                                                                                |  20.144               |  20                   |  402.88               |  84.767               |
|  [LightningModule]DeepLabV3LightningModule.transfer_batch_to_device                                                                                                                                   |  20.144               |  20                   |  402.88               |  84.766               |
|  [Callback]TQDMProgressBar.on_validation_end                                                                                                                                                          |  44.494               |  1                    |  44.494               |  9.3616               |
|  run_training_batch                                                                                                                                                                                   |  1.1091               |  10                   |  11.091               |  2.3336               |
|  [LightningModule]DeepLabV3LightningModule.optimizer_step                                                                                                                                             |  1.1085               |  10                   |  11.085               |  2.3323               |
|  [Strategy]DDPStrategy.validation_step                                                                                                                                                                |  0.26495              |  10                   |  2.6495               |  0.55745              |
|  [Strategy]DDPStrategy.backward                                                                                                                                                                       |  0.25789              |  10                   |  2.5789               |  0.54261              |
|  [Strategy]DDPStrategy.training_step                                                                                                                                                                  |  0.17081              |  10                   |  1.7081               |  0.3594               |
|  [Callback]ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None, 'save_on_train_epoch_end': True}.on_train_epoch_end            |  1.111                |  1                    |  1.111                |  0.23375              |
|  [TrainingEpochLoop].train_dataloader_next                                                                                                                                                            |  0.0067655            |  10                   |  0.067655             |  0.014235             |
|  [EvaluationEpochLoop].val_dataloader_idx_0_next                                                                                                                                                      |  0.0063154            |  10                   |  0.063154             |  0.013288             |
|  [LightningModule]DeepLabV3LightningModule.optimizer_zero_grad                                                                                                                                        |  0.0012365            |  10                   |  0.012365             |  0.0026017            |
|  [Callback]TQDMProgressBar.on_train_batch_end                                                                                                                                                         |  0.00091617           |  10                   |  0.0091617            |  0.0019276            |
|  [Callback]ModelSummary.on_fit_start                                                                                                                                                                  |  0.0057881            |  1                    |  0.0057881            |  0.0012178            |
|  [Callback]TQDMProgressBar.on_validation_batch_end                                                                                                                                                    |  0.0005465            |  10                   |  0.005465             |  0.0011498            |

The validation step is now reasonably fast (~260ms), but pushing the data to the GPU somehow takes forever. If I run the script again with the line self.iou_metric.update(pred_labels, masks) in my validation step commented out, everything behaves normally again. Here is the profiler output for this case:

----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                                                               |  Mean duration (s)    |  Num calls            |  Total time (s)       |  Percentage %         |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                                                                                                                                                |  -                    |  701                  |  26.378               |  100 %                |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  run_training_epoch                                                                                                                                                                                   |  15.044               |  1                    |  15.044               |  57.034               |
|  run_training_batch                                                                                                                                                                                   |  1.1192               |  10                   |  11.192               |  42.428               |
|  [LightningModule]DeepLabV3LightningModule.optimizer_step                                                                                                                                             |  1.1185               |  10                   |  11.185               |  42.404               |
|  [Strategy]DDPStrategy.backward                                                                                                                                                                       |  0.26588              |  10                   |  2.6588               |  10.08                |
|  [Strategy]DDPStrategy.batch_to_device                                                                                                                                                                |  0.12089              |  20                   |  2.4178               |  9.1659               |
|  [LightningModule]DeepLabV3LightningModule.transfer_batch_to_device                                                                                                                                   |  0.12074              |  20                   |  2.4147               |  9.1544               |
|  [Strategy]DDPStrategy.training_step                                                                                                                                                                  |  0.1625               |  10                   |  1.625                |  6.1605               |
|  [Callback]ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None, 'save_on_train_epoch_end': True}.on_train_epoch_end            |  1.088                |  1                    |  1.088                |  4.1247               |
|  [Callback]TQDMProgressBar.on_validation_end                                                                                                                                                          |  0.24772              |  1                    |  0.24772              |  0.93914              |
|  [Strategy]DDPStrategy.validation_step                                                                                                                                                                |  0.014465             |  10                   |  0.14465              |  0.54838              |
|  [EvaluationEpochLoop].val_dataloader_idx_0_next                                                                                                                                                      |  0.0055015            |  10                   |  0.055015             |  0.20856              |
|  [TrainingEpochLoop].train_dataloader_next                                                                                                                                                            |  0.0051861            |  10                   |  0.051861             |  0.19661              |
|  [LightningModule]DeepLabV3LightningModule.optimizer_zero_grad                                                                                                                                        |  0.0012385            |  10                   |  0.012385             |  0.046952             |
|  [Callback]TQDMProgressBar.on_train_batch_end                                                                                                                                                         |  0.00089197           |  10                   |  0.0089197            |  0.033816             |
|  [Callback]TQDMProgressBar.on_validation_batch_end                                                                                                                                                    |  0.00058113           |  10                   |  0.0058113            |  0.022031             |
|  [Callback]ModelSummary.on_fit_start                                                                                                                                                                  |  0.0055644            |  1                    |  0.0055644            |  0.021095             |
|  [Callback]TQDMProgressBar.on_train_start                                                                                                                                                             |  0.0017953            |  1                    |  0.0017953            |  0.0068062            |
|  [LightningModule]DeepLabV3LightningModule.on_validation_model_train                                                                                                                                  |  0.0011453            |  1                    |  0.0011453            |  0.004342             |
|  [LightningModule]DeepLabV3LightningModule.on_validation_model_eval                                                                                                                                   |  0.0011215            |  1                    |  0.0011215            |  0.0042518            |
|  [Callback]TQDMProgressBar.on_validation_batch_start                                                                                                                                                  |  9.9515e-05           |  10                   |  0.00099515           |  0.0037727            |

I am a little bit lost now. Thanks for your help!

@carmocca
Contributor

carmocca commented Dec 21, 2022

@lukazso Your repro was very helpful. I am able to reproduce the slowdown locally.

First I tried to simplify your version:

Using trainer.validate, removing distributed
from statistics import mean
from timeit import default_timer as timer

import pytorch_lightning as pl
import torch
import torchmetrics
from lightning_lite.utilities.seed import seed_everything
from pytorch_lightning import LightningModule, LightningDataModule
from torch.utils.data import Dataset, DataLoader
from torchvision.models.segmentation.deeplabv3 import deeplabv3_resnet50


class DummyDataset(Dataset):
    def __init__(self) -> None:
        super().__init__()
        self.random_img = torch.rand((3, 480, 640), dtype=torch.float32)
        self.mask = torch.ones((480, 640), dtype=torch.long)
    
    def __len__(self):
        return 320
    
    def __getitem__(self, index):
        return self.random_img, self.mask


class DummyDataModule(LightningDataModule):
    def __init__(self, bs: int = 8, num_workers: int = 8) -> None:
        super().__init__()
        self.bs = bs
        self.num_workers = num_workers
    
    def val_dataloader(self):
        dataset = DummyDataset()
        dataloader = DataLoader(dataset, self.bs, shuffle=False, num_workers=self.num_workers)
        return dataloader


class DeepLabV3LightningModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = deeplabv3_resnet50(
            num_classes=38,
            aux_loss=False
        )
        self.iou_metric = torchmetrics.JaccardIndex(
            task="multiclass", 
            threshold=0.5, 
            num_classes=38,
            average="macro",
        )

    def transfer_batch_to_device(self, batch, device, _):
        profiler = self.trainer.profiler
        with profiler.profile("to0"):
            batch[0] = batch[0].to(device)
        with profiler.profile("to1"):
            batch[1] = batch[1].to(device)
        return batch

    def validation_step(self, batch, _):
        imgs, masks = batch
        out = self.model(imgs)
        preds = out["out"]
        # preds [num_classes, h, w] -> [h, w]
        preds = torch.softmax(preds, dim=1)
        pred_labels = torch.argmax(preds, dim=1)
        
        # measure runtime of metric update
        start = timer()
        with self.trainer.profiler.profile("iou_metric.update"):
            self.iou_metric(pred_labels, masks)
        elapsed = timer() - start
        return elapsed

    def validation_epoch_end(self, outputs):
        avg_runtime = round(mean(outputs), 4)
        print(f"GPU {self.local_rank}: {avg_runtime} seconds")


if __name__ == "__main__":
    seed_everything(1, workers=True)
    data_module = DummyDataModule(bs=8, num_workers=0)
    model = DeepLabV3LightningModule()
    trainer = pl.Trainer(
        accelerator="gpu", devices=1, limit_val_batches=5, profiler="simple"
    )
    trainer.validate(model, data_module)

After seeing no changes, I decided to avoid trainer.validate while still driving the LightningModule manually:

Manually driving the LightningModule
from statistics import mean
from timeit import default_timer as timer

import pytorch_lightning as pl
import torch
import torchmetrics
from lightning_lite.utilities.seed import seed_everything
from pytorch_lightning import LightningModule
from torch.utils.data import Dataset, DataLoader
from torchvision.models.segmentation.deeplabv3 import deeplabv3_resnet50


class DummyDataset(Dataset):
    def __init__(self) -> None:
        super().__init__()
        self.random_img = torch.rand((3, 480, 640), dtype=torch.float32)
        self.mask = torch.ones((480, 640), dtype=torch.long)
    
    def __len__(self):
        return 8 * 5  # batch size * limit_val_batches
    
    def __getitem__(self, index):
        return self.random_img, self.mask


class DeepLabV3LightningModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = deeplabv3_resnet50(
            num_classes=38,
            aux_loss=False
        )
        self.iou_metric = torchmetrics.JaccardIndex(
            task="multiclass", 
            threshold=0.5, 
            num_classes=38,
            average="macro",
        )

    def transfer_batch_to_device(self, batch, device, _):
        profiler = self.trainer.profiler
        with profiler.profile("to0"):
            batch[0] = batch[0].to(device)
        with profiler.profile("to1"):
            batch[1] = batch[1].to(device)
        return batch

    def validation_step(self, batch, _):
        imgs, masks = batch
        out = self.model(imgs)
        preds = out["out"]
        # preds [num_classes, h, w] -> [h, w]
        preds = torch.softmax(preds, dim=1)
        pred_labels = torch.argmax(preds, dim=1)
        
        # measure runtime of metric update
        start = timer()
        with self.trainer.profiler.profile("iou_metric.update"):
            self.iou_metric(pred_labels, masks)
        elapsed = timer() - start
        return elapsed

    def validation_epoch_end(self, outputs):
        avg_runtime = round(mean(outputs), 4)
        print(f"GPU {self.local_rank}: {avg_runtime} seconds")


if __name__ == "__main__":
    seed_everything(1, workers=True)

    model = DeepLabV3LightningModule()
    trainer = pl.Trainer(
        accelerator="gpu", devices=1, limit_val_batches=5, profiler="simple"
    )
    model.trainer = trainer
    dataset = DummyDataset()
    dataloader = DataLoader(dataset, 8, shuffle=False, num_workers=0)
    device = torch.device("cuda")

    model.to(device)
    outputs = []
    for batch in dataloader:
        batch = model.transfer_batch_to_device(batch, device, 0)
        with torch.inference_mode():
            elapsed = model.validation_step(batch, 0)
        outputs.append(elapsed)
        model.validation_epoch_end(outputs)
    print(trainer.profiler.summary())

Still, same behaviour

So I completely stripped out PyTorch Lightning, leaving only its profiler and switching to the AdvancedProfiler:

Pure PyTorch
from statistics import mean
from timeit import default_timer as timer

import pytorch_lightning as pl
import torch
import torchmetrics
from torch.utils.data import Dataset, DataLoader
from torchvision.models.segmentation.deeplabv3 import deeplabv3_resnet50


class DummyDataset(Dataset):
    def __init__(self) -> None:
        super().__init__()
        self.random_img = torch.rand((3, 480, 640), dtype=torch.float32)
        self.mask = torch.ones((480, 640), dtype=torch.long)
    
    def __len__(self):
        return 8 * 5  # batch size * limit_val_batches
    
    def __getitem__(self, index):
        return self.random_img, self.mask


class DeepLabV3Module(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = deeplabv3_resnet50(
            num_classes=38,
            aux_loss=False
        )
        self.iou_metric = torchmetrics.JaccardIndex(
            task="multiclass", 
            threshold=0.5, 
            num_classes=38,
            average="macro",
        )

        self.profiler = pl.profilers.AdvancedProfiler()

    def transfer_batch_to_device(self, batch, device, _):
        profiler = self.profiler
        with profiler.profile("to0"):
            batch[0] = batch[0].to(device)
        with profiler.profile("to1"):
            batch[1] = batch[1].to(device)
        return batch

    def validation_step(self, batch, _):
        imgs, masks = batch
        out = self.model(imgs)
        preds = out["out"]
        preds = torch.softmax(preds, dim=1)
        pred_labels = torch.argmax(preds, dim=1)

        start = timer()
        with self.profiler.profile("iou_metric.update"):
            # COMMENT THIS to see difference
            self.iou_metric(pred_labels, masks)
        elapsed = timer() - start
        return elapsed

    def validation_epoch_end(self, outputs):
        avg_runtime = round(mean(outputs), 4)
        print(f"{avg_runtime} seconds")


if __name__ == "__main__":
    model = DeepLabV3Module()
    dataset = DummyDataset()
    dataloader = DataLoader(dataset, 8, shuffle=False, num_workers=0)
    device = torch.device("cuda")

    model.to(device)
    outputs = []
    for batch in dataloader:
        batch = model.transfer_batch_to_device(batch, device, 0)
        with torch.inference_mode():
            elapsed = model.validation_step(batch, 0)
        outputs.append(elapsed)
        model.validation_epoch_end(outputs)
    print(model.profiler.summary())

And still, no changes. So it must be an issue with PyTorch or torchmetrics. You'll have to debug further.

The profiler report shows that the difference is in batch[0].to(device), where updating the iou_metric makes the {method 'to' of 'torch._C._TensorBase' objects} call much slower.
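
A plausible explanation for the shift in timings is CUDA's asynchronous execution: kernels are only enqueued by the Python call that launches them, and their cost is paid at the next operation that has to wait for the device, such as a blocking host-to-device copy. A minimal sketch of the effect (a generic heavy kernel stands in for the metric update; the numbers are illustrative only):

from timeit import default_timer as timer
import torch

device = torch.device("cuda")
a = torch.randn(4096, 4096, device=device)

start = timer()
for _ in range(20):
    a = a @ a                      # heavy kernels, only enqueued here
print(f"matmul loop (enqueue only): {timer() - start:.3f} s")

start = timer()
b = torch.rand(8, 3, 480, 640).to(device)  # blocking copy waits for the queue to drain
print(f"subsequent .to(device):     {timer() - start:.3f} s")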

cc @SkafteNicki or @justusschock in case you have any ideas of what could be causing this in torchmetrics. Maybe this issue should be transferred there.

@lukazso
Contributor Author

lukazso commented Dec 26, 2022

@carmocca thanks for your help! Do you have the permission to move this issue to the metrics repo?

@carmocca carmocca transferred this issue from Lightning-AI/pytorch-lightning Dec 26, 2022
@github-actions

Hi! Thanks for your contribution, great first issue!

@lukazso
Contributor Author

lukazso commented Jan 18, 2023

Hi @carmocca, the issue currently has the "waiting on author" label. I think a response from the metrics maintainers would be of great help to point in the right direction. Can you change the label?

@carmocca carmocca added bug / fix Something isn't working and removed waiting on author labels Jan 18, 2023
@lukazso
Contributor Author

lukazso commented Jan 23, 2023

Does one of the maintainers have an opinion on this? :)

@justusschock
Member

Hey @lukazso
Sorry for the late reply. Unfortunately, I currently have no idea what could be causing this.

To further narrow this down: Do you experience the same with other metrics?

Could you maybe try with ConfusionMatrix (which is the base class of Jaccard) and Accuracy (as another classification metric)?
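
For reference, a sketch of how the two comparison metrics could be set up with the same configuration as the JaccardIndex above (the constructor arguments are assumed to mirror the original module):

import torchmetrics

confmat = torchmetrics.ConfusionMatrix(task="multiclass", num_classes=38)
accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=38, average="macro")

# both take the same (pred_labels, masks) pair as the JaccardIndex
# confmat.update(pred_labels, masks)
# accuracy.update(pred_labels, masks)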

@SkafteNicki
Member

Hi @lukazso,
I finally got around to debugging this issue. I have some results to share. Thanks @carmocca for helping with the initial narrowing of the problem. Here are the steps I went through:

  1. Started with another classification metric (Accuracy) that comes from another family of metrics than JaccardIndex (which comes from the ConfusionMatrix family) as proposed by @justusschock . No change.

  2. Tried a metric from a completely different package, in this case MeanSquaredError from the regression package. Everything works as expected. Problem must therefore be with classification metrics.

  3. Wanted to make sure that it had nothing to do with the logic in the class metrics, so tried out the functional version multiclass_jaccard_index. No change, so the problem must be in the backbone of the calculation.

  4. Started removing one line at a time from the implementation. Finally found the line that causes the problem:
    https://github.com/Lightning-AI/metrics/blob/0e1639271fed7326650ef75115114f60b9e83802/src/torchmetrics/utilities/data.py#L228
    Calling torch.bincount for some reason creates this problem.

  5. Tried removing everything that has to do with torchmetrics and just calling torch.bincount directly.

    Removed torchmetrics, plain lightning + torch.bincount
    from statistics import mean
    from timeit import default_timer as timer
    
    import pytorch_lightning as pl
    import torch
    from torch.utils.data import Dataset, DataLoader
    from torchvision.models.segmentation.deeplabv3 import deeplabv3_resnet50
    
    
    class DummyDataset(Dataset):
        def __init__(self) -> None:
            super().__init__()
            self.random_img = torch.rand((3, 480, 640), dtype=torch.float32)
            self.mask = torch.ones((480, 640), dtype=torch.long)
    
        def __len__(self):
            return 8 * 5  # batch size * limit_val_batches
    
        def __getitem__(self, index):
            return self.random_img, self.mask
    
    
    class DeepLabV3Module(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.model = deeplabv3_resnet50(
                num_classes=38,
                aux_loss=False
            )
            self.profiler = pl.profilers.AdvancedProfiler()
    
        def transfer_batch_to_device(self, batch, device, _):
            profiler = self.profiler
            with profiler.profile("to0"):
                batch[0] = batch[0].to(device)
            with profiler.profile("to1"):
                batch[1] = batch[1].to(device)
            return batch
    
        def validation_step(self, batch, _):
            imgs, masks = batch
            out = self.model(imgs)
            preds = out["out"]
            preds = torch.softmax(preds, dim=1)
            pred_labels = torch.argmax(preds, dim=1)
    
            start = timer()
            with self.profiler.profile("bincount"):
                torch.bincount(pred_labels.flatten(), minlength=38)
            elapsed = timer() - start
            return elapsed
    
        def validation_epoch_end(self, outputs):
            avg_runtime = round(mean(outputs), 4)
            print(f"{avg_runtime} seconds")
    
    
    if __name__ == "__main__":
        model = DeepLabV3Module()
        dataset = DummyDataset()
        dataloader = DataLoader(dataset, 8, shuffle=False, num_workers=0)
        device = torch.device("cuda")
    
        model.to(device)
        outputs = []
        for batch in dataloader:
            batch = model.transfer_batch_to_device(batch, device, 0)
            with torch.inference_mode():
                elapsed = model.validation_step(batch, 0)
            outputs.append(elapsed)
            model.validation_epoch_end(outputs)
        print(model.profiler.summary())
    This still reproduces the slowdown.
  6. Tried removing lightning and just using plain PyTorch:

    import torch
    from time import time
    
    for _ in range(10):
        # this is just some sizes taken from the previous examples 
        x = (38*torch.randn(2457600, requires_grad=True, device='cuda')).round().abs().long()
        start = time()
        for _ in range(10):
            _ = torch.bincount(x, minlength=38*38)
        print(f"{time()-start}")

    Everything seems to work as expected here.

The TLDR seems to be that it is some weird issue between lightning and torch.bincount that I have never seen before. I have also tried switching to torch.histc without success. We cannot avoid torch.bincount for classification metrics: we use it in nearly every metric because we need to count when pred == target and pred != target for all the different class combinations.
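
For context, bincount is typically used here by encoding every (target, prediction) pair into a single index and counting all pairs in one call, roughly like this simplified sketch (not the exact torchmetrics implementation):

import torch

num_classes = 38
preds = torch.randint(0, num_classes, (8, 480, 640))
target = torch.randint(0, num_classes, (8, 480, 640))

# every (target, pred) pair maps to a unique bin in [0, num_classes ** 2)
idx = target.flatten() * num_classes + preds.flatten()
confmat = torch.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)
# confmat[i, j] = number of pixels with target class i predicted as class j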

@Borda, @justusschock, @carmocca not sure what the path forward is from here.

@lukazso
Contributor Author

lukazso commented Feb 5, 2023

Hey @SkafteNicki, thanks for your thorough analysis! I ran your example that calls torch.bincount without any torchmetrics-related code, and I also experience the slowdown. I printed the runtimes of both torch.bincount and model.transfer_batch_to_device, because in my initial MWE the profiler showed that transferring the data to the GPU is somehow the slow step.

Batch 0 | batch to device:       5.89 milliseconds
Batch 0 | bincount:              21.07 milliseconds

Batch 1 | batch to device:       12.83 milliseconds
Batch 1 | bincount:              1533.85 milliseconds

Batch 2 | batch to device:       13.11 milliseconds
Batch 2 | bincount:              1566.48 milliseconds

Batch 3 | batch to device:       15.45 milliseconds
Batch 3 | bincount:              1511.01 milliseconds

Batch 4 | batch to device:       14.43 milliseconds
Batch 4 | bincount:              1506.0 milliseconds

I also ran my MWE again (the version modified by @carmocca) and measured the runtime of torch.bincount when it is called inside torchmetrics.

Code of my MWE
from statistics import mean
from timeit import default_timer as timer

import pytorch_lightning as pl
import torch
import torchmetrics
# from lightning_lite.utilities.seed import seed_everything
from pytorch_lightning import LightningModule, LightningDataModule
from torch.utils.data import Dataset, DataLoader
from torchvision.models.segmentation.deeplabv3 import deeplabv3_resnet50


class DummyDataset(Dataset):
    def __init__(self) -> None:
        super().__init__()
        self.random_img = torch.rand((3, 480, 640), dtype=torch.float32)
        self.mask = torch.ones((480, 640), dtype=torch.long)
    
    def __len__(self):
        return 320
    
    def __getitem__(self, index):
        return self.random_img, self.mask


class DummyDataModule(LightningDataModule):
    def __init__(self, bs: int = 8, num_workers: int = 8) -> None:
        super().__init__()
        self.bs = bs
        self.num_workers = num_workers
    
    def val_dataloader(self):
        dataset = DummyDataset()
        dataloader = DataLoader(dataset, self.bs, shuffle=False, num_workers=self.num_workers)
        return dataloader


class DeepLabV3LightningModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = deeplabv3_resnet50(
            num_classes=38,
            aux_loss=False
        )
        self.iou_metric = torchmetrics.JaccardIndex(
            task="multiclass", 
            threshold=0.5, 
            num_classes=38,
            average="macro",
        )

    def transfer_batch_to_device(self, batch, device, _):
        profiler = self.trainer.profiler
        with profiler.profile("to0"):
            batch[0] = batch[0].to(device)
        with profiler.profile("to1"):
            batch[1] = batch[1].to(device)
        return batch

    def validation_step(self, batch, _):
        imgs, masks = batch
        out = self.model(imgs)
        preds = out["out"]
        # preds [num_classes, h, w] -> [h, w]
        preds = torch.softmax(preds, dim=1)
        pred_labels = torch.argmax(preds, dim=1)
        
        # measure runtime of metric update
        start = timer()
        with self.trainer.profiler.profile("iou_metric.update"):
            self.iou_metric(pred_labels, masks)
        elapsed = timer() - start
        return elapsed

    def validation_epoch_end(self, outputs):
        avg_runtime = round(mean(outputs), 4)
        print(f"GPU {self.local_rank}: {avg_runtime} seconds")


if __name__ == "__main__":
    # seed_everything(1, workers=True)
    data_module = DummyDataModule(bs=8, num_workers=0)
    model = DeepLabV3LightningModule()
    trainer = pl.Trainer(
        accelerator="gpu", devices=1, limit_val_batches=5, profiler="simple"
    )
    trainer.validate(model, data_module)

Now, the runtimes of torch.bincount are not the problem! (even if the runtimes look a bit too low to me)

1.24 milliseconds
1.0 milliseconds
1.0 milliseconds
1.0 milliseconds
1.0 milliseconds

Instead, the profiler again points to model.transfer_batch_to_device as the step which takes so long. However, you are right @SkafteNicki, torch.bincount is the problem, as model.transfer_batch_to_device behaves normally again when torch.bincount is removed from the metric calculation.

So I still do not understand why in the two cases different parts are affected by the slowdown. Maybe I am missing something?

Anyway, I also do not have a solution. The dirty quickfix would be to remove torch.bincount and use the for-loop option, knowing that this is far less performant and does not solve the actual issue (but still more performant than with the current slowdown):
https://github.com/Lightning-AI/metrics/blob/0e1639271fed7326650ef75115114f60b9e83802/src/torchmetrics/utilities/data.py#L224-L227

@carmocca
Contributor

carmocca commented Feb 6, 2023

The TLDR seems to be that it is some weird issue between lightning and torch.bincount

@SkafteNicki The snippet you are using does not use Lightning other than for the profiler. Notice how the model does not subclass LightningModule and no Trainer is used. That example can be compressed further and shared in the PyTorch issue tracker for further debugging by their dev team. As @lukazso correctly noticed, the interaction is between Tensor.to(device) and torch.bincount.

@SkafteNicki
Member

@carmocca yeah, I was probably too fast on the keyboard there and jumped to a conclusion.
I will try to narrow the script down so we can report it to the PyTorch team.
In the meantime I am looking into alternatives to see if we can solve it on our end.

@gabe-scorebreak

Any progress?

@CWrecker

CWrecker commented Jul 6, 2023

Did PyTorch acknowledge the potential issue / is there a link?

@xor-xor

xor-xor commented Jul 11, 2023

FWIW, I experience similar slowdowns with MulticlassF1Score, MulticlassPrecision and MulticlassRecall, but only when the average argument is set to "macro". This starts to bite pretty fast (say, at hundreds of classes), and my classification task has more than 10k classes, which makes those metrics completely unusable for me. Surprisingly, this doesn't happen at all with their multi-label counterparts.

(TorchMetrics: 1.0.0; Lightning: 1.9.5; PyTorch: 1.13.1)

@SkafteNicki SkafteNicki added this to the v1.1.0 milestone Jul 13, 2023
@xor-xor

xor-xor commented Jul 14, 2023

An update to my previous comment: my problem seems to happen only when deterministic=True is passed to Trainer.

BTW, during a recent upgrade of my codebase, I noticed that the --deterministic flag got removed from Lightning's args somewhere between versions 1.5.10 and 1.9.5, but the deterministic arg is still present in Trainer - why is that?

@SkafteNicki
Member

@xor-xor if you have deterministic=True then I can explain why you are seeing bad performance, it's due to this for-loop:

if minlength is None:
    minlength = len(torch.unique(x))
if torch.are_deterministic_algorithms_enabled() or _XLA_AVAILABLE or _TORCH_GREATER_EQUAL_1_12 and x.is_mps:
    output = torch.zeros(minlength, device=x.device, dtype=torch.long)
    for i in range(minlength):
        output[i] = (x == i).sum()
    return output
return torch.bincount(x, minlength=minlength)

We use bincount to count the number of times when pred == target for each individual class, but torch.bincount is not deterministic. Thus, if the user requests deterministic mode then we need to use a for-loop over the classes, which of course scales very badly when you have a lot of classes.
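
One possible vectorized alternative (a sketch only, not what torchmetrics currently does) is to compare every element against all bin indices at once, which avoids both torch.bincount and the per-class Python loop at the cost of extra memory:

import torch

def loopless_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor:
    # builds a (numel(x), minlength) boolean matrix, so very large inputs may need chunking
    bins = torch.arange(minlength, device=x.device)
    return (x.unsqueeze(-1) == bins).sum(dim=0)

x = torch.randint(0, 38 * 38, (100_000,))
assert torch.equal(loopless_bincount(x, 38 * 38), torch.bincount(x, minlength=38 * 38))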

@carmocca
Contributor

@xor-xor

got removed from Lightning's args

What do you mean by Lightning's args? The LightningCLI?

@xor-xor

xor-xor commented Jul 14, 2023

@SkafteNicki thanks for your explanation - totally makes sense!

@carmocca I mean pytorch_lightning.Trainer.add_argparse_args() - that's how I have been integrating Lightning's command-line args with my codebase's args since at least PyTorch Lightning 0.8.1, so I didn't know that something like LightningCLI had emerged in the meantime ;)

@SkafteNicki SkafteNicki modified the milestones: v1.1.0, future Aug 14, 2023
@Borda
Member

Borda commented Aug 25, 2023

We use bincount to count the number of times when pred == target for each individual class, but torch.bincount is not deterministic. Thus, if the user requests deterministic mode then we need to use a for-loop over the classes, which of course scales very badly when you have a lot of classes.

@xor-xor would you be interested in finding a better/faster solution? 🐿️

@Borda Borda added enhancement New feature or request help wanted Extra attention is needed and removed bug / fix Something isn't working labels Aug 25, 2023
@xor-xor

xor-xor commented Aug 28, 2023

@xor-xor would you be interested in finding a better/faster solution? 🐿️

@Borda unfortunately, I'm very time-constrained nowadays, so more than likely I won't find space for that. Also, my team mates decided that we can live without deterministic=True in our extreme classification tasks for now, so it will be hard to carve such space from my working hours :(

@kyle-dorman
Contributor

kyle-dorman commented Oct 11, 2023

use meshgrid?

import time
import torch

num_classes = 1000
minlength = num_classes ** 2
batch_size = 64
device = torch.device(torch.cuda.current_device())

output = torch.zeros(minlength, device=device, dtype=torch.long)
target = torch.zeros(batch_size, device=device, dtype=torch.long)

# baseline: the current deterministic fallback, one kernel launch per bin
start = time.time()

for i in range(minlength):
    output[i] = (target == i).sum()
print(output.shape)

end = time.time()

print(end - start)

# alternative: compare against all bin indices at once via meshgrid
start = time.time()

mesh = torch.meshgrid(torch.arange(batch_size), torch.arange(minlength), indexing='ij')[1].to(device)
check = torch.eq(target.reshape(-1, 1), mesh)
output2 = check.sum(dim=0)
print(output2.shape)

end = time.time()

print(end - start)

print(torch.equal(output, output2))

Result

torch.Size([1000000])
38.242313861846924
torch.Size([1000000])
0.2178659439086914
True

@Borda
Member

Borda commented Oct 12, 2023

use meshgrid?

Looks great! Mind sending a PR?

@kyle-dorman
Contributor

kyle-dorman commented Oct 17, 2023

Sent! #2184
