Issue 871 (#985)
* Fixes #871
- added tests
- updated FAQ
- added FAQ's code into tests

* Update faq.rst

* Updated Engine's docstring

* Fixes failing test

* Update faq.rst
vfdev-5 authored Apr 28, 2020
1 parent 67eb7b8 commit e4cbd93
Showing 4 changed files with 484 additions and 11 deletions.
240 changes: 238 additions & 2 deletions docs/source/faq.rst
@@ -66,7 +66,7 @@ flexibility to the user to allow for this:
        engine.fire_event(BackpropEvents.BACKWARD_COMPLETED)
        optimizer.step()
        engine.fire_event(BackpropEvents.OPTIM_STEP_COMPLETED)

        return loss.item()

    trainer = Engine(update)
@@ -83,7 +83,7 @@ More detailed implementation can be found in `TBPTT Trainer <_modules/ignite/con
Gradients accumulation
----------------------

A best practice when we need to effectively increase the batch size on limited GPU resources. There are several ways to
do this; the simplest is the following:

.. code-block:: python
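
    # NOTE: the FAQ's exact code is collapsed in this diff view; this is a
    # minimal sketch of gradient accumulation, assuming `model`, `optimizer`,
    # `criterion` and the input batches are defined elsewhere.
    from ignite.engine import Engine

    accumulation_steps = 4

    def update_fn(engine, batch):
        model.train()
        x, y = batch
        y_pred = model(x)
        # scale the loss so that accumulated gradients match a large-batch step
        loss = criterion(y_pred, y) / accumulation_steps
        loss.backward()
        # step the optimizer only every `accumulation_steps` iterations
        if engine.state.iteration % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update_fn)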

Based on `this blog article <https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255>`_ and
`this code <https://gist.github.com/thomwolf/ac7a7da6b1888c2eeac8ac8b9b05d3d3#file-gradient_accumulation-py>`_.


Working with iterators
----------------------

If the data provider for training or validation is an iterator (infinite, or finite with known or unknown size), here are
basic examples of how to set up the trainer or evaluator.


Infinite iterator for training
``````````````````````````````

Let's use an infinite data iterator as the training dataflow:

.. code-block:: python

    import torch
    from ignite.engine import Engine, Events

    torch.manual_seed(12)

    def infinite_iterator(batch_size):
        while True:
            batch = torch.rand(batch_size, 3, 32, 32)
            yield batch

    def train_step(trainer, batch):
        # ...
        s = trainer.state
        print(
            "{}/{} : {} - {:.3f}".format(s.epoch, s.max_epochs, s.iteration, batch.norm())
        )

    trainer = Engine(train_step)

    # We need to specify epoch_length to define the epoch
    trainer.run(infinite_iterator(4), epoch_length=5, max_epochs=3)

In this case we will obtain the following output:

.. code-block:: text

    1/3 : 1 - 63.862
    1/3 : 2 - 64.042
    1/3 : 3 - 63.936
    1/3 : 4 - 64.141
    1/3 : 5 - 64.767
    2/3 : 6 - 63.791
    2/3 : 7 - 64.565
    2/3 : 8 - 63.602
    2/3 : 9 - 63.995
    2/3 : 10 - 63.943
    3/3 : 11 - 63.831
    3/3 : 12 - 64.276
    3/3 : 13 - 64.148
    3/3 : 14 - 63.920
    3/3 : 15 - 64.226

If we do not specify `epoch_length`, we can stop the training explicitly by calling :meth:`~ignite.engine.Engine.terminate`.
In this case, only a single epoch will be defined.

.. code-block:: python

    import torch
    from ignite.engine import Engine, Events

    torch.manual_seed(12)

    def infinite_iterator(batch_size):
        while True:
            batch = torch.rand(batch_size, 3, 32, 32)
            yield batch

    def train_step(trainer, batch):
        # ...
        s = trainer.state
        print(
            "{}/{} : {} - {:.3f}".format(s.epoch, s.max_epochs, s.iteration, batch.norm())
        )

    trainer = Engine(train_step)

    @trainer.on(Events.ITERATION_COMPLETED(once=15))
    def stop_training():
        trainer.terminate()

    trainer.run(infinite_iterator(4))

We obtain the following output:

.. code-block:: text

    1/1 : 1 - 63.862
    1/1 : 2 - 64.042
    1/1 : 3 - 63.936
    1/1 : 4 - 64.141
    1/1 : 5 - 64.767
    1/1 : 6 - 63.791
    1/1 : 7 - 64.565
    1/1 : 8 - 63.602
    1/1 : 9 - 63.995
    1/1 : 10 - 63.943
    1/1 : 11 - 63.831
    1/1 : 12 - 64.276
    1/1 : 13 - 64.148
    1/1 : 14 - 63.920
    1/1 : 15 - 64.226

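The `once=15` filter used above is shorthand for a custom event filter; an equivalent formulation (a sketch using the
`event_filter` argument of the events API):

.. code-block:: python

    @trainer.on(Events.ITERATION_COMPLETED(event_filter=lambda engine, event: event == 15))
    def stop_training():
        trainer.terminate()
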
The same code can be used for validating models.
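For instance, a minimal sketch of an evaluator driven by the same kind of infinite iterator (the `val_step` body and the
`epoch_length` value are illustrative assumptions):

.. code-block:: python

    def val_step(evaluator, batch):
        # ... compute validation quantities on the batch
        pass

    evaluator = Engine(val_step)

    # evaluate on a fixed number of batches drawn from the infinite iterator
    evaluator.run(infinite_iterator(4), epoch_length=10)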


Finite iterator with unknown length
```````````````````````````````````

Let's use a finite data iterator whose length is unknown to the user. For training, we would like to perform several
passes over the dataflow, so we need to restart the data iterator when it is exhausted. The engine cannot do this by
itself, since calling `iter()` on an exhausted generator returns the same exhausted generator; instead we attach a
handler that replaces the dataloader with a fresh iterator. In the code below we do not specify `epoch_length`; it will
be determined automatically.

.. code-block:: python

    import torch
    from ignite.engine import Engine, Events

    torch.manual_seed(12)

    def finite_unk_size_data_iter():
        for i in range(11):
            yield i

    def train_step(trainer, batch):
        # ...
        s = trainer.state
        print(
            "{}/{} : {} - {:.3f}".format(s.epoch, s.max_epochs, s.iteration, batch)
        )

    trainer = Engine(train_step)

    @trainer.on(Events.DATALOADER_STOP_ITERATION)
    def restart_iter():
        trainer.state.dataloader = finite_unk_size_data_iter()

    data_iter = finite_unk_size_data_iter()
    trainer.run(data_iter, max_epochs=5)

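The decorator form above is equivalent to registering the handler explicitly with
:meth:`~ignite.engine.Engine.add_event_handler`, which can be convenient when the handler is defined elsewhere:

.. code-block:: python

    trainer.add_event_handler(Events.DATALOADER_STOP_ITERATION, restart_iter)
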
For validation, the code is simply:

.. code-block:: python

    import torch
    from ignite.engine import Engine, Events

    torch.manual_seed(12)

    def finite_unk_size_data_iter():
        for i in range(11):
            yield i

    def val_step(evaluator, batch):
        # ...
        s = evaluator.state
        print(
            "{}/{} : {} - {:.3f}".format(s.epoch, s.max_epochs, s.iteration, batch)
        )

    evaluator = Engine(val_step)

    data_iter = finite_unk_size_data_iter()
    evaluator.run(data_iter)

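After such a run, the automatically determined epoch length is available on the state (a quick check, assuming the
11-element iterator above):

.. code-block:: python

    state = evaluator.run(finite_unk_size_data_iter())
    print(state.epoch_length)  # 11, set at the first StopIteration
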
Finite iterator with known length
`````````````````````````````````

Let's use a finite data iterator with known size for training or validation. If we need to restart the data iterator,
we can do it as in the unknown-size case, by attaching a restart handler to `Events.DATALOADER_STOP_ITERATION`,
but here we restart it explicitly every `size` iterations:

.. code-block:: python

    import torch
    from ignite.engine import Engine, Events

    torch.manual_seed(12)

    size = 11

    def finite_size_data_iter(size):
        for i in range(size):
            yield i

    def train_step(trainer, batch):
        # ...
        s = trainer.state
        print(
            "{}/{} : {} - {:.3f}".format(s.epoch, s.max_epochs, s.iteration, batch)
        )

    trainer = Engine(train_step)

    @trainer.on(Events.ITERATION_COMPLETED(every=size))
    def restart_iter():
        trainer.state.dataloader = finite_size_data_iter(size)

    data_iter = finite_size_data_iter(size)
    trainer.run(data_iter, max_epochs=5)

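Since the size is known here, one could also pass `epoch_length` explicitly, so that the epoch boundary does not depend
on the automatic determination at the first `StopIteration` (a sketch; the restart handler above is still required):

.. code-block:: python

    trainer.run(finite_size_data_iter(size), epoch_length=size, max_epochs=5)
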
For validation, the code is simply:

.. code-block:: python

    import torch
    from ignite.engine import Engine, Events

    torch.manual_seed(12)

    size = 11

    def finite_size_data_iter(size):
        for i in range(size):
            yield i

    def val_step(evaluator, batch):
        # ...
        s = evaluator.state
        print(
            "{}/{} : {} - {:.3f}".format(s.epoch, s.max_epochs, s.iteration, batch)
        )

    evaluator = Engine(val_step)

    data_iter = finite_size_data_iter(size)
    evaluator.run(data_iter)

Other answers can be found on GitHub among the issues labeled as
`question <https://github.com/pytorch/ignite/issues?utf8=%E2%9C%93&q=is%3Aissue+label%3Aquestion+>`_.
15 changes: 11 additions & 4 deletions ignite/engine/engine.py
@@ -554,7 +554,8 @@ def run(
                If a new state should be created (first run or run again from ended engine), its default value is 1.
                This argument should be `None` if run is resuming from a state.
            epoch_length (int, optional): Number of iterations to count as one epoch. By default, it can be set as
-                `len(data)`. If `data` is an iterator and `epoch_length` is not set, an error is raised.
+                `len(data)`. If `data` is an iterator and `epoch_length` is not set, then it will be automatically
+                determined as the iteration at which the data iterator raises `StopIteration`.
                This argument should be `None` if run is resuming from a state.
            seed (int, optional): Deprecated argument. Please, use `torch.manual_seed` or
                :meth:`~ignite.utils.manual_seed`.
@@ -590,8 +591,7 @@ def switch_batch(engine):
                    epoch_length = len(data)
                    if epoch_length < 1:
                        raise ValueError("Input data has zero size. Please provide non-empty data")
-                else:
-                    raise ValueError("Argument `epoch_length` should be defined if `data` is an iterator")

            self.state = State(iteration=0, epoch=0, max_epochs=max_epochs, epoch_length=epoch_length)
            self.logger.info("Engine run starting with max_epochs={}.".format(max_epochs))
        else:
@@ -694,13 +694,20 @@ def _run_once_on_dataset(self) -> Tuple[int, int, int]:
                    iter_counter += 1
                    should_exit = False
                except StopIteration:

                    if self._dataloader_len is None:
                        if iter_counter > 0:
                            self._dataloader_len = iter_counter
                        else:
                            # this can happen when data is finite iterator and epoch_length is equal to its size
                            self._dataloader_len = self.state.iteration

                    # Define self.state.epoch_length if it is not yet set
                    if self.state.epoch_length is None:
                        # Define epoch length and stop the epoch
                        self.state.epoch_length = iter_counter
                        break

                    # Should exit while loop if we can not iterate
                    if should_exit:
                        if not self._is_done(self.state):
@@ -734,7 +741,7 @@ def _run_once_on_dataset(self) -> Tuple[int, int, int]:
                        self._dataloader_iter = iter(self.state.dataloader)
                        break

-               if iter_counter == self.state.epoch_length:
+               if self.state.epoch_length is not None and iter_counter == self.state.epoch_length:
                    break

        except Exception as e:
3 changes: 2 additions & 1 deletion tests/ignite/engine/__init__.py
@@ -10,7 +10,8 @@ def __init__(self, data, init_counter=0):
    def check(self, batch):
        self.true_batch = self.data[self.counter % len(self.data)]
        self.counter += 1
-        return (self.true_batch == batch).all()
+        res = self.true_batch == batch
+        return res.all() if not isinstance(res, bool) else res


class IterationCounter:
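A short illustration of why this guard is needed (hypothetical values): comparing tensors yields a tensor that has an
`.all()` method, while comparing plain Python numbers, as in the FAQ's iterator examples, yields a bare `bool`:

.. code-block:: python

    import torch

    res = torch.tensor([1, 2]) == torch.tensor([1, 2])
    print(res.all())  # tensor(True) -- tensor comparison has .all()

    res = 3 == 3
    print(res)        # True -- plain bool, no .all()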