Skip to content

Commit

Permalink
[Typing][B-93] Add type annotations for `python/paddle/reader/decorat…
Browse files Browse the repository at this point in the history
…or.py` (#66305)

* [Typing] Add type annotations for `python/paddle/reader/decorator.py`

* missing pep563
  • Loading branch information
SigureMo authored Jul 22, 2024
1 parent a486468 commit 0ba68a4
Showing 1 changed file with 110 additions and 23 deletions.
133 changes: 110 additions & 23 deletions python/paddle/reader/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import itertools
import logging
import multiprocessing
Expand All @@ -21,9 +23,27 @@
from itertools import zip_longest
from queue import Queue
from threading import Thread
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generator,
Sequence,
TypedDict,
TypeVar,
overload,
)

from typing_extensions import NotRequired, TypeAlias, Unpack

from paddle.base.reader import QUEUE_GET_TIMEOUT

# Typing-only declarations: this block is evaluated by static type checkers
# only, never at runtime, so the TypedDict costs nothing when importing.
if TYPE_CHECKING:

    class _ComposeOptions(TypedDict):
        # Keyword option accepted by `compose` via `**kwargs:
        # Unpack[_ComposeOptions]`. NOTE(review): its body is not visible in
        # this diff -- presumably it controls whether `compose` verifies that
        # all input readers yield the same number of samples; confirm against
        # the full implementation.
        check_alignment: NotRequired[bool]


__all__ = []

# On macOS, the 'spawn' start method is now the default in Python3.8 multiprocessing,
Expand All @@ -41,8 +61,18 @@
else:
fork_context = multiprocessing

# Type variables shared by the reader decorators below: `_T` (and the
# positional `_T1`..`_T4` used by the `map_readers` overloads) stand for the
# element type yielded by an input reader; `_U` is the element type produced
# by a mapping function and hence by the resulting reader.
_T = TypeVar('_T')
_T1 = TypeVar('_T1')
_T2 = TypeVar('_T2')
_T3 = TypeVar('_T3')
_T4 = TypeVar('_T4')
_U = TypeVar('_U')


def cache(reader):
_Reader: TypeAlias = Callable[[], Generator[_T, None, None]]


def cache(reader: _Reader[_T]) -> _Reader[_T]:
"""
Cache the reader data into memory.
Expand Down Expand Up @@ -77,12 +107,60 @@ def cache(reader):
"""
all_data = tuple(reader())

def __impl__():
def __impl__() -> Generator[_T, None, None]:
yield from all_data

return __impl__


# A temporary solution like the builtin `map` function.
# `Map` may be the final solution in the future.
# See https://github.com/python/typing/issues/1383
#
# The overloads below give precise per-reader element types for up to four
# positional readers (positional-only, enforced by `/`); the final catch-all
# overload accepts any arity but collapses the element types to `Any`.
@overload
def map_readers(
    func: Callable[[_T1], _U], reader1: _Reader[_T1], /
) -> _Reader[_U]:
    ...


@overload
def map_readers(
    func: Callable[[_T1, _T2], _U],
    reader1: _Reader[_T1],
    reader2: _Reader[_T2],
    /,
) -> _Reader[_U]:
    ...


@overload
def map_readers(
    func: Callable[[_T1, _T2, _T3], _U],
    reader1: _Reader[_T1],
    reader2: _Reader[_T2],
    reader3: _Reader[_T3],
    /,
) -> _Reader[_U]:
    ...


@overload
def map_readers(
    func: Callable[[_T1, _T2, _T3, _T4], _U],
    reader1: _Reader[_T1],
    reader2: _Reader[_T2],
    reader3: _Reader[_T3],
    reader4: _Reader[_T4],
    /,
) -> _Reader[_U]:
    ...


# Fallback for five or more readers: typing loses the per-reader precision.
@overload
def map_readers(func: Callable[..., _U], *readers: _Reader[Any]) -> _Reader[_U]:
    ...


def map_readers(func, *readers):
"""
Creates a data reader that outputs return value of function using
Expand Down Expand Up @@ -124,7 +202,7 @@ def reader():
return reader


def shuffle(reader, buf_size):
def shuffle(reader: _Reader[_T], buf_size: int) -> _Reader[_T]:
"""
This API creates a decorated reader that outputs the shuffled data.
Expand All @@ -151,7 +229,7 @@ def shuffle(reader, buf_size):
>>> # outputs are 0~4 unordered arrangement
"""

def data_reader():
def data_reader() -> Generator[_T, None, None]:
buf = []
for e in reader():
buf.append(e)
Expand All @@ -169,7 +247,7 @@ def data_reader():
return data_reader


def chain(*readers):
def chain(*readers: _Reader[_T]) -> _Reader[_T]:
"""
Use the input data readers to create a chained data reader. The new created reader
chains the outputs of input readers together as its output, and it do not change
Expand Down Expand Up @@ -218,8 +296,8 @@ def chain(*readers):
"""

def reader():
rs = []
def reader() -> Generator[_T, None, None]:
rs: list[Generator[_T, None, None]] = []
for r in readers:
rs.append(r())

Expand All @@ -232,7 +310,9 @@ class ComposeNotAligned(ValueError):
pass


def compose(*readers, **kwargs):
def compose(
*readers: _Reader[Any], **kwargs: Unpack[_ComposeOptions]
) -> _Reader[Any]:
"""
Creates a data reader whose output is the combination of input readers.
Expand Down Expand Up @@ -289,7 +369,7 @@ def reader():
return reader


def buffered(reader, size):
def buffered(reader: _Reader[_T], size: int) -> _Reader[_T]:
"""
Creates a buffered data reader.
Expand Down Expand Up @@ -339,10 +419,7 @@ def data_reader():
q = Queue(maxsize=size)
t = Thread(
target=read_worker,
args=(
r,
q,
),
args=(r, q),
)
t.daemon = True
t.start()
Expand All @@ -354,7 +431,7 @@ def data_reader():
return data_reader


def firstn(reader, n):
def firstn(reader: _Reader[_T], n: int) -> _Reader[_T]:
"""
This API creates a decorated reader, and limits the max number of
Expand Down Expand Up @@ -399,7 +476,13 @@ class XmapEndSignal:
pass


def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
def xmap_readers(
mapper: Callable[[_T], _U],
reader: _Reader[_T],
process_num: int,
buffer_size: int,
order: bool = False,
) -> _Reader[_U]:
"""
Use multi-threads to map samples from reader by a mapper defined by user.
Expand Down Expand Up @@ -495,7 +578,11 @@ def xreader():
return xreader


def multiprocess_reader(readers, use_pipe=True, queue_size=1000):
def multiprocess_reader(
readers: Sequence[_Reader[_T]],
use_pipe: bool = True,
queue_size: int = 1000,
) -> _Reader[list[_T]]:
"""
This API use python ``multiprocessing`` to read data from ``readers`` parallelly,
and then ``multiprocess.Queue`` or ``multiprocess.Pipe`` is used to merge
Expand All @@ -508,13 +595,13 @@ def multiprocess_reader(readers, use_pipe=True, queue_size=1000):
in some platforms.
Parameters:
readers (list( ``generator`` ) | tuple( ``generator`` )): a python ``generator`` list
used to read input data
use_pipe (bool, optional): control the inner API used to implement the multi-processing,
default True - use ``multiprocess.Pipe`` which is recommended
queue_size (int, optional): only useful when ``use_pipe`` is False - ``multiprocess.Queue``
is used, default 1000. Increase this value can speed up the data reading, and more memory
will be consumed.
readers (list( ``generator`` ) | tuple( ``generator`` )): a python ``generator`` list
used to read input data
use_pipe (bool, optional): control the inner API used to implement the multi-processing,
default True - use ``multiprocess.Pipe`` which is recommended
queue_size (int, optional): only useful when ``use_pipe`` is False - ``multiprocess.Queue``
is used, default 1000. Increase this value can speed up the data reading, and more memory
will be consumed.
Returns:
``generator``: a new reader which can be run parallelly
Expand Down

0 comments on commit 0ba68a4

Please sign in to comment.