From d8788325395f52e198bfa13beeac1718251ce836 Mon Sep 17 00:00:00 2001 From: Christian Engel Date: Mon, 7 Aug 2023 19:20:27 +0200 Subject: [PATCH 1/2] Fix bug where self.target_transform is needed before assignment --- dataloader_pth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataloader_pth.py b/dataloader_pth.py index 008ddd6..880daed 100644 --- a/dataloader_pth.py +++ b/dataloader_pth.py @@ -96,10 +96,10 @@ def __init__(self, for aug in augmentation_list: self.augmentation_list.append(self.augmentations[aug]) trainform, testform = self.transform() + self.target_transform = self.to_binary self.build_train_dataset(trainform) self.build_val_dataset(trainform) self.build_test_dataset(testform) - self.target_transform = self.to_binary def list_dataset_variants(self): print(self.list_dataset_variant) From a7e0d1d26d881dcb8f1436f2340eb8897525732d Mon Sep 17 00:00:00 2001 From: Christian Engel Date: Mon, 7 Aug 2023 19:42:44 +0200 Subject: [PATCH 2/2] Fix T50 dataloader --- dataloader_pth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dataloader_pth.py b/dataloader_pth.py index 880daed..cf44267 100644 --- a/dataloader_pth.py +++ b/dataloader_pth.py @@ -196,7 +196,7 @@ class T50(Dataset): def __init__(self, img_dir, label_file, transform=None, target_transform=None): label_data = json.load(open(label_file, "rb")) self.label_data = label_data["annotations"] - self.frames = self.label_data.keys() + self.frames = list(self.label_data.keys()) self.img_dir = img_dir self.transform = transform self.target_transform = target_transform @@ -229,7 +229,7 @@ def get_binary_labels(self, labels): return (triplet_label, tool_label, verb_label, target_label, phase_label) def __getitem__(self, index): - labels = self.label_data["annotations"][self.frames[index]] + labels = self.label_data[self.frames[index]] basename = "{}.png".format(str(self.frames[index]).zfill(6)) img_path = os.path.join(self.img_dir, basename) image = Image.open(img_path)