This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add option for nested tasks (#575)
* Add option for nested tasks

* Update CHANGELOG.md

* Update CHANGELOG.md

* Updates

* Add grandparent test
ethanwharris authored Jul 13, 2021
Parent: fc3263f · Commit: 27cc06d
Showing 3 changed files with 55 additions and 0 deletions.
CHANGELOG.md: 2 additions, 0 deletions
@@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for Semantic Segmentation backbones and heads from `segmentation-models.pytorch` ([#562](https://github.com/PyTorchLightning/lightning-flash/pull/562))

- Added support for nesting of `Task` objects ([#575](https://github.com/PyTorchLightning/lightning-flash/pull/575))

### Changed

- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
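In practice, "nesting of `Task` objects" means a `Task` may now hold another `Task` as an ordinary submodule attribute and delegate work to it, with only the outer task handed to the trainer. A minimal sketch under stated assumptions (the `ClassificationTask` import path is assumed from the tests below; `Wrapper` and the tiny model are illustrative, not part of the Flash API):

```python
from torch import nn
from torch.nn import functional as F

from flash.core.classification import ClassificationTask  # assumed import path


class Wrapper(ClassificationTask):
    """Illustrative parent task that nests a child task and delegates to it."""

    def __init__(self, child):
        super().__init__()
        # Plain attribute assignment: Task.__setattr__ (added below) records
        # `child` so trainer state set on the wrapper reaches it as well.
        self.child = child

    def training_step(self, batch, batch_idx):
        return self.child.training_step(batch, batch_idx)

    def forward(self, x):
        return self.child(x)


child = ClassificationTask(nn.Linear(16, 2), loss_fn=F.cross_entropy)
wrapper = Wrapper(child)  # fit with pl.Trainer as usual; see the tests below
```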
flash/core/model.py: 12 additions, 0 deletions
@@ -158,6 +158,18 @@ def __init__(
        self.deserializer = deserializer
        self.serializer = serializer

        self._children = []

    def __setattr__(self, key, value):
        # Record the attribute name of any LightningModule (e.g. a nested
        # Task) assigned to this task.
        if isinstance(value, LightningModule):
            self._children.append(key)
        patched_attributes = ["_current_fx_name", "_current_hook_fx_name", "_results"]
        # Forward the trainer and Lightning's patched attributes to every
        # recorded child so nested tasks stay in sync with the parent.
        if isinstance(value, pl.Trainer) or key in patched_attributes:
            if hasattr(self, "_children"):
                for child in self._children:
                    setattr(getattr(self, child), key, value)
        super().__setattr__(key, value)

    def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any:
        """
        The training/validation/test step. Override for custom behavior.
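The `__setattr__` override above carries the whole feature: assigning a `LightningModule` to an attribute records its name in `self._children`, and assigning the trainer (or one of Lightning's internally patched attributes) replays that assignment on every recorded child. Because each child runs the same `__setattr__`, the forwarding recurses through deeper nesting, which is what the `GrandParent` test below exercises. A standalone sketch of the same pattern using plain `torch.nn.Module`s (class and attribute names here are illustrative, not Flash API):

```python
from torch import nn


class Propagating(nn.Module):
    """Illustrative child-tracking pattern, mirroring Task.__setattr__ above."""

    PROPAGATED = ("trainer",)  # stand-in for pl.Trainer / patched attributes

    def __init__(self):
        super().__init__()
        self._children = []

    def __setattr__(self, key, value):
        # Record the attribute name of every module child.
        if isinstance(value, nn.Module) and hasattr(self, "_children"):
            self._children.append(key)
        # Replay propagated assignments on the recorded children; children
        # run this same __setattr__, so grandchildren are reached too.
        if key in self.PROPAGATED and hasattr(self, "_children"):
            for child in self._children:
                setattr(getattr(self, child), key, value)
        super().__setattr__(key, value)


leaf = Propagating()
root = Propagating()
root.leaf = leaf         # "leaf" is recorded in root._children
root.trainer = object()  # forwarded: leaf receives the same object
assert leaf.trainer is root.trainer
```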
tests/core/test_model.py: 41 additions, 0 deletions
@@ -98,6 +98,32 @@ def forward(self, x):
        return x * self.zeros + self.zero_one


class Parent(ClassificationTask):
    """A task that nests a child task and delegates each step to it."""

    def __init__(self, child):
        super().__init__()

        self.child = child

    def training_step(self, batch, batch_idx):
        return self.child.training_step(batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        return self.child.validation_step(batch, batch_idx)

    def test_step(self, batch, batch_idx):
        return self.child.test_step(batch, batch_idx)

    def forward(self, x):
        return self.child(x)


class GrandParent(Parent):
    """Two levels of nesting: wraps the child in a Parent first."""

    def __init__(self, child):
        super().__init__(Parent(child))


# ================================


@@ -113,6 +139,21 @@ def test_classificationtask_train(tmpdir: str, metrics: Any):
    assert "test_nll_loss" in result[0]


@pytest.mark.parametrize("task", [Parent, GrandParent])
def test_nested_tasks(tmpdir, task):
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax())
    train_dl = torch.utils.data.DataLoader(DummyDataset())
    val_dl = torch.utils.data.DataLoader(DummyDataset())
    child_task = ClassificationTask(model, loss_fn=F.nll_loss)

    parent_task = task(child_task)

    trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
    trainer.fit(parent_task, train_dl, val_dl)
    result = trainer.test(parent_task, val_dl)
    assert "test_nll_loss" in result[0]


def test_classificationtask_task_predict():
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax())
    task = ClassificationTask(model, preprocess=DefaultPreprocess())
