fix: get_minibatch() of SupervisedNE
This commit fixes and improves `SupervisedNE`
in two ways:

1)
The `get_minibatch()` method of `SupervisedNE`
no longer uses recursion to handle reaching
the end of the data loader's minibatches. The
new implementation catches only `StopIteration`
and non-recursively restarts the data loader's
iterator. Any other type of error is now
deliberately left unhandled, both to avoid
unwanted infinite recursion and to let the
user see the details of the error. (A minimal
sketch of this pattern is given below.)

2)
The example notebook `Training_MNIST30K.ipynb`,
which demonstrates the usage of `SupervisedNE`,
is updated so that its hyperparameter
configuration follows the one reported in the
EvoTorch technical report. This way, the
results reported there can be reproduced.
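
Below is a minimal, self-contained sketch of
the non-recursive pattern described in point 1.
The attribute names `dataloader` and
`dataloader_iterator` mirror the ones touched
by this commit; the surrounding class is only
illustrative, and the authoritative change is
the diff of `supervisedne.py` further down.

    from torch.utils.data import DataLoader

    class MinibatchSource:
        """Illustrative stand-in for the relevant part of SupervisedNE."""

        def __init__(self, dataloader: DataLoader):
            self.dataloader = dataloader
            self.dataloader_iterator = None

        def get_minibatch(self):
            if self.dataloader_iterator is None:
                self.dataloader_iterator = iter(self.dataloader)
            try:
                # Usual case: the current iterator still has minibatches left.
                return next(self.dataloader_iterator)
            except StopIteration:
                # End of an epoch: restart the iterator once, without recursion.
                # Any other exception is left to propagate to the user.
                self.dataloader_iterator = iter(self.dataloader)
                return next(self.dataloader_iterator)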

Co-authored-by: Timothy Atkinson <timothy@nnaisense.com>
Co-authored-by: Nihat Engin Toklu <engin@nnaisense.com>
engintoklu and Timothy Atkinson committed Mar 24, 2023
1 parent e8060ff commit 8078df0
Showing 2 changed files with 60 additions and 28 deletions.
71 changes: 50 additions & 21 deletions examples/notebooks/Training_MNIST30K.ipynb
@@ -22,15 +22,16 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from evotorch.neuroevolution.net import count_parameters\n",
"\n",
"class MNIST30K(nn.Module):\n",
" \n",
" def __init__(self) -> None:\n",
" super().__init__()\n",
" # The first convolution uses a 5x5 kernel and has 16 filters\n",
@@ -42,9 +43,8 @@
" # Another max pooling is applied with a kernel size of 2\n",
" self.pool2 = nn.MaxPool2d(kernel_size = 2)\n",
" \n",
" # The authors are unclear about when they apply batchnorm. \n",
" # As a result, we will apply after the second pool\n",
" self.norm = nn.BatchNorm1d(1568, affine = False)\n",
" # Apply layer normalization after the second pool\n",
" self.norm = nn.LayerNorm(1568, elementwise_affine=False)\n",
" \n",
" # A final linear layer maps outputs to the 10 target classes\n",
" self.out = nn.Linear(1568, 10)\n",
@@ -80,7 +80,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from torchvision import datasets, transforms\n",
@@ -105,7 +107,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from evotorch.neuroevolution import SupervisedNE\n",
@@ -114,7 +118,7 @@
" train_dataset, # Using the dataset specified earlier\n",
" MNIST30K, # Training the MNIST30K module designed earlier\n",
" nn.CrossEntropyLoss(), # Minimizing CrossEntropyLoss\n",
" minibatch_size = 256, # With a minibatch size of 256\n",
" minibatch_size = 1024, # With a minibatch size of 1024\n",
" common_minibatch = True, # Always using the same minibatch across all solutions on an actor\n",
" num_actors = 4, # The total number of CPUs used\n",
" num_gpus_per_actor = 'max', # Dividing all available GPUs between the 4 actors\n",
@@ -129,17 +133,31 @@
"## Training\n",
"Now we can set up the searcher.\n",
"\n",
"In the paper, they used SNES with, effectively, default parameters, and standard deviation 1. The authors achieved 98%+ with only a population size of 1k, but you can push this value higher as you wish. Note that by using the `distributed = True` keyword argument, we obtain semi-updates from the individual actors which are averaged."
"In the paper, they used SNES with, effectively, default parameters, and standard deviation 1. The authors achieved 98%+ with only a population size of 1k, but this value can be pushed further. Note that by using the `distributed = True` keyword argument, we obtain semi-updates from the individual actors which are averaged.\n",
"\n",
"In our example, we use PGPE with a population size of 3200. Hyperparameter configuration can be seen below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from evotorch.algorithms import SNES\n",
"searcher = SNES(mnist_problem, stdev_init = 1, popsize = 1000, distributed = True)"
"from evotorch.algorithms import PGPE\n",
"\n",
"searcher = PGPE(\n",
" mnist_problem,\n",
" radius_init=2.25, # Initial radius of the search distribution\n",
" center_learning_rate=1e-2, # Learning rate used by adam optimizer\n",
" stdev_learning_rate=0.1, # Learning rate for the standard deviation\n",
" popsize=3200,\n",
" distributed=True, # Gradients are computed locally at actors and averaged\n",
" optimizer=\"adam\", # Using the adam optimizer\n",
" ranking_method=None, # No rank-based fitness shaping is used\n",
")"
]
},
{
@@ -152,7 +170,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from evotorch.logging import StdOutLogger, PandasLogger\n",
@@ -164,16 +184,18 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Running evolution for 2000 generations (note that in the paper, it was 10k generations)..."
"Running evolution for 400 generations (note that in the paper, it was 10k generations)..."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"searcher.run(2000)"
"searcher.run(400)"
]
},
{
Expand All @@ -186,7 +208,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"pandas_logger.to_dataframe().mean_eval.plot()"
@@ -202,10 +226,13 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"net = mnist_problem.parameterize_net(searcher.status['center']).cpu()\n",
"#net = mnist_problem.parameterize_net(searcher.status['center']).cpu()\n",
"net = mnist_problem.make_net(searcher.status[\"center\"]).cpu()\n",
"\n",
"loss = torch.nn.CrossEntropyLoss()\n",
"net.eval()\n",
@@ -228,7 +255,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"mnist_problem.kill_actors()"
@@ -260,7 +289,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.9.12"
}
},
"nbformat": 4,
17 changes: 10 additions & 7 deletions src/evotorch/neuroevolution/supervisedne.py
@@ -208,6 +208,7 @@ def __init__(

self.dataset = dataset
self.dataloader: DataLoader = None
self.dataloader_iterator = None

self._loss_func = loss_func
self._minibatch_size = None if minibatch_size is None else int(minibatch_size)
@@ -314,16 +315,18 @@ def get_minibatch(self) -> Any:
if self.dataloader is None:
self._prepare()

if self.dataloader_iterator is None:
self.dataloader_iterator = iter(self.dataloader)

batch = None
try:
batch = next(self.dataloader_iterator)
if batch is None:
self.dataloader_iterator = iter(self.dataloader)
batch = self.get_minibatch()
else:
batch = batch
except Exception:
except StopIteration:
pass

if batch is None:
self.dataloader_iterator = iter(self.dataloader)
batch = self.get_minibatch()
batch = next(self.dataloader_iterator)

# Move batch to device of network
return [var.to(self.network_device) for var in batch]
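
As a rough usage sketch (not part of this commit), the fixed `get_minibatch()` can be exercised past an epoch boundary as follows; `my_dataset` and `MyNet` are hypothetical placeholders, and the minibatch size is chosen arbitrarily:

    import torch.nn as nn
    from evotorch.neuroevolution import SupervisedNE

    problem = SupervisedNE(
        my_dataset,              # hypothetical torch Dataset of (input, target) pairs
        MyNet,                   # hypothetical nn.Module subclass
        nn.CrossEntropyLoss(),
        minibatch_size=64,
    )

    # Request more minibatches than a single pass over the dataset provides;
    # at each epoch boundary the iterator is restarted instead of recursing.
    for _ in range(3 * (len(my_dataset) // 64 + 1)):
        inputs, targets = problem.get_minibatch()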