From 27cc06de64c1b0c53a4ed91c9623ede2bf274f03 Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Tue, 13 Jul 2021 15:26:01 +0100
Subject: [PATCH] Add option for nested tasks (#575)

* Add option for nested tasks

* Update CHANGELOG.md

* Update CHANGELOG.md

* Updates

* Add grandparent test
---
 CHANGELOG.md             |  2 ++
 flash/core/model.py      | 12 ++++++++++++
 tests/core/test_model.py | 41 ++++++++++++++++++++++++++++++++++++++++
 3 files changed, 55 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7d2fff491e..aded4ca732 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for Semantic Segmentation backbones and heads from `segmentation-models.pytorch` ([#562](https://github.com/PyTorchLightning/lightning-flash/pull/562))
 
+- Added support for nesting of `Task` objects ([#575](https://github.com/PyTorchLightning/lightning-flash/pull/575))
+
 ### Changed
 
 - Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
diff --git a/flash/core/model.py b/flash/core/model.py
index 76db8a189a..8e1dc45686 100644
--- a/flash/core/model.py
+++ b/flash/core/model.py
@@ -158,6 +158,18 @@ def __init__(
         self.deserializer = deserializer
         self.serializer = serializer
 
+        self._children = []
+
+    def __setattr__(self, key, value):
+        if isinstance(value, LightningModule):
+            self._children.append(key)
+        patched_attributes = ["_current_fx_name", "_current_hook_fx_name", "_results"]
+        if isinstance(value, pl.Trainer) or key in patched_attributes:
+            if hasattr(self, "_children"):
+                for child in self._children:
+                    setattr(getattr(self, child), key, value)
+        super().__setattr__(key, value)
+
     def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any:
         """
         The training/validation/test step. Override for custom behavior.
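The `__setattr__` override above is the heart of the change: assigning a `LightningModule` to an attribute records it as a nested child, and any later assignment of the `Trainer` (or of the internal attributes Lightning patches onto a module during a run) is mirrored onto every recorded child, so a child task called from inside a parent behaves as if the trainer were driving it directly. The standalone sketch below illustrates that propagation pattern in isolation; the `Child` and `MirroringParent` classes and the mirrored attribute name are illustrative only and are not part of Flash.

# A minimal, framework-free sketch of the attribute-propagation idea, assuming
# a "child" is anything passing an isinstance check and that a fixed set of
# attribute names should be copied onto every registered child.


class Child:
    trainer = None


class MirroringParent:
    _mirrored = ("trainer",)  # attribute names to copy onto nested children

    def __init__(self):
        # Bypass the custom __setattr__ while the registry does not exist yet.
        object.__setattr__(self, "_children", [])

    def __setattr__(self, key, value):
        # Record any attribute that holds a Child so it can be kept in sync.
        if isinstance(value, Child):
            self._children.append(key)
        # Mirror selected attributes onto every recorded child.
        if key in self._mirrored:
            for name in self._children:
                setattr(getattr(self, name), key, value)
        object.__setattr__(self, key, value)


parent = MirroringParent()
parent.child = Child()        # registered as a nested child
parent.trainer = "a trainer"  # propagated to parent.child as well
assert parent.child.trainer == "a trainer"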
diff --git a/tests/core/test_model.py b/tests/core/test_model.py
index ec6437f038..eb04ecdb68 100644
--- a/tests/core/test_model.py
+++ b/tests/core/test_model.py
@@ -98,6 +98,32 @@ def forward(self, x):
         return x * self.zeros + self.zero_one
 
 
+class Parent(ClassificationTask):
+
+    def __init__(self, child):
+        super().__init__()
+
+        self.child = child
+
+    def training_step(self, batch, batch_idx):
+        return self.child.training_step(batch, batch_idx)
+
+    def validation_step(self, batch, batch_idx):
+        return self.child.validation_step(batch, batch_idx)
+
+    def test_step(self, batch, batch_idx):
+        return self.child.test_step(batch, batch_idx)
+
+    def forward(self, x):
+        return self.child(x)
+
+
+class GrandParent(Parent):
+
+    def __init__(self, child):
+        super().__init__(Parent(child))
+
+
 # ================================
 
 
@@ -113,6 +139,21 @@ def test_classificationtask_train(tmpdir: str, metrics: Any):
     assert "test_nll_loss" in result[0]
 
 
+@pytest.mark.parametrize("task", [Parent, GrandParent])
+def test_nested_tasks(tmpdir, task):
+    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax())
+    train_dl = torch.utils.data.DataLoader(DummyDataset())
+    val_dl = torch.utils.data.DataLoader(DummyDataset())
+    child_task = ClassificationTask(model, loss_fn=F.nll_loss)
+
+    parent_task = task(child_task)
+
+    trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
+    trainer.fit(parent_task, train_dl, val_dl)
+    result = trainer.test(parent_task, val_dl)
+    assert "test_nll_loss" in result[0]
+
+
 def test_classificationtask_task_predict():
     model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax())
     task = ClassificationTask(model, preprocess=DefaultPreprocess())
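With this in place, composing tasks is plain attribute assignment, which is exactly what the `Parent` and `GrandParent` test tasks above exercise. The hypothetical sketch below shows the same delegation pattern from a user's point of view rather than a test's; `WrappedClassifier` is an invented name, the `flash.core.classification` import path is an assumption based on the classes used in the tests, and the commented-out training calls stand in for real dataloaders.

# Hypothetical user-facing sketch, reusing the delegation pattern from the
# Parent test task above; not a documented Flash workflow.
import torch
from torch import nn
from torch.nn import functional as F

from flash.core.classification import ClassificationTask  # import path assumed


class WrappedClassifier(ClassificationTask):
    """Wraps an existing ClassificationTask and post-processes its output."""

    def __init__(self, child: ClassificationTask):
        super().__init__()
        # Plain attribute assignment registers `child` as a nested task, so
        # the trainer given to this task also reaches the child.
        self.child = child

    def training_step(self, batch, batch_idx):
        return self.child.training_step(batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        return self.child.validation_step(batch, batch_idx)

    def forward(self, x):
        # Return predicted class indices instead of raw probabilities.
        return torch.argmax(self.child(x), dim=-1)


backbone = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax(dim=-1))
task = WrappedClassifier(ClassificationTask(backbone, loss_fn=F.nll_loss))
# trainer = pl.Trainer(fast_dev_run=True)
# trainer.fit(task, train_dataloader, val_dataloader)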