Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

upgrade pytorch-paddle mapping table generating process #6752

Merged
merged 3 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
print(script_dir)

from validate_mapping_in_api_difference import (
DiffMeta,
get_meta_from_diff_file,
process_mapping_index as reference_mapping_item,
)
Expand Down Expand Up @@ -64,9 +65,13 @@ def mapping_type_to_description(mapping_type):
return "【未知类型】", False


REFERENCE_PATTERN = re.compile(
# 以后没有 REFERENCE-ITEM 需要维护了,全部从 api_difference/ 目录生成
_REFERENCE_ITEM_PATTERN = re.compile(
r"^\| *REFERENCE-MAPPING-ITEM\( *(?P<torch_api>[^,]+) *, *(?P<diff_url>.+) *\) *\|$"
)
REFERENCE_TABLE_PATTERN = re.compile(
r"^\| *REFERENCE-MAPPING-TABLE\( *(?P<api_prefix>[^,]+) *(, *max_depth *= *(?P<max_depth>\d+) *)?\) *\|$"
)
ALIAS_PATTERN = re.compile(
r"^\| *ALIAS-REFERENCE-ITEM\( *(?P<alias_name>[^,]+) *, *(?P<torch_api>[^,]+) *\) *\|$"
)
Expand All @@ -79,6 +84,7 @@ def mapping_type_to_description(mapping_type):


def docs_url_to_relative_page(url):
"""将映射文档的 PaddlePaddle/docs url 转换为网页路径"""
if not url.startswith(DOCS_REPO_BASEURL):
return url

Expand All @@ -88,110 +94,101 @@ def docs_url_to_relative_page(url):
return md_path


def apply_reference_to_row(line, metadata_dict, table_row_idx, line_idx):
reference_match = REFERENCE_PATTERN.match(line)
alias_match = ALIAS_PATTERN.match(line)
not_implemented_match = NOT_IMPLEMENTED_PATTERN.match(line)
def doc_path_to_relative_page(path):
"""将映射文档的本地路径转换为网页相对路径"""
md_path = os.path.relpath(path, script_dir)

row_idx_s = str(table_row_idx)
assert md_path.endswith(".md"), f"Unexpected mapping doc path: {path}"

if reference_match:
torch_api = reference_match["torch_api"].strip("`").replace(r"\_", "_")
diff_url = reference_match["diff_url"]
return md_path[:-3] + ".html"

diff_page_url = docs_url_to_relative_page(diff_url)

if torch_api not in metadata_dict:
raise Exception(
f"Cannot find torch_api: {torch_api} in line {line_idx}"
)
def reference_table_match_to_condition(m):
    """Convert a REFERENCE-MAPPING-TABLE match into a lookup condition.

    Args:
        m: A match object of ``REFERENCE_TABLE_PATTERN`` (or any mapping)
            exposing the ``"api_prefix"`` and ``"max_depth"`` groups.
            ``"max_depth"`` is ``None`` when the option was omitted.

    Returns:
        A hashable ``(api_prefix, max_depth)`` tuple. ``api_prefix`` has
        surrounding whitespace and backticks removed and markdown-escaped
        underscores unescaped; ``max_depth`` is an ``int`` and falls back
        to 255 when it was not given.
    """
    # The pattern's greedy ``[^,]+`` group also captures the blanks that
    # precede ", max_depth" or ")", so whitespace has to be trimmed first;
    # otherwise the closing backtick would survive ``strip("`")``.
    # Escaped underscores are unescaped the same way the ALIAS rows do it.
    api_prefix = m["api_prefix"].strip().strip("`").replace(r"\_", "_")

    raw_depth = m["max_depth"]
    # 255 is the fallback used when the row carries no explicit max_depth.
    max_depth = 255 if raw_depth is None else int(raw_depth)
    return api_prefix, max_depth

meta_dict[torch_api]["diff_url"] = diff_page_url

reference_item = metadata_dict.get(torch_api, None)
torch_api_url = reference_item["torch_api_url"]
torch_api_column = f"[`{torch_api}`]({torch_api_url})"
def get_referenced_api_columns(torch_api, metadata_dict, alias=None):
assert (
torch_api in metadata_dict
), f'Error: cannot find mapping doc of api "{torch_api}"'
api_data: DiffMeta = metadata_dict[torch_api]

mapping_type = reference_item["mapping_type"]
mapping_type_column = mapping_type
diff_page_url = doc_path_to_relative_page(api_data["source_file"])

_mapping_type_desc, show_diff_url = mapping_type_to_description(
mapping_type
)
mapping_url_column = ""
if show_diff_url:
mapping_url_column = f"[详细对比]({diff_page_url})"

if "paddle_api" not in reference_item:
if mapping_type not in ["组合替代实现", "可删除", "功能缺失"]:
print(
f"Cannot find paddle_api for torch_api: {torch_api} in line {line_idx}"
)
paddle_api_column = ""
else:
paddle_api = reference_item["paddle_api"]
paddle_api_url = reference_item["paddle_api_url"]
paddle_api_column = f"[`{paddle_api}`]({paddle_api_url})"
torch_api_url = api_data["torch_api_url"]
api_disp_name = torch_api if alias is None else alias
torch_api_column = f"[`{api_disp_name}`]({torch_api_url})"

content = [
row_idx_s,
torch_api_column,
paddle_api_column,
mapping_type_column,
mapping_url_column,
]
mapping_type = api_data["mapping_type"]
mapping_type_column = mapping_type

output = "| " + " | ".join(content) + " |\n"
return output
elif alias_match:
alias_name = alias_match["alias_name"].strip("`").replace(r"\_", "_")
torch_api = alias_match["torch_api"].strip("`").replace(r"\_", "_")

if torch_api not in metadata_dict:
raise Exception(
f"Cannot find torch_api: {torch_api} in line {line_idx}"
)
_mapping_type_desc, show_diff_url = mapping_type_to_description(
mapping_type
)
desc_column = ""
if show_diff_url:
desc_column = f"[详细对比]({diff_page_url})"
if alias is not None:
desc_column = f"`{torch_api}` 别名,{desc_column}"

if "paddle_api" not in api_data:
if mapping_type not in ["组合替代实现", "可删除", "功能缺失"]:
print(f"Error: cannot find paddle_api for torch_api: {torch_api}")
paddle_api_column = ""
else:
paddle_api = api_data["paddle_api"]
paddle_api_url = api_data["paddle_api_url"]
paddle_api_column = f"[`{paddle_api}`]({paddle_api_url})"

return [
torch_api_column,
paddle_api_column,
mapping_type_column,
desc_column,
]

diff_page_url = metadata_dict[torch_api].get("diff_url", "")

reference_item = metadata_dict.get(torch_api, None)
torch_api_url = reference_item["torch_api_url"]
alisa_name_column = f"[`{alias_name}`]({torch_api_url})"
def apply_reference_to_row_ex(line, metadata_dict, context, line_idx):
reference_table_match = REFERENCE_TABLE_PATTERN.match(line)
alias_match = ALIAS_PATTERN.match(line)
not_implemented_match = NOT_IMPLEMENTED_PATTERN.match(line)

mapping_type = reference_item["mapping_type"]
mapping_type_column = mapping_type
row_idx_s = str(context["table_row_idx"])

if reference_table_match:
condition = reference_table_match_to_condition(reference_table_match)
api_list = context["c2a_dict"][
condition
] # 这个键一定存在,否则说明前面出错了
output_lines = []
cur_row_idx = context["table_row_idx"]
for api in api_list:
content = get_referenced_api_columns(api, metadata_dict)
content.insert(0, str(cur_row_idx))
output = "| " + " | ".join(content) + " |\n"
output_lines.append(output)
cur_row_idx += 1
# 因为外面会给 table_row_idx 自动加 1,所以这里减去 1
context["table_row_idx"] = cur_row_idx - 1
return output_lines
elif alias_match:
alias_name = alias_match["alias_name"].strip("`").replace(r"\_", "_")
torch_api = alias_match["torch_api"].strip("`").replace(r"\_", "_")

_mapping_type_desc, show_diff_url = mapping_type_to_description(
mapping_type
content = get_referenced_api_columns(
torch_api, metadata_dict, alias=alias_name
)

desc_column = f"`{torch_api}` 别名"

if show_diff_url:
desc_column += f",[详细对比]({diff_page_url})"

if "paddle_api" not in reference_item:
if mapping_type not in ["组合替代实现", "可删除", "功能缺失"]:
print(
f"Cannot find paddle_api for torch_api: {torch_api} in line {line_idx}"
)
paddle_api_column = ""
else:
paddle_api = reference_item["paddle_api"]
paddle_api_url = reference_item["paddle_api_url"]
paddle_api_column = f"[`{paddle_api}`]({paddle_api_url})"

content = [
row_idx_s,
alisa_name_column,
paddle_api_column,
mapping_type_column,
desc_column,
]
content.insert(0, row_idx_s)

output = "| " + " | ".join(content) + " |\n"
return output

return [output]
elif not_implemented_match:
torch_api = (
not_implemented_match["torch_api"].strip("`").replace(r"\_", "_")
Expand All @@ -213,12 +210,12 @@ def apply_reference_to_row(line, metadata_dict, table_row_idx, line_idx):
]

output = "| " + " | ".join(content) + " |\n"
return output
return [output]
else:
raise ValueError(
f"found manual-maintaining row at line [{line_idx}]: {line}"
)
return line
return [line]


def reference_mapping_item_processer(line, line_idx, state, output, context):
Expand All @@ -232,11 +229,7 @@ def reference_mapping_item_processer(line, line_idx, state, output, context):
# check column names in common process
output.append(line)
return True
elif state == 1:
# check separator of table to ignore in common process
output.append(line)
return True
elif state == 2:
elif state == 1 or state == 2:
# check separator of table to process in common process
output.append(line)
return True
Expand All @@ -246,17 +239,56 @@ def reference_mapping_item_processer(line, line_idx, state, output, context):
return True
elif state == 6:
# check content of table to process in common process
referenced_row = apply_reference_to_row(
line, metadata_dict, context["table_row_idx"], line_idx + 1
output_lines = apply_reference_to_row_ex(
line, metadata_dict, context, line_idx + 1
)

output.append(referenced_row)
output += output_lines
return True

print(state)
return False


def reference_table_scanner(line, _line_idx, state, output, context):
    """Pre-scan callback that records every REFERENCE-MAPPING-TABLE condition.

    Args:
        line: The current line of the mapping index file.
        _line_idx: Index of the line (unused).
        state: Parser state number supplied by the driver.
            NOTE(review): the meaning of each state is defined by the
            driver, which is not visible here; state 6 appears to be
            "content row of a table to process" -- confirm.
        output: Output buffer supplied by the driver (never touched here).
        context: Shared dict; matched conditions are appended to
            ``context["table_conditions"]``.

    Returns:
        ``True`` for non-table lines and for states 0-2, 5 and 6;
        ``False`` for any other state.
    """
    # Lines outside of a table carry nothing of interest for the pre-scan.
    if not line.startswith("|"):
        return True

    # These states are accepted without inspecting the row.
    if 0 <= state <= 2 or state == 5:
        return True

    # Every remaining state except 6 is rejected.
    if state != 6:
        return False

    # Remember the condition of each REFERENCE-MAPPING-TABLE row.
    table_match = REFERENCE_TABLE_PATTERN.match(line)
    if table_match is not None:
        context["table_conditions"].append(
            reference_table_match_to_condition(table_match)
        )
    return True


def get_c2a_dict(conditions, meta_dict):
    """Assign each API in ``meta_dict`` to the first table condition it fits.

    Conditions are tried with the longest prefix first and, among prefixes
    of equal length, the smallest ``max_depth`` first. An API fits a
    condition when its name starts with the prefix and its depth (the
    number of dots in the name) does not exceed ``max_depth``. Each API
    lands in at most one list; an API that fits no condition only produces
    a warning on stdout.

    Args:
        conditions: Iterable of ``(api_prefix, max_depth)`` tuples. It is
            left unmodified.
        meta_dict: Mapping whose keys are API names; only the keys are read.

    Returns:
        A dict mapping every condition to the list of API names assigned
        to it, in ``meta_dict`` iteration order.
    """
    c2a_dict = {c: [] for c in conditions}

    # Sort a copy so the caller's list keeps its original order.
    ordered_conditions = sorted(conditions, key=lambda c: (-len(c[0]), c[1]))

    for api in meta_dict:
        # The depth depends only on the API name, so compute it once.
        depth = api.count(".")
        for api_prefix, max_depth in ordered_conditions:
            if api.startswith(api_prefix) and depth <= max_depth:
                c2a_dict[(api_prefix, max_depth)].append(api)
                break
        else:
            print(f"Warning: cannot find a suitable condition for api {api}")

    return c2a_dict


if __name__ == "__main__":
# convert from pytorch basedir
cfp_basedir = os.path.dirname(__file__)
Expand Down Expand Up @@ -287,7 +319,21 @@ def reference_mapping_item_processer(line, line_idx, state, output, context):
"metadata_dict": meta_dict,
"ret_code": 0,
"output": [],
"table_conditions": [],
}

# 第一遍预读,用来分析有哪些表格和匹配条件
ret_code = reference_mapping_item(
mapping_index_file, reference_table_scanner, reference_context
)
assert ret_code == 0
reference_context["output"] = []

# 现在 c2a_dict 包含每个条件对应的 api 列表
c2a_dict = get_c2a_dict(reference_context["table_conditions"], meta_dict)
reference_context["c2a_dict"] = c2a_dict

# 第二遍正式读,读并处理
ret_code = reference_mapping_item(
mapping_index_file, reference_mapping_item_processer, reference_context
)
Expand Down
Loading