# DataLoader.py (forked from OpenGVLab/SAM-Med2D)

import os
import json

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# train_transforms, get_boxes_from_mask and init_point_sampling are provided
# by the repository's utils module.
from utils import train_transforms, get_boxes_from_mask, init_point_sampling


class TestingDataset(Dataset):
    def __init__(self, data_path, image_size=256, mode='test', requires_name=True,
                 point_num=1, return_ori_mask=True, prompt_path=None):
        """
        Initializes a TestingDataset object.

        Args:
            data_path (str): The path to the data.
            image_size (int, optional): The size of the image. Defaults to 256.
            mode (str, optional): The mode of the dataset. Defaults to 'test'.
            requires_name (bool, optional): Whether to return image names. Defaults to True.
            point_num (int, optional): The number of point prompts to sample. Defaults to 1.
            return_ori_mask (bool, optional): Whether to return the original-resolution mask. Defaults to True.
            prompt_path (str, optional): The path to a precomputed prompt file. Defaults to None.
        """
        self.image_size = image_size
        self.return_ori_mask = return_ori_mask
        self.prompt_path = prompt_path
        if prompt_path is None:
            self.prompt_list = {}
        else:
            with open(prompt_path, "r") as f:
                self.prompt_list = json.load(f)
        self.requires_name = requires_name
        self.point_num = point_num

        # label2image_{mode}.json maps mask paths (keys) to image paths (values).
        with open(os.path.join(data_path, f'label2image_{mode}.json'), "r") as f:
            dataset = json.load(f)
        self.image_paths = list(dataset.values())
        self.label_paths = list(dataset.keys())

        # Per-channel normalization statistics on the 0-255 scale (the values
        # used by SAM: the ImageNet mean/std multiplied by 255).
        self.pixel_mean = np.array([123.675, 116.28, 103.53])
        self.pixel_std = np.array([58.395, 57.12, 57.375])
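
    # The prompt file, when given, is assumed (from how it is indexed in
    # __getitem__ below) to map each mask filename to its precomputed prompts,
    # e.g. (filenames and box format illustrative):
    #   {"<mask>.png": {"boxes": [[x0, y0, x1, y1]],
    #                   "point_coords": [[x, y], ...],
    #                   "point_labels": [0 or 1, ...]}}
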
    def __getitem__(self, index):
        """
        Retrieves and preprocesses an item from the dataset.

        Args:
            index (int): The index of the item to retrieve.

        Returns:
            dict: A dictionary containing the preprocessed image and associated information.
        """
        image_input = {}

        # cv2.imread returns None rather than raising on a bad path, so check
        # explicitly instead of relying on a bare except.
        image = cv2.imread(self.image_paths[index])
        if image is None:
            raise FileNotFoundError(f"Could not read image: {self.image_paths[index]}")
        image = (image - self.pixel_mean) / self.pixel_std

        mask_path = self.label_paths[index]
        ori_np_mask = cv2.imread(mask_path, 0)  # read the mask as grayscale
        if ori_np_mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")
        if ori_np_mask.max() == 255:
            ori_np_mask = ori_np_mask / 255
        assert np.array_equal(ori_np_mask, ori_np_mask.astype(bool)), \
            f"Mask should only contain binary values 0 and 1. {self.label_paths[index]}"

        h, w = ori_np_mask.shape
        ori_mask = torch.tensor(ori_np_mask).unsqueeze(0)

        transforms = train_transforms(self.image_size, h, w)
        augments = transforms(image=image, mask=ori_np_mask)
        image, mask = augments['image'], augments['mask'].to(torch.int64)
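        # Note: train_transforms comes from the repo's utils module; given its
        # (image_size, h, w) arguments, it presumably resizes the image and mask
        # from (h, w) to (image_size, image_size) and converts them to tensors.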
        if self.prompt_path is None:
            # No precomputed prompts: derive a box and sampled points from the mask.
            boxes = get_boxes_from_mask(mask)
            point_coords, point_labels = init_point_sampling(mask, self.point_num)
        else:
            # Look up precomputed prompts by the mask's filename.
            prompt_key = os.path.basename(mask_path)
            boxes = torch.as_tensor(self.prompt_list[prompt_key]["boxes"], dtype=torch.float)
            point_coords = torch.as_tensor(self.prompt_list[prompt_key]["point_coords"], dtype=torch.float)
            point_labels = torch.as_tensor(self.prompt_list[prompt_key]["point_labels"], dtype=torch.int)

        image_input["image"] = image
        image_input["label"] = mask.unsqueeze(0)
        image_input["point_coords"] = point_coords
        image_input["point_labels"] = point_labels
        image_input["boxes"] = boxes
        image_input["original_size"] = (h, w)
        image_input["label_path"] = os.path.dirname(mask_path)
        if self.return_ori_mask:
            image_input["ori_label"] = ori_mask

        if self.requires_name:
            image_input["name"] = os.path.basename(mask_path)
        return image_input
    def __len__(self):
        return len(self.label_paths)
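
# Hedged sketch (not in the original file): one way a prompt JSON compatible
# with `prompt_path` could be precomputed. `build_prompt_json` is a
# hypothetical helper; it assumes the dataset was constructed with
# prompt_path=None and requires_name=True, so each sample carries its mask
# filename under "name", matching the key used in __getitem__.
def build_prompt_json(dataset, out_path):
    prompts = {}
    for i in range(len(dataset)):
        sample = dataset[i]
        prompts[sample["name"]] = {
            "boxes": sample["boxes"].tolist(),
            "point_coords": sample["point_coords"].tolist(),
            "point_labels": sample["point_labels"].tolist(),
        }
    with open(out_path, "w") as f:
        json.dump(prompts, f)
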
if __name__ == "__main__":
    test_dataset = TestingDataset("data_demo", image_size=256, mode='test',
                                  requires_name=True, point_num=1,
                                  return_ori_mask=True, prompt_path=None)
    print("Dataset:", len(test_dataset))
    # batch_size=1: samples carry variable-size fields such as the
    # original-resolution mask, which the default collate_fn cannot stack.
    test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=4)
    for i, batched_image in enumerate(tqdm(test_loader)):
        for k, v in batched_image.items():
            print(k, v)
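
# Hypothetical usage of the build_prompt_json sketch above (paths illustrative):
#   build_prompt_json(test_dataset, "data_demo/prompts_test.json")
#   TestingDataset("data_demo", mode='test', prompt_path="data_demo/prompts_test.json")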