[PRIM][IR] Complete IR vjp code gen for more vjp code gen #56798

Merged
Changes from all commits
31 changes: 25 additions & 6 deletions paddle/fluid/ir/dialect/op_generator/api_gen.py
@@ -156,13 +156,26 @@ def _gen_api_args(self, op_info, with_default_attr, is_mutable_attr):

def _gen_ret_type(self, op_info):
type_list = op_info.output_type_list
if len(type_list) > 1:
intermediate_list = op_info.output_intermediate_list
assert len(type_list) == len(intermediate_list)

output_num = len(type_list) - intermediate_list.count('true')
if output_num > 1:
return 'std::tuple<{}>'.format(
', '.join([self._type_map[type] for type in type_list])
', '.join(
[
self._type_map[type]
for type, intermediate in zip(
type_list, intermediate_list
)
if intermediate == 'false'
]
)
)
elif len(type_list) == 1:
return self._type_map[type_list[0]]
elif len(type_list) == 0:
elif output_num == 1:
index = intermediate_list.index('false')
return self._type_map[type_list[index]]
elif output_num == 0:
return 'void'

def _gen_one_declare(self, op_info, op_name, is_mutable_attr):
@@ -252,10 +265,16 @@ def _gen_compute_op(
def _gen_out_split_and_ret_list(self, op_info, op_inst_name):
name_list = op_info.output_name_list
type_list = op_info.output_type_list
intermediate_list = op_info.output_intermediate_list
assert len(name_list) == len(type_list) == len(intermediate_list)

split_op_str = ''
ret_list = []
for i, (name, type) in enumerate(zip(name_list, type_list)):
for i, (name, type, intermediate) in enumerate(
zip(name_list, type_list, intermediate_list)
):
if intermediate == 'true':
continue
if VECTOR_TYPE in type:
split_op_name = f'{name}_split_op'
split_op_str += SPLIT_OP_TEMPLATE.format(
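For illustration, here is a minimal, self-contained sketch of the updated return-type logic in _gen_ret_type. The _type_map contents and the example type/intermediate lists are assumptions made up for the demo, not values from the PR; the point is only that outputs marked intermediate are dropped from the generated API's return type.

# Illustrative sketch; _type_map entries and the example lists are hypothetical.
_type_map = {
    'paddle::dialect::DenseTensorType': 'ir::OpResult',
}


def gen_ret_type(type_list, intermediate_list):
    # Intermediate outputs ('true') do not appear in the API return type.
    output_num = len(type_list) - intermediate_list.count('true')
    if output_num > 1:
        return 'std::tuple<{}>'.format(
            ', '.join(
                _type_map[t]
                for t, inter in zip(type_list, intermediate_list)
                if inter == 'false'
            )
        )
    elif output_num == 1:
        return _type_map[type_list[intermediate_list.index('false')]]
    return 'void'


# An op with one visible output and one intermediate output now yields a single
# return type instead of a two-element tuple:
print(
    gen_ret_type(
        ['paddle::dialect::DenseTensorType', 'paddle::dialect::DenseTensorType'],
        ['false', 'true'],
    )
)  # -> ir::OpResult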
1 change: 1 addition & 0 deletions paddle/fluid/ir/dialect/op_generator/op_build_gen.py
@@ -129,6 +129,7 @@ def GenBuildInputArgsStr(
'float': 'phi::DataType::FLOAT32',
'std::vector<int64_t>': 'phi::DataType::INT64',
'const std::vector<int64_t>&': 'phi::DataType::INT64',
'bool': 'phi::DataType::BOOL',
}


13 changes: 13 additions & 0 deletions paddle/fluid/ir/dialect/op_generator/op_gen.py
@@ -598,6 +598,19 @@ def parse_attribute_build_arg_type_list(self):
if 'Scalar' in temp_type:
if 'data_type' in attribute_info:
temp_type = attribute_info['data_type']
op_name = self.op_yaml_item['name']
attr_name = attribute_info['name']
if (
op_name not in ["isclose", "allclose"]
and self.op_compat_item is not None
and 'scalar' in self.op_compat_item.keys()
and attr_name in self.op_compat_item['scalar'].keys()
and 'data_type'
in self.op_compat_item['scalar'][attr_name].keys()
):
temp_type = self.op_compat_item['scalar'][attr_name][
'data_type'
]
if 'IntArray' in temp_type:
if 'data_type' in attribute_info:
temp_type = "const " + attribute_info['data_type'] + "&"
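For context, a rough sketch of the lookup this new branch performs. The op_compat_item dict below is a hypothetical record shaped like an op_compat.yaml entry (only the keys the condition reads), not data taken from the PR.

# Hypothetical op_compat entry; only the keys used by the condition are shown.
op_compat_item = {
    'op': 'scale',
    'scalar': {'scale': {'data_type': 'float'}},
}

op_name = 'scale'
attr_name = 'scale'
temp_type = 'Scalar'  # type parsed from the op yaml before this branch runs

if (
    op_name not in ['isclose', 'allclose']
    and op_compat_item is not None
    and 'scalar' in op_compat_item
    and attr_name in op_compat_item['scalar']
    and 'data_type' in op_compat_item['scalar'][attr_name]
):
    # Prefer the concrete data_type recorded for the scalar attribute.
    temp_type = op_compat_item['scalar'][attr_name]['data_type']

print(temp_type)  # -> float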
10 changes: 10 additions & 0 deletions paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.cc
@@ -29,5 +29,15 @@ ir::OpResult split_grad(std::vector<ir::OpResult> out_grads,

return split_grad_op.x_grad();
}

ir::OpResult split_grad(std::vector<ir::OpResult> out_grads, int axis) {
auto combine_op =
APIBuilder::Instance().GetBuilder()->Build<ir::CombineOp>(out_grads);
paddle::dialect::SplitGradOp split_grad_op =
APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::SplitGradOp>(
combine_op.out(), axis);

return split_grad_op.x_grad();
}
} // namespace dialect
} // namespace paddle
1 change: 1 addition & 0 deletions paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.h
@@ -25,5 +25,6 @@ namespace dialect {

ir::OpResult split_grad(std::vector<ir::OpResult> out_grads, ir::OpResult axis);

ir::OpResult split_grad(std::vector<ir::OpResult> out_grads, int axis);
} // namespace dialect
} // namespace paddle
7 changes: 2 additions & 5 deletions paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op_vjp.cc
@@ -35,11 +35,8 @@ std::vector<std::vector<ir::OpResult>> SumOp::Vjp(
Tensor x(std::make_shared<primitive::LazyTensor>(op_obj.x()));
Tensor out_grad(std::make_shared<primitive::LazyTensor>(out_grads[0][0]));

IntArray axis = op_obj.axis()
.GetDefiningOp()
->attribute("value")
.dyn_cast<paddle::dialect::IntArrayAttribute>()
.data();
Tensor axis(std::make_shared<primitive::LazyTensor>(op_obj.axis()));

bool keepdim = op->attribute("keepdim").dyn_cast<ir::BoolAttribute>().data();
bool reduce_all = false;
std::vector<std::vector<Tensor>> tensor_res = primitive::sum_vjp(
18 changes: 18 additions & 0 deletions paddle/fluid/operators/generator/tests_utils.py
@@ -84,3 +84,21 @@ def supports_no_need_buffer(op):

def is_tensor_list(s):
return s == 'Tensor[]'


def exist_mutable_attribute(attrs):
for attr in attrs:
if (
attr['typename'] in ['Scalar', 'IntArray']
and attr['support_tensor'] is True
Reviewer comment (Contributor):

1. The Scalar typename is not always just Scalar; it can also be Scalar(int), Scalar(int64_t), and so on. This function could use the helpers in tests_utils.py for the check:

def is_scalar(s):
    return re.match(r"Scalar(\(\w+\))*", s) is not None


def is_intarray(s):
    return s == 'IntArray'

2. Under the new IR, does a mutable attribute also need to be checked with:

attr['tensor_name'] is not None or attr['tensors_name'] is not None

?

Author reply (@Charles-hit, Sep 1, 2023):

Thanks for the reminder. On the first point, my understanding is that attributes with an explicit data type should not need any change here; the second point is already handled in gen.py.

):
return True
else:
return False


def is_mutable_attribute(attr):
return (
attr['typename'] in ['Scalar', 'IntArray']
and attr['support_tensor'] is True
Reviewer comment (Contributor):

How are operators handled whose attributes are of these two types, 'Scalar' and 'IntArray', but do not have support_tensor set?

Author reply (Charles-hit):

They are treated as constants.

)
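To make the exchange above concrete, a small illustrative sketch. The attribute dicts are hypothetical, is_mutable_attribute is restated from the diff above so the snippet is self-contained, and is_scalar is the existing helper the reviewer points to; a Scalar/IntArray attribute without support_tensor does not count as mutable, matching the "treated as constants" reply.

import re


def is_scalar(s):
    # Existing helper in tests_utils.py that the reviewer references.
    return re.match(r"Scalar(\(\w+\))*", s) is not None


def is_mutable_attribute(attr):
    # Restated from the diff above for a self-contained example.
    return (
        attr['typename'] in ['Scalar', 'IntArray']
        and attr['support_tensor'] is True
    )


axis_attr = {'typename': 'IntArray', 'support_tensor': True}   # can be fed a Tensor
scale_attr = {'typename': 'Scalar', 'support_tensor': False}   # handled as a constant

print(is_scalar('Scalar(int64_t)'))        # True
print(is_mutable_attribute(axis_attr))     # True
print(is_mutable_attribute(scale_attr))    # False
print(any(is_mutable_attribute(a) for a in [axis_attr, scale_attr]))  # True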
8 changes: 0 additions & 8 deletions paddle/fluid/primitive/backend/manual/manual_backend.h
@@ -28,14 +28,6 @@ using Scalar = paddle::experimental::Scalar;
using IntArray = paddle::experimental::IntArray;
using DataType = phi::DataType;

template <typename T>
std::vector<Tensor> concat_grad(const std::vector<Tensor>& x,
const Tensor& out_grad,
const Tensor& axis);

template <typename T>
Tensor split_grad(const std::vector<Tensor>& out_grads, const Tensor& axis);

} // namespace backend
} // namespace primitive
} // namespace paddle
48 changes: 0 additions & 48 deletions paddle/fluid/primitive/backend/manual/manual_static_backend.cc
@@ -23,54 +23,6 @@ namespace backend {

using LazyTensor = paddle::primitive::LazyTensor;

template <>
std::vector<Tensor> concat_grad<LazyTensor>(const std::vector<Tensor>& x,
const Tensor& out_grad,
const Tensor& axis) {
std::vector<ir::OpResult> x_res;
for (uint64_t idx = 0; idx < x.size(); idx++) {
x_res.emplace_back(std::static_pointer_cast<LazyTensor>(x[idx].impl())
->getValue()
.dyn_cast<ir::OpResult>());
}

ir::OpResult out_grad_res =
std::static_pointer_cast<LazyTensor>(out_grad.impl())
->getValue()
.dyn_cast<ir::OpResult>();

ir::OpResult axis_res = std::static_pointer_cast<LazyTensor>(axis.impl())
->getValue()
.dyn_cast<ir::OpResult>();

std::vector<ir::OpResult> op_res =
paddle::dialect::concat_grad(x_res, out_grad_res, axis_res);

std::vector<Tensor> op_result;
for (uint64_t idx = 0; idx < op_res.size(); idx++) {
op_result.emplace_back(
std::make_shared<primitive::LazyTensor>(op_res[idx]));
}
return op_result;
}

template <>
Tensor split_grad<LazyTensor>(const std::vector<Tensor>& out_grads,
const Tensor& axis) {
std::vector<ir::OpResult> out_grads_res;
for (uint64_t idx = 0; idx < out_grads.size(); idx++) {
out_grads_res.emplace_back(
std::static_pointer_cast<LazyTensor>(out_grads[idx].impl())
->getValue()
.dyn_cast<ir::OpResult>());
}
ir::OpResult axis_res = std::static_pointer_cast<LazyTensor>(axis.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult op_res = paddle::dialect::split_grad(out_grads_res, axis_res);
return Tensor(std::make_shared<primitive::LazyTensor>(op_res));
}

} // namespace backend
} // namespace primitive
} // namespace paddle
3 changes: 2 additions & 1 deletion paddle/fluid/primitive/codegen/CMakeLists.txt
@@ -7,6 +7,7 @@ set(rev_legacy_path ${parsed_yaml_path}/legacy_backward_ops.parsed.yaml)
set(prim_path "${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/primitive.yaml")
set(templates_dir
"${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/codegen/templates/")
set(compat_path "${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml")
set(destination_dir "${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/")
set(scripts "${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/codegen/gen.py")

@@ -17,7 +18,7 @@ execute_process(
${PYTHON_EXECUTABLE} ${scripts} --fwd_path ${fwd_path} --fwd_legacy_path
${fwd_legacy_path} --rev_path ${rev_path} --rev_legacy_path
${rev_legacy_path} --prim_path ${prim_path} --templates_dir ${templates_dir}
--destination_dir ${destination_dir}
--compat_path ${compat_path} --destination_dir ${destination_dir}
RESULT_VARIABLE _result)
if(${_result})
message(