diff --git a/python/paddle/reader/decorator.py b/python/paddle/reader/decorator.py
index 49e699aa94523a..cbda11e0375ba7 100644
--- a/python/paddle/reader/decorator.py
+++ b/python/paddle/reader/decorator.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 import itertools
 import logging
 import multiprocessing
@@ -21,9 +23,27 @@
 from itertools import zip_longest
 from queue import Queue
 from threading import Thread
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Generator,
+    Sequence,
+    TypedDict,
+    TypeVar,
+    overload,
+)
+
+from typing_extensions import NotRequired, TypeAlias, Unpack
 
 from paddle.base.reader import QUEUE_GET_TIMEOUT
 
+if TYPE_CHECKING:
+
+    class _ComposeOptions(TypedDict):
+        check_alignment: NotRequired[bool]
+
+
 __all__ = []
 
 # On macOS, the 'spawn' start method is now the default in Python3.8 multiprocessing,
@@ -41,8 +61,18 @@
 else:
     fork_context = multiprocessing
 
+_T = TypeVar('_T')
+_T1 = TypeVar('_T1')
+_T2 = TypeVar('_T2')
+_T3 = TypeVar('_T3')
+_T4 = TypeVar('_T4')
+_U = TypeVar('_U')
+
 
-def cache(reader):
+_Reader: TypeAlias = Callable[[], Generator[_T, None, None]]
+
+
+def cache(reader: _Reader[_T]) -> _Reader[_T]:
     """
     Cache the reader data into memory.
 
@@ -77,12 +107,60 @@ def cache(reader):
     """
     all_data = tuple(reader())
 
-    def __impl__():
+    def __impl__() -> Generator[_T, None, None]:
         yield from all_data
 
     return __impl__
 
 
+# A temporary solution like builtin map function.
+# `Map` maybe the final solution in the future.
+# See https://github.com/python/typing/issues/1383
+@overload
+def map_readers(
+    func: Callable[[_T1], _U], reader1: _Reader[_T1], /
+) -> _Reader[_U]:
+    ...
+
+
+@overload
+def map_readers(
+    func: Callable[[_T1, _T2], _U],
+    reader1: _Reader[_T1],
+    reader2: _Reader[_T2],
+    /,
+) -> _Reader[_U]:
+    ...
+
+
+@overload
+def map_readers(
+    func: Callable[[_T1, _T2, _T3], _U],
+    reader1: _Reader[_T1],
+    reader2: _Reader[_T2],
+    reader3: _Reader[_T3],
+    /,
+) -> _Reader[_U]:
+    ...
+
+
+@overload
+def map_readers(
+    func: Callable[[_T1, _T2, _T3, _T4], _U],
+    reader1: _Reader[_T1],
+    reader2: _Reader[_T2],
+    reader3: _Reader[_T3],
+    reader4: _Reader[_T4],
+    /,
+) -> _Reader[_U]:
+    ...
+
+
+@overload
+def map_readers(func: Callable[..., _U], *readers: _Reader[Any]) -> _Reader[_U]:
+    ...
+
+
 def map_readers(func, *readers):
     """
     Creates a data reader that outputs return value of function using
@@ -124,7 +202,7 @@ def reader():
     return reader
 
 
-def shuffle(reader, buf_size):
+def shuffle(reader: _Reader[_T], buf_size: int) -> _Reader[_T]:
     """
     This API creates a decorated reader that outputs the shuffled data.
 
@@ -151,7 +229,7 @@ def shuffle(reader, buf_size):
             >>> # outputs are 0~4 unordered arrangement
     """
 
-    def data_reader():
+    def data_reader() -> Generator[_T, None, None]:
         buf = []
         for e in reader():
             buf.append(e)
@@ -169,7 +247,7 @@ def data_reader():
     return data_reader
 
 
-def chain(*readers):
+def chain(*readers: _Reader[_T]) -> _Reader[_T]:
     """
     Use the input data readers to create a chained data reader.
     The new created reader chains the outputs of input readers together as its output, and it do not change
@@ -218,8 +296,8 @@ def chain(*readers):
 
     """
 
-    def reader():
-        rs = []
+    def reader() -> Generator[_T, None, None]:
+        rs: list[Generator[_T, None, None]] = []
         for r in readers:
             rs.append(r())
 
@@ -232,7 +310,9 @@ class ComposeNotAligned(ValueError):
     pass
 
 
-def compose(*readers, **kwargs):
+def compose(
+    *readers: _Reader[Any], **kwargs: Unpack[_ComposeOptions]
+) -> _Reader[Any]:
     """
     Creates a data reader whose output is the combination of input readers.
 
@@ -289,7 +369,7 @@ def reader():
     return reader
 
 
-def buffered(reader, size):
+def buffered(reader: _Reader[_T], size: int) -> _Reader[_T]:
     """
     Creates a buffered data reader.
 
@@ -339,10 +419,7 @@ def data_reader():
         q = Queue(maxsize=size)
         t = Thread(
             target=read_worker,
-            args=(
-                r,
-                q,
-            ),
+            args=(r, q),
         )
         t.daemon = True
         t.start()
@@ -354,7 +431,7 @@ def data_reader():
     return data_reader
 
 
-def firstn(reader, n):
+def firstn(reader: _Reader[_T], n: int) -> _Reader[_T]:
     """
 
     This API creates a decorated reader, and limits the max number of
@@ -399,7 +476,13 @@ class XmapEndSignal:
     pass
 
 
-def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
+def xmap_readers(
+    mapper: Callable[[_T], _U],
+    reader: _Reader[_T],
+    process_num: int,
+    buffer_size: int,
+    order: bool = False,
+) -> _Reader[_U]:
     """
     Use multi-threads to map samples from reader by a mapper defined by user.
 
@@ -495,7 +578,11 @@ def xreader():
     return xreader
 
 
-def multiprocess_reader(readers, use_pipe=True, queue_size=1000):
+def multiprocess_reader(
+    readers: Sequence[_Reader[_T]],
+    use_pipe: bool = True,
+    queue_size: int = 1000,
+) -> _Reader[list[_T]]:
     """
     This API use python ``multiprocessing`` to read data from ``readers`` parallelly,
     and then ``multiprocess.Queue`` or ``multiprocess.Pipe`` is used to merge
@@ -508,13 +595,13 @@ def multiprocess_reader(readers, use_pipe=True, queue_size=1000):
     in some platforms.
 
     Parameters:
-            readers (list( ``generator`` ) | tuple( ``generator`` )): a python ``generator`` list
-                used to read input data
-            use_pipe (bool, optional): control the inner API used to implement the multi-processing,
-                default True - use ``multiprocess.Pipe`` which is recommended
-            queue_size (int, optional): only useful when ``use_pipe`` is False - ``multiprocess.Queue``
-                is used, default 1000. Increase this value can speed up the data reading, and more memory
-                will be consumed.
+        readers (list( ``generator`` ) | tuple( ``generator`` )): a python ``generator`` list
+            used to read input data
+        use_pipe (bool, optional): control the inner API used to implement the multi-processing,
+            default True - use ``multiprocess.Pipe`` which is recommended
+        queue_size (int, optional): only useful when ``use_pipe`` is False - ``multiprocess.Queue``
+            is used, default 1000. Increase this value can speed up the data reading, and more memory
+            will be consumed.
 
     Returns:
         ``generator``: a new reader which can be run parallelly