Skip to content

Commit

Permalink
Add iterable dataset support for multiprocess DataLoader (#25558)
Browse files Browse the repository at this point in the history
* add IterableDataset support in multiprocess DataLoader. test=develop
  • Loading branch information
heavengate authored Aug 12, 2020
1 parent 54003b8 commit dbc88bb
Show file tree
Hide file tree
Showing 12 changed files with 932 additions and 58 deletions.
6 changes: 5 additions & 1 deletion python/paddle/fluid/dataloader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,9 @@
from . import batch_sampler
from .batch_sampler import *

from . import dataloader_iter
from .dataloader_iter import *

__all__ = dataset.__all__ \
+ batch_sampler.__all__
+ batch_sampler.__all__ \
+ dataloader_iter.__all__
37 changes: 31 additions & 6 deletions python/paddle/fluid/dataloader/batch_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from __future__ import division

import numpy as np
from .dataset import Dataset
from .dataset import Dataset, IterableDataset

__all__ = ["BatchSampler"]

Expand Down Expand Up @@ -106,12 +106,18 @@ def __init__(self,
assert isinstance(indices, list) or isinstance(indices, tuple), \
"indices should be a list or tuple, but got {}".format(type(indices))
self.indices = indices
self.sampler_iter = None
else:
assert isinstance(dataset, Dataset), \
"dataset should be an instance of paddle.io.Dataset"
assert indices is None, \
"should not set both dataset and indices"
self.indices = list(range(len(dataset)))
if isinstance(dataset, IterableDataset):
self.sampler_iter = iter(
_InfiniteIterableSampler(dataset, batch_size))
else:
self.sampler_iter = None
assert isinstance(dataset, Dataset), \
"dataset should be an instance of paddle.io.Dataset"
assert indices is None, \
"should not set both dataset and indices"
self.indices = list(range(len(dataset)))

assert isinstance(batch_size, int) and batch_size > 0, \
"batch_size should be a positive integer, but got {}".format(batch_size)
Expand All @@ -124,6 +130,9 @@ def __init__(self,
self.drop_last = drop_last

def __iter__(self):
if self.sampler_iter:
yield next(self.sampler_iter)

if self.shuffle:
np.random.shuffle(self.indices)
_iter = iter(self.indices)
Expand All @@ -138,6 +147,22 @@ def __iter__(self):
yield batch_indices

def __len__(self):
if self.sampler_iter:
raise RuntimeError("'{}' should not be called for IterableDataset".
format('__len__'))
num_samples = len(self.indices)
num_samples += int(not self.drop_last) * (self.batch_size - 1)
return num_samples // self.batch_size


class _InfiniteIterableSampler(object):
def __init__(self, dataset, batch_size=1):
assert isinstance(
dataset, IterableDataset
), "dataset should be an instance of paddle.io.IterableDataset"
self.dataset = dataset
self.batch_size = batch_size

def __iter__(self):
while True:
yield [None] * self.batch_size
Loading

0 comments on commit dbc88bb

Please sign in to comment.