eval_caption.py
import argparse
from dataclasses import dataclass, field
import json
import os
import re
import torch.utils.data
from adem.build import create_model
from adem.tokenizer import Tokenizer
from util.coco_karpathy_dataset import coco_caption_eval, coco_karpathy_caption_eval
from util.misc import MetricLogger
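# Default configuration for the LLaMA-based captioning model; most fields are
# overwritten below from the command-line arguments.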
@dataclass
class ModelArgs:
    llama_model_path: str = './data/weights/'
    llm_model: str = '7B'
    max_seq_len: int = 512
    hidden_proj: int = 128
    cpu_load: bool = False
    alpha: float = 0.1
    adapter_dim: int = 12
    gradient_checkpointing: bool = False
    is_train: bool = False
    data_root: str = './data/'
    clip: str = 'ViT-L/14'
    clip_root: str = './clip'
    down_sample_num: list = field(default_factory=lambda: [256, 64])
    no_cls: bool = False
    drop_ratio: float = 0.1
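# Command-line arguments for data/checkpoint locations and evaluation hyper-parameters.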
parser = argparse.ArgumentParser()
parser.add_argument('--data_root', type=str, default='./data')
parser.add_argument('--clip', type=str, default='ViT-L/14')
parser.add_argument('--clip_root', type=str, default='./clip')
parser.add_argument('--llm_model', type=str, default='7B')
parser.add_argument('--adapter_path', type=str, default='./output_dir')
parser.add_argument('--log_dir', type=str, default='./output_dir')
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--down_sample_num', type=int, nargs='+', default=[256, 64])
parser.add_argument('--alpha', type=float, default=0.1)
parser.add_argument('--beta', type=float, default=0.01)
parser.add_argument('--drop_ratio', type=float, default=0.1)
parser.add_argument('--no_cls', action='store_true')
args = parser.parse_args()
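# Example invocation (paths are illustrative, not fixed by this script):
#   python eval_caption.py --data_root ./data --adapter_path ./output_dir --batch_size 4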
log_dir = args.log_dir if args.log_dir is not None else './logs'
os.makedirs(log_dir, exist_ok=True)
llama_model_path = os.path.join(args.data_root, "weights/")
model_args = ModelArgs()
model_args.llama_model_path = llama_model_path
model_args.llm_model = args.llm_model
model_args.alpha = args.alpha
model_args.beta = args.beta
model_args.data_root = args.data_root
model_args.clip = args.clip
model_args.clip_root = args.clip_root
model_args.down_sample_num = args.down_sample_num
model_args.no_cls = args.no_cls
model_args.drop_ratio = args.drop_ratio
llama = create_model(model_args)
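# Load the fine-tuned adapter checkpoint on top of the base model built above.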
adapter = torch.load(os.path.join(args.adapter_path, 'checkpoint-4.pth'))['model']
# Strip the 'module.' prefix left over from DistributedDataParallel training.
sd = {}
for k in adapter:
    sd[k.replace('module.', '')] = adapter[k]
_IncompatibleKeys = llama.load_state_dict(sd, strict=False)
print(_IncompatibleKeys)
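# Tokenizer and COCO Karpathy-split evaluation data.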
tokenizer = Tokenizer(model_path=os.path.join(llama_model_path, 'tokenizer.model'))
dataset_test = coco_karpathy_caption_eval(image_root=os.path.join(args.data_root, 'images'),
                                           ann_root=os.path.join(args.data_root, 'coco_caption'))
data_loader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=args.batch_size,
    shuffle=False,
    drop_last=False,
)
llama.eval()
pattern = re.compile(r'picture of (.+)')
metric_logger = MetricLogger(delimiter=" ")
header = 'Caption generation:'
print_freq = 100
result = []
prompt = 'a picture of'
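# Generate a caption for every batch of test images; the model is prompted with
# "a picture of" and the regex above recovers the completed caption text.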
for image, image_id in metric_logger.log_every(data_loader_test, print_freq, header):
    # Greedy decoding (temperature 0) with a short generation budget per caption.
    captions = llama.generate(
        [prompt] * image.size(0), images=image, indicators=[1] * image.size(0),
        max_gen_len=20, tokenizer=tokenizer, temperature=0.0
    )
    # Keep only the text after the prompt; fall back to the raw output if the pattern is missing.
    matched_caption = []
    for c in captions:
        pred = pattern.findall(c)
        if len(pred) >= 1:
            pred = pred[0]
        else:
            print(c)
            pred = c
        matched_caption.append(pred)
    for caption, img_id in zip(matched_caption, image_id):
        result.append({"image_id": img_id.item(), "caption": caption})
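# Save the generated captions in COCO result format and run the official caption evaluation.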
result_file = os.path.join(log_dir, 'test_result.json')
with open(result_file, 'w') as f:
    json.dump(result, f)
coco_test = coco_caption_eval(os.path.join(args.data_root, 'coco_caption/'), result_file, split='val')
log_stats = {f'test_{k}': v for k, v in coco_test.eval.items()}
with open(os.path.join(log_dir, "evaluate.txt"), "a") as f:
    f.write(json.dumps(log_stats) + "\n")