【Hackathon 5th No.23】Add ConcatDataset API to Paddle (#57720)
* add ConcatDataset API

* remove redundant code
Patrick-Star125 authored Nov 17, 2023
1 parent 2811d1f commit 1916775
Showing 4 changed files with 134 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/paddle/io/__init__.py
@@ -30,6 +30,7 @@
    WeightedRandomSampler,
    get_worker_info,
    random_split,
    ConcatDataset,
)
from .reader import DataLoader

@@ -50,4 +51,5 @@
    'random_split',
    'Subset',
    'SubsetRandomSampler',
    'ConcatDataset',
]
1 change: 1 addition & 0 deletions python/paddle/io/dataloader/__init__.py
@@ -19,6 +19,7 @@
from .dataset import ChainDataset
from .dataset import random_split
from .dataset import Subset
from .dataset import ConcatDataset

from .batch_sampler import BatchSampler
from .batch_sampler import DistributedBatchSampler
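With both `__init__.py` files updated above, the new class resolves from the public `paddle.io` namespace as well as from its defining module. A quick sanity check, assuming a build that includes this commit:

from paddle.io import ConcatDataset                     # public API export
from paddle.io.dataloader.dataset import ConcatDataset  # defining module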
81 changes: 81 additions & 0 deletions python/paddle/io/dataloader/dataset.py
@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import bisect
from typing import Iterable

import paddle

from ... import framework
@@ -566,3 +569,81 @@ def _accumulate(iterable, fn=lambda x, y: x + y):
    for element in it:
        total = fn(total, element)
        yield total


class ConcatDataset(Dataset):
    """
    Dataset as a concatenation of multiple datasets.

    This class is useful for assembling multiple existing datasets.

    Args:
        datasets (sequence): List of datasets to be concatenated.

    Returns:
        Dataset: A Dataset formed by concatenating the given datasets.

    Examples:

        .. code-block:: python

            >>> import numpy as np
            >>> import paddle
            >>> from paddle.io import Dataset, ConcatDataset

            >>> # define a random dataset
            >>> class RandomDataset(Dataset):
            ...     def __init__(self, num_samples):
            ...         self.num_samples = num_samples
            ...
            ...     def __getitem__(self, idx):
            ...         image = np.random.random([32]).astype('float32')
            ...         label = np.random.randint(0, 9, (1, )).astype('int64')
            ...         return image, label
            ...
            ...     def __len__(self):
            ...         return self.num_samples
            ...
            >>> dataset = ConcatDataset([RandomDataset(10), RandomDataset(10)])
            >>> for i in range(len(dataset)):
            ...     image, label = dataset[i]
            ...     # do something
    """

    @staticmethod
    def cumsum(sequence):
        # Build a running total of dataset lengths,
        # e.g. lengths (5, 0, 5) -> [5, 5, 10].
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets: Iterable[Dataset]):
        self.datasets = list(datasets)
        assert (
            len(self.datasets) > 0
        ), 'datasets should not be an empty iterable'
        for d in self.datasets:
            assert not isinstance(
                d, IterableDataset
            ), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            # Normalize negative indices to their positive equivalents.
            idx = len(self) + idx
        # Find the first dataset whose cumulative size exceeds idx.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            # Offset by the combined length of all preceding datasets.
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]
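
The lookup in `__getitem__` is easiest to see with concrete numbers. Below is a minimal standalone sketch, not part of this commit and using hypothetical dataset sizes, of how the cumulative-size table built by `cumsum` combines with `bisect.bisect_right` to map a flat index to a `(dataset, sample)` pair:

import bisect

# Cumulative sizes for three datasets of lengths 5, 0 and 5, as cumsum() would build.
cumulative_sizes = [5, 5, 10]

def locate(idx):
    # First dataset whose cumulative size exceeds idx; empty datasets are
    # skipped automatically because bisect_right moves past equal entries.
    dataset_idx = bisect.bisect_right(cumulative_sizes, idx)
    # Subtract the combined length of all preceding datasets for the local index.
    sample_idx = idx if dataset_idx == 0 else idx - cumulative_sizes[dataset_idx - 1]
    return dataset_idx, sample_idx

print(locate(0))  # (0, 0): first sample of the first dataset
print(locate(5))  # (2, 0): the empty middle dataset is skipped
print(locate(9))  # (2, 4): last sample of the third dataset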
50 changes: 50 additions & 0 deletions test/legacy_test/test_multiprocess_dataloader_dataset.py
@@ -21,6 +21,7 @@
from paddle.io import (
    ChainDataset,
    ComposeDataset,
    ConcatDataset,
    DataLoader,
    Dataset,
    IterableDataset,
@@ -440,5 +441,54 @@ def test_iterable_dataset(self):
        self.run_main(dataset, 10, 3)


class TestConcatDataset(unittest.TestCase):
    def run_main(self, num_workers, places):
        result = ConcatDataset([[0], [1]])
        self.assertEqual(2, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(1, result[1])

        result = ConcatDataset([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
        self.assertEqual(10, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(5, result[5])

        result = ConcatDataset([[0, 1, 2, 3, 4], [], [5, 6, 7, 8, 9]])
        self.assertEqual(10, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(5, result[5])

        result = ConcatDataset([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
        with self.assertRaises(IndexError):
            result[11]

    def test_main(self):
        places = [paddle.CPUPlace()]
        if paddle.is_compiled_with_cuda():
            places.append(paddle.CUDAPlace(0))
        for p in places:
            self.run_main(num_workers=0, places=p)

    def test_iterable_dataset_err(self):
        d1 = TensorDataset([paddle.rand((7, 3, 28, 28)), paddle.rand((7,))])
        it1 = RandomIterableDataset(10)
        it2 = RandomIterableDataset(10)

        with self.assertRaisesRegex(
            AssertionError, "does not support IterableDataset"
        ):
            ConcatDataset([d1, it2, it1])

        with self.assertRaisesRegex(
            AssertionError, "does not support IterableDataset"
        ):
            ConcatDataset([it2])

        with self.assertRaisesRegex(
            AssertionError, "does not support IterableDataset"
        ):
            ConcatDataset([it1, d1])


if __name__ == '__main__':
    unittest.main()
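
Beyond these unit tests, a minimal end-to-end sketch of feeding a `ConcatDataset` through `paddle.io.DataLoader` — not part of this commit; the dataset sizes and batch size are illustrative:

import numpy as np
import paddle
from paddle.io import ConcatDataset, DataLoader, Dataset

class RandomDataset(Dataset):
    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __getitem__(self, idx):
        image = np.random.random([32]).astype('float32')
        label = np.random.randint(0, 9, (1,)).astype('int64')
        return image, label

    def __len__(self):
        return self.num_samples

# Two datasets of different sizes; the concatenation behaves as one of length 15.
dataset = ConcatDataset([RandomDataset(10), RandomDataset(5)])
loader = DataLoader(dataset, batch_size=4, shuffle=True)
for images, labels in loader:
    print(images.shape, labels.shape)  # e.g. [4, 32] and [4, 1] per full batch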
