forked from NRCan/geo-deep-learning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
CreateDataset.py
31 lines (21 loc) · 855 Bytes
/
CreateDataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import os
import h5py
from torch.utils.data import Dataset
class SegmentationDataset(Dataset):
    """Dataset for semantic segmentation.

    Samples are read from ``<work_folder>/<dataset_type>_samples.hdf5``,
    which must contain the datasets ``"sat_img"`` and ``"map_img"``. Each
    item is the slice at position ``index`` along the first axis of both
    datasets.

    Args:
        work_folder: Directory that holds the HDF5 sample file.
        num_samples: Value reported by ``len()``. It is not checked against
            the file contents. NOTE(review): presumably equals the number of
            rows along axis 0 of both datasets; verify against the writer.
        dataset_type: Prefix of the sample file name, as in
            ``f"{dataset_type}_samples.hdf5"``.
        transform: Optional callable applied to each sample dict. It is
            applied only when truthy.
    """

    def __init__(self, work_folder, num_samples, dataset_type, transform=None):
        self.work_folder = work_folder
        self.num_samples = num_samples
        self.dataset_type = dataset_type
        self.transform = transform

    def __len__(self):
        """Return the configured number of samples."""
        return self.num_samples

    def __getitem__(self, index):
        """Load sample ``index`` and return it as a dict.

        Args:
            index: Position along the first axis of the HDF5 datasets.

        Returns:
            ``{'sat_img': ..., 'map_img': ...}`` with the arrays read from
            the file, or whatever ``self.transform`` returns for that dict.

        Raises:
            Whatever h5py raises for a missing file, missing dataset key,
            or out-of-range index. The file handle is closed in every case.
        """
        hdf5_path = os.path.join(
            self.work_folder, self.dataset_type + "_samples.hdf5"
        )
        # The file is reopened on every access. NOTE(review): presumably this
        # avoids sharing one h5py handle across DataLoader worker processes;
        # confirm before caching the handle on the instance.
        # Use ``with`` so the handle is released even if a lookup or the
        # indexing raises. The original manual close() leaked it then.
        with h5py.File(hdf5_path, "r") as hdf5_file:
            sat_img = hdf5_file["sat_img"][index, ...]
            map_img = hdf5_file["map_img"][index, ...]

        sample = {'sat_img': sat_img, 'map_img': map_img}
        if self.transform:
            sample = self.transform(sample)
        return sample