Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add iterable dataset support for multiprocess DataLoader #25558

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion python/paddle/fluid/dataloader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,9 @@
from . import batch_sampler
from .batch_sampler import *

from . import dataloader_iter
from .dataloader_iter import *

__all__ = dataset.__all__ \
+ batch_sampler.__all__
+ batch_sampler.__all__ \
+ dataloader_iter.__all__
37 changes: 31 additions & 6 deletions python/paddle/fluid/dataloader/batch_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from __future__ import division

import numpy as np
from .dataset import Dataset
from .dataset import Dataset, IterableDataset

__all__ = ["BatchSampler"]

Expand Down Expand Up @@ -106,12 +106,18 @@ def __init__(self,
assert isinstance(indices, list) or isinstance(indices, tuple), \
"indices should be a list or tuple, but got {}".format(type(indices))
self.indices = indices
self.sampler_iter = None
else:
assert isinstance(dataset, Dataset), \
"dataset should be an instance of paddle.io.Dataset"
assert indices is None, \
"should not set both dataset and indices"
self.indices = list(range(len(dataset)))
if isinstance(dataset, IterableDataset):
self.sampler_iter = iter(
_InfiniteIterableSampler(dataset, batch_size))
else:
self.sampler_iter = None
assert isinstance(dataset, Dataset), \
"dataset should be an instance of paddle.io.Dataset"
assert indices is None, \
"should not set both dataset and indices"
self.indices = list(range(len(dataset)))

assert isinstance(batch_size, int) and batch_size > 0, \
"batch_size should be a positive integer, but got {}".format(batch_size)
Expand All @@ -124,6 +130,9 @@ def __init__(self,
self.drop_last = drop_last

def __iter__(self):
if self.sampler_iter:
yield next(self.sampler_iter)

if self.shuffle:
np.random.shuffle(self.indices)
_iter = iter(self.indices)
Expand All @@ -138,6 +147,22 @@ def __iter__(self):
yield batch_indices

def __len__(self):
if self.sampler_iter:
raise RuntimeError("'{}' should not be called for IterableDataset".
format('__len__'))
num_samples = len(self.indices)
num_samples += int(not self.drop_last) * (self.batch_size - 1)
return num_samples // self.batch_size


class _InfiniteIterableSampler(object):
def __init__(self, dataset, batch_size=1):
assert isinstance(
dataset, IterableDataset
), "dataset should be an instance of paddle.io.IterableDataset"
self.dataset = dataset
self.batch_size = batch_size

def __iter__(self):
while True:
yield [None] * self.batch_size
Loading