Skip to content

Commit

Permalink
Modify Workflow to Allow IterableDataset Inputs
Browse files Browse the repository at this point in the history
Signed-off-by: Eric Kerfoot <eric.kerfoot@kcl.ac.uk>
  • Loading branch information
ericspod committed Dec 13, 2024
1 parent 21920a3 commit 1d2f57e
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 11 deletions.
21 changes: 10 additions & 11 deletions monai/engines/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,24 +121,23 @@ def __init__(
to_kwargs: dict | None = None,
amp_kwargs: dict | None = None,
) -> None:
if iteration_update is not None:
super().__init__(iteration_update)
else:
super().__init__(self._iteration)
super().__init__(self._iteration if iteration_update is None else iteration_update)

if isinstance(data_loader, DataLoader):
sampler = data_loader.__dict__["sampler"]
if isinstance(sampler, DistributedSampler):
sampler = getattr(data_loader, "sampler", None)

# set the epoch value for DistributedSampler objects when an epoch starts
if isinstance(sampler, DistributedSampler):
@self.on(Events.EPOCH_STARTED)
def set_sampler_epoch(engine: Engine) -> None:
sampler.set_epoch(engine.state.epoch)

# if the epoch_length isn't given, attempt to get it from the length of the data loader
if epoch_length is None:
epoch_length = len(data_loader)
else:
if epoch_length is None:
raise ValueError("If data_loader is not PyTorch DataLoader, must specify the epoch_length.")
try:
epoch_length = len(data_loader)
except TypeError: # raised when data_loader is given an iterable dataset which has no length
pass # deliberately leave epoch_length as None

# set all sharable data for the workflow based on Ignite engine.state
self.state: Any = State(
Expand All @@ -147,7 +146,7 @@ def set_sampler_epoch(engine: Engine) -> None:
iteration=0,
epoch=0,
max_epochs=max_epochs,
epoch_length=epoch_length,
epoch_length=epoch_length, # None when the dataset is iterable and so has no length
output=None,
batch=None,
metrics={},
Expand Down
14 changes: 14 additions & 0 deletions tests/test_iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@

from monai.data import DataLoader, Dataset, IterableDataset
from monai.transforms import Compose, LoadImaged, SimulateDelayd
from monai.engines import SupervisedEvaluator

import torch.nn as nn


class _Stream:
Expand Down Expand Up @@ -59,6 +62,17 @@ def test_shape(self):
for d in dataloader:
self.assertTupleEqual(d["image"].shape[1:], expected_shape)

def test_supervisedevaluator(self):
    """
    Test that a SupervisedEvaluator is compatible with IterableDataset in conjunction with DataLoader.
    """
    items = list(range(10))
    loader = DataLoader(IterableDataset(items))
    ev = SupervisedEvaluator(device="cpu", val_data_loader=loader, network=nn.Identity())

    # run() fails here if the epoch length or other internal setup is not done correctly
    ev.run()

    # one engine iteration per streamed item shows the whole dataset was consumed
    self.assertEqual(ev.state.iteration, len(items))


if __name__ == "__main__":
unittest.main()

0 comments on commit 1d2f57e

Please sign in to comment.