Skip to content

Commit

Permalink
Revert "fix cifar label dimension. test=develop (#33475)" (#34242)
Browse files Browse the repository at this point in the history
This reverts commit 6c11034.
  • Loading branch information
heavengate authored Jul 20, 2021
1 parent 301fb64 commit 1f6f223
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 16 deletions.
12 changes: 0 additions & 12 deletions python/paddle/tests/test_dataset_cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ def test_main(self):
self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32)
self.assertTrue(len(label.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 9)


Expand All @@ -51,8 +49,6 @@ def test_main(self):
self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32)
self.assertTrue(len(label.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 9)

# test cv2 backend
Expand All @@ -67,8 +63,6 @@ def test_main(self):
self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32)
self.assertTrue(len(label.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 99)

with self.assertRaises(ValueError):
Expand All @@ -89,8 +83,6 @@ def test_main(self):
self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32)
self.assertTrue(len(label.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 99)


Expand All @@ -108,8 +100,6 @@ def test_main(self):
self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32)
self.assertTrue(len(label.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 99)

# test cv2 backend
Expand All @@ -124,8 +114,6 @@ def test_main(self):
self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32)
self.assertTrue(len(label.shape) == 1)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 99)

with self.assertRaises(ValueError):
Expand Down
7 changes: 3 additions & 4 deletions python/paddle/vision/datasets/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,7 @@ def _load_data(self):
six.b('labels'), batch.get(six.b('fine_labels'), None))
assert labels is not None
for sample, label in six.moves.zip(data, labels):
self.data.append((sample,
np.array([label]).astype('int64')))
self.data.append((sample, label))

def __getitem__(self, idx):
image, label = self.data[idx]
Expand All @@ -162,9 +161,9 @@ def __getitem__(self, idx):
image = self.transform(image)

if self.backend == 'pil':
return image, label.astype('int64')
return image, np.array(label).astype('int64')

return image.astype(self.dtype), label.astype('int64')
return image.astype(self.dtype), np.array(label).astype('int64')

def __len__(self):
return len(self.data)
Expand Down

0 comments on commit 1f6f223

Please sign in to comment.