Skip to content

Commit

Permalink
support ir ops and fused ops vjp gen
Browse files Browse the repository at this point in the history
  • Loading branch information
Charles-hit committed Oct 17, 2023
1 parent 3e540c4 commit d21fea6
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 5 deletions.
9 changes: 8 additions & 1 deletion paddle/fluid/primitive/codegen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ set(fwd_path ${parsed_yaml_path}/ops.parsed.yaml)
set(fwd_legacy_path ${parsed_yaml_path}/legacy_ops.parsed.yaml)
set(rev_path ${parsed_yaml_path}/backward_ops.parsed.yaml)
set(rev_legacy_path ${parsed_yaml_path}/legacy_backward_ops.parsed.yaml)
set(fwd_pd_op_path
${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/generated/ops.parsed.yaml
)
set(rev_pd_op_path
${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/generated//ops_backward.parsed.yaml
)
set(prim_path "${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/primitive.yaml")
set(templates_dir
"${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/codegen/templates/")
Expand All @@ -17,7 +23,8 @@ execute_process(
COMMAND
${PYTHON_EXECUTABLE} ${scripts} --fwd_path ${fwd_path} --fwd_legacy_path
${fwd_legacy_path} --rev_path ${rev_path} --rev_legacy_path
${rev_legacy_path} --prim_path ${prim_path} --templates_dir ${templates_dir}
${rev_legacy_path} --fwd_pd_op_path ${fwd_pd_op_path} --rev_pd_op_path
${rev_pd_op_path} --prim_path ${prim_path} --templates_dir ${templates_dir}
--compat_path ${compat_path} --destination_dir ${destination_dir}
RESULT_VARIABLE _result)
if(${_result})
Expand Down
44 changes: 40 additions & 4 deletions paddle/fluid/primitive/codegen/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,14 @@
'add_n_grad',
]

BACKENDS_BLACK_LIST = ['copy_to', 'add_n_grad', "allclose", "isclose"]
BACKENDS_BLACK_LIST = [
'copy_to',
'add_n_grad',
"allclose",
"isclose",
"send_v2",
"assert",
]


PRIM_VJP = [
Expand Down Expand Up @@ -301,6 +308,8 @@ def gen(
rev_path: pathlib.Path,
rev_legacy_path: pathlib.Path,
compat_path: pathlib.Path,
fwd_pd_op_path: pathlib.Path,
rev_pd_op_path: pathlib.Path,
templates_dir: pathlib.Path,
destination_dir: pathlib.Path,
):
Expand All @@ -316,23 +325,38 @@ def gen(
rev_legacy_path (pathlib.Path): The YAML file path of the legacy
backward API.
compat_path: (pathlib.Path): The YAML file path of the ops compat.
fwd_pd_op_path (pathlib.Path): The YAML file path of the ir forward API.
rev_pd_op_path (pathlib.Path): The YAML file path of the ir backward API.
templates_dir (pathlib.Path): The directory of the templates.
destination_dir (pathlib.Path): The Directory of the generated file.
Returns:
None
"""
prims, fwds, legacy_fwds, revs, legacy_revs, compats = (
(
prims,
fwds,
legacy_fwds,
revs,
legacy_revs,
compats,
ir_fwds,
ir_revs,
) = (
load(prim_path),
load(fwd_path),
load(fwd_legacy_path),
load(rev_path),
load(rev_legacy_path),
load(compat_path),
load(fwd_pd_op_path),
load(rev_pd_op_path),
)
filter_compat_info(compats)
apis = [{**api, **{'is_fwd': True}} for api in fwds + legacy_fwds]
apis = apis + [{**api, **{'is_fwd': False}} for api in revs + legacy_revs]
apis = [{**api, **{'is_fwd': True}} for api in fwds + legacy_fwds + ir_fwds]
apis = apis + [
{**api, **{'is_fwd': False}} for api in revs + legacy_revs + ir_revs
]
apis = [
{**api, **{'is_prim': True}}
if api['name'] in prims
Expand Down Expand Up @@ -383,6 +407,16 @@ def gen(
type=str,
help='The parsed ops compat yaml file.',
)
parser.add_argument(
'--fwd_pd_op_path',
type=str,
help='The ir forward ops parsed yaml file.',
)
parser.add_argument(
'--rev_pd_op_path',
type=str,
help='The ir backward ops parsed yaml file.',
)
parser.add_argument(
'--templates_dir',
type=str,
Expand All @@ -402,6 +436,8 @@ def gen(
pathlib.Path(args.rev_path),
pathlib.Path(args.rev_legacy_path),
pathlib.Path(args.compat_path),
pathlib.Path(args.fwd_pd_op_path),
pathlib.Path(args.rev_pd_op_path),
pathlib.Path(args.templates_dir),
pathlib.Path(args.destination_dir),
)

0 comments on commit d21fea6

Please sign in to comment.