
Commit

support manual vjp
Charles-hit committed Aug 17, 2023
2 parents 95642e3 + 488071a commit 7fdb9c4
Showing 17 changed files with 457 additions and 482 deletions.
20 changes: 18 additions & 2 deletions paddle/fluid/framework/new_executor/new_ir_interpreter.cc
@@ -344,6 +344,19 @@ void NewIRInterpreter::UpdateSyncOpNum() {
VLOG(4) << "Update sync op num, sync op num is: " << sync_op_num_;
}

void NewIRInterpreter::UpdateNcclOpNum() {
static std::set<std::string> nccl_op_set = {
"pd.sync_batch_norm_", "pd.sync_batch_norm", "pd.sync_batch_norm_grad"};
int64_t nccl_op_num = 0;
for (auto& ins : vec_instruction_base_) {
if (nccl_op_set.count(ins->Name())) {
nccl_op_num = nccl_op_num + 1;
}
}
nccl_op_num_ = nccl_op_num;
VLOG(4) << "Update nccl op num, nccl op num is: " << nccl_op_num;
}

// Note(zhangbo):
// When there is a KQueueSync type OP in the model, breadth traversal is better
// than depth traversal. For example: OP(O) ->(direct_run)-> OP(A)
@@ -852,7 +865,7 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
VLOG(4) << "Done PreAnalysis";

// Run
if (FLAGS_enable_new_ir_in_executor_trace_run ||
if (FLAGS_enable_new_ir_in_executor_trace_run || nccl_op_num_ > 1 ||
((execution_config_.used_for_jit || execution_config_.used_for_cinn) &&
(sync_op_num_ == 0))) {
LOG_FIRST_N(INFO, 1) << "New ir interpreter is running in BetaRun mode "
@@ -867,7 +880,7 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
is_build_ = true;
is_shared_results_build_ = true;
} else {
if (FLAGS_enable_new_ir_in_executor_trace_run ||
if (FLAGS_enable_new_ir_in_executor_trace_run || nccl_op_num_ > 1 ||
((execution_config_.used_for_jit || execution_config_.used_for_cinn) &&
(sync_op_num_ == 0))) {
TraceRunImpl();
@@ -1182,6 +1195,9 @@ void NewIRInterpreter::PreAnalysis() {

UpdateSyncOpNum();
VLOG(4) << "Done UpdateSyncOpNum";

UpdateNcclOpNum();
VLOG(4) << "Done UpdateNcclOpNum";
}

::ir::Value NewIRInterpreter::GetValueByName(const std::string& var_name) {
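For readability, the run-mode predicate that the new counter feeds can be condensed as below; this is only a restatement of the condition used in both branches above, not new behaviour:

// Trace run is chosen when explicitly flagged, when more than one NCCL op
// (the sync_batch_norm family counted by UpdateNcclOpNum) is present, or for
// JIT/CINN programs that contain no sync ops.
bool use_trace_run =
    FLAGS_enable_new_ir_in_executor_trace_run || nccl_op_num_ > 1 ||
    ((execution_config_.used_for_jit || execution_config_.used_for_cinn) &&
     sync_op_num_ == 0);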
2 changes: 2 additions & 0 deletions paddle/fluid/framework/new_executor/new_ir_interpreter.h
@@ -84,6 +84,7 @@ class NewIRInterpreter : public InterpreterBaseImpl {
private:
// build graph
void UpdateSyncOpNum();
void UpdateNcclOpNum();
void AnalyseExecuteOrderForTrace(
std::map<size_t, std::set<size_t>> op_downstream_map,
InstructionSchedulingPriorityLess compare);
@@ -148,6 +149,7 @@ class NewIRInterpreter : public InterpreterBaseImpl {

// used for Trace
int64_t sync_op_num_{-1};
int64_t nccl_op_num_{-1};
std::vector<size_t> trace_execute_order_;

std::vector<HookFunc> hookfuncs_;
40 changes: 32 additions & 8 deletions paddle/fluid/ir/dialect/op_generator/op_gen.py
@@ -20,10 +20,14 @@
from op_interface_gen import (
gen_exclusive_interface_str,
gen_op_infer_meta_str,
vjp_interface_gen_op_list,
gen_op_vjp_str,
)
from op_member_func_gen import gen_op_get_inputs_outputs_str
from op_verify_gen import gen_verify_func_str
from vjp_interface_gen_op_list import (
vjp_interface_declare_gen_op_list,
vjp_interface_implementation_gen_op_list,
)

# =====================================
# String Template for h file code gen
@@ -286,6 +290,9 @@ def __init__(self, op_yaml_item, op_compat_item):
self.attribute_build_arg_type_list = (
self.parse_attribute_build_arg_type_list()
)
self.attribute_gen_arg_type_list = (
self.parse_attribute_gen_arg_type_list()
)
self.attribute_data_type_list = self.parse_attribute_data_type_list()
self.attribute_default_value_list = (
self.parse_attribute_default_value_list()
@@ -584,6 +591,17 @@ def parse_attribute_build_arg_type_list(self):
type_list.append(self.get_phi_dtype_name(temp_type))
return type_list

def parse_attribute_gen_arg_type_list(self):
type_list = []
for attribute_info in self.op_yaml_item['attrs']:
assert (
attribute_info['typename'] in self.attr_types_map
), f"{self.op_phi_name} : Attr type error."

temp_type = self.attr_types_map[attribute_info['typename']][1]
type_list.append(self.get_phi_dtype_name(temp_type))
return type_list

def parse_attribute_type_list(self):
type_list = []
for attribute_info in self.op_yaml_item['attrs']:
@@ -741,7 +759,7 @@ def OpGenerator(

if (
op_info.backward_name
and op_info.op_phi_name[0] in vjp_interface_gen_op_list
and op_info.op_phi_name[0] in vjp_interface_declare_gen_op_list
):
op_interfaces += ["VjpInterface"]
exclusive_interface_str = gen_exclusive_interface_str(op_info)
@@ -1038,12 +1056,18 @@ def OpGenerator(
op_vjp_str = ''

# TODO(chenzhiyang) add vjp gen code
# if op_info.backward_name and op_info.op_phi_name[0] in vjp_interface_gen_op_list:
# op_vjp_str = gen_op_vjp_str(op_class_name,
# op_info.backward_name,
# op_name,
# op_info_items[op_info.op_phi_name[0]],
# op_info_items[op_info.backward_name])
if (
op_info.backward_name
and op_info.op_phi_name[0]
in vjp_interface_implementation_gen_op_list
):
op_vjp_str = gen_op_vjp_str(
op_class_name,
op_info.backward_name,
op_name,
op_info_items[op_info.op_phi_name[0]],
op_info_items[op_info.backward_name],
)

ops_name_list.append(op_class_name)
ops_declare_list.append(op_declare_str)
129 changes: 88 additions & 41 deletions paddle/fluid/ir/dialect/op_generator/op_interface_gen.py
@@ -13,7 +13,7 @@
# limitations under the License.

# generator interfaces
from vjp_interface_gen_op_list import vjp_interface_gen_op_list
from vjp_interface_gen_op_list import vjp_interface_declare_gen_op_list

OP_INFER_SHAPE_TEMPLATE = """
void {op_name}::InferMeta( phi::InferMetaContext *infer_meta ) {{
@@ -23,57 +23,61 @@
"""

OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE = """
{input_type} {input_name}(std::make_shared<primitive::LazyTensor>(op_obj.{input_name}()));
"""
{input_type} {input_name}(std::make_shared<primitive::LazyTensor>(op_obj.{input_name}()));"""

OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE = """
Tensor {output_grad_name}(std::make_shared<primitive::LazyTensor>((out_grads[{idx1}][{idx2}]);
"""
Tensor {output_grad_name}(std::make_shared<primitive::LazyTensor>(out_grads[{idx1}][{idx2}]));"""

OP_VJP_FORWARD_OUTPUT_GRAD_LIST_TEMPLATE = """
std::vector<Tensor> {output_grad_name}(std::make_shared<primitive::LazyTensor>((out_grads[{idx1}]);
"""
std::vector<Tensor> {output_grad_name}(std::make_shared<primitive::LazyTensor>(out_grads[{idx1}]));"""

OP_VJP_CALL_VJP_TEMPLATE = """
Tensor std::vector<std::vector<Tensor>> tensor_res =
primitive::{op_phi_name}_vjp({inputs_list}, stop_gradients);
"""
OP_VJP_ATTRIBUTE_TEMPLATE = """
{attr_type} {attr_name} = op->attribute("{attr_name}").dyn_cast<{attr_parse_type}>().data();"""

OP_VJP_STOPGRADIENT_TEMPLATE = """
if(!stop_gradients[{idx1}][{idx2}]){{
res[{idx1}][{idx2}] = std::static_pointer_cast<primitive::LazyTensor>(
tensor_res[idx1][idx2].impl())
->getValue()
.dyn_cast<ir::OpResult>();
}}
"""
OP_VJP_ATTRIBUTE_DEFAULT_TEMPLATE = """
{attr_type} {attr_name} = {default_value};"""

OP_VJP_DEFINE_TEMPLATE = """
std::vector<std::vector<ir::OpResult>> {op_class_name}::Vjp(
ir::Operation* op,
const std::vector<std::vector<ir::OpResult>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients){{
{op_class_name} op_obj = op->dyn_cast<{op_class_name}>();

VLOG(6) << "Prepare inputs of {op_grad_name}";
OP_VJP_CALL_VJP_TEMPLATE = """ std::vector<std::vector<Tensor>> tensor_res =
primitive::{op_phi_name}_vjp(
{inputs_list}stop_gradients);"""

{forward_input_code}
{forward_output_code}
{forward_output_grad_code}
OP_VJP_STOPGRADIENT_TEMPLATE = """
std::vector<std::vector<ir::OpResult>> res(tensor_res.size());
for (size_t i = 0; i < tensor_res.size(); ++i) {{
res[i].resize(tensor_res[i].size());
for (size_t j = 0; j < tensor_res[i].size(); ++j) {{
if(tensor_res[i][j].defined()){{
res[i][j] = std::static_pointer_cast<primitive::LazyTensor>(tensor_res[i][j].impl())->getValue().dyn_cast<ir::OpResult>();
}}
}}
}}"""

OP_VJP_DEFINE_TEMPLATE = """
std::vector<std::vector<ir::OpResult>> {op_class_name}::Vjp(ir::Operation* op, const std::vector<std::vector<ir::OpResult>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients){{
{op_class_name} op_obj = op->dyn_cast<{op_class_name}>();
VLOG(6) << "Vjp prepare Prepare attributes of {op_grad_name}";
{attribute_code}
VLOG(6) << "Prepare inputs of {op_grad_name}";
{forward_input_code}
{forward_output_grad_code}
VLOG(4) << "Vjp prepare call {op_phi_name}'s vjp inteface";
{call_vjp_code}
VLOG(6) << "Vjp prepare Prepare attributes of {op_grad_name}";
{attribute_code}
std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1));
{stop_gradient_input_grad_code}
VLOG(4) << "Vjp prepare call {op_phi_name}'s vjp inteface";
{call_vjp_code}
return res;
VLOG(4) << "Vjp prepare stop gradient of {op_grad_name}";
{stop_gradient_input_grad_code}
return res;
}}
"""

input_types_map = {
'paddle::dialect::DenseTensorType': 'Tensor',
'ir::VectorType<paddle::dialect::DenseTensorType>': 'Tensor[]',
}


def gen_op_vjp_str(
op_class_name,
@@ -82,19 +86,62 @@ def gen_op_vjp_str(
op_info,
op_grad_info,
):
bw_input_list = op_grad_info.input_name_list
forward_input_code = ''
forward_output_code = ''
forward_output_grad_code = ''
build_args_str = ''
grad_idx = -1
for idx in range(len(bw_input_list)):
build_args_str += bw_input_list[idx] + ", "
if (
bw_input_list[idx] in op_info.input_name_list
or bw_input_list[idx] in op_info.output_name_list
):
forward_input_code += (
OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE.format(
input_type=input_types_map[
op_grad_info.input_type_list[idx]
],
input_name=bw_input_list[idx],
)
)
else:
grad_idx += 1
forward_output_grad_code += (
OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE.format(
output_grad_name=bw_input_list[idx], idx1=grad_idx, idx2=0
)
)
op_attribute_list = op_grad_info.attribute_name_list
attribute_code = ''
call_vjp_code = ''
stop_gradient_input_grad_code = ''
for idx in range(len(op_attribute_list)):
build_args_str += op_attribute_list[idx] + ", "
if op_attribute_list[idx] in op_info.attribute_name_list:
attribute_code += OP_VJP_ATTRIBUTE_TEMPLATE.format(
attr_type=op_grad_info.attribute_gen_arg_type_list[idx],
attr_name=op_attribute_list[idx],
attr_parse_type=op_grad_info.attribute_type_list[idx],
)
else:
attribute_code += OP_VJP_ATTRIBUTE_DEFAULT_TEMPLATE.format(
attr_type=op_grad_info.attribute_gen_arg_type_list[idx],
attr_name=op_attribute_list[idx],
default_value=op_grad_info.attribute_default_value_list[idx],
)
if op_phi_name[-1] == '_':
op_phi_name = op_phi_name[:-1]
call_vjp_code = OP_VJP_CALL_VJP_TEMPLATE.format(
op_phi_name=op_phi_name,
inputs_list=build_args_str,
)
stop_gradient_input_grad_code = OP_VJP_STOPGRADIENT_TEMPLATE

str = OP_VJP_DEFINE_TEMPLATE.format(
op_class_name=op_class_name,
op_grad_name=op_grad_name,
op_phi_name=op_phi_name,
res_size=len(op_info.input_name_list),
forward_input_code=forward_input_code,
forward_output_code=forward_output_code,
forward_output_grad_code=forward_output_grad_code,
attribute_code=attribute_code,
call_vjp_code=call_vjp_code,
@@ -119,6 +166,6 @@ def gen_exclusive_interface_str(op_info):
exclusive_interface_str += (
" static void InferMeta( phi::InferMetaContext *infer_meta );"
)
if op_info.op_phi_name[0] in vjp_interface_gen_op_list:
if op_info.op_phi_name[0] in vjp_interface_declare_gen_op_list:
exclusive_interface_str += "\n static std::vector<std::vector<ir::OpResult>> Vjp(ir::Operation* op, const std::vector<std::vector<ir::OpResult>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients);"
return exclusive_interface_str
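For reference, a sketch of what these templates expand to for a simple op — assuming the generated forward op class is named TanhOp, its grad op tanh_grad consumes the forward output out plus out_grad and has no attributes, and primitive::tanh_vjp has the signature implied by OP_VJP_CALL_VJP_TEMPLATE; the exact output depends on the op YAML, so treat names and signatures here as illustrative:

// Illustrative expansion of OP_VJP_DEFINE_TEMPLATE (names are assumptions).
std::vector<std::vector<ir::OpResult>> TanhOp::Vjp(ir::Operation* op, const std::vector<std::vector<ir::OpResult>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients){
  TanhOp op_obj = op->dyn_cast<TanhOp>();
  // {attribute_code} is empty here because tanh_grad has no attributes.
  VLOG(6) << "Prepare inputs of tanh_grad";
  // OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE: forward output reused by the grad op.
  Tensor out(std::make_shared<primitive::LazyTensor>(op_obj.out()));
  // OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE: incoming gradient.
  Tensor out_grad(std::make_shared<primitive::LazyTensor>(out_grads[0][0]));
  VLOG(4) << "Vjp prepare call tanh's vjp interface";
  // OP_VJP_CALL_VJP_TEMPLATE
  std::vector<std::vector<Tensor>> tensor_res =
      primitive::tanh_vjp(
      out, out_grad, stop_gradients);
  VLOG(4) << "Vjp prepare stop gradient of tanh_grad";
  // OP_VJP_STOPGRADIENT_TEMPLATE: convert defined results back to ir::OpResult.
  std::vector<std::vector<ir::OpResult>> res(tensor_res.size());
  for (size_t i = 0; i < tensor_res.size(); ++i) {
    res[i].resize(tensor_res[i].size());
    for (size_t j = 0; j < tensor_res[i].size(); ++j) {
      if(tensor_res[i][j].defined()){
        res[i][j] = std::static_pointer_cast<primitive::LazyTensor>(tensor_res[i][j].impl())->getValue().dyn_cast<ir::OpResult>();
      }
    }
  }
  return res;
}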
paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py
@@ -21,4 +21,5 @@
# TODO(wanghao107)
# remove this file and support Vjp methods
# code gen.
vjp_interface_gen_op_list = ["tanh", "mean", "divide", "sum", "add"]
vjp_interface_declare_gen_op_list = ["tanh", "mean", "divide", "sum", "add"]
vjp_interface_implementation_gen_op_list = ["tanh", "mean", "divide", "add"]
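The split between the two lists is what the commit title refers to: an op present in the declare list but absent from the implementation list (here "sum") still gets the VjpInterface and the declaration below, emitted by gen_exclusive_interface_str, while its definition is presumably written by hand elsewhere rather than generated from the templates:

// Declaration emitted into the generated op class for every op in
// vjp_interface_declare_gen_op_list; for "sum" the matching definition
// is supplied manually instead of being code-generated.
static std::vector<std::vector<ir::OpResult>> Vjp(
    ir::Operation* op,
    const std::vector<std::vector<ir::OpResult>>& out_grads,
    const std::vector<std::vector<bool>>& stop_gradients);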