Skip to content

Commit

Permalink
Refactor code auto-gene for no_need_buffer (#41025)
Browse files Browse the repository at this point in the history
* refactor code auto-gene for no_need_buffer

* fix some bug

* delete test code
  • Loading branch information
zyfncg authored Mar 30, 2022
1 parent 9219495 commit 97cd0f5
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -334,10 +334,10 @@ def ParseInplaceInfo(self):
self.inplace_map[key] = val

def ParseNoNeedBuffer(self):
forward_api_contents = self.forward_api_contents
grad_api_contents = self.grad_api_contents

if 'no_need_buffer' in forward_api_contents.keys():
no_need_buffer_str = forward_api_contents['no_need_buffer']
if 'no_need_buffer' in grad_api_contents.keys():
no_need_buffer_str = grad_api_contents['no_need_buffer']
for name in no_need_buffer_str.split(","):
name = name.strip()
name = RemoveSpecialSymbolsInName(name)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
Expand Down Expand Up @@ -56,7 +56,7 @@ def ParseArguments():
########################
SET_PLAIN_TENSOR_WRAPPER_TEMPLATE = \
"""
void SetTensorWrapper{}(const paddle::experimental::Tensor& {}, bool full_reserved) {{
void SetTensorWrapper{}(const paddle::experimental::Tensor& {}, bool full_reserved) {{
{} = egr::TensorWrapper({}, full_reserved, {});
}}
"""
Expand Down Expand Up @@ -121,19 +121,19 @@ class {} : public egr::GradNodeBase {{
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph = false) override;
std::string name() override {{ return \" {} \"; }}
void ClearTensorWrappers() override {{
{}
is_tensor_wrappers_cleared = true;
}}
// SetTensorWrapperX, SetTensorWrapperY, ...
{}
// SetAttributes
{}
bool IsTensorWrappersCleared() override {{
return is_tensor_wrappers_cleared;
return is_tensor_wrappers_cleared;
}}
private:
// TensorWrappers
Expand Down Expand Up @@ -192,7 +192,7 @@ class {} : public egr::GradNodeBase {{
if(require_any_grad) {{
{}
egr::EagerUtils::PassStopGradient({});
// Node Construction
{}
// SetAttributes
Expand Down Expand Up @@ -379,7 +379,7 @@ def __init__(self, forward_api_contents, grad_api_contents, namespace):
#self.forward_outputs_position_map
#self.optional_inputs
#self.no_need_buffers
#self.intermediate_outputs
#self.intermediate_outputs
#self.inplace_map
FunctionGeneratorBase.__init__(self, forward_api_contents, namespace)

Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/api/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ cc_library(api_custom_impl SRCS api_custom_impl.cc DEPS phi_tensor_raw phi kerne
cc_library(sparse_api_custom_impl SRCS sparse_api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform)

cc_library(phi_function_api SRCS ${api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform api_custom_impl)
cc_library(phi_bw_function_api SRCS ${bw_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils backward_infermeta phi_data_transform phi_function_api api_custom_impl)
cc_library(phi_bw_function_api SRCS ${bw_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils backward_infermeta phi_data_transform phi_function_api api_custom_impl global_utils)
cc_library(sparse_api SRCS ${sparse_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils sparse_api_custom_impl)
cc_library(sparse_bw_api SRCS ${sparse_bw_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils sparse_api sparse_api_custom_impl)
cc_library(phi_dygraph_api SRCS ${dygraph_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform phi_function_api sparse_api)
Expand Down
46 changes: 26 additions & 20 deletions python/paddle/utils/code_gen/api_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,31 @@ def gene_api_declaration(self):

return api_declaration

# Backward API Override this method
def gene_kernel_backend_select(self):
backend_select_code = ""
if self.kernel['backend'] is not None:
if '>' in self.kernel['backend']:
vars_list = self.kernel['backend'].split('>')
assert len(
vars_list
) == 2, f"{self.api} api: The number of params to set backend with '>' only allows 2, but received {len(vars_list)}."
assert (vars_list[0].strip() in self.attrs['names']) and (self.attrs['attr_info'][vars_list[0].strip()][0] == 'Place'), \
f"{self.api} api: When use '>' to set kernel backend, the first param should be a attribute with Place type."
backend_select_code = f"""
kernel_backend = ParseBackendWithInputOrder({vars_list[0].strip()}, {vars_list[1].strip()});
"""

else:
backend_args = [
ele.strip() for ele in self.kernel['backend'].split(',')
]
backend_select_code = f"""
kernel_backend = ParseBackend({", ".join(backend_args)});
"""

return backend_select_code

def gene_kernel_select(self) -> str:
api = self.api
input_names = self.inputs['names']
Expand Down Expand Up @@ -345,26 +370,7 @@ def gene_kernel_select(self) -> str:
attr_data_type_count = attr_data_type_count + 1

# preprocess kernel configures
kernel_select_code = ""
if kernel['backend'] is not None:
if '>' in kernel['backend']:
vars_list = kernel['backend'].split('>')
assert len(
vars_list
) == 2, f"{api} api: The number of params to set backend with '>' only allows 2, but received {len(vars_list)}."
assert (vars_list[0].strip() in attrs['names']) and (attrs['attr_info'][vars_list[0].strip()][0] == 'Place'), \
f"{api} api: When use '>' to set kernel backend, the first param should be a attribute with Place type."
kernel_select_code = kernel_select_code + f"""
kernel_backend = ParseBackendWithInputOrder({vars_list[0].strip()}, {vars_list[1].strip()});
"""

else:
args_str = ""
for ele in kernel['backend'].split(','):
args_str = args_str + ele.strip() + ', '
kernel_select_code = kernel_select_code + f"""
kernel_backend = ParseBackend({args_str[:-2]});
"""
kernel_select_code = self.gene_kernel_backend_select()

if kernel['layout'] is not None:
if '>' in kernel['layout']:
Expand Down
24 changes: 24 additions & 0 deletions python/paddle/utils/code_gen/backward_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class BackwardAPI(BaseAPI):
def __init__(self, backward_item_yaml):
super(BackwardAPI, self).__init__(backward_item_yaml)
self.check_args(backward_item_yaml['forward'])
self.no_need_buffer = self.parse_no_need_buffer(backward_item_yaml)

def get_api_name(self, api_item_yaml):
return api_item_yaml['backward_api']
Expand All @@ -41,6 +42,15 @@ def parse_forward_config(self, forward_config):

return api, fw_inputs, fw_attrs, outputs

def parse_no_need_buffer(self, api_item_yaml):
no_need_buffer = []
if 'no_need_buffer' in api_item_yaml:
no_need_buffer = [
item.strip()
for item in api_item_yaml['no_need_buffer'].split(',')
]
return no_need_buffer

def check_args(self, forward_config):
# parse the forward and backward config
_, fw_inputs, fw_attrs, fw_outputs = self.parse_forward_config(
Expand All @@ -67,6 +77,19 @@ def check_args(self, forward_config):
f"{self.api} : Output error: The number of outputs should be less then the number of inputs of forward api. \
Please check the output of {self.api} in yaml."

def gene_kernel_backend_select(self):
all_no_need_buffer = True
for in_name in self.inputs['names']:
if in_name not in self.no_need_buffer:
all_no_need_buffer = False

if all_no_need_buffer:
return """
kernel_backend = ParseBackend(egr::Controller::Instance().GetExpectedPlace());
"""
else:
return super().gene_kernel_backend_select()

def get_return_type(self, out_type_list):
return out_type_list[0] if len(
out_type_list) == 1 else "std::vector<std::vector<Tensor>>"
Expand Down Expand Up @@ -154,6 +177,7 @@ def source_include(header_file_path):
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
"""

Expand Down
3 changes: 3 additions & 0 deletions python/paddle/utils/code_gen/sparse_bw_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ def __init__(self, bw_api_item_yaml):
def get_api_func_name(self):
return self.api

def gene_kernel_backend_select(self):
return BackwardAPI.gene_kernel_backend_select(self)

def get_return_type(self, out_type_list):
return BackwardAPI.get_return_type(self, out_type_list)

Expand Down

0 comments on commit 97cd0f5

Please sign in to comment.