diff --git a/paddle/fluid/operators/generator/filters.py b/paddle/fluid/operators/generator/filters.py index 6b5f23a04acc3..6007c1b2b8502 100644 --- a/paddle/fluid/operators/generator/filters.py +++ b/paddle/fluid/operators/generator/filters.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import itertools import re -from typing import Dict, List, Sequence +from typing import TYPE_CHECKING from type_mapping import ( attr_types_map, @@ -30,6 +31,9 @@ sr_output_types_map, ) +if TYPE_CHECKING: + from collections.abc import Sequence + def get_infer_var_type_func(op_name): if op_name == "assign": @@ -238,7 +242,7 @@ def to_composite_grad_opmaker_name(backward_op_name): return composite_grad_opmaker_name -def to_variable_names(dict_list: List[Dict], key: str) -> List[str]: +def to_variable_names(dict_list: list[dict], key: str) -> list[str]: names = [] for var in dict_list: names.append(var[key]) diff --git a/paddle/fluid/operators/generator/parse_utils.py b/paddle/fluid/operators/generator/parse_utils.py index f61697aa47866..2ae377362a57b 100644 --- a/paddle/fluid/operators/generator/parse_utils.py +++ b/paddle/fluid/operators/generator/parse_utils.py @@ -11,16 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import re from copy import copy -from typing import Any, Dict, List, Tuple +from typing import Any from tests_utils import is_attr, is_input, is_output, is_vec from type_mapping import opmaker_attr_types_map -def to_named_dict(items: List[Dict], is_op=False) -> Dict[str, Dict]: +def to_named_dict(items: list[dict], is_op=False) -> dict[str, dict]: named_dict = {} if is_op: for item in items: @@ -46,7 +47,7 @@ def to_named_dict(items: List[Dict], is_op=False) -> Dict[str, Dict]: return named_dict -def parse_arg(op_name: str, s: str) -> Dict[str, str]: +def parse_arg(op_name: str, s: str) -> dict[str, str]: """parse an argument in following formats: 1. typename name 2. typename name = default_value @@ -82,7 +83,7 @@ def parse_arg(op_name: str, s: str) -> Dict[str, str]: def parse_input_and_attr( op_name: str, arguments: str -) -> Tuple[List, List, Dict, Dict]: +) -> tuple[list, list, dict, dict]: args_str = arguments.strip() assert args_str.startswith('(') and args_str.endswith(')'), ( f"Args declaration should start with '(' and end with ')', " @@ -122,7 +123,7 @@ def parse_input_and_attr( return inputs, attrs -def parse_output(op_name: str, s: str) -> Dict[str, str]: +def parse_output(op_name: str, s: str) -> dict[str, str]: """parse an output, typename or typename(name).""" match = re.search( r"(?P[a-zA-Z0-9_[\]]+)\s*(?P\([a-zA-Z0-9_@]+\))?\s*(?P\{[^\}]+\})?", @@ -149,7 +150,7 @@ def parse_output(op_name: str, s: str) -> Dict[str, str]: return {"typename": typename, "name": name} -def parse_outputs(op_name: str, outputs: str) -> List[Dict]: +def parse_outputs(op_name: str, outputs: str) -> list[dict]: if outputs is None: return [] outputs = parse_plain_list(outputs, sep=",") @@ -159,14 +160,14 @@ def parse_outputs(op_name: str, outputs: str) -> List[Dict]: return output_items -def parse_infer_meta(infer_meta: Dict[str, Any]) -> Dict[str, Any]: +def parse_infer_meta(infer_meta: dict[str, Any]) -> dict[str, Any]: infer_meta = copy(infer_meta) # to prevent mutating the input if "param" not in infer_meta: infer_meta["param"] = None return infer_meta -def parse_candidates(s: str) -> Dict[str, Any]: +def parse_candidates(s: str) -> dict[str, Any]: "parse candidates joined by either '>'(ordered) or ','(unordered)" delimiter = ">" if ">" in s else "," ordered = delimiter == ">" @@ -175,7 +176,7 @@ def parse_candidates(s: str) -> Dict[str, Any]: return {"ordered": ordered, "candidates": candidates} -def parse_plain_list(s: str, sep=",") -> List[str]: +def parse_plain_list(s: str, sep=",") -> list[str]: if sep == ",": patten = re.compile(r',(?![^{]*\})') # support "int[] a={1,2}" items = re.split(patten, s.strip()) @@ -185,7 +186,7 @@ def parse_plain_list(s: str, sep=",") -> List[str]: return [item.strip() for item in s.strip().split(sep)] -def parse_kernel(op_name: str, kernel_config: Dict[str, Any]) -> Dict[str, Any]: +def parse_kernel(op_name: str, kernel_config: dict[str, Any]) -> dict[str, Any]: # kernel : # func : [], Kernel functions (example: scale, scale_sr) # param : [], Input params of kernel @@ -276,7 +277,7 @@ def delete_bracket(name: str): return name -def parse_inplace(op_name: str, inplace_cfg: str) -> Dict[str, str]: +def parse_inplace(op_name: str, inplace_cfg: str) -> dict[str, str]: inplace_map = {} inplace_cfg = inplace_cfg.lstrip("(").rstrip(")") pairs = parse_plain_list(inplace_cfg) @@ -288,7 +289,7 @@ def parse_inplace(op_name: str, inplace_cfg: str) -> Dict[str, str]: return inplace_map -def parse_invoke(op_name: str, invoke_config: str) -> Dict[str, Any]: +def parse_invoke(op_name: str, invoke_config: str) -> dict[str, Any]: invoke_config = invoke_config.strip() func, rest = invoke_config.split("(", 1) func = func.strip() @@ -297,7 +298,7 @@ def parse_invoke(op_name: str, invoke_config: str) -> Dict[str, Any]: return invocation -def extract_type_and_name(records: List[Dict]) -> List[Dict]: +def extract_type_and_name(records: list[dict]) -> list[dict]: """extract type and name from forward call, it is simpler than forward op .""" extracted = [ {"name": item["name"], "typename": item["typename"]} for item in records @@ -305,7 +306,7 @@ def extract_type_and_name(records: List[Dict]) -> List[Dict]: return extracted -def parse_forward(op_name: str, forward_config: str) -> Dict[str, Any]: +def parse_forward(op_name: str, forward_config: str) -> dict[str, Any]: # op_name (const Tensor& input, ... , int attr, ...) -> Tensor(out) result = re.search( r"(?P[a-z][a-z0-9_]+)\s*(?P\([^\)]+\))\s*->\s*(?P.+)", @@ -330,7 +331,7 @@ def parse_forward(op_name: str, forward_config: str) -> Dict[str, Any]: def parse_composite( op_name: str, composite_config: str, -) -> Dict[str, Any]: +) -> dict[str, Any]: # composite_config: func(args1, args2,.....) result = re.search( r"(?P[a-z][a-z0-9_]+)\s*\((?P[^\)]+)\)", @@ -402,7 +403,7 @@ def check_op_config(op_entry, op_name): ), f"Op ({op_name}) : invalid key (kernel.{kernel_key}) in Yaml." -def parse_op_entry(op_entry: Dict[str, Any], name_field="op"): +def parse_op_entry(op_entry: dict[str, Any], name_field="op"): op_name = op_entry[name_field] inputs, attrs = parse_input_and_attr(op_name, op_entry["args"]) outputs = parse_outputs(op_name, op_entry["output"])