-
Notifications
You must be signed in to change notification settings - Fork 0
/
custom_dataset.py
112 lines (92 loc) · 4.7 KB
/
custom_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
class CustomImageDataset(Dataset):
    """
    Image-folder dataset: each sub-directory of ``img_dir`` is one class.

    Only entries of ``img_dir`` that are directories are treated as classes.
    Their names are sorted lexicographically and mapped to consecutive integer
    labels ``0..N-1`` (available as ``class_names`` / ``class_to_idx``).
    Because the sort is by string, digit-named folders beyond '9' order as
    '0', '1', '10', '11', ..., '2', ..., so a label is not necessarily
    ``int(folder_name)``.

    Images are returned as grayscale ('L' mode) PIL images, or whatever
    ``transform`` produces from them.

    Args:
        img_dir (str): Path to the directory containing the class folders.
        transform (callable, optional): Transform applied to each PIL image.
        limit_per_class (int, optional): Maximum number of images to load per
            class. ``None`` (default) loads every matching image.
        extensions (tuple[str, ...] | str, optional): File suffixes treated as
            images, matched case-insensitively. Defaults to ``('.jpg', '.png')``.
    """

    def __init__(self, img_dir, transform=None, limit_per_class=None,
                 extensions=(".jpg", ".png")):
        """
        Scan ``img_dir`` and record every image path with its integer label.

        Files in each class folder are sorted by name so the dataset order,
        and the subset kept by ``limit_per_class``, are deterministic
        (``os.listdir`` returns entries in arbitrary order).

        Args:
            img_dir (str): The directory where the class folders are stored.
            transform (callable, optional): Transformation applied to images.
            limit_per_class (int, optional): Cap on images loaded per class.
            extensions (tuple[str, ...] | str, optional): Accepted suffixes.

        Raises:
            FileNotFoundError / NotADirectoryError: propagated from
                ``os.listdir`` if ``img_dir`` is not a readable directory.
        """
        self.img_dir = img_dir
        self.transform = transform
        self.image_filenames = []
        self.labels = []

        # Accept a single suffix string as well as a tuple; lower-case once so
        # matching is case-insensitive ('.PNG', '.Jpg', ...).
        if isinstance(extensions, str):
            extensions = (extensions,)
        suffixes = tuple(ext.lower() for ext in extensions)

        # Only sub-directories are classes; stray files (README, .DS_Store, ...)
        # must not consume a label index.
        self.class_names = sorted(
            entry for entry in os.listdir(img_dir)
            if os.path.isdir(os.path.join(img_dir, entry))
        )
        self.class_to_idx = {name: idx for idx, name in enumerate(self.class_names)}

        for class_name in self.class_names:
            folder_path = os.path.join(img_dir, class_name)
            files = sorted(
                file_name for file_name in os.listdir(folder_path)
                if file_name.lower().endswith(suffixes)
            )
            # `is not None` so limit_per_class=0 means "no images", not "no limit".
            if limit_per_class is not None:
                files = files[:limit_per_class]
            label = self.class_to_idx[class_name]
            for file_name in files:
                self.image_filenames.append(os.path.join(folder_path, file_name))
                self.labels.append(label)

    def __len__(self):
        """
        Returns the total number of images in the dataset.
        """
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """
        Retrieve an image and its label by index.

        Args:
            idx (int): Index of the image.

        Returns:
            image (PIL Image or transformed image): The grayscale image at
                ``idx``, passed through ``transform`` if one was given.
            label (int): The integer class label of the image.
        """
        img_path = self.image_filenames[idx]
        # Context manager closes the file handle deterministically; convert()
        # loads the pixels and returns a new image independent of the file.
        with Image.open(img_path) as img:
            image = img.convert('L')  # grayscale
        label = self.labels[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label
if __name__ == "__main__":
    import sys

    # Grayscale preprocessing pipeline.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # Resize images to 224x224 pixels
        transforms.ToTensor(),  # PIL image -> float tensor in [0, 1], shape (C, H, W)
        transforms.Normalize((0.5,), (0.5,)),  # (x - 0.5) / 0.5: maps [0, 1] to [-1, 1]
    ])

    # Dataset root may be passed as the first CLI argument; otherwise fall back
    # to the original Google Drive location.
    dataset_path = (
        sys.argv[1] if len(sys.argv) > 1
        else '/content/drive/MyDrive/ViT/data_alphabets_digits/'
    )
    dataset = CustomImageDataset(img_dir=dataset_path, transform=transform)

    # Output the total number of images and details of the first image
    print(f"Number of images in dataset: {len(dataset)}")
    if len(dataset) == 0:
        # Indexing an empty dataset would raise IndexError.
        print("Dataset is empty; nothing to inspect.")
    else:
        image, label = dataset[0]
        print(f"Image shape: {image.shape}, Label: {label}")