-
-
Notifications
You must be signed in to change notification settings - Fork 333
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
1. The original commits were messed up. This commit overwrites all previous commits. 2.The markdown files and the instructions to run the code is given.
- Loading branch information
1 parent
7d820e8
commit 81f6567
Showing
9 changed files
with
345 additions
and
0 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
# Get the imports done | ||
using Flux, Flux.Data.MNIST | ||
using Flux: @epochs, back!, testmode!, throttle | ||
using Base.Iterators: partition | ||
using Distributions: Uniform,Normal | ||
using CUDAnative: tanh, log, exp | ||
using CuArrays | ||
using Images | ||
using Statistics | ||
|
||
# Define the hyperparameters | ||
BATCH_SIZE = 128 | ||
NUM_EPOCHS = 15 | ||
noise_dim = 100 | ||
channels = 128 | ||
hidden_dim = 7 * 7 * channels | ||
training_steps = 0 | ||
verbose_freq = 100 | ||
dis_lr = 0.0001f0 # Discriminator Learning Rate | ||
gen_lr = 0.0001f0 # Generator Learning Rate | ||
|
||
# Loading Data | ||
|
||
# We use Flux's built in MNIST Loader | ||
imgs = MNIST.images() | ||
|
||
# Partition into batches of size 'BATCH_SIZE' | ||
data = [reshape(float(hcat(vec.(imgs)...)),28,28,1,:) for imgs in partition(imgs, BATCH_SIZE)] | ||
|
||
# Define out distribution for random sampling for the generator to sample noise from | ||
dist = Normal(0.0,1.0) # Standard Normal noise is found to give better results | ||
|
||
expand_dims(x,n::Int) = reshape(x,ones(Int64,n)...,size(x)...) | ||
squeeze(x) = dropdims(x, dims = tuple(findall(size(x) .== 1)...)) | ||
|
||
# The Generator | ||
generator = Chain( | ||
Dense(noise_dim, 1024, leakyrelu), | ||
x->expand_dims(x,1), | ||
BatchNorm(1024), | ||
x->squeeze(x), | ||
Dense(1024, hidden_dim, leakyrelu), | ||
x->expand_dims(x,1), | ||
BatchNorm(hidden_dim), | ||
x->squeeze(x), | ||
x->reshape(x,7,7,channels,:), | ||
ConvTranspose((4,4), channels=>64, relu; stride=(2,2), pad=(1,1)), | ||
x->expand_dims(x,2), | ||
BatchNorm(64), | ||
x->squeeze(x), | ||
ConvTranspose((4,4), 64=>1, tanh; stride=(2,2), pad=(1,1)) | ||
) |> gpu | ||
|
||
# The Discriminator | ||
discriminator = Chain( | ||
Conv((3,3), 1=>32, leakyrelu;pad = 1), | ||
x->meanpool(x, (2,2)), | ||
Conv((3,3), 32=>64, leakyrelu;pad = 1), | ||
x->meanpool(x, (2,2)), | ||
x->reshape(x,7*7*64,:), | ||
Dense(7*7*64, 1024, leakyrelu), | ||
x->expand_dims(x,1), | ||
BatchNorm(1024), | ||
x->squeeze(x), | ||
Dense(1024, 1,sigmoid) | ||
) |> gpu | ||
|
||
# <b>Define the optimizers</b> | ||
|
||
opt_gen = ADAM(params(generator),gen_lr, β1 = 0.5) | ||
opt_disc = ADAM(params(discriminator),dis_lr, β1 = 0.5) | ||
|
||
# <b>Utility functions to zero out our model gradients</b> | ||
function nullify_grad!(p) | ||
if typeof(p) <: TrackedArray | ||
p.grad .= 0.0f0 | ||
end | ||
return p | ||
end | ||
|
||
function zero_grad!(model) | ||
model = mapleaves(nullify_grad!, model) | ||
end | ||
|
||
# <b>Creating and Saving Utilities</b> | ||
|
||
img(x) = Gray.(reshape((x+1)/2, 28, 28)) # For denormalizing the generated image | ||
|
||
function sample() | ||
noise = [rand(dist, noise_dim, 1) for i=1:9] # Sample 9 digits | ||
noise = gpu.(noise) # Add to GPU | ||
|
||
testmode!(generator) | ||
fake_imgs = img.(map(x -> gpu(generator(x).data), noise)) # Generate a new image from random noise | ||
testmode!(generator, false) | ||
|
||
img_grid = vcat([hcat(imgs...) for imgs in partition(fake_imgs, 3)]...) # Create grid for saving | ||
end | ||
|
||
cd(@__DIR__) | ||
|
||
|
||
# We use the <b>Binary Cross Entropy Loss</b> | ||
function bce(ŷ, y) | ||
mean(-y.*log.(ŷ) - (1 .- y .+ 1f-10).*log.(1 .- ŷ .+ 1f-10)) | ||
end | ||
|
||
function train(x) | ||
global training_steps | ||
println("TRAINING") | ||
z = rand(dist, noise_dim, BATCH_SIZE) |> gpu | ||
inp = 2x .- 1 |> gpu # Normalize images to [-1,1] | ||
|
||
zero_grad!(discriminator) | ||
|
||
D_real = discriminator(inp) # D(x) | ||
real_labels = ones(size(D_real)) |> gpu | ||
|
||
|
||
D_real_loss = bce(D_real,real_labels) | ||
|
||
fake_x = generator(z) # G(z) | ||
D_fake = discriminator(fake_x) # D(G(z)) | ||
fake_labels = zeros(size(D_fake)) |> gpu | ||
|
||
D_fake_loss = bce(D_fake,fake_labels) | ||
|
||
D_loss = D_real_loss + D_fake_loss | ||
Flux.back!(D_loss) | ||
opt_disc() # Optimize the discriminator | ||
|
||
zero_grad!(generator) | ||
|
||
fake_x = generator(z) # G(z) | ||
D_fake = discriminator(fake_x) # D(G(z)) | ||
real_labels = ones(size(D_fake)) |> gpu | ||
|
||
G_loss = bce(D_fake,real_labels) | ||
|
||
Flux.back!(G_loss) | ||
opt_gen() # Optimise the generator | ||
|
||
if training_steps % verbose_freq == 0 | ||
println("D Loss: $(D_loss.data) | G loss: $(G_loss.data)") | ||
end | ||
|
||
training_steps += 1 | ||
end | ||
|
||
for e = 1:NUM_EPOCHS | ||
for imgs in data | ||
train(imgs) | ||
end | ||
println("Epoch $e over.") | ||
end | ||
|
||
save("sample_dcgan.png", sample()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# Generative Adversarial Network Tutorial | ||
|
||
Generative Adversarial Nets (GAN), are generative models used to infer a complicated probability distribution. | ||
|
||
We have two networks competing against each other - The Generator and the discriminator. | ||
|
||
![GAN](GAN-1.jpg) | ||
|
||
The first net generates data from randomly sampled noise, and the second net tries to tell the difference between the real data and the fake data generated by the first net. | ||
|
||
The formulation per se involves the following min-max objective : | ||
|
||
![gan_loss](gan.png) | ||
|
||
At equilibrium, the discriminator will output a probability of 0.5 for each generated image. | ||
|
||
## Run the script | ||
|
||
``` | ||
julia dcgan.jl | ||
``` | ||
|
||
*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).* | ||
|
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
# Get the imports done | ||
using Flux, Flux.Data.MNIST,Flux | ||
using Flux: @epochs, back!, testmode!, throttle | ||
using Base.Iterators: partition,flatten | ||
using Flux: onehot,onehotbatch | ||
using Distributions: Normal | ||
using Statistics | ||
using Images | ||
|
||
# Define the hyperparameters | ||
NUM_EPOCHS = 5000 | ||
BATCH_SIZE = 100 | ||
NOISE_DIM = 100 | ||
gen_lr = 0.0001f0 # Generator learning rate | ||
dis_lr = 0.0001f0 # discriminator learning rate | ||
training_steps = 0 | ||
verbose_freq = 2 | ||
|
||
# Loading Data | ||
@info("Loading data set") | ||
train_labels = MNIST.labels()[1:100] |> gpu | ||
train_imgs = MNIST.images()[1:100] |> gpu | ||
|
||
# Bundle images together with labels and group into minibatches | ||
function make_minibatch(X, Y, idxs) | ||
X_batch = Array{Float32}(undef, 784, length(idxs)) | ||
for i in 1:length(idxs) | ||
X_batch[:, i] = Float32.(reshape(X[idxs[i]],784)) | ||
end | ||
Y_batch = onehotbatch(Y[idxs], 0:9) | ||
return vcat(X_batch, Y_batch) | ||
end | ||
|
||
mb_idxs = partition(1:length(train_imgs), BATCH_SIZE) | ||
train_set = [make_minibatch(train_imgs, train_labels, i) for i in mb_idxs] | ||
|
||
# Define out distribution for random sampling for the generator to sample noise from | ||
dist = Normal(0.0,1.0) # Standard Normal noise is found to give better results | ||
|
||
# The Generator | ||
generator = Chain(Dense(NOISE_DIM + 10,1200,leakyrelu), | ||
Dense(1200,1000,leakyrelu), | ||
Dense(1000,784,tanh) | ||
) |> gpu | ||
|
||
# The Discriminator | ||
discriminator = Chain(Dense(794,512,leakyrelu), | ||
Dense(512,128,leakyrelu), | ||
Dense(128,1,sigmoid) | ||
) |> gpu | ||
|
||
# <b>Define the optimizers</b> | ||
opt_gen = ADAM(params(generator),gen_lr, β1 = 0.5) | ||
opt_disc = ADAM(params(discriminator),dis_lr, β1 = 0.5) | ||
|
||
# <b>Utility functions to zero out our model gradients</b> | ||
function nullify_grad!(p) | ||
if typeof(p) <: TrackedArray | ||
p.grad .= 0.0f0 | ||
end | ||
return p | ||
end | ||
|
||
function zero_grad!(model) | ||
model = mapleaves(nullify_grad!, model) | ||
end | ||
|
||
# <b>Creating and Saving Utilities</b> | ||
|
||
img(x) = Gray.(reshape((x.+1)/2, 28, 28, 1)) # For denormalizing the generated image | ||
|
||
function sample() | ||
num_samples = 9 # Number of digits to sample | ||
fake_labels = zeros(10,num_samples) | ||
for i in 1:num_samples | ||
fake_labels[rand(1:9),i] = 1 | ||
end | ||
|
||
noise = [vcat(rand(dist, NOISE_DIM, 1),fake_labels[:,i]) for i=1:num_samples] # Sample 9 digits | ||
noise = gpu.(noise) # Add to GPU | ||
|
||
testmode!(generator) | ||
fake_imgs = img.(map(x -> gpu(generator(x).data), noise)) # Generate a new image from random noise | ||
testmode!(generator, false) | ||
|
||
img_grid = fake_imgs[1] | ||
end | ||
|
||
cd(@__DIR__) | ||
|
||
# We use the <b>Binary Cross Entropy Loss</b> | ||
function bce(ŷ, y) | ||
mean(-y.*log.(ŷ .+ 1f-10) - (1 .- y .+ 1f-10).*log.(1 .- ŷ .+ 1f-10)) | ||
end | ||
|
||
function train(x) | ||
global training_steps | ||
|
||
z = rand(dist,NOISE_DIM, BATCH_SIZE) |> gpu | ||
inp = 2x .- 1 |> gpu # Normalize images to [-1,1] | ||
inp[end-9:end,:] = x[end-9:end,:] # The labels should not be modified | ||
|
||
labels = Float32.(x[end-9:end,:]) |> gpu # y | ||
zero_grad!(discriminator) | ||
zero_grad!(generator) | ||
|
||
D_real = discriminator(inp) # D(x|y) | ||
real_labels = ones(size(D_real)) |> gpu | ||
|
||
D_real_loss = bce(D_real,real_labels) | ||
|
||
fake_x = generator(vcat(z,labels)) # G(z|y) | ||
D_fake = discriminator(vcat(fake_x,labels)) # D(G(z|y)) | ||
fake_labels = zeros(size(D_fake)) |> gpu | ||
|
||
D_fake_loss = bce(D_fake,fake_labels) | ||
|
||
D_loss = D_real_loss + D_fake_loss | ||
Flux.back!(D_loss) | ||
opt_disc() # Optimize the discriminator | ||
|
||
zero_grad!(discriminator) | ||
zero_grad!(generator) | ||
|
||
fake_x = generator(vcat(z,labels)) # G(z|y) | ||
D_fake = discriminator(vcat(fake_x,labels)) # D(G(z|y)) | ||
real_labels = ones(size(D_fake)) |> gpu | ||
|
||
G_loss = bce(D_fake,real_labels) | ||
Flux.back!(G_loss) | ||
opt_gen() # Optimise the generator | ||
|
||
if training_steps % verbose_freq == 0 | ||
println("D Loss: $(D_loss.data) | G loss: $(G_loss.data)") | ||
end | ||
|
||
println(training_steps) | ||
training_steps += 1 | ||
end | ||
|
||
for e = 1:NUM_EPOCHS | ||
for data in train_set | ||
train(data) | ||
end | ||
println("Epoch $e over.") | ||
end | ||
|
||
save("sample_cgan.png", sample()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Conditional Generative Adversarial Network Tutorial | ||
|
||
A cGAN is a GAN wherein both the generator and the discriminator are fed prior labels along with the image and random noise. | ||
|
||
It models the conditional probabilities conditioned on the labels. | ||
|
||
![cGAN](cGAN.jpg) | ||
|
||
## Run the script | ||
|
||
``` | ||
julia cgan.jl | ||
``` | ||
|
||
*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).* | ||
|
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.