From fc6af3653cf73218fb6d25207ba4f0f977eae9e0 Mon Sep 17 00:00:00 2001 From: Mark Ma <519329064@qq.com> Date: Sun, 30 Aug 2020 15:07:15 +0800 Subject: [PATCH 1/4] add stargan-v2 style FID calculation. add --style command line option to let user choose stargan or gan-compression style (by default gan-compression style will be used). move `dygraph.guard()` declaration into fid module for two reason: 1. the inference model didn't work in dygraph mode, so we dynamically choose whether to use dygraph mode after style is determined. 2. easier to use for end user (no need to call fluid.dygraph.guard() explicitly) --- ppgan/metric/compute_fid.py | 69 ++++++++++++++++++++++++++-------- ppgan/metric/test_fid_score.py | 10 +++-- 2 files changed, 59 insertions(+), 20 deletions(-) diff --git a/ppgan/metric/compute_fid.py b/ppgan/metric/compute_fid.py index c8fc8059e2658..8d000a0d5235d 100644 --- a/ppgan/metric/compute_fid.py +++ b/ppgan/metric/compute_fid.py @@ -16,6 +16,7 @@ import fnmatch import numpy as np import cv2 +from PIL import Image from cv2 import imread from scipy import linalg import paddle.fluid as fluid @@ -128,7 +129,7 @@ def calculate_fid_given_img(img_fake, return fid_value -def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path): +def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path, style=None): if len(files) % batch_size != 0: print(('Warning: number of images is not a multiple of the ' 'batch size. Some samples are going to be ignored.')) @@ -144,8 +145,23 @@ def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path): for i in tqdm(range(n_batches)): start = i * batch_size end = start + batch_size - images = np.array( - [imread(str(f)).astype(np.float32) for f in files[start:end]]) + + # same as stargan-v2 official implementation: resize to 256 first, then resize to 299 + if style == 'stargan': + img_list = [] + for f in files[start:end]: + im = Image.open(str(f)).convert('RGB') + if im.size[0] != 299: + im = im.resize((256, 256), 2) + im = im.resize((299, 299), 2) + + img_list.append(np.array(im).astype('float32')) + + images = np.array( + img_list) + else: + images = np.array( + [imread(str(f)).astype(np.float32) for f in files[start:end]]) if len(images.shape) != 4: images = imread(str(files[start])) @@ -155,33 +171,53 @@ def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path): images = images.transpose((0, 3, 1, 2)) images /= 255 - images = to_variable(images) - param_dict, _ = fluid.load_dygraph(premodel_path) - model.set_dict(param_dict) - model.eval() + # imagenet normalization + if style == 'stargan': + mean = np.array([0.485, 0.456, 0.406]).astype('float32') + std = np.array([0.229, 0.224, 0.225]).astype('float32') + images[:] = (images[:] - mean[:, None, None]) / std[:, None, None] - pred = model(images)[0][0].numpy() + if style=='stargan': + pred_arr[start:end] = inception_infer(images, premodel_path) + else: + with fluid.dygraph.guard(): + images = to_variable(images) + param_dict, _ = fluid.load_dygraph(premodel_path) + model.set_dict(param_dict) + model.eval() - pred_arr[start:end] = pred.reshape(end - start, -1) + pred = model(images)[0][0].numpy() + + pred_arr[start:end] = pred.reshape(end - start, -1) return pred_arr +def inception_infer(x, model_path): + exe = fluid.Executor() + [inference_program, feed_target_names, fetch_targets] = fluid.io.load_inference_model(model_path, exe) + results = exe.run(inference_program, + feed={feed_target_names[0]: x}, + fetch_list=fetch_targets) + return results[0] + + def _calculate_activation_statistics(files, model, premodel_path, batch_size=50, dims=2048, - use_gpu=False): + use_gpu=False, + style = None): act = _get_activations(files, model, batch_size, dims, use_gpu, - premodel_path) + premodel_path, style) mu = np.mean(act, axis=0) sigma = np.cov(act, rowvar=False) return mu, sigma def _compute_statistics_of_path(path, model, batch_size, dims, use_gpu, - premodel_path): + premodel_path, style=None): if path.endswith('.npz'): f = np.load(path) m, s = f['mu'][:], f['sigma'][:] @@ -193,7 +229,7 @@ def _compute_statistics_of_path(path, model, batch_size, dims, use_gpu, filenames, '*.jpg') or fnmatch.filter(filenames, '*.png'): files.append(os.path.join(root, filename)) m, s = _calculate_activation_statistics(files, model, premodel_path, - batch_size, dims, use_gpu) + batch_size, dims, use_gpu, style) return m, s @@ -202,7 +238,8 @@ def calculate_fid_given_paths(paths, batch_size, use_gpu, dims, - model=None): + model=None, + style = None): assert os.path.exists( premodel_path ), 'pretrain_model path {} is not exists! Please download it first'.format( @@ -216,9 +253,9 @@ def calculate_fid_given_paths(paths, model = InceptionV3([block_idx], class_dim=1008) m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size, dims, - use_gpu, premodel_path) + use_gpu, premodel_path, style) m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size, dims, - use_gpu, premodel_path) + use_gpu, premodel_path, style) fid_value = _calculate_frechet_distance(m1, s1, m2, s2) return fid_value diff --git a/ppgan/metric/test_fid_score.py b/ppgan/metric/test_fid_score.py index e8abccaaf3e8c..36412a5510408 100644 --- a/ppgan/metric/test_fid_score.py +++ b/ppgan/metric/test_fid_score.py @@ -38,6 +38,9 @@ def parse_args(): type=int, default=1, help='sample number in a batch for inference.') + parser.add_argument('--style', + type=str, + help='calculation style: stargan or default (gan-compression style)') args = parser.parse_args() return args @@ -50,10 +53,9 @@ def main(): inference_model_path = args.inference_model batch_size = args.batch_size - with fluid.dygraph.guard(): - fid_value = calculate_fid_given_paths(paths, inference_model_path, - batch_size, args.use_gpu, 2048) - print('FID: ', fid_value) + fid_value = calculate_fid_given_paths(paths, inference_model_path, + batch_size, args.use_gpu, 2048, style=args.style) + print('FID: ', fid_value) if __name__ == "__main__": From 4fde6c882309db66483f6093d5d8b9d7e56ab48c Mon Sep 17 00:00:00 2001 From: Mark Ma <519329064@qq.com> Date: Sun, 30 Aug 2020 15:15:56 +0800 Subject: [PATCH 2/4] update README --- ppgan/metric/README.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/ppgan/metric/README.md b/ppgan/metric/README.md index d27e99d639bfe..08fe7e700a48b 100644 --- a/ppgan/metric/README.md +++ b/ppgan/metric/README.md @@ -8,3 +8,12 @@ wget https://paddlegan.bj.bcebos.com/InceptionV3.pdparams ``` python test_fid_score.py --image_data_path1 /path/to/dataset1 --image_data_path2 /path/to/dataset2 --inference_model ./InceptionV3.pdparams ``` + +### Inception-V3 weights converted from torchvision + +Download: https://aistudio.baidu.com/aistudio/datasetdetail/51890 + +This model weights file is converted from official torchvision inception-v3 model. And both BigGAN and StarGAN-v2 is using it to calculate FID score. + +Note that this model weights is different from above one (which is converted from tensorflow unofficial version) + From 35eb66c320b65a99379214db1802f51e593068e9 Mon Sep 17 00:00:00 2001 From: Mark Ma <519329064@qq.com> Date: Tue, 1 Sep 2020 17:50:11 +0800 Subject: [PATCH 3/4] add progress bar display --- ppgan/metric/compute_fid.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ppgan/metric/compute_fid.py b/ppgan/metric/compute_fid.py index 8d000a0d5235d..c7f8c0a0d4c5b 100644 --- a/ppgan/metric/compute_fid.py +++ b/ppgan/metric/compute_fid.py @@ -23,9 +23,11 @@ from inception import InceptionV3 from paddle.fluid.dygraph.base import to_variable - -def tqdm(x): - return x +try: + from tqdm import tqdm +except: + def tqdm(x): + return x """ based on https://github.com/mit-han-lab/gan-compression/blob/master/metric/fid_score.py From b1797a3498b7b323e207a4de18ae3930e2b187c8 Mon Sep 17 00:00:00 2001 From: Mark Ma <519329064@qq.com> Date: Fri, 11 Sep 2020 11:36:04 +0800 Subject: [PATCH 4/4] fixed InceptionV3 dygraph class creation is not wrapped with fluid.dygraph.guard() --- ppgan/metric/compute_fid.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ppgan/metric/compute_fid.py b/ppgan/metric/compute_fid.py index c7f8c0a0d4c5b..3e1d013ed2dfc 100644 --- a/ppgan/metric/compute_fid.py +++ b/ppgan/metric/compute_fid.py @@ -250,9 +250,10 @@ def calculate_fid_given_paths(paths, if not os.path.exists(p): raise RuntimeError('Invalid path: %s' % p) - if model is None: - block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] - model = InceptionV3([block_idx], class_dim=1008) + if model is None and style != 'stargan': + with fluid.dygraph.guard(): + block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] + model = InceptionV3([block_idx], class_dim=1008) m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size, dims, use_gpu, premodel_path, style)