Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5th No.23】Add ConcatDataset API to Paddle #57720

Merged
merged 7 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/paddle/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
WeightedRandomSampler,
get_worker_info,
random_split,
ConcatDataset,
)
from .reader import DataLoader

Expand All @@ -48,4 +49,5 @@
'WeightedRandomSampler',
'random_split',
'Subset',
'ConcatDataset',
]
1 change: 1 addition & 0 deletions python/paddle/io/dataloader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .dataset import ChainDataset
from .dataset import random_split
from .dataset import Subset
from .dataset import ConcatDataset

from .batch_sampler import BatchSampler
from .batch_sampler import DistributedBatchSampler
Expand Down
81 changes: 81 additions & 0 deletions python/paddle/io/dataloader/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import bisect

import paddle

from ... import framework
Expand Down Expand Up @@ -566,3 +568,82 @@ def _accumulate(iterable, fn=lambda x, y: x + y):
for element in it:
total = fn(total, element)
yield total


class ConcatDataset(Dataset):
"""
Dataset as a concatenation of multiple datasets.

This class is useful to assemble different existing datasets.

Args:
datasets (sequence): List of datasets to be concatenated

Returns:
Dataset: A Dataset which concatenated by multiple datasets.

Examples:

.. code-block:: python

>>> import numpy as np
>>> import paddle
>>> from paddle.io import Dataset, ConcatDataset


>>> # define a random dataset
>>> class RandomDataset(Dataset):
... def __init__(self, num_samples):
... self.num_samples = num_samples
...
... def __getitem__(self, idx):
... image = np.random.random([32]).astype('float32')
... label = np.random.randint(0, 9, (1, )).astype('int64')
... return image, label
...
... def __len__(self):
... return self.num_samples
...
>>> dataset = ConcatDataset([RandomDataset(10), RandomDataset(10)])
>>> for i in range(len(dataset)):
... image, label = dataset[i]
... # do something
"""

@staticmethod
def cumsum(sequence):
r, s = [], 0
for e in sequence:
l = len(e)
r.append(l + s)
s += l
return r

def __init__(self, datasets) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

添加datasets的类型判断

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

具体是对什么类型做判断呢,判断是否为sequence(list、tuple)吗?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以加上typehint吗,Iterable[Dataset]表示Dataset是一个可迭代对象

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

super().__init__()
self.datasets = list(datasets)
assert (
len(self.datasets) > 0
), 'datasets should not be an empty iterable'
for d in self.datasets:
assert not isinstance(
d, IterableDataset
), "ConcatDataset does not support IterableDataset"
self.cumulative_sizes = self.cumsum(self.datasets)

def __len__(self):
return self.cumulative_sizes[-1]

def __getitem__(self, idx):
if idx < 0:
if -idx > len(self):
raise ValueError(
"absolute value of index should not exceed dataset length"
)
idx = len(self) + idx
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
return self.datasets[dataset_idx][sample_idx]
51 changes: 51 additions & 0 deletions test/legacy_test/test_multiprocess_dataloader_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from paddle.io import (
ChainDataset,
ComposeDataset,
ConcatDataset,
DataLoader,
Dataset,
IterableDataset,
Expand Down Expand Up @@ -440,5 +441,55 @@ def test_iterable_dataset(self):
self.run_main(dataset, 10, 3)


class TestConcatDataset(unittest.TestCase):
def run_main(self, num_workers, places):
result = ConcatDataset([[0], [1]])
self.assertEqual(2, len(result))
self.assertEqual(0, result[0])
self.assertEqual(1, result[1])

result = ConcatDataset([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
self.assertEqual(10, len(result))
self.assertEqual(0, result[0])
self.assertEqual(5, result[5])

result = ConcatDataset([[0, 1, 2, 3, 4], [], [5, 6, 7, 8, 9]])
self.assertEqual(10, len(result))
self.assertEqual(0, result[0])
self.assertEqual(5, result[5])

result = ConcatDataset([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
with self.assertRaises(IndexError):
# this one goes to 11
result[11]

def test_main(self):
places = [paddle.CPUPlace()]
if paddle.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
for p in places:
self.run_main(num_workers=0, places=p)

def test_iterable_dataset_err(self):
d1 = TensorDataset([paddle.rand((7, 3, 28, 28)), paddle.rand((7,))])
it1 = RandomIterableDataset(10)
it2 = RandomIterableDataset(10)

with self.assertRaisesRegex(
AssertionError, "does not support IterableDataset"
):
ConcatDataset([d1, it2, it1])

with self.assertRaisesRegex(
AssertionError, "does not support IterableDataset"
):
ConcatDataset([it2])

with self.assertRaisesRegex(
AssertionError, "does not support IterableDataset"
):
ConcatDataset([it1, d1])


if __name__ == '__main__':
unittest.main()