-
Notifications
You must be signed in to change notification settings - Fork 110
/
video_loader.py
executable file
·115 lines (97 loc) · 3.86 KB
/
video_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from __future__ import print_function, absolute_import
import os
from PIL import Image
import numpy as np
import torch
from torch.utils.data import Dataset
import random
def read_image(img_path):
    """Keep reading image until succeed.

    This can avoid IOError incurred by heavy IO process. A path that does
    not exist can never succeed, so it is rejected up front instead of
    being retried forever.

    Args:
        img_path: path of the image file to load.

    Returns:
        The image opened with PIL and converted to RGB mode.

    Raises:
        IOError: if ``img_path`` does not exist.
    """
    # IOError from a missing file is indistinguishable from a transient one
    # inside the retry loop, so check existence before entering it.
    if not os.path.exists(img_path):
        raise IOError("'{}' does not exist".format(img_path))
    while True:
        try:
            return Image.open(img_path).convert('RGB')
        except IOError:
            print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
class VideoDataset(Dataset):
    """Video Person ReID Dataset.

    Note batch data has shape (batch, seq_len, channel, height, width).

    Args:
        dataset: indexable collection whose items unpack as
            ``(img_paths, pid, camid)``.
        seq_len: number of frames in every sampled clip.
        sample: sampling strategy used by ``__getitem__``; the implemented
            strategies are listed in ``sample_methods``. NOTE: the default
            ``'evenly'`` is kept for backward compatibility but is not
            implemented here, so ``__getitem__`` raises ``KeyError`` for it.
        transform: optional callable applied to every loaded frame. Each
            frame must support ``.unsqueeze(0)`` afterwards (presumably the
            transform converts the PIL image to a tensor -- confirm against
            the caller).
    """

    # Strategies that __getitem__ actually implements; also reported in the
    # KeyError message for an unknown strategy.
    sample_methods = ['random', 'dense']

    def __init__(self, dataset, seq_len=15, sample='evenly', transform=None):
        self.dataset = dataset
        self.seq_len = seq_len
        self.sample = sample
        self.transform = transform

    def __len__(self):
        """Return the number of tracklets in the dataset."""
        return len(self.dataset)

    @staticmethod
    def _pad_cyclic(indices, seq_len):
        """Return ``indices`` as a list, repeated from its start up to ``seq_len`` items."""
        padded = list(indices)
        if padded and len(padded) < seq_len:
            base_len = len(padded)
            padded = [padded[i % base_len] for i in range(seq_len)]
        return padded

    def _load_clip(self, img_paths, indices):
        """Read the frames at ``indices``, transform them and concatenate along dim 0."""
        if not indices:
            raise RuntimeError("Cannot build a clip from a tracklet with no frames")
        frames = []
        for frame_idx in indices:
            img = read_image(img_paths[int(frame_idx)])
            if self.transform is not None:
                img = self.transform(img)
            # Add a leading frame axis so the clip can be concatenated on it.
            frames.append(img.unsqueeze(0))
        return torch.cat(frames, dim=0)

    def __getitem__(self, index):
        """Sample clip(s) from the tracklet at ``index``.

        Returns:
            ``(imgs, pid, camid)``. For ``'random'``, ``imgs`` is one clip of
            ``seq_len`` frames concatenated along dim 0. For ``'dense'``,
            ``imgs`` stacks every clip of the tracklet along a new leading
            dimension.

        Raises:
            KeyError: if ``self.sample`` is not an implemented strategy.
            RuntimeError: if the tracklet contains no frames.
        """
        img_paths, pid, camid = self.dataset[index]
        num = len(img_paths)

        if self.sample == 'random':
            # Randomly sample seq_len consecutive frames from num frames,
            # if num is smaller than seq_len, then replicate items.
            # This sampling strategy is used in training phase.
            #
            # random.randint is inclusive on both ends, so the largest valid
            # start is num - seq_len; anything smaller would make the last
            # frame of a long tracklet unreachable.
            rand_end = max(0, num - self.seq_len)
            begin_index = random.randint(0, rand_end)
            end_index = min(begin_index + self.seq_len, num)
            indices = self._pad_cyclic(range(begin_index, end_index), self.seq_len)
            imgs = self._load_clip(img_paths, indices)
            return imgs, pid, camid

        elif self.sample == 'dense':
            # Sample all frames in a video into a list of clips, each clip
            # contains seq_len frames, batch_size needs to be set to 1.
            # This sampling strategy is used in test phase.
            cur_index = 0
            indices_list = []
            while num - cur_index > self.seq_len:
                indices_list.append(list(range(cur_index, cur_index + self.seq_len)))
                cur_index += self.seq_len
            # The remaining frames (at most seq_len of them) form the last
            # clip, padded by cycling so every clip has the same length.
            indices_list.append(self._pad_cyclic(range(cur_index, num), self.seq_len))
            imgs_list = [self._load_clip(img_paths, clip) for clip in indices_list]
            imgs_array = torch.stack(imgs_list)
            return imgs_array, pid, camid

        else:
            raise KeyError("Unknown sample method: {}. Expected one of {}".format(self.sample, self.sample_methods))