Added Conditional GAN and DCGAN tutorial #111
Open. shreyas-kowshik wants to merge 4 commits into FluxML:master from shreyas-kowshik:gans
@@ -0,0 +1,157 @@
# Get the imports done
using Flux, Flux.Data.MNIST
using Flux: @epochs, back!, testmode!, throttle
using Base.Iterators: partition
using Distributions: Uniform, Normal
using CUDAnative: tanh, log, exp
using CuArrays
using Images
using Statistics

# Define the hyperparameters
BATCH_SIZE = 128
NUM_EPOCHS = 15
noise_dim = 100
channels = 128
hidden_dim = 7 * 7 * channels
training_steps = 0
verbose_freq = 100
dis_lr = 0.0001f0 # Discriminator Learning Rate
gen_lr = 0.0001f0 # Generator Learning Rate

# Loading Data

# We use Flux's built-in MNIST Loader
imgs = MNIST.images()

# Partition into batches of size 'BATCH_SIZE'
data = [reshape(float(hcat(vec.(imgs)...)), 28, 28, 1, :) for imgs in partition(imgs, BATCH_SIZE)]

# Define our distribution for random sampling for the generator to sample noise from
dist = Normal(0.0, 1.0) # Standard Normal noise is found to give better results

expand_dims(x, n::Int) = reshape(x, ones(Int64, n)..., size(x)...)
squeeze(x) = dropdims(x, dims = tuple(findall(size(x) .== 1)...))

# The Generator
generator = Chain(
    Dense(noise_dim, 1024, leakyrelu),
    x -> expand_dims(x, 1),
    BatchNorm(1024),
    x -> squeeze(x),
    Dense(1024, hidden_dim, leakyrelu),
    x -> expand_dims(x, 1),
    BatchNorm(hidden_dim),
    x -> squeeze(x),
    x -> reshape(x, 7, 7, channels, :),
    ConvTranspose((4,4), channels=>64, relu; stride=(2,2), pad=(1,1)),
    x -> expand_dims(x, 2),
    BatchNorm(64),
    x -> squeeze(x),
    ConvTranspose((4,4), 64=>1, tanh; stride=(2,2), pad=(1,1))
    ) |> gpu

# The Discriminator
discriminator = Chain(
    Conv((3,3), 1=>32, leakyrelu; pad = 1),
    x -> meanpool(x, (2,2)),
    Conv((3,3), 32=>64, leakyrelu; pad = 1),
    x -> meanpool(x, (2,2)),
    x -> reshape(x, 7*7*64, :),
    Dense(7*7*64, 1024, leakyrelu),
    x -> expand_dims(x, 1),
    BatchNorm(1024),
    x -> squeeze(x),
    Dense(1024, 1, sigmoid)
    ) |> gpu

# <b>Define the optimizers</b>

opt_gen = ADAM(params(generator), gen_lr, β1 = 0.5)
opt_disc = ADAM(params(discriminator), dis_lr, β1 = 0.5)

# <b>Utility functions to zero out our model gradients</b>
function nullify_grad!(p)
    if typeof(p) <: TrackedArray
        p.grad .= 0.0f0
    end
    return p
end

function zero_grad!(model)
    model = mapleaves(nullify_grad!, model)
end

# <b>Creating and Saving Utilities</b>

img(x) = Gray.(reshape((x .+ 1) ./ 2, 28, 28)) # For denormalizing the generated image

function sample()
    noise = [rand(dist, noise_dim, 1) for i=1:9] # Sample 9 digits
    noise = gpu.(noise) # Add to GPU

    testmode!(generator)
    fake_imgs = img.(map(x -> gpu(generator(x).data), noise)) # Generate a new image from random noise
    testmode!(generator, false)

    img_grid = vcat([hcat(imgs...) for imgs in partition(fake_imgs, 3)]...) # Create grid for saving
end

cd(@__DIR__)


# We use the <b>Binary Cross Entropy Loss</b>
function bce(ŷ, y)
    mean(-y.*log.(ŷ) - (1 .- y .+ 1f-10).*log.(1 .- ŷ .+ 1f-10))
end

function train(x)
    global training_steps
    println("TRAINING")
    z = rand(dist, noise_dim, BATCH_SIZE) |> gpu
    inp = 2x .- 1 |> gpu # Normalize images to [-1,1]

    zero_grad!(discriminator)

    D_real = discriminator(inp) # D(x)
    real_labels = ones(size(D_real)) |> gpu


    D_real_loss = bce(D_real, real_labels)

    fake_x = generator(z) # G(z)
    D_fake = discriminator(fake_x) # D(G(z))
    fake_labels = zeros(size(D_fake)) |> gpu

    D_fake_loss = bce(D_fake, fake_labels)

    D_loss = D_real_loss + D_fake_loss
    Flux.back!(D_loss)
    opt_disc() # Optimize the discriminator

    zero_grad!(generator)

    fake_x = generator(z) # G(z)
    D_fake = discriminator(fake_x) # D(G(z))
    real_labels = ones(size(D_fake)) |> gpu

    G_loss = bce(D_fake, real_labels)

    Flux.back!(G_loss)
    opt_gen() # Optimise the generator

    if training_steps % verbose_freq == 0
        println("D Loss: $(D_loss.data) | G loss: $(G_loss.data)")
    end

    training_steps += 1
end

for e = 1:NUM_EPOCHS
    for imgs in data
        train(imgs)
    end
    println("Epoch $e over.")
end

save("sample_dcgan.png", sample())
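A note on the `bce` helper in the script above: its first logarithm, `log.(ŷ)`, carries no epsilon, so a discriminator output of exactly 0 on a real image makes the loss infinite. The following is a minimal sketch, not part of the PR, of one common way to guard both logarithms by clamping the predictions. The name `bce_clamped` and the default `ϵ = 1f-7` are assumptions made for illustration, and the sketch uses Base Julia and `Statistics` only, so it runs without Flux or a GPU.

```julia
using Statistics: mean

# Hypothetical helper (not from the PR): binary cross entropy with the
# predictions clamped into [ϵ, 1 - ϵ] so that log never receives 0.
function bce_clamped(ŷ, y; ϵ = 1f-7)
    p = clamp.(ŷ, ϵ, 1 - ϵ)
    return mean(@. -y * log(p) - (1 - y) * log(1 - p))
end

ŷ = Float32[0.0 0.5 1.0]                  # includes both saturated extremes
real_labels = ones(Float32, size(ŷ))
fake_labels = zeros(Float32, size(ŷ))

naive = mean(-real_labels .* log.(ŷ))     # unguarded first term: log(0) gives Inf
loss_real = bce_clamped(ŷ, real_labels)   # finite
loss_fake = bce_clamped(ŷ, fake_labels)   # finite
```

Clamping changes the loss only where the discriminator saturates, which is also where the unguarded version breaks down.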
@@ -0,0 +1,24 @@
# Generative Adversarial Network Tutorial

Generative Adversarial Nets (GANs) are generative models used to learn a complicated probability distribution.

Two networks compete against each other: the generator and the discriminator.

![GAN](GAN-1.jpg)

The first net generates data from randomly sampled noise, and the second net tries to tell the difference between the real data and the fake data generated by the first net.

The formulation involves the following min-max objective:

![gan_loss](gan.png)

At equilibrium, the discriminator outputs a probability of 0.5 for each generated image.

## Run the script

```
julia dcgan.jl
```

*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*
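The tutorial above shows its min-max objective only as an image (`gan.png`), which is not visible in this view. For reference, the standard GAN objective, which is presumably what that image contains, can be written as follows. The symbols $p_{\text{data}}$ (data distribution), $p_z$ (noise distribution) and $p_g$ (distribution of generated samples) are the usual notation and are an assumption here, not taken from the PR.

```latex
\min_G \max_D \; V(D, G)
  = \mathbb{E}_{x \sim p_{\text{data}}}\bigl[\log D(x)\bigr]
  + \mathbb{E}_{z \sim p_z}\bigl[\log\bigl(1 - D(G(z))\bigr)\bigr]

% For a fixed generator, the optimal discriminator is
D^{*}(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)}

% so when the generator matches the data, p_g = p_{\text{data}}, and
D^{*}(x) = \tfrac{1}{2}
```

The last line is the source of the 0.5 figure in the tutorial: once generated and real samples are indistinguishable, the best the discriminator can do is output one half everywhere.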
Binary file not shown.
@@ -0,0 +1,148 @@
# Get the imports done
using Flux, Flux.Data.MNIST
using Flux: @epochs, back!, testmode!, throttle
using Base.Iterators: partition, flatten
using Flux: onehot, onehotbatch
using Distributions: Normal
using Statistics
using Images

# Define the hyperparameters
NUM_EPOCHS = 5000
BATCH_SIZE = 100
NOISE_DIM = 100
gen_lr = 0.0001f0 # Generator learning rate
dis_lr = 0.0001f0 # Discriminator learning rate
training_steps = 0
verbose_freq = 2

# Loading Data
@info("Loading data set")
train_labels = MNIST.labels()[1:100] |> gpu
train_imgs = MNIST.images()[1:100] |> gpu

# Bundle images together with labels and group into minibatches
function make_minibatch(X, Y, idxs)
    X_batch = Array{Float32}(undef, 784, length(idxs))
    for i in 1:length(idxs)
        X_batch[:, i] = Float32.(reshape(X[idxs[i]], 784))
    end
    Y_batch = onehotbatch(Y[idxs], 0:9)
    return vcat(X_batch, Y_batch)
end

mb_idxs = partition(1:length(train_imgs), BATCH_SIZE)
train_set = [make_minibatch(train_imgs, train_labels, i) for i in mb_idxs]

# Define our distribution for random sampling for the generator to sample noise from
dist = Normal(0.0, 1.0) # Standard Normal noise is found to give better results

# The Generator
generator = Chain(Dense(NOISE_DIM + 10, 1200, leakyrelu),
    Dense(1200, 1000, leakyrelu),
    Dense(1000, 784, tanh)
    ) |> gpu

# The Discriminator
discriminator = Chain(Dense(794, 512, leakyrelu),
    Dense(512, 128, leakyrelu),
    Dense(128, 1, sigmoid)
    ) |> gpu

# <b>Define the optimizers</b>
opt_gen = ADAM(params(generator), gen_lr, β1 = 0.5)
opt_disc = ADAM(params(discriminator), dis_lr, β1 = 0.5)

# <b>Utility functions to zero out our model gradients</b>
function nullify_grad!(p)
    if typeof(p) <: TrackedArray
        p.grad .= 0.0f0
    end
    return p
end

function zero_grad!(model)
    model = mapleaves(nullify_grad!, model)
end

# <b>Creating and Saving Utilities</b>

img(x) = Gray.(reshape((x .+ 1) ./ 2, 28, 28, 1)) # For denormalizing the generated image

function sample()
    num_samples = 9 # Number of digits to sample
    fake_labels = zeros(10, num_samples)
    for i in 1:num_samples
        fake_labels[rand(1:10), i] = 1 # Rows 1:10 encode digits 0:9
    end

    noise = [vcat(rand(dist, NOISE_DIM, 1), fake_labels[:, i]) for i=1:num_samples] # Sample 9 digits
    noise = gpu.(noise) # Add to GPU

    testmode!(generator)
    fake_imgs = img.(map(x -> gpu(generator(x).data), noise)) # Generate a new image from random noise
    testmode!(generator, false)

    img_grid = fake_imgs[1]
end

cd(@__DIR__)

# We use the <b>Binary Cross Entropy Loss</b>
function bce(ŷ, y)
    mean(-y.*log.(ŷ .+ 1f-10) - (1 .- y .+ 1f-10).*log.(1 .- ŷ .+ 1f-10))
end

function train(x)
    global training_steps

    z = rand(dist, NOISE_DIM, BATCH_SIZE) |> gpu
    inp = 2x .- 1 |> gpu # Normalize images to [-1,1]
    inp[end-9:end, :] = x[end-9:end, :] # The labels should not be modified

    labels = Float32.(x[end-9:end, :]) |> gpu # y
    zero_grad!(discriminator)
    zero_grad!(generator)

    D_real = discriminator(inp) # D(x|y)
    real_labels = ones(size(D_real)) |> gpu

    D_real_loss = bce(D_real, real_labels)

    fake_x = generator(vcat(z, labels)) # G(z|y)
    D_fake = discriminator(vcat(fake_x, labels)) # D(G(z|y))
    fake_labels = zeros(size(D_fake)) |> gpu

    D_fake_loss = bce(D_fake, fake_labels)

    D_loss = D_real_loss + D_fake_loss
    Flux.back!(D_loss)
    opt_disc() # Optimize the discriminator

    zero_grad!(discriminator)
    zero_grad!(generator)

    fake_x = generator(vcat(z, labels)) # G(z|y)
    D_fake = discriminator(vcat(fake_x, labels)) # D(G(z|y))
    real_labels = ones(size(D_fake)) |> gpu

    G_loss = bce(D_fake, real_labels)
    Flux.back!(G_loss)
    opt_gen() # Optimise the generator

    if training_steps % verbose_freq == 0
        println("D Loss: $(D_loss.data) | G loss: $(G_loss.data)")
    end

    println(training_steps)
    training_steps += 1
end

for e = 1:NUM_EPOCHS
    for data in train_set
        train(data)
    end
    println("Epoch $e over.")
end

save("sample_cgan.png", sample())
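The script above conditions both networks by concatenation: one-hot labels are stacked under the noise for the generator (100 + 10 = 110 rows) and under the flattened image for the discriminator (784 + 10 = 794 rows). The sketch below, which is not part of the PR, reproduces only that shape bookkeeping in Base Julia. `onehot_cols` is a hypothetical stand-in for Flux's `onehotbatch`, and the image data is random placeholder input.

```julia
# Hypothetical stand-in for `onehotbatch`: one column per label,
# one row per class, so that row k encodes the k-th entry of `classes`.
onehot_cols(labels, classes) = Float32[l == c for c in classes, l in labels]

batch = 4
labels = [3, 0, 9, 1]                   # digit labels for the batch
y = onehot_cols(labels, 0:9)            # 10 × batch; digit 9 lives in row 10

x = rand(Float32, 784, batch)           # placeholder for flattened 28×28 images in [0, 1]
z = randn(Float32, 100, batch)          # noise with NOISE_DIM = 100

gen_input = vcat(z, y)                  # 110 × batch, matches Dense(NOISE_DIM + 10, ...)
disc_input = vcat(2 .* x .- 1, y)       # 794 × batch, pixels rescaled to [-1, 1], labels untouched
```

Rescaling the pixels before concatenating keeps the label rows at 0 and 1, which is the effect the script achieves by overwriting the last ten rows after normalizing the whole batch.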
@@ -0,0 +1,16 @@
# Conditional Generative Adversarial Network Tutorial

A cGAN is a GAN in which both the generator and the discriminator receive class labels as an extra input: the generator alongside the random noise, and the discriminator alongside the image.

It models the data distribution conditioned on the labels.

![cGAN](cGAN.jpg)

## Run the script

```
julia cgan.jl
```

*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*
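The tutorial above describes the conditioning in words only. As a hedged reference, the usual conditional form of the GAN objective, using the same $D(x \mid y)$ and $G(z \mid y)$ notation that appears in the comments of the cGAN script, is sketched below. The symbols $p_{\text{data}}$ and $p_z$ are standard notation assumed here, not taken from the PR.

```latex
\min_G \max_D \; V(D, G)
  = \mathbb{E}_{x \sim p_{\text{data}}}\bigl[\log D(x \mid y)\bigr]
  + \mathbb{E}_{z \sim p_z}\bigl[\log\bigl(1 - D(G(z \mid y) \mid y)\bigr)\bigr]
```

In the script, "conditioning on $y$" amounts to appending the one-hot label to each network's input vector.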
You shouldn't need these utilities; just use `update!` with `Params` and `Grads`, like `Flux.train!` does. Also, you're still using HTML tags above.

@MikeInnes I used `update!` without manually zeroing out the gradients of the models. However, that model does not converge to the expected output even after repeated trials, whereas manually zeroing out the gradients does. I don't know if it's a bug on my part. I was using this as a reference:
https://github.com/eriklindernoren/PyTorch-GAN/blob/1f130dfca726e14254e4fd78e5fb63f08931acd3/implementations/cgan/cgan.py#L161-L195
As pointed out on Slack, `gradient` used in `update!` should automatically zero out the gradients, but the results are not reflecting that...

https://github.com/FluxML/Flux.jl/blob/66ce8d8066e00b0f8eccd69649dabaee26f97129/src/optimise/train.jl#L20 should zero out the gradients with a simple `update!`

The normal `update!` method will work if you use `gradient` rather than `back!`. `back!` should be avoided as it's effectively deprecated.

@MikeInnes Sorry for replying late. Made the requested changes.
Does it look good?

`update!` does zero out the gradient so no need to do it explicitly I suppose.

I believe `update!` only zeros out the gradient in the call but not all the gradient.

See this MWE below:
gives
See how Case 2 is different from Case 1 and Case 3.

When I created a DCGAN model with Tracker backend (Flux v0.9.0), it didn't converge without zeroing out gradients after training the discriminator. https://github.com/matsueushi/fluxjl-gan/blob/flux0.9.0/mnist-dcgan.jl
However, with Zygote backend (Flux v0.10.0),
gives expected results
and I didn't have to zero out gradients. https://github.com/matsueushi/fluxjl-gan/blob/e60684b6c8ecc601eb6784ae393eae9a3a3ba57a/mnist-dcgan.jl

This is expected. Only a Tracker-based AD needs the zero-out part. AD based on Zygote doesn't have this side effect.
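The side effect discussed in this thread can be illustrated without Flux at all. The sketch below is a toy model written for this purpose, not the MWE referenced above (which is not preserved here) and not Tracker's actual implementation. It assumes only what the thread states: with a Tracker-style backend, a backward pass adds into each parameter's gradient buffer, so backpropagating the discriminator loss also leaves gradients on the generator. All names (`ToyParam`, `back_d_loss!`, `back_g_loss!`) and both loss functions are hypothetical.

```julia
# Toy Tracker-style parameter: a value plus a gradient buffer that
# backward passes accumulate into.
mutable struct ToyParam
    value::Float64
    grad::Float64
end
ToyParam(v::Real) = ToyParam(v, 0.0)

zero_grad!(p::ToyParam) = (p.grad = 0.0; p)

# Stand-in "discriminator loss" (g*d - 1)^2. Its backward pass writes
# into BOTH buffers because the loss depends on both parameters.
function back_d_loss!(g::ToyParam, d::ToyParam)
    r = g.value * d.value - 1
    g.grad += 2r * d.value
    d.grad += 2r * g.value
    return nothing
end

# Stand-in "generator loss" (g*d)^2, again touching both buffers.
function back_g_loss!(g::ToyParam, d::ToyParam)
    r = g.value * d.value
    g.grad += 2r * d.value
    d.grad += 2r * g.value
    return nothing
end

# Gradient the generator's optimiser would see after one D pass and one G pass.
function generator_grad(; zero_between::Bool)
    g, d = ToyParam(2.0), ToyParam(3.0)
    back_d_loss!(g, d)               # discriminator pass leaves 30.0 in g.grad
    zero_between && zero_grad!(g)    # what zero_grad!(generator) does in the scripts
    back_g_loss!(g, d)               # generator pass adds 36.0
    return g.grad
end

stale = generator_grad(zero_between = false)   # 66.0: leftover D-loss gradient included
fresh = generator_grad(zero_between = true)    # 36.0: gradient of the generator loss only
```

A backend that returns gradients as fresh values rather than accumulating into shared buffers, as the thread says the Zygote-based `gradient` does, corresponds to the `fresh` case and needs no explicit zeroing.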