Build and train a Generative Adversarial Network (GAN) to generate fake MNIST digit images
⚡Image Generation
⚡Generative Adversarial Network (GAN)
⚡Fully Connected Neural Network Layers
⚡MNIST
⚡PyTorch
A Generative Adversarial Network (GAN) is a special kind of neural network in which two networks compete, each trying to maximize its own gain (i.e., minimize its own loss, in machine-learning lingo). A GAN's objective is to generate synthetic, or fake, data that resembles the training data it learns from, hence a Generative network. Since the two networks are fighting each other, the network is called Adversarial.
A simple GAN is composed of a dataset (MNIST in our case) and two neural networks. The Generator network takes a random noise vector and learns to transform this noise so that the probability distribution of its output eventually comes very close to that of the data it is learning from. The Discriminator network is a neural-network-based classifier. The Generator and the Discriminator train in parallel. The Discriminator sees images from the actual MNIST dataset as well as fake images produced by the Generator, and its job is to correctly classify real vs. fake images. At the same time, based on how well the Discriminator performs at real vs. fake classification, the Generator keeps improving its generated images so that they look like images taken from the real dataset. Basically, the Generator tries to fool the Discriminator, and the Discriminator tries not to get fooled.
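For reference, this competition is usually written as the minimax objective from the original GAN paper. The symbols are the paper's notation, not names used in this project's code: $D(x)$ is the Discriminator's estimated probability that $x$ is real, $G(z)$ is the image the Generator produces from noise $z$, $p_{\text{data}}$ is the data distribution, and $p_z$ is the noise distribution.

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]
$$

The Discriminator pushes $V$ up by assigning high probability to real images and low probability to generated ones; the Generator pushes it down by making $D(G(z))$ large.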
A typical GAN network is shown below...
Have a look at the original GAN paper for more details. In addition, there are some exciting applications of GANs one can explore...
- StackGAN: Generates 256x256 photo-realistic images conditioned on text descriptions. For example, given a description of a bird, it generates an image of a bird that matches the description; the bird is imaginary and does not exist in the real world, but it looks real.
- iGAN: Based on simple user-drawn sketches, iGAN produces photo-realistic samples that best match the user's sketch in real time.
- CartoonGAN: Transforms photos of real-world scenes into cartoon-style images. For example, it can take an image of a real person's face and produce a cartoonized version of that face.
- CycleGAN: Performs image-to-image translation. The model trains in an unsupervised manner on collections of images from the source and target domains that do not need to be related in any way, i.e., without any paired examples. For example, it can take a photo of a summer landscape and transform the same image into a winter landscape.
Our goal in this project is to build a GAN and train it on the MNIST dataset so that the network learns to generate fake digit images that look as if they came from the real MNIST dataset.
- The dataset consists of 60,000 training images and 10,000 testing images. For this task, we'll use the training dataset. Every image in the dataset belongs to one of ten classes (digits 0 to 9); however, the image labels do not matter for this task, and we won't use them.
- Each image in the dataset is a 28x28 pixel grayscale image; a zoomed-in view of a single image is shown below...
A few samples of real MNIST images are shown below...
We will use the built-in MNIST dataset from PyTorch's torchvision. It's a clean, pre-processed dataset that pairs each image with its label; labels are not required for the task at hand, so they will be ignored. Alternatively, the raw dataset can be downloaded from the original source here. The raw dataset comes as a set of compressed files containing training images, training labels, testing images, and testing labels in separate files.
- Data is downloaded using the torchvision dataset.
- The training dataset is then wrapped in a dataloader object with a batch_size of 64. We discard the testing dataset. Note that even though the dataloader also gives us the label associated with each image, we'll simply ignore the labels.
- First, we define the Discriminator network
  - It's a typical binary classifier: it accepts 784 (28x28) inputs and produces a single logit output that is used to classify the input image as real (1) or fake (0).
  - The network has four fully-connected Linear layers with LeakyReLU activation having a negative slope of 0.2. Furthermore, a dropout of 30% is applied after each linear layer except the last one. The activation function and negative-slope value follow common GAN practice (for example, the architecture guidelines in the DCGAN paper).
  - The forward method of the Discriminator flattens the input it receives into a tensor of shape (batch_size, 784), which is then passed through the network.
- Then, we define the Generator network
  - The Generator consumes a noise vector (z of length 100) and up-samples it by passing it through the layers of the network.
  - The network has four fully-connected Linear layers with LeakyReLU activation having a negative slope of 0.2. A dropout of 30% is applied after each linear layer except the last one.
  - The final linear layer's output is passed through the tanh function so that the Generator's final output lies between -1 and 1. This too follows common GAN practice.
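A minimal sketch of two networks matching the description above is given here. The hidden-layer widths (128/64/32 and 32/64/128) and the exact ordering of activation and dropout are assumptions; the text only fixes the input and output sizes, the number of linear layers, the LeakyReLU slope, the dropout rate, and the tanh output.

```python
import torch
import torch.nn as nn


class Discriminator(nn.Module):
    """Binary classifier: flattened 28x28 image in, one real/fake logit out."""

    def __init__(self, in_features=784, hidden=(128, 64, 32)):  # hidden sizes are assumed
        super().__init__()
        h1, h2, h3 = hidden
        self.net = nn.Sequential(
            nn.Linear(in_features, h1), nn.LeakyReLU(0.2), nn.Dropout(0.3),
            nn.Linear(h1, h2), nn.LeakyReLU(0.2), nn.Dropout(0.3),
            nn.Linear(h2, h3), nn.LeakyReLU(0.2), nn.Dropout(0.3),
            nn.Linear(h3, 1),  # raw logit; no sigmoid here
        )

    def forward(self, x):
        x = x.reshape(x.size(0), -1)  # flatten to (batch_size, 784)
        return self.net(x)


class Generator(nn.Module):
    """Maps a noise vector z of length 100 to a flattened 28x28 image in [-1, 1]."""

    def __init__(self, z_size=100, hidden=(32, 64, 128), out_features=784):  # hidden sizes are assumed
        super().__init__()
        h1, h2, h3 = hidden
        self.net = nn.Sequential(
            nn.Linear(z_size, h1), nn.LeakyReLU(0.2), nn.Dropout(0.3),
            nn.Linear(h1, h2), nn.LeakyReLU(0.2), nn.Dropout(0.3),
            nn.Linear(h2, h3), nn.LeakyReLU(0.2), nn.Dropout(0.3),
            nn.Linear(h3, out_features), nn.Tanh(),  # squash the output to [-1, 1]
        )

    def forward(self, z):
        return self.net(z)
```

Because the Discriminator returns a raw logit rather than a probability, it has to be paired with a loss that applies the sigmoid itself.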
The complete structure of the Discriminator and Generator networks is shown below...
Once we have both the Discriminator and Generator networks ready, they are assembled to build the MNIST GAN as shown below...
GAN training is a bit different from typical supervised neural-network training. In a GAN, two separate networks are trained together, and these networks have different and opposing objectives (i.e., they are competing). The Discriminator tries to identify whether an image sample is real, i.e., drawn from the actual MNIST dataset, or a fake image generated by our Generator. Note that we are NOT interested in correctly classifying the digits themselves. We need to define two separate loss functions.
- real_loss: calculates the loss when images are drawn from the actual dataset. The predicted output is compared against the target label 1, indicating real images.
- fake_loss: calculates the loss when images are generated by the Generator. The predicted output is compared against the target label 0, indicating fake images.
The Discriminator computes both of the above losses and adds them together to get a total loss for back-propagation.
The Generator computes the real_loss to check its success in fooling the Discriminator. In other words, even though it generates fake images (target 0), by computing real_loss it compares the Discriminator's output with 1. In effect, the Generator's loss has its labels flipped.
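The two loss functions could be sketched as below, assuming the Discriminator returns raw logits and using PyTorch's `BCEWithLogitsLoss` as the criterion. The function signatures are illustrative and may not match the notebook's.

```python
import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()  # expects raw logits, not probabilities


def real_loss(logits):
    """Loss against target label 1 (real)."""
    return criterion(logits, torch.ones_like(logits))


def fake_loss(logits):
    """Loss against target label 0 (fake)."""
    return criterion(logits, torch.zeros_like(logits))
```

With these, the Discriminator's total loss is `real_loss(d(real_images)) + fake_loss(d(fake_images))`, while the Generator's loss is `real_loss(d(fake_images))`, which is where the label flip happens.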
Since we are training two separate networks, we need a separate optimizer for each network. In both cases, we use the Adam optimizer with a learning rate of 0.002.
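In code, this amounts to creating one Adam instance per network. The sketch below uses tiny stand-in modules named `d` and `g` so that it runs on its own; in the project these would be the actual Discriminator and Generator.

```python
import torch.nn as nn
import torch.optim as optim

# Stand-in modules (assumed for illustration only)
d = nn.Linear(784, 1)
g = nn.Linear(100, 784)

# Each optimizer only sees, and only updates, its own network's parameters
d_optim = optim.Adam(d.parameters(), lr=0.002)
g_optim = optim.Adam(g.parameters(), lr=0.002)
```

Keeping the optimizers separate means that stepping `d_optim` never changes the Generator's weights, and vice versa.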
- Since classification is between two classes (real and fake) and our Discriminator outputs a logit, we use BCEWithLogitsLoss as the loss function. This function first applies a sigmoid activation to the logits internally and then calculates the loss using the BCELoss (log loss) function.
- Before starting training, we create a fixed random noise vector of shape (16 x 100) drawn from a uniform distribution over the range -1 to 1. This vector is kept fixed throughout training. After each epoch, we feed it to the Generator trained so far to generate fake images; these images help us visualize whether, and how, the generated image quality is improving. A sample of the noise vector is shown below...
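The short sketch below illustrates both points: it checks numerically that `BCEWithLogitsLoss` matches a sigmoid followed by `BCELoss`, and it builds a fixed 16 x 100 noise batch. Drawing the noise with `torch.rand` and the name `fixed_z` are assumptions; any uniform sampler over -1 to 1 would do.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # only so that this sketch is reproducible

logits = torch.randn(8, 1)   # pretend Discriminator outputs
targets = torch.ones(8, 1)   # "real" labels

loss_with_logits = nn.BCEWithLogitsLoss()(logits, targets)
loss_two_step = nn.BCELoss()(torch.sigmoid(logits), targets)
# The two values agree up to floating-point error.

# Fixed noise: 16 vectors of length 100, uniform over [-1, 1)
fixed_z = torch.rand(16, 100) * 2 - 1
```

The fused version is generally preferred because it is more numerically stable for large-magnitude logits.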
- The Discriminator is trained as follows...
  - A batch of real MNIST images is drawn from the dataloader.
  - Each image in the batch is then scaled to values between -1 and 1. This is a crucial step: the Discriminator looks at real images from the MNIST dataset as well as fake images from the Generator, whose output is in the range -1 to 1 (the last layer of the Generator network is tanh-activated), so the range of input values must be consistent in both cases.
  - The data batch is then fed to the Discriminator, its predicted output is captured, and real_loss is calculated.
  - A batch of noise data (z) drawn from a uniform distribution over the range -1 to 1 is created.
  - The noise z is fed through the Generator, its outputs (fake images) are captured and passed through the Discriminator, and fake_loss is calculated on the Discriminator's output.
  - Then the Discriminator's total_loss is computed as real_loss + fake_loss.
  - Finally, total_loss is back-propagated and the Discriminator's weights are updated using its optimizer.
- After one batch of Discriminator training (above), the Generator is trained as follows...
  - A batch of noise data (z) drawn from a uniform distribution over the range -1 to 1 is created.
  - The noise z is fed through the Generator, and its outputs (fake images) are captured.
  - The generated fake images are then fed through the Discriminator, its predicted output is captured, and real_loss is calculated.
  - Note that for the generated fake images we calculate real_loss (and not fake_loss), as discussed in the Loss Definition section above.
  - The loss computed above is then back-propagated and the Generator's weights are updated using its optimizer.
- At the end of each epoch, the fixed random noise vector is fed to the Generator trained so far to produce a batch of fake images, which we save to fixed_samples.pkl. We can load and view these pickled images later for further analysis.
After training our GAN for 100 epochs, we plot both the Generator and Discriminator losses, and the plot looks like this...
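For reference, one training iteration as described in the steps above (a Discriminator update followed by a Generator update), together with the end-of-epoch snapshot, might be sketched as follows. All function names are hypothetical. The sketch assumes the dataloader yields images in [0, 1], that `d` returns raw logits, and it detaches the fake images during the Discriminator step, which is a choice made here rather than something stated in the text.

```python
import pickle

import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()


def sample_z(batch_size, z_size=100):
    """Noise drawn uniformly from [-1, 1)."""
    return torch.rand(batch_size, z_size) * 2 - 1


def train_step(d, g, d_optim, g_optim, real_images, z_size=100):
    """Run one Discriminator update, then one Generator update."""
    batch_size = real_images.size(0)
    real_images = real_images * 2 - 1  # rescale [0, 1] -> [-1, 1] to match tanh output

    # Discriminator update
    d_optim.zero_grad()
    d_real = d(real_images)
    r_loss = criterion(d_real, torch.ones_like(d_real))   # real_loss, target 1
    fake_images = g(sample_z(batch_size, z_size)).detach()  # no Generator gradients needed here
    d_fake = d(fake_images)
    f_loss = criterion(d_fake, torch.zeros_like(d_fake))  # fake_loss, target 0
    d_loss = r_loss + f_loss
    d_loss.backward()
    d_optim.step()

    # Generator update (labels flipped: fakes are scored against target 1)
    g_optim.zero_grad()
    d_out = d(g(sample_z(batch_size, z_size)))
    g_loss = criterion(d_out, torch.ones_like(d_out))
    g_loss.backward()
    g_optim.step()

    return d_loss.item(), g_loss.item()


def snapshot(g, fixed_z):
    """Generate images from the fixed noise without tracking gradients or dropout."""
    g.eval()
    with torch.no_grad():
        samples = g(fixed_z)
    g.train()
    return samples


def save_samples(samples_per_epoch, path="fixed_samples.pkl"):
    """Pickle the list of per-epoch sample batches."""
    with open(path, "wb") as f:
        pickle.dump(samples_per_epoch, f)
```

A full run would call `train_step` for every batch in every epoch, append `snapshot(g, fixed_z)` to a list at the end of each epoch, and pass that list to `save_samples` once training finishes.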
The loss plot does not look like a typical neural-network training loss plot. There are huge fluctuations at the beginning, and the curves remain wobbly after that. This behavior is typical of GAN training and is expected. At the end of the 100th epoch, the discriminator loss (in red) seems to be going down while the generator loss (in blue) increases. It's possible that if we trained our network for even more epochs, the losses would converge, indicating an equilibrium between the competing networks. Equilibrium is the fundamental idea behind GANs; it is drawn from game theory, which suggests that competing rational agents will ultimately reach a state in which neither can improve any further.
Let's visualize the intermediate generator outputs that we saved during the training. This will show us how Generator learns and generates better fake images as training progresses.
Generator output after 1 epoch of training...
Generator output after 10 epochs of training...
Generator output after 50 epochs of training...
Generator output after 100 epochs of training...
We can see how the Generator has improved from generating random noisy blobs after the 1st epoch to realistic-looking MNIST digits after 100 epochs.
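Grids like these can be reproduced from the saved samples with a few lines of matplotlib. The sketch below is illustrative: the helper names are hypothetical, and it assumes that fixed_samples.pkl holds a list with one entry per epoch, each entry being a batch of 16 flattened 28x28 images with values in [-1, 1].

```python
import pickle

import matplotlib.pyplot as plt
import numpy as np


def load_samples(path="fixed_samples.pkl"):
    """Load the pickled per-epoch sample batches."""
    with open(path, "rb") as f:
        return pickle.load(f)


def plot_samples(samples, n_cols=4):
    """Show one batch of flattened 28x28 images as a grid and return the figure."""
    images = np.asarray(samples).reshape(-1, 28, 28)
    n_rows = int(np.ceil(len(images) / n_cols))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols, n_rows), squeeze=False)
    for ax in axes.flat:
        ax.axis("off")  # hide ticks, including on any unused cells
    for ax, img in zip(axes.flat, images):
        ax.imshow(img, cmap="gray", vmin=-1, vmax=1)
    return fig
```

Under the stated assumption about the file layout, `plot_samples(load_samples()[0])` would show the output after the first epoch and `plot_samples(load_samples()[-1])` the output after the last one.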
- Ensure the below-listed packages are installed
  - NumPy
  - pickle (part of the Python standard library)
  - matplotlib
  - torch
  - torchvision
- Download the mnist_gan.ipynb Jupyter notebook from this repo
- Execute the notebook from start to finish in one go. If a GPU is available (recommended), it'll use it automatically; otherwise, it'll fall back to the CPU.
- A machine with an NVIDIA Quadro P5000 GPU with 16GB memory takes approximately 20 minutes to train for 100 epochs.
- Longer training may yield better results
- A trained model can be used to generate fake MNIST digits, as shown below...
# Assumes the notebook has already defined the trained generator `g` and the helper `display_images`
import numpy as np
import torch
# Bring the generator back to the CPU and switch it to eval mode
g.to('cpu')
g.eval()
# Feed a latent vector of size 100 to the trained generator and get a fake generated image back
z = np.random.uniform(-1, 1, size=(1, 100))
z = torch.from_numpy(z).float()
fake_image = g(z)
# Reshape and display
fake_image = fake_image.view(1, 1, 28, 28).detach()
display_images(fake_image, n_cols=1, figsize=(2, 2))
A randomly generated fake image produced by the above code is shown below; we can see that the generated fake image looks very similar to the actual number 3 image shown above...
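A slightly more general variant of the sampling code is sketched below. It picks the GPU when one is available and falls back to the CPU otherwise, and it generates several digits at once. The helper names `pick_device` and `sample_digits` are hypothetical and not part of the notebook; `generator` can be any module that maps a (n, 100) noise batch to (n, 784) outputs.

```python
import torch


def pick_device():
    """Use CUDA when available, otherwise fall back to the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def sample_digits(generator, n_images=16, z_size=100, device=None):
    """Generate n_images fake digits as a (n_images, 1, 28, 28) CPU tensor."""
    if device is None:
        device = pick_device()
    generator = generator.to(device).eval()  # eval mode disables dropout
    z = torch.rand(n_images, z_size, device=device) * 2 - 1  # uniform over [-1, 1)
    with torch.no_grad():
        images = generator(z)
    return images.view(n_images, 1, 28, 28).cpu()
```

The returned tensor has values in [-1, 1] when the generator ends in tanh, so it may need rescaling to [0, 1] before being passed to plotting helpers that expect that range.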