Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PluggableDevice] custom kernel supports multi cpp_dtypes #39385

Merged
merged 1 commit into from
Feb 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 76 additions & 38 deletions paddle/fluid/framework/custom_kernel_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,12 @@ limitations under the License. */
// user kernel function
namespace custom_kernel {

// Here we use dot <CPU, ANY, UINT8> for test
// This test will fail when these two kernels are aupported in framework
// Here we use fake_dot for test
// input 3: two Tensors and one std::vector<Tensor>
// attribute 11: fake_attributes
// output 2: one Tensor* and one std::vector<Tensor*>
template <typename T>
void FakeDot(const paddle::CPUContext& dev_ctx, const paddle::Tensor& x,
template <typename T, typename Context>
void FakeDot(const Context& dev_ctx, const paddle::Tensor& x,
const paddle::Tensor& y,
const std::vector<paddle::Tensor>& fake_input_vec,
bool fake_attr_bool, int fake_attr_int, float fake_attr_float,
Expand Down Expand Up @@ -93,53 +92,91 @@ void FakeDot(const paddle::CPUContext& dev_ctx, const paddle::Tensor& x,
}
} // namespace custom_kernel

PD_REGISTER_KERNEL(dot, CPU, ALL_LAYOUT, UINT8,
custom_kernel::FakeDot<uint8_t>) {
/* do some args define here
* the only param can be used is OpKernelInfo* kernel */
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UINT8);
}
PD_REGISTER_KERNEL(fake_dot, CPU, ALL_LAYOUT, custom_kernel::FakeDot, float,
double, int, int64_t, int8_t, uint8_t) {}

// Upper code will store dot kernels info into OpKernelInfoMap
TEST(CustomKernel, custom_kernel_dot) {
std::string op_name = "dot";
std::string op_name = "fake_dot";
pten::Backend backend = pten::Backend::CPU;
pten::DataLayout layout = pten::DataLayout::ANY;
pten::DataType dtype = pten::DataType::UINT8;
pten::DataLayout layout = pten::DataLayout::ALL_LAYOUT;

// 1.custom kernel info parsed and store
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance().GetMap().find("dot") !=
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance().GetMap().find(op_name) !=
paddle::OpKernelInfoMap::Instance().GetMap().end());

// 2.info check
EXPECT_EQ(
1, static_cast<int>(paddle::OpKernelInfoMap::Instance()["dot"].size()));
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()["dot"][0].GetBackend() ==
6, static_cast<int>(paddle::OpKernelInfoMap::Instance()[op_name].size()));
// index 0
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][0].GetBackend() ==
backend);
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()["dot"][0].GetDataLayout() ==
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][0].GetDataLayout() ==
layout);
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()["dot"][0].GetDataType() ==
dtype);

// 3.register
EXPECT_TRUE(pten::KernelFactory::Instance().kernels().end() !=
pten::KernelFactory::Instance().kernels().find("dot"));

pten::KernelKey kernel_key(backend, layout, dtype);
EXPECT_TRUE(
pten::KernelFactory::Instance().kernels()["dot"].find(kernel_key) ==
pten::KernelFactory::Instance().kernels()["dot"].end());

EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][0].GetDataType() ==
pten::DataType::FLOAT32);
// index 5
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][5].GetBackend() ==
backend);
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][5].GetDataLayout() ==
layout);
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][5].GetDataType() ==
pten::DataType::UINT8);

// 3.before register
auto& kernel_factory_instance = pten::KernelFactory::Instance();
auto& kernels = pten::KernelFactory::Instance().kernels();
EXPECT_TRUE(!kernel_factory_instance.HasCompatiblePtenKernel(op_name));

// mock fake_dot is supported by pten for HasCompatiblePtenKernel check while
// registering
auto& fake_dot_kernels = kernels[op_name];

EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::FLOAT32)) ==
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::FLOAT64)) ==
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::INT32)) ==
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::INT64)) ==
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::INT8)) ==
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::UINT8)) ==
fake_dot_kernels.end());

// register
paddle::framework::RegisterKernelWithMetaInfoMap(
paddle::OpKernelInfoMap::Instance());

EXPECT_TRUE(
pten::KernelFactory::Instance().kernels()["dot"].find(kernel_key) !=
pten::KernelFactory::Instance().kernels()["dot"].end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::FLOAT32)) !=
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::FLOAT64)) !=
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::INT32)) !=
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::INT64)) !=
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::INT8)) !=
fake_dot_kernels.end());
EXPECT_TRUE(fake_dot_kernels.find(
pten::KernelKey(backend, layout, pten::DataType::UINT8)) !=
fake_dot_kernels.end());

// 4.kernel select
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
op_name, kernel_key);
auto kernel = kernel_factory_instance.SelectKernelOrThrowError(
op_name, pten::KernelKey(backend, layout, pten::DataType::UINT8));

// 5.prepare parameters for kernel
const auto alloc = std::make_unique<paddle::experimental::DefaultAllocator>(
Expand Down Expand Up @@ -252,10 +289,10 @@ TEST(CustomKernel, custom_kernel_dot) {
// test OpKernelInfoHelper
TEST(OpKernelInfoHelper, op_kernel_info_help_getters) {
using OpKernelInfoHelper = paddle::framework::OpKernelInfoHelper;
std::string op_name = "dot";
std::string op_name = "fake_dot";
pten::Backend backend = pten::Backend::CPU;
pten::DataLayout layout = pten::DataLayout::ANY;
pten::DataType dtype = pten::DataType::UINT8;
pten::DataType dtype = pten::DataType::FLOAT32;

auto op_kernel_info = paddle::OpKernelInfoMap::Instance()[op_name][0];

Expand All @@ -268,10 +305,11 @@ TEST(OpKernelInfoHelper, op_kernel_info_help_getters) {
OpKernelInfoHelper::GetKernelKey(op_kernel_info));

paddle::CustomKernelFunc kernel_fn =
PD_PT_KERNEL(custom_kernel::FakeDot<uint8_t>);
PD_PT_KERNEL(custom_kernel::FakeDot<float, paddle::CPUContext>);
EXPECT_EQ(kernel_fn, OpKernelInfoHelper::GetKernelFn(op_kernel_info));

void* variadic_func = PD_PT_VARIADIC_KERNEL(custom_kernel::FakeDot<uint8_t>);
void* variadic_func =
PD_PT_VARIADIC_KERNEL(custom_kernel::FakeDot<float, paddle::CPUContext>);
EXPECT_EQ(variadic_func,
OpKernelInfoHelper::GetVariadicKernelFn(op_kernel_info));

Expand Down
Loading