diff --git a/CHANGELOG.md b/CHANGELOG.md index 877962446e..14cd73c12b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,7 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where serve sanity checking would not be triggered using the latest PyTorchLightning version ([#493](https://github.com/PyTorchLightning/lightning-flash/pull/493)) - +- Fixed a bug where train and validation metrics weren't being correctly computed ([#559](https://github.com/PyTorchLightning/lightning-flash/pull/559)) ## [0.4.0] - 2021-06-22 diff --git a/flash/core/model.py b/flash/core/model.py index 2c4c2b6ada..76db8a189a 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -140,7 +140,8 @@ def __init__( self.optimizer_kwargs = optimizer_kwargs or {} self.scheduler_kwargs = scheduler_kwargs or {} - self.metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics)) + self.train_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics)) + self.val_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(deepcopy(metrics))) self.learning_rate = learning_rate # TODO: should we save more? Bug on some regarding yaml if we save metrics self.save_hyperparameters("learning_rate", "optimizer") @@ -157,7 +158,7 @@ def __init__( self.deserializer = deserializer self.serializer = serializer - def step(self, batch: Any, batch_idx: int) -> Any: + def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any: """ The training/validation/test step. Override for custom behavior. """ @@ -168,7 +169,7 @@ def step(self, batch: Any, batch_idx: int) -> Any: losses = {name: l_fn(y_hat, y) for name, l_fn in self.loss_fn.items()} logs = {} y_hat = self.to_metrics_format(output["y_hat"]) - for name, metric in self.metrics.items(): + for name, metric in metrics.items(): if isinstance(metric, torchmetrics.metric.Metric): metric(y_hat, y) logs[name] = metric # log the metric itself if it is of type Metric @@ -195,16 +196,16 @@ def forward(self, x: Any) -> Any: return self.model(x) def training_step(self, batch: Any, batch_idx: int) -> Any: - output = self.step(batch, batch_idx) + output = self.step(batch, batch_idx, self.train_metrics) self.log_dict({f"train_{k}": v for k, v in output["logs"].items()}, on_step=True, on_epoch=True, prog_bar=True) return output["loss"] def validation_step(self, batch: Any, batch_idx: int) -> None: - output = self.step(batch, batch_idx) + output = self.step(batch, batch_idx, self.val_metrics) self.log_dict({f"val_{k}": v for k, v in output["logs"].items()}, on_step=False, on_epoch=True, prog_bar=True) def test_step(self, batch: Any, batch_idx: int) -> None: - output = self.step(batch, batch_idx) + output = self.step(batch, batch_idx, self.val_metrics) self.log_dict({f"test_{k}": v for k, v in output["logs"].items()}, on_step=False, on_epoch=True, prog_bar=True) @predict_context diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py index e1da47be55..80e4094cf3 100644 --- a/flash/text/classification/model.py +++ b/flash/text/classification/model.py @@ -97,10 +97,10 @@ def to_metrics_format(self, x) -> torch.Tensor: x = x.logits return super().to_metrics_format(x) - def step(self, batch, batch_idx) -> dict: + def step(self, batch, batch_idx, metrics) -> dict: target = batch.pop("labels") batch = (batch, target) - return super().step(batch, batch_idx) + return super().step(batch, batch_idx, metrics) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: return self(batch) diff --git a/flash/video/classification/model.py b/flash/video/classification/model.py index 8e05069a2b..5819b6bf2a 100644 --- a/flash/video/classification/model.py +++ b/flash/video/classification/model.py @@ -146,8 +146,8 @@ def on_train_epoch_start(self) -> None: encoded_dataset._video_sampler.set_epoch(self.trainer.current_epoch) super().on_train_epoch_start() - def step(self, batch: Any, batch_idx: int) -> Any: - return super().step((batch["video"], batch["label"]), batch_idx) + def step(self, batch: Any, batch_idx: int, metrics) -> Any: + return super().step((batch["video"], batch["label"]), batch_idx, metrics) def forward(self, x: Any) -> Any: x = self.backbone(x) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 6336bdfb06..ec6437f038 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import math from numbers import Number from pathlib import Path from typing import Any, Tuple @@ -20,6 +21,7 @@ import pytest import pytorch_lightning as pl import torch +from pytorch_lightning.callbacks import Callback from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn, Tensor from torch.nn import functional as F @@ -68,6 +70,34 @@ class DummyPostprocess(Postprocess): pass +class FixedDataset(torch.utils.data.Dataset): + + def __init__(self, targets): + super().__init__() + + self.targets = targets + + def __getitem__(self, index: int) -> Tuple[Tensor, Number]: + return torch.rand(1), self.targets[index] + + def __len__(self) -> int: + return len(self.targets) + + +class OnesModel(nn.Module): + + def __init__(self): + super().__init__() + + self.layer = nn.Linear(1, 2) + self.register_buffer('zeros', torch.zeros(2)) + self.register_buffer('zero_one', torch.tensor([0.0, 1.0])) + + def forward(self, x): + x = self.layer(x) + return x * self.zeros + self.zero_one + + # ================================ @@ -249,3 +279,19 @@ def test_optimization(tmpdir): assert isinstance(scheduler[0], torch.optim.lr_scheduler.LambdaLR) expected = get_linear_schedule_with_warmup.__name__ assert scheduler[0].lr_lambdas[0].__qualname__.split('.')[0] == expected + + +def test_classification_task_metrics(): + train_dataset = FixedDataset([0, 1]) + val_dataset = FixedDataset([1, 1]) + + model = OnesModel() + + class CheckAccuracy(Callback): + + def on_train_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: + assert math.isclose(trainer.callback_metrics['train_accuracy_epoch'], 0.5) + + task = ClassificationTask(model) + trainer = flash.Trainer(max_epochs=1, callbacks=CheckAccuracy()) + trainer.fit(task, train_dataloader=DataLoader(train_dataset), val_dataloaders=DataLoader(val_dataset))