-
Notifications
You must be signed in to change notification settings - Fork 0
/
kaggle_imagenet.py
69 lines (57 loc) · 2.4 KB
/
kaggle_imagenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import torch
import torch.nn.functional as F
import torchvision
import xml.etree.ElementTree as Xml
from config import Config
from PIL import Image
from torch.utils.data import Dataset
# Fill value used by Annotation.synset_idx_tensor to pad label tensors to a
# fixed length. Real entries are list indices (always >= 0), so -1 cannot be
# mistaken for a class index.
PADDING_VALUE = -1
class Annotation:
    """Object labels parsed from a single XML annotation file.

    The distinct ``<object>/<name>`` values found in the file are stored in
    ``self.labels`` (a set of strings), capped at
    ``config.runtime.num_predictions`` entries.
    """

    def __init__(self, config: Config, file_path, synsets: list):
        """Parse ``file_path`` and collect the object names it contains.

        Args:
            config: Project configuration; ``config.runtime.num_predictions``
                is read as the maximum number of labels to keep.
            file_path: Path of the XML annotation file to parse.
            synsets: List of synset ids; a label's position in this list is
                the class index emitted by ``synset_idx_tensor``.

        Raises:
            ValueError: If the ``<filename>`` element inside the XML does not
                match the name of the file being parsed.
        """
        self.config = config
        self.synsets = synsets
        self.labels = set()

        root = Xml.parse(file_path).getroot()
        filename = root.find("filename").text
        if not file_path.endswith(f"{filename}.xml"):
            raise ValueError("File name mismatch with annotation")

        max_labels = config.runtime.num_predictions
        for obj in root.iter("object"):
            name = obj.find("name").text
            if name is None:
                continue
            self.labels.add(name)
            # Stop as soon as the cap is reached; later objects are ignored.
            if len(self.labels) >= max_labels:
                break

    def synset_idx_tensor(self):
        """Return the labels as a fixed-length tensor of synset indices.

        Returns:
            A 1-D ``torch.long`` tensor of length
            ``config.runtime.num_predictions`` holding the indices of the
            labels within ``self.synsets`` in ascending order, right-padded
            with ``PADDING_VALUE``.

        Raises:
            ValueError: If a label is not present in ``self.synsets``.
        """
        # Sort the indices: iteration order of a set of strings is not stable
        # between interpreter runs, which would make the output irreproducible.
        indices = sorted(self.synsets.index(label) for label in self.labels)
        # Pin the dtype so an annotation without any labels still produces an
        # integer tensor (torch.tensor([]) would otherwise default to float).
        labels = torch.tensor(indices, dtype=torch.long)
        # Pad the tensor so all label tensors are the same size as the number
        # of predictions. The fact that they match the prediction tensor size
        # is somewhat arbitrary, but all label tensors must be the same size,
        # regardless of what that size is.
        return F.pad(
            labels,
            (0, self.config.runtime.num_predictions - len(indices)),
            mode="constant",
            value=PADDING_VALUE,
        )
class KaggleImageNetDataset(Dataset):
    """Dataset pairing images in a directory with their XML annotations.

    Each item is ``(image, labels)``: the image after the torchvision AlexNet
    ``IMAGENET1K_V1`` preprocessing transforms, and the padded synset-index
    tensor produced by ``Annotation.synset_idx_tensor``.
    """

    def __init__(self, config: Config):
        """Index the image directory and load the synset id list.

        Args:
            config: Project configuration. Reads ``config.imagenet.data_dir``
                (image directory) and ``config.imagenet.synset_file`` (text
                file with one ``"<synset_id> <description>"`` entry per
                line); ``config.imagenet.annotations_dir`` is read later, on
                item access.

        Raises:
            ValueError: If a non-blank synset line has no space separating
                the id from the rest of the line.
        """
        self.config = config
        # Sorted so that an index always refers to the same image.
        self.img_names = sorted(os.listdir(config.imagenet.data_dir))
        self.transforms = torchvision.models.AlexNet_Weights.IMAGENET1K_V1.transforms()
        self.synsets = []
        with open(config.imagenet.synset_file) as synset_file:
            for line in synset_file:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                if not line.strip():
                    continue
                synset_id, _ = line.split(" ", 1)
                self.synsets.append(synset_id)

    def __len__(self):
        """Return the number of entries found in the image directory."""
        return len(self.img_names)

    def __getitem__(self, i):
        """Load the ``i``-th image and its label tensor.

        Args:
            i: Index into the sorted list of image file names.

        Returns:
            A tuple ``(transformed_image, label_tensor)``.
        """
        img_name = self.img_names[i]
        img_path = os.path.join(self.config.imagenet.data_dir, img_name)
        # convert() returns an independent image, so the underlying file can
        # be closed right away instead of leaking a handle per item.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")

        # Swap only the extension; a plain str.replace("JPEG", "xml") would
        # also rewrite "JPEG" appearing elsewhere in the file name.
        stem, _ = os.path.splitext(img_name)
        annotation_path = os.path.join(
            self.config.imagenet.annotations_dir,
            f"{stem}.xml",
        )
        annotation = Annotation(self.config, annotation_path, self.synsets)
        return self.transforms(img), annotation.synset_idx_tensor()