-
Notifications
You must be signed in to change notification settings - Fork 6
/
eval_overall_mps_on_imagereward.py
143 lines (117 loc) · 5.15 KB
/
eval_overall_mps_on_imagereward.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import numpy as np
# from transformers import AutoProcessor #, AutoModel
import torch
from PIL import Image
from io import BytesIO
from tqdm.auto import tqdm
from fire import Fire
from transformers import CLIPFeatureExtractor, CLIPImageProcessor
from dataclasses import dataclass
from transformers import CLIPModel as HFCLIPModel
from torch import nn, einsum
from trainer.models.base_model import BaseModelConfig
from transformers import CLIPConfig
from transformers import AutoProcessor, AutoModel, AutoTokenizer
from typing import Any, Optional, Tuple, Union
import torch
import cv2
import os
from trainer.models.cross_modeling import Cross_model
import matplotlib.pyplot as plt
import torch.nn.functional as F
import gc
import json
# Default condition text. Its token-level similarity to the prompt decides which
# prompt tokens are left unmasked in the mask handed to ``cross_model`` below.
_DEFAULT_CONDITION = (
    "light, color, clarity, tone, style, ambiance, artistry, "
    "shape, face, hair, hands, limbs, structure, instance, "
    "texture, quantity, attributes, position, number, location, "
    "word, things."
)


@torch.no_grad()
def infer_one_sample(image, prompt, clip_model, clip_processor, tokenizer, device, condition=None):
    """Compute the model's preference score of one image for one prompt.

    Args:
        image: A PIL image, a file path (``str``), raw encoded image ``bytes``,
            or a dict whose ``"bytes"`` entry holds the encoded image.
        prompt: Text prompt to score the image against.
        clip_model: Loaded model; must expose ``.model.get_text_features``,
            ``.model.get_image_features``, ``.cross_model`` and ``.logit_scale``.
        clip_processor: Image processor returning ``pixel_values`` tensors.
        tokenizer: Text tokenizer matching the model.
        device: Device the image/text inputs are moved to.
        condition: Optional condition text; ``None`` uses ``_DEFAULT_CONDITION``.

    Returns:
        ``image_score[0]``, the first row of
        ``logit_scale.exp() * text_features @ image_features.T``. This is
        presumably a 1-element tensor for a single image/prompt pair (TODO confirm).
    """

    def _process_image(img):
        # Unwrap the accepted container types down to a PIL image.
        if isinstance(img, dict):
            img = img["bytes"]
        if isinstance(img, bytes):
            img = Image.open(BytesIO(img))
        if isinstance(img, str):
            img = Image.open(img)
        img = img.convert("RGB")
        return clip_processor(img, return_tensors="pt")["pixel_values"]

    def _tokenize(caption):
        # Pad/truncate to the tokenizer's fixed length so the prompt and the
        # condition share the same token layout.
        return tokenizer(
            caption,
            max_length=tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids

    image_input = _process_image(image).to(device)
    text_input = _tokenize(prompt).to(device)
    if condition is None:
        condition = _DEFAULT_CONDITION
    # One copy of the condition per prompt row in the batch.
    condition_batch = _tokenize(condition).repeat(text_input.shape[0], 1).to(device)

    # Autograd is already disabled by the @torch.no_grad() decorator.
    text_f, text_features = clip_model.model.get_text_features(text_input)
    # NOTE(review): the pixels are cast to fp16, which assumes the checkpoint's
    # weights are fp16; confirm.
    image_f = clip_model.model.get_image_features(image_input.half())
    condition_f, _ = clip_model.model.get_text_features(condition_batch)

    # For every prompt token, take its best similarity to any condition token,
    # then rescale by the global maximum.
    # NOTE(review): assumes that maximum is positive.
    sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)
    sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
    sim_text_condition = sim_text_condition / sim_text_condition.max()
    # 0 keeps a prompt token. -inf presumably masks it out inside cross_model.
    mask = torch.where(sim_text_condition > 0.3, 0, float('-inf'))
    mask = mask.repeat(1, image_f.shape[1], 1)
    # Cast the mask to the dtype of the features it accompanies
    # (identical to the old .half() for fp16 checkpoints).
    image_features = clip_model.cross_model(image_f, text_f, mask.to(image_f.dtype))[:, 0, :]

    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    image_score = clip_model.logit_scale.exp() * text_features @ image_features.T
    return image_score[0]
def infer_example(images, prompt, clip_model, clip_processor, tokenizer, device):
    """Score each candidate image for one prompt and softmax-normalize the scores.

    Each image is scored independently with ``infer_one_sample``. The scores
    are stacked along the last axis, and a softmax across the candidates turns
    them into probabilities.

    Assumes each per-image score is a tensor of shape ``(1,)`` (TODO confirm),
    so the stacked tensor is ``(1, len(images))`` and row 0 is returned.

    Returns:
        A Python list of floats, one probability per image, in input order.
    """
    per_image = [
        infer_one_sample(img, prompt, clip_model, clip_processor, tokenizer, device)
        for img in images
    ]
    logits = torch.stack(per_image, dim=-1)
    return torch.softmax(logits, dim=-1)[0].cpu().tolist()
def acc(score_sample, predict_sample):
    """Pairwise ranking accuracy of predicted rewards against reference rankings.

    Every pair of generations within an item whose reference ranks differ is
    one comparison. It is correct when the generation with the smaller rank
    value has the strictly larger predicted reward.
    - Tied predictions count as wrong.
    - Pairs with tied reference ranks are skipped.
    - Pairs where a reward is NaN are skipped, because neither comparison holds.

    Args:
        score_sample: Sequence of dicts, each with a ``"ranking"`` list
            (a smaller value means ranked better).
        predict_sample: Sequence aligned index-by-index with ``score_sample``.
            Each dict has a ``"rewards"`` list aligned with ``"ranking"``.

    Returns:
        The fraction of correct comparisons, or ``nan`` when there is no
        comparable pair (empty input or only tied rankings).
    """
    from itertools import combinations

    total = 0
    correct = 0
    for idx, scored in enumerate(score_sample):
        ranking = scored["ranking"]
        rewards = predict_sample[idx]["rewards"]
        for i, j in combinations(range(len(ranking)), 2):
            # Orient the pair so `better` is the reward of the better-ranked one.
            if ranking[i] > ranking[j]:
                better, worse = rewards[j], rewards[i]
            elif ranking[i] < ranking[j]:
                better, worse = rewards[i], rewards[j]
            else:
                continue  # tied reference ranks express no preference
            if better > worse:
                total += 1
                correct += 1
            elif better <= worse:  # explicit (not `else`) so NaN pairs are skipped
                total += 1
    if total == 0:
        return float("nan")
    return correct / total
def main(
    processor_name_or_path="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
    model_ckpt_path="outputs/MPS_overall_checkpoint.pth",
    test_json_path="imagereward/test.json",
    device="cuda",
):
    """Evaluate the MPS overall checkpoint on the ImageReward test split.

    Fire exposes every parameter as a command-line flag, for example
    ``--test_json_path=/data/imagereward/test.json``. With no flags the
    original hard-coded paths are used.

    Args:
        processor_name_or_path: Hub id or local path for the CLIP image
            processor and tokenizer.
        model_ckpt_path: Path of the pickled model object (presumably written
            with ``torch.save(model)``; confirm).
        test_json_path: ImageReward test JSON. Each entry needs ``id``,
            ``prompt``, ``generations`` and ``ranking``.
        device: Torch device to run inference on.
    """
    image_processor = CLIPImageProcessor.from_pretrained(processor_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True)

    # The checkpoint is a whole pickled nn.Module, so it needs weights_only=False;
    # torch >= 2.6 defaults to True and would refuse to load it.
    # SECURITY: unpickling can execute arbitrary code, so only load trusted
    # checkpoints.
    model = torch.load(model_ckpt_path, map_location=device, weights_only=False)
    model.eval().to(device)

    with open(test_json_path, "r", encoding="utf-8") as f:
        score_sample = json.load(f)

    predict_sample = []
    with torch.no_grad():
        for item in tqdm(score_sample, desc="ImageReward test"):
            rewards = infer_example(
                item["generations"], item["prompt"], model, image_processor, tokenizer, device
            )
            predict_sample.append({
                "id": item["id"],
                "prompt": item["prompt"],
                "rewards": rewards,
            })

    test_acc = acc(score_sample, predict_sample)
    print(f"ImageReward Test Acc: {100 * test_acc:.2f}%")
if __name__ == '__main__':
    # Run main() as a command-line entry point through python-fire.
    Fire(main)