Skip to content

Commit

Permalink
fixed shape of the virtual adversarial direction
Browse files Browse the repository at this point in the history
  • Loading branch information
mbarbetti committed Jun 21, 2022
1 parent 3907b28 commit 9f54b2f
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 10 deletions.
157 changes: 157 additions & 0 deletions tests/algorithms/gan/WGAN_ALP_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import PIL
import glob

import numpy as np
import tensorflow as tf

from datetime import datetime
from tensorflow.keras import Sequential, layers
from tf_gen_models.algorithms.gan import WGAN_ALP
from tf_gen_models.callbacks import GanExpLrScheduler, ImageSaver


# +---------------------------------+
# |   Load and prepare datasets     |
# +---------------------------------+

def _to_unit_range ( images ):
    """Return `images` as float32 tensors of shape (N, 28, 28, 1) scaled by (x - 127.5) / 127.5.

    The leading axis (number of images) is preserved; a trailing axis of
    size 1 is appended. With 8-bit inputs in [0, 255] the result lies in
    [-1, 1].
    """
    n_images = images.shape[0]
    as_float = images.reshape ( n_images, 28, 28, 1 ).astype ( np.float32 )
    return (as_float - 127.5) / 127.5


# Only the images are needed here: the labels returned by the loader are discarded.
(train_img, _), (test_img, _) = tf.keras.datasets.mnist.load_data()

train_img = _to_unit_range ( train_img )
test_img  = _to_unit_range ( test_img )

BUFFER_SIZE = 60000   # shuffle buffer size
BATCH_SIZE  = 64      # mini-batch size

## TF.DATA.DATASET

# Order of the transformations matters: `cache()` replays exactly the elements
# recorded during the first pass over the data. Caching *after* `shuffle()` and
# `batch()` would therefore freeze both the sample order and the composition of
# every mini-batch, so all epochs would see the very same batches. Caching the
# raw slices first keeps the speed-up while letting `shuffle()` draw a new
# permutation at each epoch.
# NOTE: `.prefetch ( tf.data.AUTOTUNE )` was disabled in the original script and
#       is left disabled here.

train_ds = (
    tf.data.Dataset.from_tensor_slices ( train_img )
      .cache()                                        # cache the un-shuffled images
      .shuffle ( BUFFER_SIZE )                        # reshuffle the images at each epoch
      .batch ( BATCH_SIZE, drop_remainder = True )    # mini-batch splitting
)

test_ds = (
    tf.data.Dataset.from_tensor_slices ( test_img )
      .cache()                                        # cache the un-shuffled images
      .shuffle ( BUFFER_SIZE )                        # reshuffle the images at each epoch
      .batch ( BATCH_SIZE, drop_remainder = True )    # mini-batch splitting
)

# +---------------------------+
# |   Adversarial players     |
# +---------------------------+

LATENT_DIM = 100   # size of the latent vector fed to the generator

## GENERATOR

generator = Sequential ( name = "generator" )

# Dense projection of the latent vector, later reshaped to a 7x7x256 feature map.
generator.add ( layers.Dense ( 7 * 7 * 256, use_bias = False, input_shape = (LATENT_DIM,) ) )
generator.add ( layers.BatchNormalization() )
generator.add ( layers.LeakyReLU() )
generator.add ( layers.Reshape ( (7, 7, 256) ) )

# Hidden transposed-convolution stages, listed as (filters, kernel side, stride).
# With "valid" padding the spatial side should grow as 7 -> 9 -> 20 -> 24.
# NOTE(review): `BatchNormalization ( axis = 1 )` normalizes along the first
#               spatial axis if the tensors are channels-last (Keras default) --
#               confirm this is intended rather than the channel axis (-1).
for n_filters, side, stride in ( (256, 3, 1), (128, 4, 2), (64, 5, 1) ):
    generator.add ( layers.Conv2DTranspose ( n_filters, (side, side),
                                             strides = (stride, stride),
                                             padding = "valid" ) )
    generator.add ( layers.BatchNormalization ( axis = 1 ) )
    generator.add ( layers.LeakyReLU() )

# Output stage: single-channel image (24 -> 28 per side) squashed by tanh.
generator.add ( layers.Conv2DTranspose ( 1, (5, 5),
                                         strides = (1, 1),
                                         padding = "valid",
                                         activation = "tanh" ) )

## DISCRIMINATOR

discriminator = Sequential ( name = "discriminator" )

# Three strided-convolution stages with 32, 64 and 128 filters; only the first
# layer declares the input shape of the model (28x28 images with one channel).
# NOTE(review): `BatchNormalization ( axis = 1 )` normalizes along the first
#               spatial axis if the tensors are channels-last (Keras default) --
#               confirm this is intended rather than the channel axis (-1).
for stage, n_filters in enumerate ( (32, 64, 128) ):
    first_layer_opts = { "input_shape" : [28, 28, 1] } if stage == 0 else {}
    discriminator.add ( layers.Conv2D ( n_filters, (4, 4),
                                        strides = (2, 2),
                                        padding = "same",
                                        **first_layer_opts ) )
    discriminator.add ( layers.BatchNormalization ( axis = 1 ) )
    discriminator.add ( layers.LeakyReLU ( alpha = 0.2 ) )

# Flatten the feature maps and map them to one unbounded (linear) output.
discriminator.add ( layers.Flatten() )
discriminator.add ( layers.Dense ( 1, activation = "linear" ) )

# +--------------------------+
# |   Training procedure     |
# +--------------------------+

# Wrap the two players into the WGAN-ALP algorithm and report its structure.
gan = WGAN_ALP ( generator, discriminator, latent_dim = LATENT_DIM )
gan.summary()

## OPTIMIZERS

# Generator and discriminator get separate Adam instances with the same settings.
g_opt = tf.keras.optimizers.Adam ( learning_rate = 1e-4, beta_1 = 0.0, beta_2 = 0.9 )
d_opt = tf.keras.optimizers.Adam ( learning_rate = 1e-4, beta_1 = 0.0, beta_2 = 0.9 )

# NOTE(review): judging by their names, `g_updt_per_batch` / `d_updt_per_batch`
#               are the number of generator / discriminator updates per batch --
#               confirm against the `WGAN_ALP.compile` documentation.
gan.compile (
    g_optimizer      = g_opt,
    d_optimizer      = d_opt,
    g_updt_per_batch = 1,
    d_updt_per_batch = 5,
    v_adv_dir_updt   = 1,
    adv_lp_penalty   = 100,
)

## CALLBACKS

lr_sched = GanExpLrScheduler ( factor = 0.90, step = 5 )

img_saver = ImageSaver (
    name    = "dc-wgan-alp",
    dirname = "./images/dc-wgan-alp",
    step    = 1,
    look    = "multi",
)

## TRAINING

EPOCHS = 50
# Number of full mini-batches contained in the training set.
STEPS_PER_EPOCH = len ( train_img ) // BATCH_SIZE

start = datetime.now()

train = gan.fit (
    train_ds,
    epochs          = EPOCHS,
    steps_per_epoch = STEPS_PER_EPOCH,
    validation_data = test_ds,
    callbacks       = [ lr_sched, img_saver ],
    verbose         = 1,
)

stop = datetime.now()

# `str(timedelta)` reads "H:MM:SS.ffffff": drop the fractional seconds, then
# split the remaining fields to rebuild a human-readable duration.
hours, minutes, seconds = str ( stop - start ).split ( "." )[0].split ( ":" )
timestamp = f"{hours}h {minutes}min {seconds}s"

print (f"Model training completed in {timestamp}.")

# +--------------------+
# |   Create a GIF     |
# +--------------------+

# `import PIL` alone does not load the `PIL.Image` submodule: import it
# explicitly instead of relying on another package having loaded it already.
from contextlib import ExitStack
from PIL import Image

anim_file = "./images/dc-wgan-alp.gif"

# Frames are ordered by file name.
# NOTE(review): this is chronological only if the epoch number in the file
#               names is zero-padded -- confirm against `ImageSaver`.
frame_pattern = "./images/dc-wgan-alp/dc-wgan-alp_ep*.png"
filenames = sorted ( glob.glob ( frame_pattern ) )

# Fail with an explicit message rather than with an opaque unpacking error.
if not filenames:
    raise FileNotFoundError ( f"No frames matching '{frame_pattern}' were found: "
                              f"the GIF cannot be created." )

# The ExitStack closes every opened frame once the GIF is written,
# even if saving fails.
with ExitStack() as stack:
    img , *imgs = [ stack.enter_context ( Image.open (f) ) for f in filenames ]
    img.save ( fp = anim_file, format = "GIF", append_images = imgs,
               save_all = True, duration = 135, loop = 0 )

print (f"GIF correctly exported to {anim_file}.")
19 changes: 10 additions & 9 deletions tf_gen_models/algorithms/gan/WGAN_ALP.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _compute_d_loss (self, gen_sample, ref_sample) -> tf.Tensor:
d_loss = tf.reduce_mean ( w_gen * D_gen - w_ref * D_ref , axis = None )

## initial virtual adversarial direction
r_k = tf.random.uniform ( shape = tf.shape(input_ref) ,
r_k = tf.random.uniform ( shape = tf.shape(input_ref)[1:] ,
minval = 0.0 ,
maxval = 1.0 ,
dtype = input_ref.dtype )
Expand All @@ -83,13 +83,14 @@ def _compute_d_loss (self, gen_sample, ref_sample) -> tf.Tensor:
## approximation of virtual adversarial direction
D_gen_pert = tf.cast ( self._discriminator (input_gen_pert), dtype = input_gen.dtype )
D_ref_pert = tf.cast ( self._discriminator (input_ref_pert), dtype = input_ref.dtype )
diff = tf.abs ( tf.concat ( [D_gen, D_ref], axis = 0 ) - \
tf.concat ( [D_gen_pert, D_ref_pert], axis = 0 ) )
r_k = tf.gradients ( tf.reduce_mean (diff, axis = None) , r_k )[0]
diff = tf.abs ( tf.concat ( [ D_gen , D_ref ] , axis = 0 ) - \
tf.concat ( [ D_gen_pert , D_ref_pert ] , axis = 0 ) )
diff = tf.reduce_mean ( diff, axis = None )
r_k = tf.gradients ( diff, [r_k] ) [0]
r_k /= tf.norm ( r_k , axis = None )

## virtual adversarial direction
epsilon = self._epsilon_sampler ( shape = tf.shape(input_ref), dtype = input_ref.dtype )
epsilon = self._epsilon_sampler ( shape = tf.shape(input_ref)[1:], dtype = input_ref.dtype )
r_adv = epsilon * r_k

## adversarial perturbation of input tensors
Expand All @@ -103,8 +104,8 @@ def _compute_d_loss (self, gen_sample, ref_sample) -> tf.Tensor:
## adversarial Lipschitz penalty correction
D_gen_pert = tf.cast ( self._discriminator (input_gen_pert), dtype = input_gen.dtype )
D_ref_pert = tf.cast ( self._discriminator (input_ref_pert), dtype = input_ref.dtype )
diff = tf.abs ( tf.concat ( [D_gen, D_ref], axis = 0 ) - \
tf.concat ( [D_gen_pert, D_ref_pert], axis = 0 ) )
diff = tf.abs ( tf.concat ( [ D_gen , D_ref ] , axis = 0 ) - \
tf.concat ( [ D_gen_pert , D_ref_pert ] , axis = 0 ) )
alp_term = tf.math.maximum ( diff / tf.norm ( r_adv, axis = None ) - self._lp_const, 0.0 ) # one-side penalty
alp_term = self._adv_lp_penalty * tf.reduce_mean (alp_term, axis = None) # adversarial Lipschitz penalty
d_loss += alp_term ** 2
Expand Down Expand Up @@ -176,8 +177,8 @@ def epsilon_sampler (self, func) -> None:
raise TypeError ("The epsilon sampler should be passed as a lambda function.")

## data-value control
func_args = func.__code__.co_varnames
if (len(func_args) != 2) or (func_args[0] != "shape") or (func_args[1] != "dtype"):
args = func.__code__.co_varnames
if (len(args) != 2) or ("shape" not in args) or ("dtype" not in args):
raise ValueError ( f"The lambda function for the epsilon sampler "
f"should have only ('shape', 'dtype') as arguments." )

Expand Down
2 changes: 1 addition & 1 deletion tf_gen_models/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.0.10"
__version__ = "0.0.11"

0 comments on commit 9f54b2f

Please sign in to comment.