"""Data loaders for the ImageNet-LT long-tailed classification benchmark."""
import os
import random

import numpy as np
from PIL import Image
from torch.utils.data import DataLoader, Dataset, Sampler
from torchvision import transforms

# from base import BaseDataLoader

class BalancedSampler(Sampler):
    """Class-balanced sampler.

    Each draw first picks a class bucket uniformly at random, then takes the
    next index from that bucket; a bucket is reshuffled once exhausted.
    """

    def __init__(self, buckets, retain_epoch_size=False):
        for bucket in buckets:
            random.shuffle(bucket)

        self.bucket_num = len(buckets)
        self.buckets = buckets
        self.bucket_pointers = [0 for _ in range(self.bucket_num)]
        self.retain_epoch_size = retain_epoch_size

    def __iter__(self):
        count = self.__len__()
        while count > 0:
            yield self._next_item()
            count -= 1

    def _next_item(self):
        bucket_idx = random.randint(0, self.bucket_num - 1)
        bucket = self.buckets[bucket_idx]
        item = bucket[self.bucket_pointers[bucket_idx]]
        self.bucket_pointers[bucket_idx] += 1
        if self.bucket_pointers[bucket_idx] == len(bucket):
            self.bucket_pointers[bucket_idx] = 0
            random.shuffle(bucket)
        return item

    def __len__(self):
        if self.retain_epoch_size:
            # Epoch length matches the dataset size (ideally this would be
            # rounded up to the next full batch).
            return sum([len(bucket) for bucket in self.buckets])
        else:
            # Long enough that every instance has a chance to be visited
            # within one epoch.
            return max([len(bucket) for bucket in self.buckets]) * self.bucket_num
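
# Example (hypothetical buckets): with buckets = [[0, 1, 2, 3], [4]], the lone
# sample of class 1 (index 4) is drawn about as often as all of class 0
# combined, and with retain_epoch_size=False an epoch has max(4, 1) * 2 = 8
# draws, so head and tail classes contribute equally in expectation.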

class LT_Dataset(Dataset):
    """Dataset built from a txt file listing one `<image_path> <label>` per line."""

    def __init__(self, root, txt, transform=None, training=False):
        self.img_path = []
        self.labels = []
        self.transform = transform
        with open(txt) as f:
            if 'ImageNet_LT_test' in txt:
                for line in f:
                    tmp = line.split()[0]
                    # Drop characters 3..12 of the listed path so it matches
                    # the on-disk layout of the test images.
                    tmp = tmp[:3] + tmp[13:]
                    pth = os.path.join(root, tmp)
                    self.img_path.append(pth)
                    self.labels.append(int(line.split()[1]))
            else:
                for line in f:
                    self.img_path.append(os.path.join(root, line.split()[0]))
                    self.labels.append(int(line.split()[1]))
        self.targets = self.labels  # the sampler reads class labels from `targets`
        self.train = training

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        path = self.img_path[index]
        label = self.labels[index]

        with open(path, 'rb') as f:
            sample = Image.open(f).convert('RGB')

        if self.transform is not None:
            sample = self.transform(sample)

        if self.train:
            # Training items also return their index, which is useful for
            # methods that track per-sample state across epochs.
            return sample, label, index
        else:
            return sample, label
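
# The expected txt format is one `<relative_image_path> <integer_label>` pair
# per line, e.g. (hypothetical entry):
#   train/n01440764/n01440764_10026.JPEG 0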

class ImageNetLTDataLoader(DataLoader):
    """ImageNet-LT data loader, optionally class-balanced during training."""

    def __init__(self, data_dir, batch_size, shuffle=True, num_workers=1, training=True,
                 balanced=False, retain_epoch_size=True,
                 train_txt="./data_txt/ImageNet_LT/ImageNet_LT_train.txt",
                 val_txt="./data_txt/ImageNet_LT/ImageNet_LT_val.txt",
                 test_txt="./data_txt/ImageNet_LT/ImageNet_LT_test.txt"):
        train_trsfm = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        test_trsfm = transforms.Compose([
            transforms.Resize(256),
            # CenterCrop gives every eval image a fixed 224x224 size so that
            # batches larger than one can be collated.
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        if training:
            dataset = LT_Dataset(data_dir, train_txt, train_trsfm, training=training)
            val_dataset = LT_Dataset(data_dir, val_txt, test_trsfm, training=False)
        else:  # test
            dataset = LT_Dataset(data_dir, test_txt, test_trsfm, training=False)
            val_dataset = None

        self.dataset = dataset
        self.val_dataset = val_dataset
        self.n_samples = len(self.dataset)

        num_classes = len(np.unique(dataset.targets))
        assert num_classes == 1000  # ImageNet-LT has 1000 classes
        self.num_classes = num_classes

        # Per-class sample counts, handy for loss re-weighting schemes.
        cls_num_list = [0] * num_classes
        for label in dataset.targets:
            cls_num_list[label] += 1
        self.cls_num_list = cls_num_list

        sampler = None  # default: plain (shuffled) sampling
        if balanced:
            if training:
                # One bucket of sample indices per class.
                buckets = [[] for _ in range(num_classes)]
                for idx, label in enumerate(dataset.targets):
                    buckets[label].append(idx)
                sampler = BalancedSampler(buckets, retain_epoch_size)
                shuffle = False  # a custom sampler is mutually exclusive with shuffle
            else:
                print("The test set is not evaluated with the balanced sampler; it is left unbalanced")

        self.shuffle = shuffle
        self.init_kwargs = {
            'batch_size': batch_size,
            'shuffle': self.shuffle,
            'num_workers': num_workers
        }
        # The sampler applies only to this (training) loader, not to the
        # validation loader returned by split_validation().
        super().__init__(dataset=self.dataset, **self.init_kwargs, sampler=sampler)

    def split_validation(self):
        # In test mode there is no validation split.
        if self.val_dataset is None:
            return None
        return DataLoader(dataset=self.val_dataset, **self.init_kwargs)
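
if __name__ == '__main__':
    # Minimal usage sketch, assuming the default txt lists from this repo and
    # the ImageNet-LT images under ./data (adjust both to your setup).
    loader = ImageNetLTDataLoader('./data', batch_size=64, num_workers=4,
                                  training=True, balanced=True)
    images, labels, indices = next(iter(loader))
    print(images.shape, labels.shape, indices.shape)  # e.g. [64, 3, 224, 224] ...
    val_loader = loader.split_validation()
    print(len(loader.dataset), len(val_loader.dataset))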