Skip to content

Commit

Permalink
support ir ops and fused ops vjp gen
Browse files Browse the repository at this point in the history
  • Loading branch information
Charles-hit committed Oct 17, 2023
1 parent 3e540c4 commit f3b34ad
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 17 deletions.
9 changes: 8 additions & 1 deletion paddle/fluid/primitive/codegen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ set(fwd_path ${parsed_yaml_path}/ops.parsed.yaml)
set(fwd_legacy_path ${parsed_yaml_path}/legacy_ops.parsed.yaml)
set(rev_path ${parsed_yaml_path}/backward_ops.parsed.yaml)
set(rev_legacy_path ${parsed_yaml_path}/legacy_backward_ops.parsed.yaml)
set(fwd_pd_op_path
${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/generated/ops.parsed.yaml
)
set(rev_pd_op_path
${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/generated/ops_backward.parsed.yaml
)
set(prim_path "${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/primitive.yaml")
set(templates_dir
"${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/codegen/templates/")
Expand All @@ -17,7 +23,8 @@ execute_process(
COMMAND
${PYTHON_EXECUTABLE} ${scripts} --fwd_path ${fwd_path} --fwd_legacy_path
${fwd_legacy_path} --rev_path ${rev_path} --rev_legacy_path
${rev_legacy_path} --prim_path ${prim_path} --templates_dir ${templates_dir}
${rev_legacy_path} --fwd_pd_op_path ${fwd_pd_op_path} --rev_pd_op_path
${rev_pd_op_path} --prim_path ${prim_path} --templates_dir ${templates_dir}
--compat_path ${compat_path} --destination_dir ${destination_dir}
RESULT_VARIABLE _result)
if(${_result})
Expand Down
64 changes: 51 additions & 13 deletions paddle/fluid/primitive/codegen/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,15 @@
'add_n_grad',
]

BACKENDS_BLACK_LIST = ['copy_to', 'add_n_grad', "allclose", "isclose"]
BACKENDS_BLACK_LIST = [
'copy_to',
'add_n_grad',
"allclose",
"isclose",
"send_v2",
"assert",
"embedding_grad_sparse",
]


PRIM_VJP = [
Expand Down Expand Up @@ -280,18 +288,19 @@ def process_backward_invoke_info(apis):

def process_optional_output_info(apis):
for api in apis:
if not api['is_fwd']:
continue
inputs_dict = to_named_dict(api['inputs'])
for output in api['outputs']:
if (
api.get("inplace", None)
and output['name'] in api['inplace']
and inputs_dict[api['inplace'][output['name']]]['optional']
):
output['optional'] = True
else:
if not api['is_fwd']:
output['optional'] = False
else:
if (
api.get("inplace", None)
and output['name'] in api['inplace']
and inputs_dict[api['inplace'][output['name']]]['optional']
):
output['optional'] = True
else:
output['optional'] = False


def gen(
Expand All @@ -301,6 +310,8 @@ def gen(
rev_path: pathlib.Path,
rev_legacy_path: pathlib.Path,
compat_path: pathlib.Path,
fwd_pd_op_path: pathlib.Path,
rev_pd_op_path: pathlib.Path,
templates_dir: pathlib.Path,
destination_dir: pathlib.Path,
):
Expand All @@ -316,23 +327,38 @@ def gen(
rev_legacy_path (pathlib.Path): The YAML file path of the legacy
backward API.
compat_path: (pathlib.Path): The YAML file path of the ops compat.
fwd_pd_op_path (pathlib.Path): The YAML file path of the ir forward API.
rev_pd_op_path (pathlib.Path): The YAML file path of the ir backward API.
templates_dir (pathlib.Path): The directory of the templates.
destination_dir (pathlib.Path): The Directory of the generated file.
Returns:
None
"""
prims, fwds, legacy_fwds, revs, legacy_revs, compats = (
(
prims,
fwds,
legacy_fwds,
revs,
legacy_revs,
compats,
ir_fwds,
ir_revs,
) = (
load(prim_path),
load(fwd_path),
load(fwd_legacy_path),
load(rev_path),
load(rev_legacy_path),
load(compat_path),
load(fwd_pd_op_path),
load(rev_pd_op_path),
)
filter_compat_info(compats)
apis = [{**api, **{'is_fwd': True}} for api in fwds + legacy_fwds]
apis = apis + [{**api, **{'is_fwd': False}} for api in revs + legacy_revs]
apis = [{**api, **{'is_fwd': True}} for api in fwds + legacy_fwds + ir_fwds]
apis = apis + [
{**api, **{'is_fwd': False}} for api in revs + legacy_revs + ir_revs
]
apis = [
{**api, **{'is_prim': True}}
if api['name'] in prims
Expand Down Expand Up @@ -383,6 +409,16 @@ def gen(
type=str,
help='The parsed ops compat yaml file.',
)
parser.add_argument(
'--fwd_pd_op_path',
type=str,
help='The ir forward ops parsed yaml file.',
)
parser.add_argument(
'--rev_pd_op_path',
type=str,
help='The ir backward ops parsed yaml file.',
)
parser.add_argument(
'--templates_dir',
type=str,
Expand All @@ -402,6 +438,8 @@ def gen(
pathlib.Path(args.rev_path),
pathlib.Path(args.rev_legacy_path),
pathlib.Path(args.compat_path),
pathlib.Path(args.fwd_pd_op_path),
pathlib.Path(args.rev_pd_op_path),
pathlib.Path(args.templates_dir),
pathlib.Path(args.destination_dir),
)
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ using IntArray = paddle::experimental::IntArray;
using DataType = phi::DataType;

{% for api in apis %}
{%- if api is only_composite_op -%}{#- render nothing -#}
{%- if api is only_composite_op or "infer_meta" not in api and "composite" not in api and "invoke" not in api -%}{#- render nothing -#}
{%- elif api.name not in backend_black_list -%}
{%- if 'invoke' not in api or 'invoke' in api and api.is_fwd -%}
{% if api.attrs is exist_mutable_attribute %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ namespace backend {

{%- macro args(inputs, attrs) -%} {#- Arguments are variable pass into method -#}
{{common.sequence('', '', ', ', inputs)}}
{%- if attrs|length > 0 -%} {{", "}} {%- endif -%} {#- append comma between inputs and attrs -#}
{%- if attrs|length > 0 -%} {{", "}} {%- endif -%} {#- append comma between
nputs and attrs -#}
{{common.sequence('', '', ', ', attrs)}}
{%- endmacro -%}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ auto op_res = paddle::dialect::{{name}}({{common.args(input_names, attr_names)}}


{% for api in apis %}
{%- if api is only_composite_op -%}{#- render nothing -#}
{%- if api is only_composite_op or "infer_meta" not in api and "composite" not in api and "invoke" not in api -%}{#- render nothing -#}
{% elif api.name not in backend_black_list %}
{%- if 'invoke' not in api or 'invoke' in api and api.is_fwd-%}
{% set api_outputs = api.outputs | trip_intermediate %}
Expand Down

0 comments on commit f3b34ad

Please sign in to comment.