-
Notifications
You must be signed in to change notification settings - Fork 74
/
features.py
77 lines (71 loc) · 3.08 KB
/
features.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""
Helpers for extracting feature embeddings from images.
"""
import os
import platform
import numpy as np
import torch
import cleanfid
from cleanfid.downloads_helper import check_download_url
from cleanfid.inception_pytorch import InceptionV3
from cleanfid.inception_torchscript import InceptionV3W
"""
returns a functions that takes an image in range [0,255]
and outputs a feature embedding vector
"""
def feature_extractor(name="torchscript_inception", device=torch.device("cuda"), resize_inside=False, use_dataparallel=True):
if name == "torchscript_inception":
path = "./" if platform.system() == "Windows" else "/tmp"
model = InceptionV3W(path, download=True, resize_inside=resize_inside).to(device)
model.eval()
if use_dataparallel:
model = torch.nn.DataParallel(model)
def model_fn(x): return model(x)
elif name == "pytorch_inception":
model = InceptionV3(output_blocks=[3], resize_input=False).to(device)
model.eval()
if use_dataparallel:
model = torch.nn.DataParallel(model)
def model_fn(x): return model(x/255)[0].squeeze(-1).squeeze(-1)
else:
raise ValueError(f"{name} feature extractor not implemented")
return model_fn
"""
Build a feature extractor for each of the modes
"""
def build_feature_extractor(mode, device=torch.device("cuda"), use_dataparallel=True):
if mode == "legacy_pytorch":
feat_model = feature_extractor(name="pytorch_inception", resize_inside=False, device=device, use_dataparallel=use_dataparallel)
elif mode == "legacy_tensorflow":
feat_model = feature_extractor(name="torchscript_inception", resize_inside=True, device=device, use_dataparallel=use_dataparallel)
elif mode == "clean":
feat_model = feature_extractor(name="torchscript_inception", resize_inside=False, device=device, use_dataparallel=use_dataparallel)
return feat_model
"""
Load precomputed reference statistics for commonly used datasets
"""
def get_reference_statistics(name, res, mode="clean", model_name="inception_v3", seed=0, split="test", metric="FID"):
base_url = "https://www.cs.cmu.edu/~clean-fid/stats/"
if split == "custom":
res = "na"
if model_name=="inception_v3":
model_modifier = ""
else:
model_modifier = "_"+model_name
if metric == "FID":
rel_path = (f"{name}_{mode}{model_modifier}_{split}_{res}.npz").lower()
url = f"{base_url}/{rel_path}"
mod_path = os.path.dirname(cleanfid.__file__)
stats_folder = os.path.join(mod_path, "stats")
fpath = check_download_url(local_folder=stats_folder, url=url)
stats = np.load(fpath)
mu, sigma = stats["mu"], stats["sigma"]
return mu, sigma
elif metric == "KID":
rel_path = (f"{name}_{mode}{model_modifier}_{split}_{res}_kid.npz").lower()
url = f"{base_url}/{rel_path}"
mod_path = os.path.dirname(cleanfid.__file__)
stats_folder = os.path.join(mod_path, "stats")
fpath = check_download_url(local_folder=stats_folder, url=url)
stats = np.load(fpath)
return stats["feats"]