Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding FID statistics calculation as an option (can now do "train", "eval", or "fid_stats") #5

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def central_crop(image, size):
return tf.image.crop_to_bounding_box(image, top, left, size, size)


def get_dataset(config, additional_dim=None, uniform_dequantization=False, evaluation=False):
def get_dataset(config, additional_dim=None, uniform_dequantization=False, evaluation=False, drop_remainder=True):
"""Create data loaders for training and evaluation.

Args:
Expand Down Expand Up @@ -198,7 +198,7 @@ def create_dataset(dataset_builder, split):
ds = ds.shuffle(shuffle_buffer_size)
ds = ds.map(preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
for batch_size in reversed(batch_dims):
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.batch(batch_size, drop_remainder=drop_remainder)
return ds.prefetch(prefetch_size)

train_ds = create_dataset(dataset_builder, train_split_name)
Expand Down
7 changes: 6 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@
config_flags.DEFINE_config_file(
"config", None, "Training configuration.", lock_config=True)
flags.DEFINE_string("workdir", None, "Work directory.")
flags.DEFINE_enum("mode", None, ["train", "eval"], "Running mode: train or eval")
flags.DEFINE_enum("mode", None, ["train", "eval","fid_stats"], "Running mode: train or eval or fid_stats")
flags.DEFINE_string("eval_folder", "eval",
"The folder name for storing evaluation results")
flags.DEFINE_string("fid_folder", "assets/stats",
"The folder name for storing FID statistics")
flags.mark_flags_as_required(["workdir", "config", "mode"])


Expand Down Expand Up @@ -58,6 +60,9 @@ def main(argv):
elif FLAGS.mode == "eval":
# Run the evaluation pipeline
run_lib.evaluate(FLAGS.config, FLAGS.workdir, FLAGS.eval_folder)
elif FLAGS.mode == "fid_stats":
# Calculate the FID statistics
run_lib.fid_stats(FLAGS.config, FLAGS.fid_folder)
else:
raise ValueError(f"Mode {FLAGS.mode} not recognized.")

Expand Down
55 changes: 55 additions & 0 deletions run_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,3 +580,58 @@ class EvalMeta:
os.path.join(eval_dir, f"meta_{jax.host_id()}_*"))
for file in meta_files:
tf.io.gfile.remove(file)

# Create FID stats by looping once through the whole training split
def fid_stats(config,
              fid_dir="assets/stats"):
  """Compute and save Inception activations used for FID evaluation.

  Runs every batch of the training split through the Inception network and
  writes the concatenated ``pool_3`` features to
  ``<fid_dir>/<dataset>_<image_size>_stats.npz``.

  Args:
    config: Configuration to use.
    fid_dir: The subfolder for storing fid statistics.
  """
  # Create directory for the statistics file
  tf.io.gfile.makedirs(fid_dir)

  # Build data pipeline. drop_remainder=False so the final partial batch is
  # included and the statistics cover the entire dataset.
  train_ds, _, _ = datasets.get_dataset(config,
                                        additional_dim=None,
                                        uniform_dequantization=False,
                                        evaluation=True,
                                        drop_remainder=False)

  # Use inceptionV3 for images with resolution higher than 256.
  inceptionv3 = config.data.image_size >= 256
  inception_model = evaluation.get_inception_model(inceptionv3=inceptionv3)

  all_pools = []
  # Iterate the dataset directly: len() on a tf.data.Dataset raises TypeError
  # when the cardinality is unknown, whereas plain iteration always works.
  for batch_id, batch in enumerate(train_ds):
    if jax.host_id() == 0:
      logging.info("Making FID stats -- step: %d" % (batch_id))

    # Convert TF tensors to numpy, scale [0, 1] floats to uint8 pixels, and
    # flatten any leading device/batch dimensions into a single batch axis.
    batch_ = jax.tree_map(lambda x: x._numpy(), batch)
    batch_ = (batch_['image']*255).astype(np.uint8).reshape((-1, config.data.image_size, config.data.image_size, 3))

    # Force garbage collection before calling TensorFlow code for Inception network
    gc.collect()
    latents = evaluation.run_inception_distributed(batch_, inception_model,
                                                   inceptionv3=inceptionv3)
    all_pools.append(latents["pool_3"])
    # Force garbage collection again before returning to JAX code
    gc.collect()

  all_pools = np.concatenate(all_pools, axis=0)  # Combine into one array

  # Save latent representations of the Inception network to disk or Google Cloud Storage
  filename = f'{config.data.dataset.lower()}_{config.data.image_size}_stats.npz'
  with tf.io.gfile.GFile(
      os.path.join(fid_dir, filename), "wb") as fout:
    io_buffer = io.BytesIO()
    np.savez_compressed(
        io_buffer, pool_3=all_pools)
    fout.write(io_buffer.getvalue())