
A Generative Adversarial Network (GAN) trained on the MNIST dataset, capable of creating fake but realistic-looking MNIST digit images that appear to be drawn from the original dataset.


sssingh/mnist-digit-generation-gan


MNIST Digit Generation using GAN

Build and train a Generative Adversarial Network (GAN) to generate fake MNIST digit images

Features

⚡Image Generation
⚡Generative Adversarial Network (GAN)
⚡Fully Connected Neural Network Layers
⚡MNIST
⚡PyTorch


Introduction

Generative Adversarial Networks (GANs) are a special kind of neural network architecture in which two networks compete to maximize their gain (i.e., minimize their losses, in machine-learning lingo). A GAN's objective is to generate synthetic (fake) data based on the training data it learns from, hence a Generative network. Since the two networks are fighting each other, the network is called Adversarial. A simple GAN is composed of a dataset (MNIST in our case) and two neural networks. A Generator network takes a random noise vector and learns to transform this noise so that its probability distribution eventually comes very close to that of the data it's learning from. A Discriminator network is a neural-network-based classifier. The Generator and Discriminator train in parallel. The Discriminator sees images from the actual MNIST dataset as well as fake images generated by the Generator, and its job is to correctly classify real images vs. fake images. At the same time, based on how the Discriminator performs at real-vs-fake classification, the Generator keeps improving its generated images to make them look like images taken from the real dataset. In short, the Generator tries to fool the Discriminator, and the Discriminator tries not to get fooled.
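For reference, the original GAN paper formalizes this competition as a two-player minimax game. Writing $D(x)$ for the Discriminator's estimated probability that an image $x$ is real and $G(z)$ for the image the Generator produces from a noise vector $z$, the value function is:

```math
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]
```

The Discriminator pushes $V$ up by scoring real images near 1 and fakes near 0, while the Generator pushes it down. The flipped-label Generator loss described under Loss Definition corresponds to the common non-saturating variant, in which the Generator instead maximizes $\log D(G(z))$.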

A typical GAN network is shown below...

Have a look at the original GAN paper for more details. In addition, there are some exciting applications of GANs one can explore...

  • StackGAN: Generates 256x256 photo-realistic images conditioned on text descriptions. For example, given a description of a bird, it can generate a matching image: an imaginary bird that does not exist in the real world but looks real.
  • iGAN: From simple user-drawn sketches, iGAN can produce, in real time, photo-realistic samples that best match the user's sketch.
  • CartoonGAN: CartoonGAN can transform photos of real-world scenes into cartoon-style images. For example, it can take an image of a real person's face and produce a cartoonized image of that face.
  • CycleGAN: CycleGAN performs image-to-image translation. The model trains in an unsupervised manner on collections of images from the source and target domains, which do not need to be related in any way; no paired examples are required. For example, it can take a photo of a summer landscape and transform it into a winter landscape.

Objective

Our goal in this project is to build a GAN and train it on the MNIST dataset so that the network learns to generate fake MNIST digit images that look as if they came from the real MNIST dataset.

Dataset

  • The dataset consists of 60,000 training images and 10,000 testing images. For this task, we'll use the training dataset. Every image in the dataset belongs to one of ten classes (digits 0 to 9); however, the image labels do not matter for this task, and we won't use them.
  • Each image in the dataset is a 28x28-pixel grayscale image; a single zoomed-in image is shown below...

A few samples of real MNIST images are shown below...

We will use the built-in MNIST dataset from PyTorch's torchvision. It's a clean, pre-processed dataset that conveniently pairs each image with its label; since labels are not required for the task at hand, they will be ignored. Alternatively, the raw dataset can be downloaded from the original source here. The raw dataset comes as a set of compressed files, with training images, training labels, testing images, and testing labels in separate files.

Solution Approach

Data Load

  • The data is downloaded using torchvision's built-in MNIST dataset.
  • The training dataset is then wrapped in a dataloader object with a batch_size of 64. We discard the testing dataset. Note that even though the dataloader also gives us the labels associated with the images, we'll simply ignore them.

Network Definition

  • First, we define the Discriminator network
    • It's a typical binary classifier that accepts 784 (28x28) inputs and produces a single logit output, which is used to classify the input image as real (1) or fake (0)
    • The network has four fully-connected Linear layers with LeakyReLU activation having a negative slope of 0.2. Furthermore, a dropout of 30% is applied after each linear layer except the last one. The choice of activation function and negative-slope value follows common GAN practice, popularized by the DCGAN paper.
    • The forward method of the Discriminator flattens the input it receives to make it a tensor of shape (batch_size, 784); this is then passed through the network.
  • Then, we define the Generator network
    • The Generator consumes a noise vector (z, of length 100) and up-samples it by passing it through the network's layers.
    • The network has four fully-connected Linear layers with LeakyReLU activation having a negative slope of 0.2. A dropout of 30% is applied after each linear layer except the last one. The final linear layer's output is then passed through the tanh function so that the Generator's final output lies between -1 and 1. This again follows common GAN practice (a tanh output layer is recommended in the DCGAN paper).
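The two networks described above could be sketched in PyTorch as follows. The layer count, LeakyReLU slope, dropout rate, 784-pixel input/output and 100-dimensional latent size come from the description; the hidden-layer widths (`hidden_dim` = 32, scaled by 1, 2 and 4) and the Linear -> LeakyReLU -> Dropout ordering are assumptions, so the notebook's actual sizes may differ.

```python
import torch
import torch.nn as nn


class Discriminator(nn.Module):
    """Fully connected binary classifier: 784 pixels in, one real/fake logit out.

    Hidden widths are illustrative assumptions.
    """

    def __init__(self, input_size: int = 784, hidden_dim: int = 32, output_size: int = 1):
        super().__init__()
        self.net = nn.Sequential(
            # Linear -> LeakyReLU(0.2) -> Dropout(0.3) for every layer except the last
            nn.Linear(input_size, hidden_dim * 4), nn.LeakyReLU(0.2), nn.Dropout(0.3),
            nn.Linear(hidden_dim * 4, hidden_dim * 2), nn.LeakyReLU(0.2), nn.Dropout(0.3),
            nn.Linear(hidden_dim * 2, hidden_dim), nn.LeakyReLU(0.2), nn.Dropout(0.3),
            # Final layer emits a raw logit; BCEWithLogitsLoss applies the sigmoid later
            nn.Linear(hidden_dim, output_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Flatten (batch_size, 1, 28, 28) images into (batch_size, 784)
        return self.net(torch.flatten(x, start_dim=1))


class Generator(nn.Module):
    """Fully connected generator: 100-dim noise z in, 784 tanh-squashed pixels out."""

    def __init__(self, z_size: int = 100, hidden_dim: int = 32, output_size: int = 784):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(z_size, hidden_dim), nn.LeakyReLU(0.2), nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim * 2), nn.LeakyReLU(0.2), nn.Dropout(0.3),
            nn.Linear(hidden_dim * 2, hidden_dim * 4), nn.LeakyReLU(0.2), nn.Dropout(0.3),
            nn.Linear(hidden_dim * 4, output_size),
            # tanh maps outputs to (-1, 1), the same range as the rescaled real images
            nn.Tanh(),
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return self.net(z)
```

Because the Discriminator flattens its input, it accepts both real image batches of shape (batch_size, 1, 28, 28) and the Generator's flat (batch_size, 784) output without any extra reshaping.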

The complete structure of the Discriminator and Generator networks is shown below...



Once we have both the Discriminator and Generator networks ready, they are assembled to build the MNIST GAN, as shown below...



Loss Definition

GAN training is a bit different from typical supervised neural-network training. In the case of a GAN, two separate networks are trained together, and these networks have different and opposing objectives (i.e., they are competing). The Discriminator tries to identify whether an image sample is real (from the actual MNIST dataset) or fake (generated by our Generator). Note that we are NOT interested in correctly classifying the digits themselves. We therefore need to define two separate loss functions.

  1. real_loss: calculates loss when images are drawn from the actual dataset. The predicted output is compared against the target label 1, indicating real images.
  2. fake_loss: calculates loss when images are generated by the Generator. The predicted output is compared against the target label 0, indicating fake images.

The Discriminator computes both of the above losses and adds them together to get a total loss for back-propagation. The Generator computes real_loss to check its success in fooling the Discriminator. In other words, even though it generates fake images (target 0), by computing real_loss it compares the Discriminator's output with 1. In effect, the Generator loss has its labels flipped.
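The two loss helpers can be sketched with `nn.BCEWithLogitsLoss`, the loss function named in the Network Training section. The function names mirror real_loss and fake_loss above, but their exact signatures in the notebook are an assumption.

```python
import torch
import torch.nn as nn

# BCEWithLogitsLoss applies the sigmoid internally, so the Discriminator emits raw logits
criterion = nn.BCEWithLogitsLoss()


def real_loss(logits: torch.Tensor) -> torch.Tensor:
    """Loss against target 1 ("real") for every sample in the batch."""
    return criterion(logits, torch.ones_like(logits))


def fake_loss(logits: torch.Tensor) -> torch.Tensor:
    """Loss against target 0 ("fake") for every sample in the batch."""
    return criterion(logits, torch.zeros_like(logits))
```

With these helpers, the Discriminator's total loss is `real_loss(D(real_images)) + fake_loss(D(G(z)))`, and the Generator's flipped-label loss is simply `real_loss(D(G(z)))`.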

Network Training

Since we are training two separate networks, we need a separate optimizer for each network. In both cases, we use the Adam optimizer with a learning rate of 0.002.
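A minimal sketch of the two-optimizer setup, assuming `D` and `G` are the Discriminator and Generator modules; `make_optimizers` is a hypothetical helper name.

```python
import torch.nn as nn
import torch.optim as optim


def make_optimizers(D: nn.Module, G: nn.Module, lr: float = 0.002):
    # One Adam optimizer per network; each one only updates its own network's parameters
    d_optimizer = optim.Adam(D.parameters(), lr=lr)
    g_optimizer = optim.Adam(G.parameters(), lr=lr)
    return d_optimizer, g_optimizer
```

Tying each optimizer to only one network's parameters means a Discriminator update never changes the Generator's weights, and vice versa, even though both losses are computed through both networks.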

  • Since classification is between two classes (real and fake) and our Discriminator outputs a logit, we use BCEWithLogitsLoss as the loss function. This function internally applies a sigmoid activation to the logits first and then calculates the loss using BCELoss (log loss).
  • Before starting training, we create a (16 x 100) fixed-random-noise-vector drawn from a uniform distribution over the range -1 to 1. This vector is kept fixed throughout training. After each epoch, we feed the noise vector to the partially trained Generator to generate fake images; these images help us visualize whether, and how, the generated image quality is improving. A sample of the noise vector is shown below...



  • Discriminator is trained as follows...
    • A batch of real MNIST images is drawn from the dataloader
    • Each image in the batch is then scaled to values between -1 and 1. This step is crucial: the Discriminator sees both real images from the MNIST dataset and fake images from the Generator, whose output lies in the range -1 to 1 (the Generator's last layer is tanh-activated), so the range of input values must be consistent in both cases.
    • The scaled batch is then fed to the Discriminator, its predicted output is captured, and real_loss is calculated
    • A batch of noise data (z) drawn from a uniform distribution over the range -1 to 1 is created
    • Noise z is then fed through the Generator, its outputs (fake images) are captured and passed to the Discriminator, and fake_loss is calculated
    • Then the Discriminator's total_loss is computed as real_loss + fake_loss
    • Finally, total_loss is back-propagated and the Discriminator's weights are updated with its optimizer
  • After one batch of Discriminator training (above), the Generator is trained as follows...
    • A batch of noise data (z) drawn from a uniform distribution over the range -1 to 1 is created
    • Noise z is then fed through the Generator, and its outputs (fake images) are captured
    • The generated fake images are then fed through the Discriminator, its predicted output is captured, and real_loss is calculated
    • Note that for the generated fake images we calculate real_loss (not fake_loss), as discussed in the Loss Definition section above
    • The computed loss is then back-propagated and the Generator's weights are updated with its optimizer
  • At the end of each epoch, the fixed-random-noise-vector is fed to the Generator trained so far to produce a batch of fake images, which we save to fixed_samples.pkl. We can load and view these saved images later for further analysis.

After training our GAN for 100 epochs, we plot both the Generator and Discriminator losses, and the plot looks like this...



The above plot does not look like a typical neural-network training loss plot. There are huge fluctuations at the beginning, and it's very wobbly after that. This behavior is typical of GAN training and is expected. At the end of the 100th epoch, the Discriminator loss (in red) seems to be going down while the Generator loss (in blue) increases. It's possible that if we train our network for more epochs, the losses may converge, indicating an equilibrium between the competing networks. Equilibrium is the fundamental idea behind GANs; it is drawn from game theory, which suggests that competing rational agents will ultimately reach an equilibrium where neither can improve any further.
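The per-batch procedure described under Network Training, plus the end-of-epoch snapshot on the fixed noise, can be sketched as one epoch of training. This is an illustrative reconstruction, not the notebook's code: it assumes real images arrive in [0, 1] from `ToTensor()`, draws latent vectors uniformly from [-1, 1] as in the How To Use snippet, and detaches the fake images during the Discriminator step so that step doesn't compute gradients for the Generator (an efficiency choice). The names `sample_noise` and `train_one_epoch` are hypothetical, and the loss helpers are repeated so the block runs on its own.

```python
import numpy as np
import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()


def real_loss(logits):
    return criterion(logits, torch.ones_like(logits))


def fake_loss(logits):
    return criterion(logits, torch.zeros_like(logits))


def sample_noise(batch_size: int, z_size: int = 100) -> torch.Tensor:
    # Uniform noise in [-1, 1], matching the latent vectors used in "How To Use"
    z = np.random.uniform(-1, 1, size=(batch_size, z_size))
    return torch.from_numpy(z).float()


def train_one_epoch(D, G, loader, d_optimizer, g_optimizer, fixed_z, z_size=100, device="cpu"):
    D.train()
    G.train()
    for real_images, _ in loader:  # labels are ignored
        batch_size = real_images.size(0)
        # Rescale real images from [0, 1] to [-1, 1] to match the Generator's tanh output
        real_images = (real_images * 2 - 1).to(device)

        # Discriminator step: real images vs. target 1, fake images vs. target 0
        d_optimizer.zero_grad()
        d_real = real_loss(D(real_images))
        fake_images = G(sample_noise(batch_size, z_size).to(device))
        # detach(): this step updates D only, so no gradients flow into G
        d_fake = fake_loss(D(fake_images.detach()))
        d_loss = d_real + d_fake
        d_loss.backward()
        d_optimizer.step()

        # Generator step: flipped labels, fake images scored against target 1
        g_optimizer.zero_grad()
        fake_images = G(sample_noise(batch_size, z_size).to(device))
        g_loss = real_loss(D(fake_images))
        g_loss.backward()
        g_optimizer.step()

    # Snapshot the Generator on the fixed noise to track progress across epochs
    G.eval()
    with torch.no_grad():
        samples = G(fixed_z.to(device)).cpu()
    # Returns the last batch's losses plus this epoch's fixed-noise samples
    return d_loss.item(), g_loss.item(), samples
```

Creating `fixed_z = sample_noise(16)` once before training, calling `train_one_epoch` for each of the 100 epochs, appending each returned `samples` tensor to a list, and pickling that list to fixed_samples.pkl would mirror the workflow described above; the returned losses can be collected for a loss plot.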

Visualize Training Progress

Let's visualize the intermediate generator outputs that we saved during the training. This will show us how Generator learns and generates better fake images as training progresses.
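A sketch for viewing the saved snapshots is shown below. It assumes fixed_samples.pkl holds a list with one (16, 784) tensor of Generator outputs per epoch (the actual pickle layout isn't documented here); `load_samples` and `show_epoch` are hypothetical helper names.

```python
import pickle

import matplotlib.pyplot as plt
import torch


def load_samples(path: str = "fixed_samples.pkl"):
    # Assumed layout: a list with one (16, 784) tensor of Generator outputs per epoch
    with open(path, "rb") as f:
        return pickle.load(f)


def show_epoch(samples, epoch: int, n_cols: int = 4):
    """Plot the fixed-noise images saved after `epoch` (1-based) as a grid."""
    images = torch.as_tensor(samples[epoch - 1]).detach().reshape(-1, 28, 28)
    n_rows = (len(images) + n_cols - 1) // n_cols
    # squeeze=False keeps `axes` a 2-D array even for a single row or column
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols, n_rows), squeeze=False)
    for ax, img in zip(axes.flat, images):
        # Pixel values are in [-1, 1] (tanh output), so pin the grayscale range explicitly
        ax.imshow(img.numpy(), cmap="gray", vmin=-1, vmax=1)
        ax.axis("off")
    fig.suptitle(f"Generator output after {epoch} epoch(s)")
    return fig
```

For example, `show_epoch(load_samples(), epoch=100)` would draw the final 4x4 grid of digits generated from the fixed noise.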

Generator output after 1 epoch of training...



Generator output after 10 epochs of training...



Generator output after 50 epochs of training...



Generator output after 100 epochs of training...



We can see how the Generator has improved from generating random noisy blobs after the 1st epoch to realistic-looking MNIST digits after 100 epochs.

How To Use

  1. Ensure the below-listed packages are installed
    • NumPy
    • pickle (part of the Python standard library)
    • matplotlib
    • torch
    • torchvision
  2. Download the mnist_gan.ipynb Jupyter notebook from this repo
  3. Execute the notebook from start to finish in one go. If a GPU is available (recommended), it'll be used automatically; otherwise, the notebook falls back to the CPU.
  4. On a machine with an NVIDIA Quadro P5000 GPU with 16GB of memory, training for 100 epochs takes approximately 20 minutes.
  5. Longer training may yield better results
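  6. A trained model can be used to generate fake MNIST digits, as shown below (g is the trained Generator and display_images is a plotting helper, both from the notebook)...
    import numpy as np
    import torch

    # Bring the generator back to the CPU and switch to eval mode (disables dropout)
    g.to('cpu')
    g.eval()
    # Feed a latent vector of size 100 to the trained generator and get a fake image back
    z = np.random.uniform(-1, 1, size=(1, 100))
    z = torch.from_numpy(z).float()
    with torch.no_grad():  # inference only, so no gradient tracking is needed
        fake_image = g(z)
    # Reshape to (batch, channels, height, width) and display
    fake_image = fake_image.view(1, 1, 28, 28)
    display_images(fake_image, n_cols=1, figsize=(2, 2))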

A randomly generated fake image produced by the above code is shown below; it looks very similar to the real digit 3 image shown earlier...

License

MIT License
