Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Prim][Dist] Support codegen of decompvjp interface #64464

Merged
merged 16 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion paddle/fluid/pir/dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -291,13 +291,15 @@ cc_library(

#Note(risemeup1):compile some *.cc files which depend on primitive_vjp_experimental into op_dialect_vjp.a/lib
set(op_decomp_source_file ${PIR_DIALECT_BINARY_DIR}/op_decomp.cc)
# set(op_decomp_vjp_source_file ${PIR_DIALECT_BINARY_DIR}/op_decomp_vjp.cc)
set(op_decomp_vjp_source_file ${PIR_DIALECT_BINARY_DIR}/op_decomp_vjp.cc)

set(op_dialect_vjp_srcs
${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/manual_op_decomp.cc
${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/manual_op_decomp_vjp.cc
${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/manual_op_vjp.cc
${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/op_dialect.cc
${op_decomp_source_file}
${op_decomp_vjp_source_file}
${op_vjp_source_file}
${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/base/decomp_trans.cc)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,53 @@
# xshape output will no longer used after decomp, but return none to keep output num the same as origin op
decomp_ops_contain_unused_output = ["squeeze", "unsqueeze"]

decomp_vjp_interface_declare_gen_op_list = [
"add_grad",
"matmul_grad",
"relu_grad",
# prim op with one input and one output, with no attribute
UNARY_PRIM_VJP_OPS = [
'abs_grad',
'erf_grad',
'exp_grad',
'floor_grad',
'log_grad',
'rsqrt_grad',
'sin_grad',
'cos_grad',
'tanh_grad',
]

# prim op with two inputs and one output, with no attribute
BINARY_PRIM_VJP_OPS = [
'matmul_grad',
'add_grad',
'divide_grad',
'subtract_grad',
'multiply_grad',
'elementwise_pow_grad',
'maximum_grad',
'reduce_as_grad',
]

CUSTOM_VJP = [
'gelu_grad',
'hardswish_grad',
'leaky_relu_grad',
'mean_grad',
'minimum_grad',
'pow_grad',
'relu_grad',
'sigmoid_grad',
'silu_grad',
'softmax_grad',
'sqrt_grad',
'swiglu_grad',
'layer_norm_grad',
'group_norm_grad',
] # custom vjp list of composite op

# declare belongs to codegen, but implementation not
OTHER_VJP = ["stack_grad"]

vjp_list = UNARY_PRIM_VJP_OPS + BINARY_PRIM_VJP_OPS + CUSTOM_VJP

decomp_vjp_interface_declare_gen_op_list = vjp_list + OTHER_VJP

decomp_vjp_interface_implementation_gen_op_list = vjp_list
49 changes: 49 additions & 0 deletions paddle/fluid/pir/dialect/op_generator/gen_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,52 @@ def to_pascal_case(s):
return "".join([word.capitalize() for word in words]) + "_"
else:
return "".join([word.capitalize() for word in words]) + ""


attr_types_map = {
'IntArray': ['paddle::dialect::IntArrayAttribute', 'IntArray'],
'Scalar': ['paddle::dialect::ScalarAttribute', 'Scalar'],
'Scalar(int)': ['pir::Int32Attribute', 'int'],
'Scalar(int64_t)': ['pir::Int64Attribute', 'int64_t'],
'Scalar(float)': ['pir::FloatAttribute', 'float'],
'Scalar(double)': ['pir::DoubleAttribute', 'double'],
'Scalar[]': [
'pir::ArrayAttribute<paddle::dialect::ScalarAttribute>',
'const std::vector<Scalar>&',
],
'int': ['pir::Int32Attribute', 'int'],
'int32_t': ['pir::Int32Attribute', 'int32_t'],
'int64_t': ['pir::Int64Attribute', 'int64_t'],
'long': ['pir::LongAttribute', 'long'],
'size_t': ['pir::Size_tAttribute', 'size_t'],
'float': ['pir::FloatAttribute', 'float'],
'float[]': [
'pir::ArrayAttribute<pir::FloatAttribute>',
'const std::vector<float>&',
],
'double': ['pir::DoubleAttribute', 'double'],
'bool': ['pir::BoolAttribute', 'bool'],
'bool[]': [
'pir::ArrayAttribute<pir::BoolAttribute>',
'const std::vector<bool>&',
],
'str': ['pir::StrAttribute', 'const std::string&'],
'str[]': [
'pir::ArrayAttribute<pir::StrAttribute>',
'const std::vector<std::string>&',
],
'Place': ['paddle::dialect::PlaceAttribute', 'const phi::Place&'],
'DataLayout': [
'paddle::dialect::DataLayoutAttribute',
'DataLayout',
],
'DataType': ['paddle::dialect::DataTypeAttribute', 'DataType'],
'int64_t[]': [
'pir::ArrayAttribute<pir::Int64Attribute>',
'const std::vector<int64_t>&',
],
'int[]': [
'pir::ArrayAttribute<pir::Int32Attribute>',
'const std::vector<int>&',
],
}
50 changes: 1 addition & 49 deletions paddle/fluid/pir/dialect/op_generator/op_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
decomp_interface_declare_gen_op_list,
decomp_vjp_interface_declare_gen_op_list,
)
from gen_utils import to_pascal_case
from gen_utils import attr_types_map, to_pascal_case
from infer_symbolic_shape_gen import gen_infer_symbolic_shape_str
from op_all_func_gen import gen_op_all_func
from op_build_gen import gen_build_func_str, gen_build_func_str_by_invoke
Expand Down Expand Up @@ -344,54 +344,6 @@ class {TEST_API} {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{
'expand',
}

attr_types_map = {
'IntArray': ['paddle::dialect::IntArrayAttribute', 'IntArray'],
'Scalar': ['paddle::dialect::ScalarAttribute', 'Scalar'],
'Scalar(int)': ['pir::Int32Attribute', 'int'],
'Scalar(int64_t)': ['pir::Int64Attribute', 'int64_t'],
'Scalar(float)': ['pir::FloatAttribute', 'float'],
'Scalar(double)': ['pir::DoubleAttribute', 'double'],
'Scalar[]': [
'pir::ArrayAttribute<paddle::dialect::ScalarAttribute>',
'const std::vector<Scalar>&',
],
'int': ['pir::Int32Attribute', 'int'],
'int32_t': ['pir::Int32Attribute', 'int32_t'],
'int64_t': ['pir::Int64Attribute', 'int64_t'],
'long': ['pir::LongAttribute', 'long'],
'size_t': ['pir::Size_tAttribute', 'size_t'],
'float': ['pir::FloatAttribute', 'float'],
'float[]': [
'pir::ArrayAttribute<pir::FloatAttribute>',
'const std::vector<float>&',
],
'double': ['pir::DoubleAttribute', 'double'],
'bool': ['pir::BoolAttribute', 'bool'],
'bool[]': [
'pir::ArrayAttribute<pir::BoolAttribute>',
'const std::vector<bool>&',
],
'str': ['pir::StrAttribute', 'const std::string&'],
'str[]': [
'pir::ArrayAttribute<pir::StrAttribute>',
'const std::vector<std::string>&',
],
'Place': ['paddle::dialect::PlaceAttribute', 'const phi::Place&'],
'DataLayout': [
'paddle::dialect::DataLayoutAttribute',
'DataLayout',
],
'DataType': ['paddle::dialect::DataTypeAttribute', 'DataType'],
'int64_t[]': [
'pir::ArrayAttribute<pir::Int64Attribute>',
'const std::vector<int64_t>&',
],
'int[]': [
'pir::ArrayAttribute<pir::Int32Attribute>',
'const std::vector<int>&',
],
}


def to_phi_and_fluid_op_name(op_item):
# Template: - op : phi_name (fluid_name)
Expand Down
Loading