Skip to content

Commit

Permalink
[Typing][PEP585 Upgrade][2-3] Use stardand collections for type hints…
Browse files Browse the repository at this point in the history
… in `paddle/fluid/operators/generator/*` (#66968)

* fix

* move import

---------

Co-authored-by: SigureMo <sigure.qaq@gmail.com>
  • Loading branch information
enkilee and SigureMo authored Aug 5, 2024
1 parent 7174d1b commit 75b4b78
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 18 deletions.
8 changes: 6 additions & 2 deletions paddle/fluid/operators/generator/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import itertools
import re
from typing import Dict, List, Sequence
from typing import TYPE_CHECKING

from type_mapping import (
attr_types_map,
Expand All @@ -30,6 +31,9 @@
sr_output_types_map,
)

if TYPE_CHECKING:
from collections.abc import Sequence


def get_infer_var_type_func(op_name):
if op_name == "assign":
Expand Down Expand Up @@ -238,7 +242,7 @@ def to_composite_grad_opmaker_name(backward_op_name):
return composite_grad_opmaker_name


def to_variable_names(dict_list: List[Dict], key: str) -> List[str]:
def to_variable_names(dict_list: list[dict], key: str) -> list[str]:
names = []
for var in dict_list:
names.append(var[key])
Expand Down
33 changes: 17 additions & 16 deletions paddle/fluid/operators/generator/parse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import re
from copy import copy
from typing import Any, Dict, List, Tuple
from typing import Any

from tests_utils import is_attr, is_input, is_output, is_vec
from type_mapping import opmaker_attr_types_map


def to_named_dict(items: List[Dict], is_op=False) -> Dict[str, Dict]:
def to_named_dict(items: list[dict], is_op=False) -> dict[str, dict]:
named_dict = {}
if is_op:
for item in items:
Expand All @@ -46,7 +47,7 @@ def to_named_dict(items: List[Dict], is_op=False) -> Dict[str, Dict]:
return named_dict


def parse_arg(op_name: str, s: str) -> Dict[str, str]:
def parse_arg(op_name: str, s: str) -> dict[str, str]:
"""parse an argument in following formats:
1. typename name
2. typename name = default_value
Expand Down Expand Up @@ -82,7 +83,7 @@ def parse_arg(op_name: str, s: str) -> Dict[str, str]:

def parse_input_and_attr(
op_name: str, arguments: str
) -> Tuple[List, List, Dict, Dict]:
) -> tuple[list, list, dict, dict]:
args_str = arguments.strip()
assert args_str.startswith('(') and args_str.endswith(')'), (
f"Args declaration should start with '(' and end with ')', "
Expand Down Expand Up @@ -122,7 +123,7 @@ def parse_input_and_attr(
return inputs, attrs


def parse_output(op_name: str, s: str) -> Dict[str, str]:
def parse_output(op_name: str, s: str) -> dict[str, str]:
"""parse an output, typename or typename(name)."""
match = re.search(
r"(?P<out_type>[a-zA-Z0-9_[\]]+)\s*(?P<name>\([a-zA-Z0-9_@]+\))?\s*(?P<expr>\{[^\}]+\})?",
Expand All @@ -149,7 +150,7 @@ def parse_output(op_name: str, s: str) -> Dict[str, str]:
return {"typename": typename, "name": name}


def parse_outputs(op_name: str, outputs: str) -> List[Dict]:
def parse_outputs(op_name: str, outputs: str) -> list[dict]:
if outputs is None:
return []
outputs = parse_plain_list(outputs, sep=",")
Expand All @@ -159,14 +160,14 @@ def parse_outputs(op_name: str, outputs: str) -> List[Dict]:
return output_items


def parse_infer_meta(infer_meta: Dict[str, Any]) -> Dict[str, Any]:
def parse_infer_meta(infer_meta: dict[str, Any]) -> dict[str, Any]:
infer_meta = copy(infer_meta) # to prevent mutating the input
if "param" not in infer_meta:
infer_meta["param"] = None
return infer_meta


def parse_candidates(s: str) -> Dict[str, Any]:
def parse_candidates(s: str) -> dict[str, Any]:
"parse candidates joined by either '>'(ordered) or ','(unordered)"
delimiter = ">" if ">" in s else ","
ordered = delimiter == ">"
Expand All @@ -175,7 +176,7 @@ def parse_candidates(s: str) -> Dict[str, Any]:
return {"ordered": ordered, "candidates": candidates}


def parse_plain_list(s: str, sep=",") -> List[str]:
def parse_plain_list(s: str, sep=",") -> list[str]:
if sep == ",":
patten = re.compile(r',(?![^{]*\})') # support "int[] a={1,2}"
items = re.split(patten, s.strip())
Expand All @@ -185,7 +186,7 @@ def parse_plain_list(s: str, sep=",") -> List[str]:
return [item.strip() for item in s.strip().split(sep)]


def parse_kernel(op_name: str, kernel_config: Dict[str, Any]) -> Dict[str, Any]:
def parse_kernel(op_name: str, kernel_config: dict[str, Any]) -> dict[str, Any]:
# kernel :
# func : [], Kernel functions (example: scale, scale_sr)
# param : [], Input params of kernel
Expand Down Expand Up @@ -276,7 +277,7 @@ def delete_bracket(name: str):
return name


def parse_inplace(op_name: str, inplace_cfg: str) -> Dict[str, str]:
def parse_inplace(op_name: str, inplace_cfg: str) -> dict[str, str]:
inplace_map = {}
inplace_cfg = inplace_cfg.lstrip("(").rstrip(")")
pairs = parse_plain_list(inplace_cfg)
Expand All @@ -288,7 +289,7 @@ def parse_inplace(op_name: str, inplace_cfg: str) -> Dict[str, str]:
return inplace_map


def parse_invoke(op_name: str, invoke_config: str) -> Dict[str, Any]:
def parse_invoke(op_name: str, invoke_config: str) -> dict[str, Any]:
invoke_config = invoke_config.strip()
func, rest = invoke_config.split("(", 1)
func = func.strip()
Expand All @@ -297,15 +298,15 @@ def parse_invoke(op_name: str, invoke_config: str) -> Dict[str, Any]:
return invocation


def extract_type_and_name(records: List[Dict]) -> List[Dict]:
def extract_type_and_name(records: list[dict]) -> list[dict]:
"""extract type and name from forward call, it is simpler than forward op ."""
extracted = [
{"name": item["name"], "typename": item["typename"]} for item in records
]
return extracted


def parse_forward(op_name: str, forward_config: str) -> Dict[str, Any]:
def parse_forward(op_name: str, forward_config: str) -> dict[str, Any]:
# op_name (const Tensor& input, ... , int attr, ...) -> Tensor(out)
result = re.search(
r"(?P<op>[a-z][a-z0-9_]+)\s*(?P<args>\([^\)]+\))\s*->\s*(?P<outputs>.+)",
Expand All @@ -330,7 +331,7 @@ def parse_forward(op_name: str, forward_config: str) -> Dict[str, Any]:
def parse_composite(
op_name: str,
composite_config: str,
) -> Dict[str, Any]:
) -> dict[str, Any]:
# composite_config: func(args1, args2,.....)
result = re.search(
r"(?P<func_name>[a-z][a-z0-9_]+)\s*\((?P<func_args>[^\)]+)\)",
Expand Down Expand Up @@ -402,7 +403,7 @@ def check_op_config(op_entry, op_name):
), f"Op ({op_name}) : invalid key (kernel.{kernel_key}) in Yaml."


def parse_op_entry(op_entry: Dict[str, Any], name_field="op"):
def parse_op_entry(op_entry: dict[str, Any], name_field="op"):
op_name = op_entry[name_field]
inputs, attrs = parse_input_and_attr(op_name, op_entry["args"])
outputs = parse_outputs(op_name, op_entry["output"])
Expand Down

0 comments on commit 75b4b78

Please sign in to comment.