Skip to content

Commit

Permalink
Paddle Tensor Operation Library initial implementation (PaddlePaddle#…
Browse files Browse the repository at this point in the history
…34425)

* initial tensor design & sign kernel demo

* add move constructor for meta & add lodtensor

* add dirs & sign xpu kernel

* add mean cpu&cuda kernel impl

* move sign & mean xpu & npu kernel

* add selected_rows basic impl

* refactor design, BaseTensor to DenseTensor, etc.

* add scale mkldnn kernel

* polish xpu & npu impl details

* fix mkldnn reuse compile failed

* change tensor operation lib name

* rename util filename

* add more comments

* change TensorImplInterface to TensorInterface

* add kernel key and factory

* remove MKLDNNTensorMeta, add MKLDNNDenseTensor

* change XXDeviceContext to XXContext

* add base kernel registrar utils & test on sign

* replace boost::any by paddle::any

* fix several ci failed

* fix npu compile error

* add ordered map util

* fix multiple ordered_map compile errors

* move dev into include dir

* support sign op in static op run

* fix static op run error

* fix new executor compile failed

* add dygraph branch & remove sign_op.h

* fix test_infer_no_need_buffer_slots

* fix rocm compile link error

* fix unitybuild error & clear glog

* fix npu compile failed

* skip quant trans test

* fix part windows compile problem

* fix xpu enforce error

* fix inference test failed

* remove ordered_map to solve quant failed

* fix part of rcom compile faild

* add more register kernels

* revert scale kernel temporarily

* fix code format error

* add new kernel registrar marco

* rename top to tcmpt

* revert xpu, npu, mkldnn impl & remove op def

* add kernel args parse functor to auto parse args

* revert some change & add scale kernels

* add op proto in dygraph kernelcontext building

* polish kernel dispatch logic & nameing rule

* fix scale kernel match error

* fix scale test failed

* add mean API and unittest

* test mean api success

* add branch to solve compiled error

* skip clang format error

* add mean skip rule in op_library

* add dot kernel, api and unittest (#6)

* remove old kernel and add symbol link

* fix dot compiled failed

* add merco for module declare

* fix npu and xpu compile error

* revert sign, mean, scale, dot kernel removing

* add comment for keeping old kernel impl

* fix mutable_data error

* fix bfloat16 conflit

* fix inference undef error

* adapt to msvc compile rules

* polish comment for template inst

* add cmake template instantiation for win

* fix backend to place device id bug

* fix ifdef error

* Op2functor (#7)

* add kernel args maker class

* make args maker non-const

* remove debug log

* modify codes by review options

* split constructPrKernelContext function

* fix output name bug

* fix test_mean_op test_sign_op failed

* fill_any_like kernel refactor (#10)

* fill_any_like kernel refactor

* remove useless code of full_like c++ api

* skip dtype for fill_any_like

* add attrs for kernel key constrcut

* add use_pt_kernel Flags to control whether to use pt kernel (#13)

* add use_pt_kernel Flags to control whether to use pt kernel

* change the default value to true for cheking pt kernels

* fix mutable_data cuda place error

* move high level apis into hapi

* remove selectedrows adapting temporarily

* Support Scalar in Tensor Compute Library (#14)

* fill_any_like kernel refactor

* remove useless code of full_like c++ api

* Support Scalar in Tensor Compute Library

* add scalar in dygraph and static graph mode

* keep the basic type for attr, instead of using scalar for all

* merge the code

* remove mkldnn tensor & polish details

* use flat_hash_map and small_vector in kernel factory

* Refactor flatten kernel (#12)

* refactor flatten kernel

* update infershape function

* fix compile bugs

* fix bugs when merge

* fix compiler bugs

* fix bugs when run test_flatten_api

* fix bugs when run test

* Revert "use flat_hash_map and small_vector in kernel factory"

This reverts commit 2309149.

* Move cpu, cuda and other device code into kernels (#15)

* fill_any_like kernel refactor

* remove useless code of full_like c++ api

* Support Scalar in Tensor Compute Library

* add scalar in dygraph and static graph mode

* keep the basic type for attr, instead of using scalar for all

* merge the code

* start refactor matmul

* move cpu, cuda and other device modules into kernels

* merge code

* polish code in operator.cc

* Perfect unitests (#16)

* perfect unittest

* update license

* replace with flat_hash_map, small_vector (#19)

* fix small_vector build error on windows platform

* replace with flat_hash_map, small_vector

* remove todo

* Perfect unitests (#20)

* perfect unittest

* update license

* fix bug when run tcmpt_utils_test

* refactor execution adapting impl

* fix insert conflit

* Fix CI bug of test_yolov3 (#21)

* fill_any_like kernel refactor

* remove useless code of full_like c++ api

* Support Scalar in Tensor Compute Library

* add scalar in dygraph and static graph mode

* keep the basic type for attr, instead of using scalar for all

* merge the code

* start refactor matmul

* move cpu, cuda and other device modules into kernels

* merge code

* polish code in operator.cc

* Fix CI bug of test_yolov3

* add the tensor base class, test=develop (#17)

* update the tensor base class, test=develop

* remove two funcs, test=develop

* update the error msg, test=develop

Co-authored-by: Chen Weihang <chenweihang@baidu.com>

* [no-verify] commit backend and tensor signature changes

* Rename tcmpt to pten (#23)

* rename tcmpt to pten

* update omitted files for rename to pten

* update omitted file for rename to pten

* remove k of all enum var

* remove kernel_instantiate (#26)

* remove symbols and spatial_tensor

* change common to functions

* readd share tensor impl methods

* add a candidate dense tensor class, test=develop (#28)

* change all Pt to Pten

* resolve conflit with xiaowei

* Op2functor opt1 (#27)

* replace to small vector and change to const &

* add std::move

Co-authored-by: Chen Weihang <chenweihang@baidu.com>

* polish kernel factory and kernel registry

* fix operator test error msg mismatch

* remove tensor signature and backend set member

* move scalar and polish enforce

* revert dtype layout change to fix error

* fix enum operator override error

* add several base unittests

* add pten utils tests

* polish some details

* Dev/op2func refactor 3 (#30)

* add a candidate dense tensor class, test=develop

* remove TensorBase::backend(), test=develop

* remove some ops, test=develop

* cherry-pick the pr of tensor meta, test=develop

* moves the dense tensor and some ops, test=develop

* update the linalg operator, test=develop

* update other operators, test=develop

* fix errors, test=develop

* fix bugs, test=develop

* try to resolve the problem of windows ci, test=develop

* updates codes, test=develop

* fix the tensor_utils.cc, test=develop

* modify the dense tensor, test=develop

* fix the data type, test=develop

Co-authored-by: shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>

* polish some details

* polish kernel signature details

* fix a bug about offsets of the tensor, test=develop (#31)

Co-authored-by: shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>

* polish some details

Co-authored-by: chentianyu03 <ctychentianyu@gmail.com>
Co-authored-by: zyfncg <1370305206@qq.com>
Co-authored-by: YuanRisheng <yuanrisheng@baidu.com>
Co-authored-by: 石晓伟 <39303645+Shixiaowei02@users.noreply.github.com>
  • Loading branch information
5 people authored Nov 1, 2021
1 parent 3c0a68c commit b9fdd3b
Show file tree
Hide file tree
Showing 147 changed files with 8,516 additions and 195 deletions.
17 changes: 17 additions & 0 deletions cmake/generic.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,20 @@ function(find_fluid_modules TARGET_NAME)
endif()
endfunction(find_fluid_modules)

set_property(GLOBAL PROPERTY PTEN_MODULES "")
# find all pten modules is used for paddle static library
# for building inference libs
function(find_pten_modules TARGET_NAME)
get_filename_component(__target_path ${TARGET_NAME} ABSOLUTE)
string(REGEX REPLACE "^${PADDLE_SOURCE_DIR}/" "" __target_path ${__target_path})
string(FIND "${__target_path}" "pten" pos)
if(pos GREATER 1)
get_property(pten_modules GLOBAL PROPERTY PTEN_MODULES)
set(pten_modules ${pten_modules} ${TARGET_NAME})
set_property(GLOBAL PROPERTY PTEN_MODULES "${pten_modules}")
endif()
endfunction(find_pten_modules)

function(common_link TARGET_NAME)
if (WITH_PROFILER)
target_link_libraries(${TARGET_NAME} gperftools::profiler)
Expand Down Expand Up @@ -310,6 +324,7 @@ function(cc_library TARGET_NAME)
else()
add_library(${TARGET_NAME} STATIC ${cc_library_SRCS})
find_fluid_modules(${TARGET_NAME})
find_pten_modules(${TARGET_NAME})
endif()
if(cc_library_DEPS)
# Don't need link libwarpctc.so
Expand Down Expand Up @@ -482,6 +497,7 @@ function(nv_library TARGET_NAME)
else()
add_library(${TARGET_NAME} STATIC ${nv_library_SRCS})
find_fluid_modules(${TARGET_NAME})
find_pten_modules(${TARGET_NAME})
endif()
if (nv_library_DEPS)
add_dependencies(${TARGET_NAME} ${nv_library_DEPS})
Expand Down Expand Up @@ -572,6 +588,7 @@ function(hip_library TARGET_NAME)
else()
hip_add_library(${TARGET_NAME} STATIC ${hip_library_SRCS})
find_fluid_modules(${TARGET_NAME})
find_pten_modules(${TARGET_NAME})
endif()
if (hip_library_DEPS)
add_dependencies(${TARGET_NAME} ${hip_library_DEPS})
Expand Down
1 change: 1 addition & 0 deletions paddle/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_subdirectory(scripts)
add_subdirectory(testing)
set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests CACHE INTERNAL "python tests directory")
add_subdirectory(pten)
add_subdirectory(fluid)
9 changes: 7 additions & 2 deletions paddle/fluid/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,12 @@ cc_library(unused_var_check SRCS unused_var_check.cc DEPS glog no_need_buffer_va

IF(WITH_XPU)
cc_library(operator SRCS operator.cc DEPS xpu_op_list op_info device_context tensor scope glog trainer_desc_proto data_feed_proto
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils)
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils
pten pten_utils kernel_factory)
ELSE()
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog trainer_desc_proto data_feed_proto
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils)
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils
pten pten_utils kernel_factory)
ENDIF()

cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)
Expand Down Expand Up @@ -394,6 +396,8 @@ cc_library(save_load_util SRCS save_load_util.cc DEPS tensor scope layer)
cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tensor scope layer)
cc_library(generator SRCS generator.cc DEPS enforce place)

cc_library(pten_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows place pten var_type_traits pten_hapi_utils)

# Get the current working branch
execute_process(
COMMAND git rev-parse --abbrev-ref HEAD
Expand Down Expand Up @@ -456,3 +460,4 @@ if(WITH_TESTING AND TEST selected_rows_test)
endif()

cc_test(scope_guard_test SRCS scope_guard_test.cc)
cc_test(pten_utils_test SRCS pten_utils_test.cc DEPS pten_utils)
223 changes: 191 additions & 32 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ limitations under the License. */
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/pten/common/scalar.h"

namespace paddle {
namespace framework {
Expand All @@ -49,6 +50,7 @@ DECLARE_bool(check_nan_inf);
DECLARE_bool(enable_unused_var_check);
PADDLE_DEFINE_EXPORTED_int32(inner_op_parallelism, 0,
"number of threads for inner op");
DECLARE_bool(run_pten_kernel);

namespace paddle {
namespace framework {
Expand Down Expand Up @@ -1120,8 +1122,24 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
}
#endif

if (kernel_type_.get() == nullptr || kernel_func_.get() == nullptr) {
ChooseKernel(*runtime_ctx, scope, place);
auto exe_ctx = ExecutionContext(*this, scope, *dev_ctx, *runtime_ctx);

// TODO(chenweihang): Now we are still reusing a lot of the original fluid
// implementation, this is a gradual replacement process
// TODO(chenweihang): in the first phase of project, we only support CPU, CUDA
// and RCOM backend, the XPU, NPU and MKLDNN will be supported in the second
// phase
if (FLAGS_run_pten_kernel &&
pten::KernelFactory::Instance().HasCompatiblePtenKernel(type_)) {
if (pt_kernel_signature_.get() == nullptr || pt_kernel_.get() == nullptr) {
ChoosePtenKernel(exe_ctx);
}
run_pten_kernel_ = pt_kernel_->IsValid();
}
if (!run_pten_kernel_) {
if (kernel_type_.get() == nullptr || kernel_func_.get() == nullptr) {
ChooseKernel(exe_ctx);
}
}

// do data transformScope &transfer_scope;
Expand Down Expand Up @@ -1159,8 +1177,13 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
{
platform::RecordEvent record_event("compute",
platform::EventRole::kInnerOp);
(*kernel_func_)(
ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx));
if (run_pten_kernel_) {
auto op_kernel_ctx = BuildPtenKernelContext(*runtime_ctx, *dev_ctx);
(*pt_kernel_)(&op_kernel_ctx);
} else {
(*kernel_func_)(
ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx));
}
}

if (!transfered_inplace_vars.empty()) {
Expand Down Expand Up @@ -1208,25 +1231,11 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
}
}

void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
const Scope& scope,
const platform::Place& place) const {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);

// check if op[type] has kernel registered.
auto& all_op_kernels = AllOpKernels();
auto kernels_iter = all_op_kernels.find(type_);
PADDLE_ENFORCE_NE(
kernels_iter, all_op_kernels.end(),
platform::errors::Unavailable(
"There are no kernels which are registered in the %s operator.",
type_));

OpKernelMap& kernels = kernels_iter->second;
OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
const ExecutionContext& ctx) const {
auto& dev_ctx = ctx.device_context();

auto expected_kernel_key = this->GetExpectedKernelType(
ExecutionContext(*this, scope, *dev_ctx, ctx));
auto expected_kernel_key = this->GetExpectedKernelType(ctx);
if (HasAttr("op_device")) {
if (Attr<std::string>("op_device") == "cpu") {
expected_kernel_key.place_ = platform::CPUPlace();
Expand All @@ -1243,9 +1252,9 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
// when the Op that only has CPUKernel is assigned to GPU, the CPUKernel
// will be executed and a warning will be given at the same time.
if (SupportGPU()) {
expected_kernel_key.place_ = dev_ctx->GetPlace();
expected_kernel_key.place_ = dev_ctx.GetPlace();
} else if (SupportNPU()) {
expected_kernel_key.place_ = dev_ctx->GetPlace();
expected_kernel_key.place_ = dev_ctx.GetPlace();
} else {
expected_kernel_key.place_ = platform::CPUPlace();
LOG_FIRST_N(WARNING, 1)
Expand All @@ -1256,6 +1265,47 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
}
VLOG(3) << "op type:" << type_
<< ", expected_kernel_key:" << expected_kernel_key;
return expected_kernel_key;
}

void OperatorWithKernel::ChoosePtenKernel(const ExecutionContext& ctx) const {
pt_kernel_signature_.reset(
new KernelSignature(std::move(this->GetExpectedPtenKernelArgs(ctx))));

VLOG(1) << KernelSignatureToString(*pt_kernel_signature_.get());

kernel_type_.reset(
new OpKernelType(std::move(InnerGetExpectedKernelType(ctx))));

auto pt_kernel_name = pten::KernelName(pt_kernel_signature_->name);
auto pt_kernel_key = TransOpKernelTypeToPtenKernelKey(*kernel_type_.get());
pt_kernel_.reset(
new pten::Kernel(pten::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_kernel_key)));

if (pt_kernel_->IsValid()) {
VLOG(1) << "Static mode ChoosePtenKernel - kernel name: " << pt_kernel_name
<< " | kernel key: " << pt_kernel_key
<< " | kernel: " << *pt_kernel_;
} else {
VLOG(1) << "Static mode ChoosePtenKernel - kernel `" << pt_kernel_name
<< "` not found.";
}
}

void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const {
// check if op[type] has kernel registered.
auto& all_op_kernels = AllOpKernels();
auto kernels_iter = all_op_kernels.find(type_);
PADDLE_ENFORCE_NE(
kernels_iter, all_op_kernels.end(),
platform::errors::Unavailable(
"There are no kernels which are registered in the %s operator.",
type_));

OpKernelMap& kernels = kernels_iter->second;

auto expected_kernel_key = InnerGetExpectedKernelType(ctx);

auto kernel_iter = kernels.find(expected_kernel_key);
#ifdef PADDLE_WITH_MKLDNN
Expand Down Expand Up @@ -1562,11 +1612,10 @@ Scope* OperatorWithKernel::PrepareData(
}

void OperatorWithKernel::ParseInputDataType(
const ExecutionContext& ctx, const std::string& name,
const std::vector<Variable*>& vars, const std::string& name,
proto::VarType::Type* data_type) const {
proto::VarType::Type default_data_type =
static_cast<proto::VarType::Type>(-1);
const std::vector<Variable*> vars = ctx.MultiInputVar(name);
for (size_t i = 0; i < vars.size(); ++i) {
const Variable* var = vars[i];
if (var != nullptr) {
Expand All @@ -1588,10 +1637,9 @@ void OperatorWithKernel::ParseInputDataType(
if (t != nullptr) {
PADDLE_ENFORCE_EQ(
t->IsInitialized(), true,
platform::errors::InvalidArgument(
"The Tensor in the %s Op's Input Variable %s(%s) is "
"not initialized.",
Type(), name, ctx.InputNames(name).at(i)));
platform::errors::InvalidArgument("The %s Op's Input Variable `%s` "
"contains uninitialized Tensor.",
Type(), name));
proto::VarType::Type tmp = t->type();
PADDLE_ENFORCE(tmp == *data_type || *data_type == default_data_type,
platform::errors::InvalidArgument(
Expand All @@ -1614,7 +1662,8 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
static_cast<proto::VarType::Type>(-1);
proto::VarType::Type data_type = dafault_data_type;
for (auto& input : ctx.InNameList()) {
ParseInputDataType(ctx, input, &data_type);
const std::vector<Variable*> vars = ctx.MultiInputVar(input);
ParseInputDataType(vars, input, &data_type);
}
PADDLE_ENFORCE_NE(
data_type, dafault_data_type,
Expand All @@ -1628,7 +1677,7 @@ proto::VarType::Type OperatorWithKernel::IndicateVarDataType(
proto::VarType::Type dafault_data_type =
static_cast<proto::VarType::Type>(-1);
proto::VarType::Type data_type = dafault_data_type;
ParseInputDataType(ctx, name, &data_type);
ParseInputDataType(ctx.MultiInputVar(name), name, &data_type);
PADDLE_ENFORCE_NE(
data_type, dafault_data_type,
platform::errors::InvalidArgument(
Expand Down Expand Up @@ -1711,5 +1760,115 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar(
tensor.layout());
}

KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(
const ExecutionContext& ctx) const {
if (!KernelSignatureMap::Instance().Has(Type())) {
// TODO(chenweihang): we can generate this map by proto info in compile time
KernelArgsNameMakerByOpProto maker(Info().proto_);
KernelSignatureMap::Instance().Emplace(
Type(), std::move(maker.GetKernelSignature()));
}
return KernelSignatureMap::Instance().Get(Type());
}

pten::KernelContext OperatorWithKernel::BuildPtenKernelContext(
const RuntimeContext& ctx, const platform::DeviceContext& dev_ctx) const {
// TODO(chenweihang): now only work for very simple case,
// many cases need to be deal with later:
// 1. the input and output are not tensor
// 2. the dispensbale, duplicable input and output
// 3. needless attributes remove
// 4. use pt Tensor directly
// 5. kernel input is not DenseTensor
pten::KernelContext op_kernel_ctx(dev_ctx);

auto& input_names = std::get<0>(pt_kernel_signature_->args);
auto& attr_names = std::get<1>(pt_kernel_signature_->args);
auto& output_names = std::get<2>(pt_kernel_signature_->args);

auto input_defs = pt_kernel_->args_def().input_defs();
auto attr_defs = pt_kernel_->args_def().attribute_defs();
auto output_defs = pt_kernel_->args_def().output_defs();

PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(),
platform::errors::InvalidArgument(
"The size of inputs_args names (%d) must be equal to "
"the size of kernel input_defs (%d).",
input_names.size(), input_defs.size()));

PADDLE_ENFORCE_EQ(output_names.size(), output_defs.size(),
platform::errors::InvalidArgument(
"The size of outputs_args names (%d) must be equal to "
"the size of kernel output_defs (%d).",
output_names.size(), output_defs.size()));

PADDLE_ENFORCE_EQ(attr_names.size(), attr_defs.size(),
platform::errors::InvalidArgument(
"The size of attribute_args names (%d) must be equal "
"to the size of kernel attribute_defs (%d).",
attr_names.size(), attr_defs.size()));

for (size_t i = 0; i < input_names.size(); ++i) {
auto in_def = input_defs.at(i);
VLOG(2) << "in_def: " << in_def.backend << ", " << in_def.dtype << ", "
<< in_def.layout;

auto ins_vector = ctx.inputs.at(input_names[i]);

paddle::SmallVector<std::shared_ptr<pten::TensorBase>> tmp_inputs;
for (auto var : ins_vector) {
tmp_inputs.emplace_back(
experimental::MakePtenTensorBaseFromVar(*var, in_def));
}
op_kernel_ctx.EmplaceBackInputs(std::move(tmp_inputs));
}

for (size_t i = 0; i < output_names.size(); ++i) {
auto out_def = output_defs.at(i);
auto outs_vector = ctx.outputs.at(output_names[i]);

paddle::SmallVector<std::shared_ptr<pten::TensorBase>> tmp_outputs;
for (auto var : outs_vector) {
tmp_outputs.emplace_back(
experimental::MakePtenTensorBaseFromVar(var, out_def));
}
op_kernel_ctx.EmplaceBackOutputs(std::move(tmp_outputs));
}

for (size_t i = 0; i < attr_names.size(); ++i) {
auto& attr = Attrs().at(attr_names[i]);
if (attr_defs[i].type_index == std::type_index(typeid(pten::Scalar))) {
// TODO(chenweihang): support other attrs later
// TODO(zhangyunfei): Scalar should hold scaler type, and we should check
// attribtue type by attr_defs
if (std::type_index(attr.type()) == std::type_index(typeid(float))) {
op_kernel_ctx.EmplaceBackAttr(
std::move(pten::Scalar(BOOST_GET_CONST(float, attr))));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"unsupported cast op attribute `%s` to Scalar when construct "
"KernelContext.",
attr_names[i]));
}
} else {
// TODO(chenweihang): support other attrs later
if (attr_defs[i].type_index == std::type_index(typeid(int))) {
op_kernel_ctx.EmplaceBackAttr(BOOST_GET_CONST(int, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(float))) {
op_kernel_ctx.EmplaceBackAttr(BOOST_GET_CONST(float, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
op_kernel_ctx.EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"unsupported cast op attribute `%s` when construct "
"KernelContext.",
attr_names[i]));
}
}
}

return op_kernel_ctx;
}

} // namespace framework
} // namespace paddle
Loading

0 comments on commit b9fdd3b

Please sign in to comment.