[Typing][PEP585 Upgrade][BUAA][21-30] Use standard collections for type hints for 9 files in `python/paddle/` (#67119)
Caogration authored Aug 7, 2024
1 parent 09cb24e commit 1415c88
Showing 9 changed files with 55 additions and 43 deletions.
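
Every diff below applies the same PEP 585 / PEP 604 pattern: deprecated typing aliases (List, Tuple, Dict, Union) become builtin generics and `|` unions, with `from __future__ import annotations` added where needed so the new spellings are stored as strings (PEP 563) instead of being evaluated on older interpreters. A minimal sketch of the pattern, using an illustrative summarize() function rather than Paddle code:

# Sketch only; summarize() is an illustrative example, not a Paddle API.
from __future__ import annotations  # annotations stay as unevaluated strings (PEP 563)


def summarize(rows: list[dict[str, int]]) -> tuple[int, float | None]:
    # Builtin generics (PEP 585) and | unions (PEP 604) replace the old
    # List[Dict[str, int]] and Tuple[int, Union[float, None]] spellings.
    total = sum(sum(row.values()) for row in rows)
    return total, (total / len(rows) if rows else None)
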
2 changes: 1 addition & 1 deletion python/paddle/base/framework.py
@@ -49,7 +49,7 @@
_RetT = TypeVar("_RetT")

if TYPE_CHECKING:
from typing import Generator, Sequence
from collections.abc import Generator, Sequence

from paddle.static.amp.fp16_utils import AmpOptions

5 changes: 4 additions & 1 deletion python/paddle/batch.py
@@ -14,7 +14,10 @@

from __future__ import annotations

from typing import Callable, Generator, TypeVar
from typing import TYPE_CHECKING, Callable, TypeVar

if TYPE_CHECKING:
from collections.abc import Generator

_T = TypeVar('_T')
__all__ = []
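
The hunk above also shows the second half of the upgrade: Generator and Sequence now come from collections.abc (their typing counterparts are deprecated), and because they are only needed for annotations, the import sits under TYPE_CHECKING and never runs at import time. A small sketch of that shape, assuming an illustrative chunked() helper that is not part of Paddle:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by type checkers only; skipped at runtime.
    from collections.abc import Generator, Sequence


def chunked(items: Sequence[int], size: int) -> Generator[list[int], None, None]:
    # With postponed evaluation the annotations are plain strings, so the
    # runtime absence of Generator/Sequence is harmless.
    for start in range(0, len(items), size):
        yield list(items[start : start + size])


print(list(chunked([1, 2, 3, 4, 5], 2)))  # [[1, 2], [3, 4], [5]]
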
25 changes: 15 additions & 10 deletions python/paddle/decomposition/recompute.py
@@ -11,14 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import math
from typing import List, Sequence, Tuple
from typing import TYPE_CHECKING

import paddle
from paddle import pir
from paddle.autograd import backward_utils
from paddle.base import core

if TYPE_CHECKING:
from collections.abc import Sequence

_PADDLE_DTYPE_2_NBYTES = {
core.DataType.BOOL: 1,
core.DataType.FLOAT16: 2,
@@ -37,7 +42,7 @@
}

# define the default recompute ops that can be fused between pairs
DEFAULT_RECOMPUTABLE_OPS: List[str] = [
DEFAULT_RECOMPUTABLE_OPS: list[str] = [
"pd_op.full_int_array",
"pd_op.full",
"pd_op.sum",
@@ -125,11 +130,11 @@
"pd_op.isnan",
]

VIEW_OPS: List[str] = []
VIEW_OPS: list[str] = []

RANDOM_OPS: List[str] = ["pd_op.randint", "pd_op.uniform", "pd_op.dropout"]
RANDOM_OPS: list[str] = ["pd_op.randint", "pd_op.uniform", "pd_op.dropout"]

COMPUTE_INTENSIVE_OPS: List[str] = [
COMPUTE_INTENSIVE_OPS: list[str] = [
"pd_op.matmul",
"pd_op.conv2d",
"pd_op.layer_norm",
@@ -151,7 +156,7 @@ def auto_recompute(
fwd_op_end_idx: int,
backward_op_start_idx: int,
recomputable_ops: Sequence[str] = None,
) -> Tuple[paddle.static.Program, int]:
) -> tuple[paddle.static.Program, int]:
'''
Considering the compiler fuse strategy, we model the pir graph.
Convert the pir calculation graph into a networkx calculation
@@ -440,12 +445,12 @@ def _ban_recomputation(value_node):

def partition_joint_graph(
program: paddle.static.Program,
saved_values: List[pir.Value],
inputs: List[pir.Value],
outputs: List[pir.Value],
saved_values: list[pir.Value],
inputs: list[pir.Value],
outputs: list[pir.Value],
fwd_op_end_idx: int,
backward_op_start_idx: int,
) -> Tuple[paddle.static.Program, int]:
) -> tuple[paddle.static.Program, int]:
"""
Partition the joint graph, recompute the intermediate values
by saved values to save memory.
16 changes: 9 additions & 7 deletions python/paddle/distributed/auto_parallel/api.py
@@ -11,9 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import copy
from types import MethodType
from typing import Callable, List, Tuple, Union
from typing import Callable

import numpy as np

@@ -2641,9 +2643,9 @@ class ShardDataloader:
def __init__(
self,
dataloader: paddle.io.DataLoader,
meshes: Union[ProcessMesh, List[ProcessMesh], Tuple[ProcessMesh]],
input_keys: Union[List[str], Tuple[str]] = None,
shard_dims: Union[list, tuple, str, int] = None,
meshes: ProcessMesh | list[ProcessMesh] | tuple[ProcessMesh],
input_keys: list[str] | tuple[str] = None,
shard_dims: list | tuple | str | int = None,
is_dataset_splitted: bool = False,
):
# do some check
@@ -2895,9 +2897,9 @@ def __call__(self):

def shard_dataloader(
dataloader: paddle.io.DataLoader,
meshes: Union[ProcessMesh, List[ProcessMesh], Tuple[ProcessMesh]],
input_keys: Union[List[str], Tuple[str]] = None,
shard_dims: Union[list, tuple, str, int] = None,
meshes: ProcessMesh | list[ProcessMesh] | tuple[ProcessMesh],
input_keys: list[str] | tuple[str] = None,
shard_dims: list | tuple | str | int = None,
is_dataset_splitted: bool = False,
) -> ShardDataloader:
"""
4 changes: 2 additions & 2 deletions python/paddle/distributed/auto_parallel/interface.py
@@ -11,9 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from functools import reduce
from typing import List, Tuple

import numpy as np

@@ -322,7 +322,7 @@ def set_mesh(mesh):
_g_mesh = mesh


def create_mesh(mesh_dims: List[Tuple[str, int]]):
def create_mesh(mesh_dims: list[tuple[str, int]]):
"""
Create a global process_mesh for auto parallel.
6 changes: 3 additions & 3 deletions python/paddle/distributed/auto_tuner/recorder.py
@@ -11,11 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import copy
import csv
import os
from typing import Tuple

import pandas as pd

@@ -53,7 +53,7 @@ def sort_metric(self, direction, metric_name) -> None:

def get_best(
self, metric, direction, buffer=None, max_mem_usage=None
) -> Tuple[dict, bool]:
) -> tuple[dict, bool]:
self.sort_metric(direction=direction, metric_name=metric)
if len(self.history) == 0:
return (None, True)
@@ -142,7 +142,7 @@ def store_history(self, path="./history.csv"):
self.store_path = path
self._store_history_impl(data=self.history, path=path)

def load_history(self, path="./history.csv") -> Tuple[list, bool]:
def load_history(self, path="./history.csv") -> tuple[list, bool]:
"""Load history from csv file."""
err = False
if self.store_path is None:
10 changes: 5 additions & 5 deletions python/paddle/distributed/auto_tuner/utils.py
@@ -11,14 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import copy
import csv
import itertools
import logging
import os
import re
from typing import Tuple

import paddle

@@ -1405,7 +1405,7 @@ def gen_new_ctx(ctx, cur_cfg, tuner_cfg):

def read_metric_log(
path, file="workerlog.0", target_metric='step/s'
) -> Tuple[float, int]:
) -> tuple[float, int]:
"""For extracting metric from log file."""
"""
return:
@@ -1469,7 +1469,7 @@ def read_metric_log(

def read_step_time_log(
path, file="workerlog.0", target_metric='interval_runtime'
) -> Tuple[float, int]:
) -> tuple[float, int]:
target_file = path + "/" + file
if not os.path.exists(target_file):
return None
@@ -1539,7 +1539,7 @@ def read_allocated_memory_log(
return metric_list[-1]


def read_memory_log(path, file) -> Tuple[float, bool]:
def read_memory_log(path, file) -> tuple[float, bool]:
log_path = os.path.join(path, file)
if not os.path.exists(log_path):
return (0.0, True)
@@ -1599,7 +1599,7 @@ def read_log(
metric_file="workerlog.0",
target_metric='step/s',
memory_file="0.gpu.log",
) -> Tuple[float, float, int]:
) -> tuple[float, float, int]:
"""
extract metric and max memory usage from log file
return:
14 changes: 7 additions & 7 deletions python/paddle/distributed/checkpoint/metadata.py
@@ -11,9 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
@@ -22,8 +22,8 @@ class LocalTensorMetadata:
The location of a local tensor in the global tensor.
"""

global_offset: Tuple[int]
local_shape: Tuple[int]
global_offset: tuple[int]
local_shape: tuple[int]
dtype: str


@@ -34,11 +34,11 @@ class LocalTensorIndex:
"""

tensor_key: str
global_offset: Tuple[int]
global_offset: tuple[int]


@dataclass
class Metadata:
state_dict_metadata: Dict[str, List[LocalTensorMetadata]] = None
storage_metadata: Dict[LocalTensorIndex, str] = None
flat_mapping: Dict[str, Tuple[str]] = None
state_dict_metadata: dict[str, list[LocalTensorMetadata]] = None
storage_metadata: dict[LocalTensorIndex, str] = None
flat_mapping: dict[str, tuple[str]] = None
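
For the dataclasses in this file, the added `from __future__ import annotations` is what keeps the builtin-generic field types safe on interpreters that predate PEP 585: the annotations are stored as strings instead of being evaluated while the class body runs. A hedged sketch with a hypothetical TileMetadata class, not the Paddle one above:

from __future__ import annotations

from dataclasses import dataclass


@dataclass
class TileMetadata:
    global_offset: tuple[int, ...]
    local_shape: tuple[int, ...]
    dtype: str


# Postponed evaluation stores the annotations verbatim as strings, roughly:
# {'global_offset': 'tuple[int, ...]', 'local_shape': 'tuple[int, ...]', 'dtype': 'str'}
print(TileMetadata.__annotations__)
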
16 changes: 9 additions & 7 deletions python/paddle/distributed/checkpoint/utils.py
@@ -11,19 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import copy
from typing import List, Tuple, Union
from typing import TYPE_CHECKING

import numpy as np

import paddle
import paddle.distributed as dist
from paddle.framework import core

if TYPE_CHECKING:
from paddle.framework import core


def get_coordinator(mesh: Union[np.array, List[List[int]]], rank: int):
def get_coordinator(mesh: np.array | list[list[int]], rank: int):
mesh = paddle.to_tensor(mesh)
rand_coordinator = (mesh == rank).nonzero()
assert rand_coordinator.shape[0] in (
@@ -46,10 +48,10 @@ def balanced_split(total_nums, num_of_pieces):


def compute_local_shape_and_global_offset(
global_shape: List[int],
global_shape: list[int],
process_mesh: core.ProcessMesh,
placements: List[core.Placement],
) -> Tuple[Tuple[int], Tuple[int]]:
placements: list[core.Placement],
) -> tuple[tuple[int], tuple[int]]:
mesh = np.array(process_mesh.process_ids).reshape(process_mesh.shape)
# deal with cross mesh case
if paddle.distributed.get_rank() not in mesh:
