Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Justin dev #62

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions cleanfid/clip_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,31 @@

def img_preprocess_clip(img_np):
    """Resize and center-crop an image array to the 224x224 CLIP input size.

    Args:
        img_np: image array; it is cast to ``uint8`` and handed to
            ``PIL.Image.fromarray`` (presumably HxWxC or HxW -- confirm
            against callers).

    Returns:
        ``np.uint8`` array of the RGB image after a bicubic ``Resize(224)``
        followed by ``CenterCrop(224)``.
    """
    pil_img = Image.fromarray(img_np.astype(np.uint8)).convert("RGB")
    to_clip_size = transforms.Compose(
        [
            transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
        ]
    )
    # Clip before the cast so any out-of-range values cannot wrap around.
    return np.asarray(to_clip_size(pil_img)).clip(0, 255).astype(np.uint8)


class CLIP_fx:
    """Callable wrapper that embeds image tensors with a CLIP image encoder.

    Attributes are documented inline in ``__init__``.
    """

    # Per-channel statistics passed to ``transforms.Normalize`` in ``__call__``.
    _NORM_MEAN = (0.48145466, 0.4578275, 0.40821073)
    _NORM_STD = (0.26862954, 0.26130258, 0.27577711)

    def __init__(self, name="ViT-B/32", device="cuda"):
        """Load the CLIP model ``name`` onto ``device`` and put it in eval mode."""
        # clip.load returns a pair; only the model is kept here.
        self.model, _ = clip.load(name, device=device)
        self.model.eval()
        # Identifier derived from the model name, e.g. "ViT-B/32" -> "clip_vit_b_32".
        self.name = "clip_" + name.lower().replace("-", "_").replace("/", "_")

    def __call__(self, img_t):
        """Return CLIP image embeddings for ``img_t``.

        Args:
            img_t: torch tensor of pixel values in [0, 255]; either a single
                image (3-D) or a batch (4-D).

        Returns:
            The output of ``self.model.encode_image`` computed without
            gradient tracking.

        Raises:
            TypeError: if ``img_t`` is not a torch tensor.
            ValueError: if the input is neither 3-D nor 4-D.
        """
        # Validate up front: the previous post-hoc assert ran after the value
        # had already been used and is stripped under ``python -O``.
        if not torch.is_tensor(img_t):
            raise TypeError(
                f"CLIP_fx expects a torch tensor, got {type(img_t).__name__}"
            )
        img_x = img_t / 255.0
        T_norm = transforms.Normalize(self._NORM_MEAN, self._NORM_STD)
        img_x = T_norm(img_x)
        if img_x.dim() == 3:
            img_x = img_x.unsqueeze(0)
        # Explicit form of the rank check that the old ``B, C, H, W = shape``
        # unpacking performed implicitly.
        if img_x.dim() != 4:
            raise ValueError(
                f"expected a 3-D or 4-D image tensor, got shape {tuple(img_x.shape)}"
            )
        with torch.no_grad():
            z = self.model.encode_image(img_x)
        return z
18 changes: 13 additions & 5 deletions cleanfid/downloads_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@
ARGS:
fpath - output folder path
"""


def check_download_inception(fpath="./"):
    """Return the local path of the Inception checkpoint, downloading if absent.

    Args:
        fpath: folder in which ``inception-2015-12-05.pt`` is looked up and,
            when missing, stored. The folder is created if needed.

    Returns:
        Path to ``inception-2015-12-05.pt`` inside ``fpath``.

    Raises:
        Whatever ``urllib.request.urlopen`` or the file operations raise; in
        that case no file is left at the returned path.
    """
    inception_path = os.path.join(fpath, "inception-2015-12-05.pt")
    if os.path.exists(inception_path):
        return inception_path

    os.makedirs(os.path.dirname(inception_path) or ".", exist_ok=True)
    # Download to a sibling temp file and rename once complete, so that an
    # interrupted transfer never leaves a truncated checkpoint that the
    # existence check above would later accept as valid.
    partial_path = f"{inception_path}.{os.getpid()}.part"
    try:
        with urllib.request.urlopen(inception_url) as response, open(
            partial_path, "wb"
        ) as f:
            shutil.copyfileobj(response, f)
        os.replace(partial_path, inception_path)
    finally:
        # Only present here if the download or rename failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    return inception_path

Expand All @@ -27,13 +31,15 @@ def check_download_inception(fpath="./"):
local_folder - output folder path
url - the web URL to download
"""


def check_download_url(local_folder, url):
    """Download ``url`` into ``local_folder`` unless it is already there.

    Args:
        local_folder: output folder; created if it does not exist.
        url: address to fetch. The local file name is the URL's basename.

    Returns:
        Path of the local copy (``local_folder`` joined with the basename).

    Raises:
        Whatever ``urllib.request.urlopen`` or the file operations raise; in
        that case no file is left at the returned path.
    """
    name = os.path.basename(url)
    local_path = os.path.join(local_folder, name)
    if os.path.exists(local_path):
        return local_path

    os.makedirs(local_folder, exist_ok=True)
    print(f"downloading statistics to {local_path}")
    # Write to a sibling temp file and rename once complete, so that an
    # interrupted transfer never leaves a truncated file that the existence
    # check above would later accept as a finished download.
    partial_path = f"{local_path}.{os.getpid()}.part"
    try:
        with urllib.request.urlopen(url) as response, open(partial_path, "wb") as f:
            shutil.copyfileobj(response, f)
        os.replace(partial_path, local_path)
    finally:
        # Only present here if the download or rename failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    return local_path

Expand All @@ -44,20 +50,22 @@ def check_download_url(local_folder, url):
file_id - id of the google drive file
out_path - output folder path
"""


def download_google_drive(file_id, out_path):
def get_confirm_token(response):
for key, value in response.cookies.items():
if key.startswith('download_warning'):
if key.startswith("download_warning"):
return value
return None

URL = "https://drive.google.com/uc?export=download"
session = requests.Session()
response = session.get(URL, params={'id': file_id}, stream=True)
response = session.get(URL, params={"id": file_id}, stream=True)
token = get_confirm_token(response)

if token:
params = {'id': file_id, 'confirm': token}
params = {"id": file_id, "confirm": token}
response = session.get(URL, params=params, stream=True)

CHUNK_SIZE = 32768
Expand Down
61 changes: 51 additions & 10 deletions cleanfid/features.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
helpers for extracting features from image
"""

import os
import platform
import numpy as np
Expand All @@ -15,20 +16,33 @@
returns a function that takes an image in range [0,255]
and outputs a feature embedding vector
"""
def feature_extractor(
    name="torchscript_inception",
    device=torch.device("cuda"),
    resize_inside=False,
    use_dataparallel=True,
):
    """Build a callable that maps an image batch to feature embeddings.

    Args:
        name: ``"torchscript_inception"`` or ``"pytorch_inception"``.
        device: device the network is moved to.
        resize_inside: forwarded to ``InceptionV3W``; only used by the
            torchscript variant.
        use_dataparallel: wrap the network in ``torch.nn.DataParallel``.

    Returns:
        A function ``model_fn(x)`` that runs the chosen network on ``x``.

    Raises:
        ValueError: if ``name`` is not one of the supported extractors.
    """
    if name == "torchscript_inception":
        # Folder handed to InceptionV3W together with download=True.
        path = "./" if platform.system() == "Windows" else "/tmp"
        model = InceptionV3W(path, download=True, resize_inside=resize_inside).to(
            device
        )
        model.eval()
        if use_dataparallel:
            model = torch.nn.DataParallel(model)

        def model_fn(x):
            return model(x)

    elif name == "pytorch_inception":
        model = InceptionV3(output_blocks=[3], resize_input=False).to(device)
        model.eval()
        if use_dataparallel:
            model = torch.nn.DataParallel(model)

        def model_fn(x):
            # This variant is fed x / 255; the first returned block has its
            # two trailing singleton dimensions squeezed away.
            return model(x / 255)[0].squeeze(-1).squeeze(-1)

    else:
        raise ValueError(f"{name} feature extractor not implemented")
    return model_fn
Expand All @@ -37,27 +51,54 @@ def model_fn(x): return model(x/255)[0].squeeze(-1).squeeze(-1)
"""
Build a feature extractor for each of the modes
"""


def build_feature_extractor(mode, device=torch.device("cuda"), use_dataparallel=True):
    """Build the feature extractor that corresponds to an evaluation mode.

    Args:
        mode: ``"legacy_pytorch"``, ``"legacy_tensorflow"`` or ``"clean"``.
        device: device forwarded to ``feature_extractor``.
        use_dataparallel: forwarded to ``feature_extractor``.

    Returns:
        The callable returned by ``feature_extractor`` for that mode.

    Raises:
        ValueError: if ``mode`` is not one of the supported modes. Previously
            an unknown mode fell through every branch and failed with an
            ``UnboundLocalError`` on the return statement.
    """
    # mode -> (extractor name, resize_inside)
    mode_settings = {
        "legacy_pytorch": ("pytorch_inception", False),
        "legacy_tensorflow": ("torchscript_inception", True),
        "clean": ("torchscript_inception", False),
    }
    if mode not in mode_settings:
        supported = ", ".join(sorted(mode_settings))
        raise ValueError(
            f"{mode} mode not implemented (supported modes: {supported})"
        )
    extractor_name, resize_inside = mode_settings[mode]
    feat_model = feature_extractor(
        name=extractor_name,
        resize_inside=resize_inside,
        device=device,
        use_dataparallel=use_dataparallel,
    )
    return feat_model


"""
Load precomputed reference statistics for commonly used datasets
"""
def get_reference_statistics(name, res, mode="clean", model_name="inception_v3", seed=0, split="test", metric="FID"):


def get_reference_statistics(
name,
res,
mode="clean",
model_name="inception_v3",
seed=0,
split="test",
metric="FID",
):
base_url = "https://www.cs.cmu.edu/~clean-fid/stats/"
if split == "custom":
res = "na"
if model_name=="inception_v3":
if model_name == "inception_v3":
model_modifier = ""
else:
model_modifier = "_"+model_name
model_modifier = "_" + model_name
if metric == "FID":
rel_path = (f"{name}_{mode}{model_modifier}_{split}_{res}.npz").lower()
url = f"{base_url}/{rel_path}"
Expand Down
Loading