-
Notifications
You must be signed in to change notification settings - Fork 0
/
img_features.py
27 lines (21 loc) · 1.04 KB
/
img_features.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import os
from utils.load_data_utils import load_config, get_image_to_caption_map
from utils.img_features_utils import extract_and_save_img_features
def main():
    """Extract and save features for all train and val images, unless done already.

    Steps:
    1. Load the project config.
    2. Collect the ``filename`` column of the image/caption maps for the
       'train' and 'val' splits.
    3. Print how many files each split and their concatenation contain.
    4. If the configured ``images_features_dir`` does not exist yet, pass the
       combined file list and the configured ``BATCH_SIZE`` to
       ``extract_and_save_img_features``.

    Returns None.
    """
    config = load_config()
    prep_cfg = config['preprocessing']

    train_df = get_image_to_caption_map(prep_cfg, 'train')
    val_df = get_image_to_caption_map(prep_cfg, 'val')

    train_files = train_df['filename'].tolist()
    val_files = val_df['filename'].tolist()
    # Train files first, then val files, preserving each split's order.
    combined = [*train_files, *val_files]

    size_report = (
        ('len of train_img_files = ', train_files),
        ('len of val_img_files = ', val_files),
        ('len of all_files = ', combined),
    )
    for message, files in size_report:
        print(message, len(files))

    # Feature extraction is expensive, so it only runs when the output
    # directory is absent (typically the first run).
    # NOTE(review): only existence is checked -- a directory left behind by an
    # interrupted run will also cause extraction to be skipped; confirm that
    # is acceptable.
    if os.path.exists(prep_cfg['images_features_dir']):
        return
    extract_and_save_img_features(combined, config['nn_params']['BATCH_SIZE'])
if __name__ == '__main__':
    # Script entry point: extract image features from the raw images using a
    # pretrained CNN model and save them to disk (see main()).
    # This is a comment rather than a triple-quoted string because a string
    # here would be a discarded expression statement, not a docstring.
    main()