Validation with torchmetrics extremely slow #1413
Your best course of action would be to profile the code. Are you sure it's not related to your data? Did you also try using a random dataset with your LightningModule?
Hi @carmocca, thanks for the reply. I created a repro script with a dummy dataset:

from statistics import mean
from timeit import default_timer as timer
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchmetrics
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, LightningDataModule
from pytorch_lightning.strategies.ddp import DDPStrategy
from torchvision.models.segmentation.deeplabv3 import deeplabv3_resnet50
torch.random.manual_seed(123)
class DummyDataset(Dataset):
    def __init__(self) -> None:
        super().__init__()
        self.random_img = torch.rand((3, 480, 640), dtype=torch.float32)
        self.mask = torch.ones((480, 640), dtype=torch.long)

    def __len__(self):
        return 320

    def __getitem__(self, index):
        return self.random_img, self.mask
class DummyDataModule(LightningDataModule):
    def __init__(self, bs: int = 8, num_workers: int = 8) -> None:
        super().__init__()
        self.bs = bs
        self.num_workers = num_workers

    def train_dataloader(self):
        dataset = DummyDataset()
        dataloader = DataLoader(dataset, self.bs, shuffle=True, num_workers=self.num_workers)
        return dataloader

    def val_dataloader(self):
        dataset = DummyDataset()
        dataloader = DataLoader(dataset, self.bs, shuffle=False, num_workers=self.num_workers)
        return dataloader
class DeepLabV3LightningModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = deeplabv3_resnet50(num_classes=38, aux_loss=False)
        self.loss = nn.CrossEntropyLoss(ignore_index=255, reduction="mean")
        self.iou_metric = torchmetrics.JaccardIndex(
            task="multiclass",
            threshold=0.5,
            num_classes=38,
            average="macro",
        )

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(params=self.model.parameters(), lr=0.001)
        return optimizer

    def training_step(self, batch, batch_idx):
        imgs, masks = batch
        out = self.model(imgs)
        preds = out["out"]
        loss = self.loss(preds, masks)
        return loss

    def validation_step(self, batch, batch_idx):
        imgs, masks = batch
        out = self.model(imgs)
        preds = out["out"]
        loss = self.loss(preds, masks)
        # preds [num_classes, h, w] -> [h, w]
        preds = torch.softmax(preds, dim=1)
        pred_labels = torch.argmax(preds, dim=1)
        # measure runtime of metric update
        start = timer()
        self.iou_metric.update(pred_labels, masks)
        elapsed = timer() - start
        return elapsed

    def validation_epoch_end(self, outputs):
        avg_runtime = round(mean(outputs), 4)
        print(f"GPU {self.local_rank}: {avg_runtime} seconds")
if __name__ == "__main__":
    data_module = DummyDataModule(bs=8, num_workers=8)
    model = DeepLabV3LightningModule()
    strategy = DDPStrategy(find_unused_parameters=False)
    trainer = pl.Trainer(
        max_epochs=1,
        accelerator="gpu",
        devices=[0, 1, 2, 3],
        num_sanity_val_steps=0,
        strategy=strategy,
        profiler="simple",
    )
    trainer.fit(model, data_module)

Now I get a strange behavior. When I run it like this, the validation part is still extremely slow. However, according to the profiler and my runtime measurement of the metric update step, the metric now is not the problem. This is what the profiler says (I omitted the last part of the profiler output for brevity):
The validation step is now reasonably fast (~260 ms), but pushing the data to the GPU somehow takes forever. If I run the script again with the metric update line commented out, the data transfer is fast again.

I am a little bit lost now. Thanks for your help!
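For what it's worth, a minimal sketch of timing the update with explicit synchronization (assuming a CUDA device): CUDA kernels are launched asynchronously, so without synchronize calls a host-side timer attributes queued GPU work to whichever later call happens to block, for example an unrelated .to(device).

import torch
from timeit import default_timer as timer

def timed_update(metric, preds, target):
    # Synchronize before and after so the measured interval covers only the
    # GPU work launched by this update, not previously queued kernels.
    torch.cuda.synchronize()
    start = timer()
    metric.update(preds, target)
    torch.cuda.synchronize()
    return timer() - start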
@lukazso Your repro was very helpful. I am able to reproduce the slowdown locally. First I tried to simplify your version.

Using trainer.validate, removing distributed:

from statistics import mean
from timeit import default_timer as timer
import pytorch_lightning as pl
import torch
import torchmetrics
from lightning_lite.utilities.seed import seed_everything
from pytorch_lightning import LightningModule, LightningDataModule
from torch.utils.data import Dataset, DataLoader
from torchvision.models.segmentation.deeplabv3 import deeplabv3_resnet50
class DummyDataset(Dataset):
    def __init__(self) -> None:
        super().__init__()
        self.random_img = torch.rand((3, 480, 640), dtype=torch.float32)
        self.mask = torch.ones((480, 640), dtype=torch.long)

    def __len__(self):
        return 320

    def __getitem__(self, index):
        return self.random_img, self.mask

class DummyDataModule(LightningDataModule):
    def __init__(self, bs: int = 8, num_workers: int = 8) -> None:
        super().__init__()
        self.bs = bs
        self.num_workers = num_workers

    def val_dataloader(self):
        dataset = DummyDataset()
        dataloader = DataLoader(dataset, self.bs, shuffle=False, num_workers=self.num_workers)
        return dataloader
class DeepLabV3LightningModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = deeplabv3_resnet50(num_classes=38, aux_loss=False)
        self.iou_metric = torchmetrics.JaccardIndex(
            task="multiclass",
            threshold=0.5,
            num_classes=38,
            average="macro",
        )

    def transfer_batch_to_device(self, batch, device, _):
        profiler = self.trainer.profiler
        with profiler.profile("to0"):
            batch[0] = batch[0].to(device)
        with profiler.profile("to1"):
            batch[1] = batch[1].to(device)
        return batch

    def validation_step(self, batch, _):
        imgs, masks = batch
        out = self.model(imgs)
        preds = out["out"]
        # preds [num_classes, h, w] -> [h, w]
        preds = torch.softmax(preds, dim=1)
        pred_labels = torch.argmax(preds, dim=1)
        # measure runtime of metric update
        start = timer()
        with self.trainer.profiler.profile("iou_metric.update"):
            self.iou_metric(pred_labels, masks)
        elapsed = timer() - start
        return elapsed

    def validation_epoch_end(self, outputs):
        avg_runtime = round(mean(outputs), 4)
        print(f"GPU {self.local_rank}: {avg_runtime} seconds")
if __name__ == "__main__":
    seed_everything(1, workers=True)
    data_module = DummyDataModule(bs=8, num_workers=0)
    model = DeepLabV3LightningModule()
    trainer = pl.Trainer(
        accelerator="gpu", devices=1, limit_val_batches=5, profiler="simple"
    )
    trainer.validate(model, data_module)

After seeing no changes, I decided to avoid the Trainer entirely.

Manually driving the LightningModule:

from statistics import mean
from timeit import default_timer as timer
import pytorch_lightning as pl
import torch
import torchmetrics
from lightning_lite.utilities.seed import seed_everything
from pytorch_lightning import LightningModule
from torch.utils.data import Dataset, DataLoader
from torchvision.models.segmentation.deeplabv3 import deeplabv3_resnet50
class DummyDataset(Dataset):
    def __init__(self) -> None:
        super().__init__()
        self.random_img = torch.rand((3, 480, 640), dtype=torch.float32)
        self.mask = torch.ones((480, 640), dtype=torch.long)

    def __len__(self):
        return 8 * 5  # batch size * limit_val_batches

    def __getitem__(self, index):
        return self.random_img, self.mask
class DeepLabV3LightningModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = deeplabv3_resnet50(num_classes=38, aux_loss=False)
        self.iou_metric = torchmetrics.JaccardIndex(
            task="multiclass",
            threshold=0.5,
            num_classes=38,
            average="macro",
        )

    def transfer_batch_to_device(self, batch, device, _):
        profiler = self.trainer.profiler
        with profiler.profile("to0"):
            batch[0] = batch[0].to(device)
        with profiler.profile("to1"):
            batch[1] = batch[1].to(device)
        return batch

    def validation_step(self, batch, _):
        imgs, masks = batch
        out = self.model(imgs)
        preds = out["out"]
        # preds [num_classes, h, w] -> [h, w]
        preds = torch.softmax(preds, dim=1)
        pred_labels = torch.argmax(preds, dim=1)
        # measure runtime of metric update
        start = timer()
        with self.trainer.profiler.profile("iou_metric.update"):
            self.iou_metric(pred_labels, masks)
        elapsed = timer() - start
        return elapsed

    def validation_epoch_end(self, outputs):
        avg_runtime = round(mean(outputs), 4)
        print(f"GPU {self.local_rank}: {avg_runtime} seconds")
if __name__ == "__main__":
    seed_everything(1, workers=True)
    model = DeepLabV3LightningModule()
    trainer = pl.Trainer(
        accelerator="gpu", devices=1, limit_val_batches=5, profiler="simple"
    )
    model.trainer = trainer
    dataset = DummyDataset()
    dataloader = DataLoader(dataset, 8, shuffle=False, num_workers=0)
    device = torch.device("cuda")
    model.to(device)
    outputs = []
    for batch in dataloader:
        batch = model.transfer_batch_to_device(batch, device, 0)
        with torch.inference_mode():
            elapsed = model.validation_step(batch, 0)
        outputs.append(elapsed)
    model.validation_epoch_end(outputs)
    print(trainer.profiler.summary())

Still, same behaviour. So I completely stripped out PyTorch Lightning, only leaving its profiler and using the advanced profiler.

Pure PyTorch:

from statistics import mean
from timeit import default_timer as timer
import pytorch_lightning as pl
import torch
import torchmetrics
from torch.utils.data import Dataset, DataLoader
from torchvision.models.segmentation.deeplabv3 import deeplabv3_resnet50
class DummyDataset(Dataset):
    def __init__(self) -> None:
        super().__init__()
        self.random_img = torch.rand((3, 480, 640), dtype=torch.float32)
        self.mask = torch.ones((480, 640), dtype=torch.long)

    def __len__(self):
        return 8 * 5  # batch size * limit_val_batches

    def __getitem__(self, index):
        return self.random_img, self.mask
class DeepLabV3Module(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = deeplabv3_resnet50(num_classes=38, aux_loss=False)
        self.iou_metric = torchmetrics.JaccardIndex(
            task="multiclass",
            threshold=0.5,
            num_classes=38,
            average="macro",
        )
        self.profiler = pl.profilers.AdvancedProfiler()

    def transfer_batch_to_device(self, batch, device, _):
        profiler = self.profiler
        with profiler.profile("to0"):
            batch[0] = batch[0].to(device)
        with profiler.profile("to1"):
            batch[1] = batch[1].to(device)
        return batch

    def validation_step(self, batch, _):
        imgs, masks = batch
        out = self.model(imgs)
        preds = out["out"]
        preds = torch.softmax(preds, dim=1)
        pred_labels = torch.argmax(preds, dim=1)
        start = timer()
        with self.profiler.profile("iou_metric.update"):
            # COMMENT THIS to see difference
            self.iou_metric(pred_labels, masks)
        elapsed = timer() - start
        return elapsed

    def validation_epoch_end(self, outputs):
        avg_runtime = round(mean(outputs), 4)
        print(f"{avg_runtime} seconds")
if __name__ == "__main__":
    model = DeepLabV3Module()
    dataset = DummyDataset()
    dataloader = DataLoader(dataset, 8, shuffle=False, num_workers=0)
    device = torch.device("cuda")
    model.to(device)
    outputs = []
    for batch in dataloader:
        batch = model.transfer_batch_to_device(batch, device, 0)
        with torch.inference_mode():
            elapsed = model.validation_step(batch, 0)
        outputs.append(elapsed)
    model.validation_epoch_end(outputs)
    print(model.profiler.summary())

And still, no changes. So it must be an issue with PyTorch or torchmetrics. You'll have to debug further. The profiler report shows that the difference is in iou_metric.update.

cc @SkafteNicki or @justusschock in case you have any ideas of what could be causing this in torchmetrics. Maybe this issue should be transferred there.
@carmocca thanks for your help! Do you have permission to move this issue to the metrics repo?
Hi! Thanks for your contribution, great first issue!
Hi @carmocca, the issue currently has the "waiting for author" tag. I think a response from the metrics maintainers would be of great help to point in the right direction. Can you change the label?
Does one of the maintainers have an opinion on this? :)
Hey @lukazso, to further narrow this down: do you experience the same with other metrics? Could you maybe try ConfusionMatrix (which is the base class of Jaccard) and Accuracy (as another classification metric)?
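A minimal sketch of that comparison, with assumed shapes matching the repro (batches of 8 masks at 480x640, 38 classes) and synchronized timing; the constructors follow the torchmetrics task-based classification API:

import torch
import torchmetrics
from timeit import default_timer as timer

device = torch.device("cuda")
metrics = {
    "jaccard": torchmetrics.JaccardIndex(task="multiclass", num_classes=38, average="macro"),
    "confmat": torchmetrics.ConfusionMatrix(task="multiclass", num_classes=38),
    "accuracy": torchmetrics.Accuracy(task="multiclass", num_classes=38),
}
preds = torch.randint(0, 38, (8, 480, 640), device=device)
target = torch.randint(0, 38, (8, 480, 640), device=device)
for name, metric in metrics.items():
    metric = metric.to(device)
    torch.cuda.synchronize()
    start = timer()
    for _ in range(10):
        metric.update(preds, target)
    torch.cuda.synchronize()
    print(f"{name}: {(timer() - start) / 10:.4f} s per update")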
Hi @lukazso,
The TLDR seems to be that it is some weird issue between lightning and torchmetrics. @Borda, @justusschock, @carmocca: not sure what the path forward is from here.
Hey @SkafteNicki, thanks for your thorough analysis! I ran your example as well. I also ran my mwe again (the version modified by @carmocca) and measured the runtime of the data transfer and the metric update with the profiler.

Code of my mwe:

from statistics import mean
from timeit import default_timer as timer
import pytorch_lightning as pl
import torch
import torchmetrics
# from lightning_lite.utilities.seed import seed_everything
from pytorch_lightning import LightningModule, LightningDataModule
from torch.utils.data import Dataset, DataLoader
from torchvision.models.segmentation.deeplabv3 import deeplabv3_resnet50
class DummyDataset(Dataset):
    def __init__(self) -> None:
        super().__init__()
        self.random_img = torch.rand((3, 480, 640), dtype=torch.float32)
        self.mask = torch.ones((480, 640), dtype=torch.long)

    def __len__(self):
        return 320

    def __getitem__(self, index):
        return self.random_img, self.mask

class DummyDataModule(LightningDataModule):
    def __init__(self, bs: int = 8, num_workers: int = 8) -> None:
        super().__init__()
        self.bs = bs
        self.num_workers = num_workers

    def val_dataloader(self):
        dataset = DummyDataset()
        dataloader = DataLoader(dataset, self.bs, shuffle=False, num_workers=self.num_workers)
        return dataloader
class DeepLabV3LightningModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = deeplabv3_resnet50(num_classes=38, aux_loss=False)
        self.iou_metric = torchmetrics.JaccardIndex(
            task="multiclass",
            threshold=0.5,
            num_classes=38,
            average="macro",
        )

    def transfer_batch_to_device(self, batch, device, _):
        profiler = self.trainer.profiler
        with profiler.profile("to0"):
            batch[0] = batch[0].to(device)
        with profiler.profile("to1"):
            batch[1] = batch[1].to(device)
        return batch

    def validation_step(self, batch, _):
        imgs, masks = batch
        out = self.model(imgs)
        preds = out["out"]
        # preds [num_classes, h, w] -> [h, w]
        preds = torch.softmax(preds, dim=1)
        pred_labels = torch.argmax(preds, dim=1)
        # measure runtime of metric update
        start = timer()
        with self.trainer.profiler.profile("iou_metric.update"):
            self.iou_metric(pred_labels, masks)
        elapsed = timer() - start
        return elapsed

    def validation_epoch_end(self, outputs):
        avg_runtime = round(mean(outputs), 4)
        print(f"GPU {self.local_rank}: {avg_runtime} seconds")
if __name__ == "__main__":
    # seed_everything(1, workers=True)
    data_module = DummyDataModule(bs=8, num_workers=0)
    model = DeepLabV3LightningModule()
    trainer = pl.Trainer(
        accelerator="gpu", devices=1, limit_val_batches=5, profiler="simple"
    )
    trainer.validate(model, data_module)

Now, the runtimes of the metric update are small again. Instead, the profiler again points to the data transfer. So I still do not understand why different parts are affected by the slowdown in the two cases. Maybe I am missing something? Anyway, I also do not have a solution. The dirty quickfix would be to remove the metric update from the validation loop.
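One way the blame could shift between to1 and iou_metric.update is asynchronous CUDA execution; a sketch of the effect (assuming a CUDA device; the two printed times are illustrative, not measurements from this thread):

import torch
from timeit import default_timer as timer

x = torch.randn(8, 38, 480, 640, device="cuda")

start = timer()
labels = x.argmax(dim=1)        # heavy kernel; the host returns immediately
print(f"queued in {timer() - start:.4f} s")   # tiny: kernel only enqueued

start = timer()
n = labels.sum().item()         # .item() blocks until the queue drains
print(f"synced in {timer() - start:.4f} s")   # absorbs the argmax runtime too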
@SkafteNicki The snippet you are using does not use lightning other than for the profiler. Notice how the model does not subclass LightningModule.
@carmocca yeah, I was probably too fast on the keyboard there and jumped to a conclusion.
Any progress?
Did PyTorch acknowledge the potential issue / is there a link?
FWIW, I experience similar slowdowns (TorchMetrics: 1.0.0; Lightning: 1.9.5; PyTorch: 1.13.1).
An update to my previous comment: my problem seems to happen only when deterministic mode is enabled. BTW, during a recent upgrade of my codebase, I've noticed that this can now be set via Lightning's args.
@xor-xor if you have deterministic mode enabled, this code path is the reason (torchmetrics/src/torchmetrics/utilities/data.py, lines 234 to 241 at 3ca8a89):

We use bincount to count the number of times when pred==target for each individual class, but torch.bincount is not deterministic. Thus, if the user requests deterministic mode, then we need to use a for loop over the classes, which of course scales very badly when you have a lot of classes.
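A simplified sketch of that fallback, paraphrasing the referenced lines (the pinned revision may differ in details):

import torch
from torch import Tensor

def _bincount(x: Tensor, minlength: int) -> Tensor:
    # torch.bincount relies on atomic adds on CUDA and is therefore
    # non-deterministic; deterministic mode falls back to one
    # comparison-and-sum per bin, i.e. O(minlength) kernel launches.
    if torch.are_deterministic_algorithms_enabled():
        output = torch.zeros(minlength, device=x.device, dtype=torch.long)
        for i in range(minlength):
            output[i] = (x == i).sum()
        return output
    return torch.bincount(x, minlength=minlength)

For a multiclass confusion matrix this is called with minlength = num_classes ** 2, so 38 classes already mean 1444 tiny kernels per update.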
What do you mean by Lightning's args? The Trainer(deterministic=True) flag?
@SkafteNicki thanks for your explanation - totally makes sense! @carmocca yes, I mean the deterministic flag of the Trainer.
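For reference, a sketch of the two equivalent ways to request deterministic kernels (the Trainer argument sets the torch flag internally):

import torch
import pytorch_lightning as pl

# Pure PyTorch: opt in to deterministic implementations globally.
torch.use_deterministic_algorithms(True)

# Lightning: the corresponding Trainer argument.
trainer = pl.Trainer(deterministic=True)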
@xor-xor would you be interested in finding a better/faster solution? 🐿️
@Borda unfortunately, I'm very time-constrained nowadays, so more than likely I won't find space for that. Also, my teammates decided that we can live without deterministic mode for now.
Use meshgrid?

import time
import torch
num_classes = 1000
minlength = num_classes ** 2
batch_size = 64
device = torch.device(torch.cuda.current_device())
output = torch.zeros(minlength, device=device, dtype=torch.long)
target = torch.zeros(batch_size, device=device, dtype=torch.long)
start = time.time()
for i in range(minlength):
    output[i] = (target == i).sum()
print(output.shape)
end = time.time()
print(end - start)
start = time.time()
mesh = torch.meshgrid(torch.arange(batch_size), torch.arange(minlength), indexing='ij')[1].to(device)
check = torch.eq(target.reshape(-1, 1), mesh)
output2 = check.sum(dim=0)
print(output2.shape)
end = time.time()
print(end - start)
torch.equal(output, output2)

Result:

torch.Size([1000000])
38.242313861846924
torch.Size([1000000])
0.2178659439086914
True
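The same count can also be written with plain broadcasting, reusing target, minlength, device, and output from the benchmark above; a sketch of an alternative, not necessarily what the PR below merged:

# (batch_size, 1) compared against (minlength,) broadcasts to a
# (batch_size, minlength) boolean matrix; summing over dim 0 gives the counts
# without materializing the meshgrid of row indices as well.
classes = torch.arange(minlength, device=device)
output3 = torch.eq(target.reshape(-1, 1), classes).sum(dim=0)
assert torch.equal(output, output3)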
Looks great! Mind sending a PR?
Sent! #2184
Bug description
Hi all,
I recently tried to implement a DeepLabV3 training pipeline. I wanted to use the built-in torchmetrics.JaccardIndex as my evaluation metric. My LightningModule looks like this:

When using this validation procedure, it is extremely slow. On average, the update step of the metric takes 23.4 seconds. However, the first 3 updates are very fast (<1 second), then they become slow.
I tried to reproduce this behavior in a MWE:
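A minimal standalone timing loop of this kind, with assumed shapes matching the repro script (8 x 480 x 640 predictions, 38 classes) and synchronized timing:

from timeit import default_timer as timer

import torch
import torchmetrics

metric = torchmetrics.JaccardIndex(
    task="multiclass", num_classes=38, average="macro"
).to("cuda")
preds = torch.randint(0, 38, (8, 480, 640), device="cuda")
target = torch.randint(0, 38, (8, 480, 640), device="cuda")

times = []
for _ in range(40):
    torch.cuda.synchronize()
    start = timer()
    metric.update(preds, target)
    torch.cuda.synchronize()
    times.append(timer() - start)
print(f"avg update: {sum(times) / len(times):.4f} s")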
Here I get an average update duration of 0.03 seconds, so I do not encounter the extremely slow update as in my LightningModule above. To me this looks like there is something wrong. At this point, I am not sure if this is an issue with torchmetrics, pytorch-lightning, or my own setup.

Here is some training information for my pytorch-lightning training pipeline:
My package versions:
Thanks so much!
Lukas
How to reproduce the bug
No response
Error messages and logs
No response
Environment
No response
More info
No response
cc @carmocca @justusschock @awaelchli @Borda