Skip to content

Commit

Permalink
Merge pull request #13 from gbstack/master
Browse files Browse the repository at this point in the history
Add StarGAN-v2 style FID calculation
  • Loading branch information
qingqing01 authored Sep 18, 2020
2 parents 8a4848d + b1797a3 commit de6cb8d
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 26 deletions.
9 changes: 9 additions & 0 deletions ppgan/metric/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,12 @@ wget https://paddlegan.bj.bcebos.com/InceptionV3.pdparams
```
python test_fid_score.py --image_data_path1 /path/to/dataset1 --image_data_path2 /path/to/dataset2 --inference_model ./InceptionV3.pdparams
```

### Inception-V3 weights converted from torchvision

Download: https://aistudio.baidu.com/aistudio/datasetdetail/51890

This model weights file is converted from official torchvision inception-v3 model. And both BigGAN and StarGAN-v2 is using it to calculate FID score.

Note that this model weights is different from above one (which is converted from tensorflow unofficial version)

84 changes: 62 additions & 22 deletions ppgan/metric/compute_fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,18 @@
import fnmatch
import numpy as np
import cv2
from PIL import Image
from cv2 import imread
from scipy import linalg
import paddle.fluid as fluid
from inception import InceptionV3
from paddle.fluid.dygraph.base import to_variable


def tqdm(x):
return x
try:
from tqdm import tqdm
except:
def tqdm(x):
return x


""" based on https://github.com/mit-han-lab/gan-compression/blob/master/metric/fid_score.py
Expand Down Expand Up @@ -128,7 +131,7 @@ def calculate_fid_given_img(img_fake,
return fid_value


def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path):
def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path, style=None):
if len(files) % batch_size != 0:
print(('Warning: number of images is not a multiple of the '
'batch size. Some samples are going to be ignored.'))
Expand All @@ -144,8 +147,23 @@ def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path):
for i in tqdm(range(n_batches)):
start = i * batch_size
end = start + batch_size
images = np.array(
[imread(str(f)).astype(np.float32) for f in files[start:end]])

# same as stargan-v2 official implementation: resize to 256 first, then resize to 299
if style == 'stargan':
img_list = []
for f in files[start:end]:
im = Image.open(str(f)).convert('RGB')
if im.size[0] != 299:
im = im.resize((256, 256), 2)
im = im.resize((299, 299), 2)

img_list.append(np.array(im).astype('float32'))

images = np.array(
img_list)
else:
images = np.array(
[imread(str(f)).astype(np.float32) for f in files[start:end]])

if len(images.shape) != 4:
images = imread(str(files[start]))
Expand All @@ -155,33 +173,53 @@ def _get_activations(files, model, batch_size, dims, use_gpu, premodel_path):
images = images.transpose((0, 3, 1, 2))
images /= 255

images = to_variable(images)
param_dict, _ = fluid.load_dygraph(premodel_path)
model.set_dict(param_dict)
model.eval()
# imagenet normalization
if style == 'stargan':
mean = np.array([0.485, 0.456, 0.406]).astype('float32')
std = np.array([0.229, 0.224, 0.225]).astype('float32')
images[:] = (images[:] - mean[:, None, None]) / std[:, None, None]

pred = model(images)[0][0].numpy()
if style=='stargan':
pred_arr[start:end] = inception_infer(images, premodel_path)
else:
with fluid.dygraph.guard():
images = to_variable(images)
param_dict, _ = fluid.load_dygraph(premodel_path)
model.set_dict(param_dict)
model.eval()

pred_arr[start:end] = pred.reshape(end - start, -1)
pred = model(images)[0][0].numpy()

pred_arr[start:end] = pred.reshape(end - start, -1)

return pred_arr


def inception_infer(x, model_path):
exe = fluid.Executor()
[inference_program, feed_target_names, fetch_targets] = fluid.io.load_inference_model(model_path, exe)
results = exe.run(inference_program,
feed={feed_target_names[0]: x},
fetch_list=fetch_targets)
return results[0]


def _calculate_activation_statistics(files,
model,
premodel_path,
batch_size=50,
dims=2048,
use_gpu=False):
use_gpu=False,
style = None):
act = _get_activations(files, model, batch_size, dims, use_gpu,
premodel_path)
premodel_path, style)
mu = np.mean(act, axis=0)
sigma = np.cov(act, rowvar=False)
return mu, sigma


def _compute_statistics_of_path(path, model, batch_size, dims, use_gpu,
premodel_path):
premodel_path, style=None):
if path.endswith('.npz'):
f = np.load(path)
m, s = f['mu'][:], f['sigma'][:]
Expand All @@ -193,7 +231,7 @@ def _compute_statistics_of_path(path, model, batch_size, dims, use_gpu,
filenames, '*.jpg') or fnmatch.filter(filenames, '*.png'):
files.append(os.path.join(root, filename))
m, s = _calculate_activation_statistics(files, model, premodel_path,
batch_size, dims, use_gpu)
batch_size, dims, use_gpu, style)
return m, s


Expand All @@ -202,7 +240,8 @@ def calculate_fid_given_paths(paths,
batch_size,
use_gpu,
dims,
model=None):
model=None,
style = None):
assert os.path.exists(
premodel_path
), 'pretrain_model path {} is not exists! Please download it first'.format(
Expand All @@ -211,14 +250,15 @@ def calculate_fid_given_paths(paths,
if not os.path.exists(p):
raise RuntimeError('Invalid path: %s' % p)

if model is None:
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
model = InceptionV3([block_idx], class_dim=1008)
if model is None and style != 'stargan':
with fluid.dygraph.guard():
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
model = InceptionV3([block_idx], class_dim=1008)

m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size, dims,
use_gpu, premodel_path)
use_gpu, premodel_path, style)
m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size, dims,
use_gpu, premodel_path)
use_gpu, premodel_path, style)

fid_value = _calculate_frechet_distance(m1, s1, m2, s2)
return fid_value
10 changes: 6 additions & 4 deletions ppgan/metric/test_fid_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ def parse_args():
type=int,
default=1,
help='sample number in a batch for inference.')
parser.add_argument('--style',
type=str,
help='calculation style: stargan or default (gan-compression style)')
args = parser.parse_args()
return args

Expand All @@ -50,10 +53,9 @@ def main():
inference_model_path = args.inference_model
batch_size = args.batch_size

with fluid.dygraph.guard():
fid_value = calculate_fid_given_paths(paths, inference_model_path,
batch_size, args.use_gpu, 2048)
print('FID: ', fid_value)
fid_value = calculate_fid_given_paths(paths, inference_model_path,
batch_size, args.use_gpu, 2048, style=args.style)
print('FID: ', fid_value)


if __name__ == "__main__":
Expand Down

0 comments on commit de6cb8d

Please sign in to comment.