diff --git a/datasets.py b/datasets.py
index 26b4b2a..b39df75 100644
--- a/datasets.py
+++ b/datasets.py
@@ -68,7 +68,7 @@ def central_crop(image, size):
   return tf.image.crop_to_bounding_box(image, top, left, size, size)
 
 
-def get_dataset(config, additional_dim=None, uniform_dequantization=False, evaluation=False):
+def get_dataset(config, additional_dim=None, uniform_dequantization=False, evaluation=False, drop_remainder=True):
   """Create data loaders for training and evaluation.
 
   Args:
@@ -198,7 +198,7 @@ def create_dataset(dataset_builder, split):
       ds = ds.shuffle(shuffle_buffer_size)
     ds = ds.map(preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
     for batch_size in reversed(batch_dims):
-      ds = ds.batch(batch_size, drop_remainder=True)
+      ds = ds.batch(batch_size, drop_remainder=drop_remainder)
     return ds.prefetch(prefetch_size)
 
   train_ds = create_dataset(dataset_builder, train_split_name)
diff --git a/main.py b/main.py
index 45d9ac5..6200c1d 100644
--- a/main.py
+++ b/main.py
@@ -28,9 +28,11 @@ config_flags.DEFINE_config_file(
   "config", None, "Training configuration.", lock_config=True)
 flags.DEFINE_string("workdir", None, "Work directory.")
-flags.DEFINE_enum("mode", None, ["train", "eval"], "Running mode: train or eval")
+flags.DEFINE_enum("mode", None, ["train", "eval", "fid_stats"], "Running mode: train, eval, or fid_stats")
 flags.DEFINE_string("eval_folder", "eval",
                     "The folder name for storing evaluation results")
+flags.DEFINE_string("fid_folder", "assets/stats",
+                    "The folder name for storing FID statistics")
 flags.mark_flags_as_required(["workdir", "config", "mode"])
 
 
@@ -58,6 +60,9 @@ def main(argv):
   elif FLAGS.mode == "eval":
     # Run the evaluation pipeline
     run_lib.evaluate(FLAGS.config, FLAGS.workdir, FLAGS.eval_folder)
+  elif FLAGS.mode == "fid_stats":
+    # Compute the FID statistics
+    run_lib.fid_stats(FLAGS.config, FLAGS.fid_folder)
   else:
     raise ValueError(f"Mode {FLAGS.mode} not recognized.")
 
diff --git a/run_lib.py b/run_lib.py
index ebb35d9..6a435b5 100644
--- a/run_lib.py
+++ b/run_lib.py
@@ -580,3 +580,58 @@ class EvalMeta:
     os.path.join(eval_dir, f"meta_{jax.host_id()}_*"))
   for file in meta_files:
     tf.io.gfile.remove(file)
+
+
+# Compute FID statistics by looping over the whole dataset
+def fid_stats(config,
+              fid_dir="assets/stats"):
+  """Compute the Inception statistics of a dataset for FID evaluation.
+
+  Args:
+    config: Configuration to use.
+    fid_dir: The folder for storing the FID statistics.
+  """
+  # Create the directory for FID statistics
+  tf.io.gfile.makedirs(fid_dir)
+
+  # Build data pipeline
+  train_ds, eval_ds, dataset_builder = datasets.get_dataset(config,
+                                                            additional_dim=None,
+                                                            uniform_dequantization=False,
+                                                            evaluation=True,
+                                                            drop_remainder=False)
+  bpd_iter = iter(train_ds)
+
+  # Use InceptionV3 for images with resolution of 256 or higher.
+  inceptionv3 = config.data.image_size >= 256
+  inception_model = evaluation.get_inception_model(inceptionv3=inceptionv3)
+
+  all_pools = []
+  for batch_id in range(len(train_ds)):
+    batch = next(bpd_iter)
+
+    if jax.host_id() == 0:
+      logging.info("Making FID stats -- step: %d" % (batch_id))
+
+    batch_ = jax.tree_map(lambda x: x._numpy(), batch)
+    batch_ = (batch_['image'] * 255).astype(np.uint8).reshape((-1, config.data.image_size, config.data.image_size, 3))
+
+    # Force garbage collection before calling TensorFlow code for Inception network
+    gc.collect()
+    latents = evaluation.run_inception_distributed(batch_, inception_model,
+                                                   inceptionv3=inceptionv3)
+    all_pools.append(latents["pool_3"])
+    # Force garbage collection again before returning to JAX code
+    gc.collect()
+
+  all_pools = np.concatenate(all_pools, axis=0)  # Combine pools from all batches
+
+  # Save latent representations of the Inception network to disk or Google Cloud Storage
+  filename = f'{config.data.dataset.lower()}_{config.data.image_size}_stats.npz'
+  with tf.io.gfile.GFile(
+      os.path.join(fid_dir, filename), "wb") as fout:
+    io_buffer = io.BytesIO()
+    np.savez_compressed(
+      io_buffer, pool_3=all_pools)
+    fout.write(io_buffer.getvalue())
+
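For reference, below is a minimal sketch of how the statistics file written by the new fid_stats mode (run via main.py with --mode fid_stats) could be consumed downstream. It assumes FID is computed from the saved pool_3 features via their sample mean/covariance and the Frechet distance; the compute_fid helper, its signature, and the example stats path are illustrative, not part of this change.

# Sketch: FID from the saved dataset pool_3 features and generated-sample
# pool_3 features. Helper name and paths are hypothetical.
import numpy as np
import scipy.linalg
import tensorflow as tf


def compute_fid(stats_path, sample_pools):
  """Frechet distance between dataset pool_3 stats and generated-sample pools."""
  with tf.io.gfile.GFile(stats_path, "rb") as fin:
    data_pools = np.load(fin)["pool_3"]
  mu1, sigma1 = data_pools.mean(axis=0), np.cov(data_pools, rowvar=False)
  mu2, sigma2 = sample_pools.mean(axis=0), np.cov(sample_pools, rowvar=False)
  diff = mu1 - mu2
  # Matrix square root of the covariance product; drop the small imaginary
  # parts that arise from numerical error.
  covmean = scipy.linalg.sqrtm(sigma1 @ sigma2).real
  return float(diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean))


# Example usage with a CIFAR-10 stats file produced by fid_stats:
# fid = compute_fid("assets/stats/cifar10_32_stats.npz", sample_pools)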