This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fixed a bug where validation metrics could be aggregated together with test metrics #900

Merged
merged 3 commits on Oct 29, 2021
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -14,6 +14,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where test metrics were not logged correctly with active learning ([#879](https://github.com/PyTorchLightning/lightning-flash/pull/879))


- Fixed a bug where validation metrics could be aggregated together with test metrics in some cases ([#900](https://github.com/PyTorchLightning/lightning-flash/pull/900))


## [0.5.1] - 2021-10-26

### Added
3 changes: 2 additions & 1 deletion flash/core/model.py
@@ -354,6 +354,7 @@ def __init__(

self.train_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics))
self.val_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(deepcopy(metrics)))
self.test_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(deepcopy(metrics)))
self.learning_rate = learning_rate
# TODO: should we save more? Bug on some regarding yaml if we save metrics
self.save_hyperparameters("learning_rate", "optimizer")
@@ -454,7 +455,7 @@ def validation_step(self, batch: Any, batch_idx: int) -> None:
)

def test_step(self, batch: Any, batch_idx: int) -> None:
- output = self.step(batch, batch_idx, self.val_metrics)
+ output = self.step(batch, batch_idx, self.test_metrics)
self.log_dict(
{f"test_{k}": v for k, v in output[OutputKeys.LOGS].items()},
on_step=False,
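For context, torchmetrics metrics are stateful: every `update()` accumulates into the same internal counters until `compute()`/`reset()` is called. Because `test_step` previously reused `self.val_metrics`, a test run could fold its batches into the validation metric state, which is the aggregation the changelog entry describes. The sketch below is illustrative only, not Flash code; it assumes torchmetrics is installed (newer releases require the `task` argument to `Accuracy`).

```python
# Minimal sketch (not Flash code): one shared, stateful metric object mixes
# validation and test results; a per-phase copy keeps the states separate.
from copy import deepcopy

import torch
from torchmetrics import Accuracy

shared = Accuracy(task="binary")
shared.update(torch.tensor([1, 1]), torch.tensor([1, 1]))  # "validation": 2/2 correct
shared.update(torch.tensor([1, 1]), torch.tensor([0, 0]))  # "test": 0/2 correct
print(shared.compute())  # tensor(0.5000) -- the two phases aggregated together

val_metric = Accuracy(task="binary")
test_metric = deepcopy(val_metric)  # analogous to the new self.test_metrics
val_metric.update(torch.tensor([1, 1]), torch.tensor([1, 1]))
test_metric.update(torch.tensor([1, 1]), torch.tensor([0, 0]))
print(val_metric.compute())   # tensor(1.)
print(test_metric.compute())  # tensor(0.)
```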
8 changes: 8 additions & 0 deletions tests/core/test_model.py
@@ -437,13 +437,21 @@ def i_will_create_a_misconfiguration_exception(optimizer):
def test_classification_task_metrics():
train_dataset = FixedDataset([0, 1])
val_dataset = FixedDataset([1, 1])
test_dataset = FixedDataset([0, 0])

model = OnesModel()

class CheckAccuracy(Callback):
def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
assert math.isclose(trainer.callback_metrics["train_accuracy_epoch"], 0.5)

def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
assert math.isclose(trainer.callback_metrics["val_accuracy"], 1.0)

def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
assert math.isclose(trainer.callback_metrics["test_accuracy"], 0.0)

task = ClassificationTask(model)
trainer = flash.Trainer(max_epochs=1, callbacks=CheckAccuracy(), gpus=torch.cuda.device_count())
trainer.fit(task, train_dataloader=DataLoader(train_dataset), val_dataloaders=DataLoader(val_dataset))
trainer.test(task, dataloaders=DataLoader(test_dataset))
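The three datasets give each phase a distinct expected accuracy (train 0.5 on targets [0, 1], validation 1.0 on [1, 1], test 0.0 on [0, 0]), so any metric state leaking between phases would trip at least one assertion in `CheckAccuracy`. `FixedDataset` and `OnesModel` are helpers defined earlier in `tests/core/test_model.py`; the bodies below are an assumption sketched for illustration, not the repository's exact code.

```python
# Hypothetical sketch of the test helpers used above; the real definitions
# live earlier in tests/core/test_model.py and may differ in detail.
import torch
from torch import nn
from torch.utils.data import Dataset


class FixedDataset(Dataset):
    """Pairs a random input with a fixed target label."""

    def __init__(self, targets):
        self.targets = targets

    def __getitem__(self, index):
        return torch.rand(1), self.targets[index]

    def __len__(self):
        return len(self.targets)


class OnesModel(nn.Module):
    """Always predicts class 1, whatever the input."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(1, 2)  # gives the optimizer real parameters to update

    def forward(self, x):
        # Multiplying by zero keeps a gradient path to the parameters while
        # forcing the logits to [0, 1], i.e. class 1 always wins.
        return self.layer(x) * 0.0 + torch.tensor([0.0, 1.0], device=x.device)
```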