From fb85bfd97998905e3a6d80cb809a79ab9065438a Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 29 Oct 2021 14:35:07 +0100 Subject: [PATCH 1/2] Fix test metric logging --- CHANGELOG.md | 2 ++ flash/core/model.py | 3 ++- tests/core/test_model.py | 8 ++++++++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 58c5b77345..8cb0f4793d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed a bug where validation metrics could be aggregated together with test metrics in some cases + ## [0.5.1] - 2021-10-26 ### Added diff --git a/flash/core/model.py b/flash/core/model.py index 20da95d285..6f87bcb4c3 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -354,6 +354,7 @@ def __init__( self.train_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics)) self.val_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(deepcopy(metrics))) + self.test_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(deepcopy(metrics))) self.learning_rate = learning_rate # TODO: should we save more? Bug on some regarding yaml if we save metrics self.save_hyperparameters("learning_rate", "optimizer") @@ -454,7 +455,7 @@ def validation_step(self, batch: Any, batch_idx: int) -> None: ) def test_step(self, batch: Any, batch_idx: int) -> None: - output = self.step(batch, batch_idx, self.val_metrics) + output = self.step(batch, batch_idx, self.test_metrics) self.log_dict( {f"test_{k}": v for k, v in output[OutputKeys.LOGS].items()}, on_step=False, diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 0e68344bb5..f31bba3e70 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -437,6 +437,7 @@ def i_will_create_a_misconfiguration_exception(optimizer): def test_classification_task_metrics(): train_dataset = FixedDataset([0, 1]) val_dataset = FixedDataset([1, 1]) + test_dataset = FixedDataset([0, 0]) model = OnesModel() @@ -444,6 +445,13 @@ class CheckAccuracy(Callback): def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: assert math.isclose(trainer.callback_metrics["train_accuracy_epoch"], 0.5) + def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + assert math.isclose(trainer.callback_metrics["val_accuracy"], 1.0) + + def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + assert math.isclose(trainer.callback_metrics["test_accuracy"], 0.0) + task = ClassificationTask(model) trainer = flash.Trainer(max_epochs=1, callbacks=CheckAccuracy(), gpus=torch.cuda.device_count()) trainer.fit(task, train_dataloader=DataLoader(train_dataset), val_dataloaders=DataLoader(val_dataset)) + trainer.test(task, dataloaders=DataLoader(test_dataset)) From 9c3257a6e4d0d94bfda9634505e1ac171670db46 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 29 Oct 2021 14:36:45 +0100 Subject: [PATCH 2/2] Update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8cb0f4793d..9be71fdaad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed -- Fixed a bug where validation metrics could be aggregated together with test metrics in some cases +- Fixed a bug where validation metrics could be aggregated together with test metrics in some cases ([#900](https://github.com/PyTorchLightning/lightning-flash/pull/900)) ## [0.5.1] - 2021-10-26