From bd87368680793dedfe0302e4ff5ff0f01a164af1 Mon Sep 17 00:00:00 2001 From: Gustavo Rosa Date: Tue, 20 Apr 2021 15:17:36 -0300 Subject: [PATCH] Still making another small cleanup on package. --- nalp/datasets/image.py | 12 +++++------- nalp/models/dcgan.py | 2 +- nalp/models/generators/stacked_rnn.py | 2 +- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/nalp/datasets/image.py b/nalp/datasets/image.py index bd7df5c..68bd366 100644 --- a/nalp/datasets/image.py +++ b/nalp/datasets/image.py @@ -56,15 +56,13 @@ def _preprocess(self, images, shape, normalize): """ + # Makes sure that images are float typed + images = images.astype('float32') + # If a shape is supplied if shape: - # Reshapes the array and make sure that it is float typed - images = images.reshape(shape).astype('float32') - - # If no shape is supplied - else: - # Just make sure that the array is float typed - images = images.astype('float32') + # Reshapes the array + images = images.reshape(shape) # If images should be normalized if normalize: diff --git a/nalp/models/dcgan.py b/nalp/models/dcgan.py index 51020fa..111bad6 100644 --- a/nalp/models/dcgan.py +++ b/nalp/models/dcgan.py @@ -43,7 +43,7 @@ def __init__(self, input_shape=(28, 28, 1), noise_dim=100, n_samplings=3, alpha= super(DCGAN, self).__init__(D, G, name='dcgan') logger.debug('Input: %s | Noise: %d | Number of samplings: %d | ' - 'Activation rate: %s | Dropout rate: %s', + 'Activation rate: %s | Dropout rate: %s.', input_shape, noise_dim, n_samplings, alpha, dropout_rate) logger.info('Class overrided.') diff --git a/nalp/models/generators/stacked_rnn.py b/nalp/models/generators/stacked_rnn.py index bb17ea0..00743a3 100644 --- a/nalp/models/generators/stacked_rnn.py +++ b/nalp/models/generators/stacked_rnn.py @@ -52,7 +52,7 @@ def __init__(self, encoder=None, vocab_size=1, embedding_size=32, hidden_size=(6 # Creates the linear (Dense) layer self.linear = Dense(vocab_size, name='out') - logger.debug('Number of cells: %d', len(hidden_size)) + logger.debug('Number of cells: %d.', len(hidden_size)) logger.info('Class overrided.') @property