Skip to content

Commit

Permalink
[XPU] Support elementwise_min/max/floordiv/where more precision types for KL1 and KL2 (PaddlePaddle#58422)
Browse files Browse the repository at this point in the history
  • Loading branch information
newway authored and zeroRains committed Nov 8, 2023
1 parent d01915a commit c8985fa
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 27 deletions.
75 changes: 66 additions & 9 deletions paddle/phi/backends/xpu/xpu1_op_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ namespace xpu {
XPUOpMap& get_kl1_ops() {
// KL1支持的op,通过op_name, data_type
static XPUOpMap s_xpu1_kernels{
{"abs", XPUKernelSet({phi::DataType::FLOAT32})},
{"abs",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
phi::DataType::INT64})},
{"accuracy", XPUKernelSet({phi::DataType::FLOAT32})},
{"adam", XPUKernelSet({phi::DataType::FLOAT32})},
{"adamw", XPUKernelSet({phi::DataType::FLOAT32})},
Expand All @@ -34,6 +37,11 @@ XPUOpMap& get_kl1_ops() {
phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::BOOL})},
{"assign_value",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::BOOL})},
{"batch_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"batch_norm", XPUKernelSet({phi::DataType::FLOAT32})},
{"bilinear_interp", XPUKernelSet({phi::DataType::FLOAT32})},
Expand All @@ -48,13 +56,20 @@ XPUOpMap& get_kl1_ops() {
{"cast",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT64,
phi::DataType::BOOL,
phi::DataType::INT32})},
{"clip_by_norm", XPUKernelSet({phi::DataType::FLOAT32})},
{"coalesce_tensor",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT64,
phi::DataType::INT32})},
{"concat", XPUKernelSet({phi::DataType::FLOAT32})},
{"concat",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::INT64,
phi::DataType::INT32})},
{"concat_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"conv2d", XPUKernelSet({phi::DataType::FLOAT32})},
{"conv2d_grad", XPUKernelSet({phi::DataType::FLOAT32})},
Expand All @@ -67,20 +82,39 @@ XPUOpMap& get_kl1_ops() {
{"dropout_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"c_allreduce_sum", XPUKernelSet({phi::DataType::FLOAT32})},
{"c_reduce_sum", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_add", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_add",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT64,
phi::DataType::INT32})},
{"elementwise_add_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_div_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_div", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_floordiv", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_floordiv",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::INT32,
phi::DataType::INT64})},
{"elementwise_max_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_max", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_max",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::INT32,
phi::DataType::INT64})},
{"elementwise_min_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_min", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_min",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::INT32,
phi::DataType::INT64})},
{"elementwise_mul_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_mul", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_pow", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_sub_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_sub", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_sub",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::INT32,
phi::DataType::INT64})},
{"embedding_with_eltwise_add_xpu",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"equal", XPUKernelSet({phi::DataType::INT64})},
Expand Down Expand Up @@ -115,14 +149,26 @@ XPUOpMap& get_kl1_ops() {
phi::DataType::INT32,
phi::DataType::INT64,
})},
{"greater_than",
XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::FLOAT32})},
{"hard_switch_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"hard_switch", XPUKernelSet({phi::DataType::FLOAT32})},
{"index_select",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
phi::DataType::INT64})},
{"iou_similarity", XPUKernelSet({phi::DataType::FLOAT32})},
{"lamb", XPUKernelSet({phi::DataType::FLOAT32})},
{"layer_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"layer_norm", XPUKernelSet({phi::DataType::FLOAT32})},
{"leaky_relu_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"leaky_relu", XPUKernelSet({phi::DataType::FLOAT32})},
{"less_than",
XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::FLOAT32})},
{"load",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT8,
Expand Down Expand Up @@ -206,7 +252,10 @@ XPUOpMap& get_kl1_ops() {
{"rnn", XPUKernelSet({phi::DataType::FLOAT32})},
{"roi_align_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"roi_align", XPUKernelSet({phi::DataType::FLOAT32})},
{"scale", XPUKernelSet({phi::DataType::FLOAT32})},
{"scale",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT64,
phi::DataType::INT32})},
{"sgd", XPUKernelSet({phi::DataType::FLOAT32})},
{"shape",
XPUKernelSet({phi::DataType::FLOAT64,
Expand All @@ -218,7 +267,11 @@ XPUOpMap& get_kl1_ops() {
{"sigmoid", XPUKernelSet({phi::DataType::FLOAT32})},
{"sign", XPUKernelSet({phi::DataType::FLOAT32})},
{"slice_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"slice", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})},
{"slice",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::FLOAT16})},
{"softmax_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"softmax_with_cross_entropy", XPUKernelSet({phi::DataType::FLOAT32})},
{"softmax_with_cross_entropy_grad",
Expand Down Expand Up @@ -306,6 +359,10 @@ XPUOpMap& get_kl1_ops() {
phi::DataType::UINT8,
phi::DataType::FLOAT32})},
{"where_index", XPUKernelSet({phi::DataType::BOOL})},
{"where",
XPUKernelSet({phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::FLOAT32})},
// AddMore
};

Expand Down
24 changes: 20 additions & 4 deletions paddle/phi/backends/xpu/xpu2_op_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"add_layernorm_xpu",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"abs", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"abs",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::INT32,
phi::DataType::INT8,
phi::DataType::INT64})},
{"abs_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"accuracy", XPUKernelSet({phi::DataType::FLOAT32})},
Expand Down Expand Up @@ -248,15 +253,24 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT64,
phi::DataType::INT32})},
{"elementwise_floordiv",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::INT32,
phi::DataType::INT64})},
{"elementwise_max_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"elementwise_max",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::INT32,
phi::DataType::INT64})},
{"elementwise_min_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"elementwise_min",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::INT32,
phi::DataType::INT64})},
{"elementwise_mul_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"elementwise_mul",
Expand Down Expand Up @@ -665,6 +679,8 @@ XPUOpMap& get_kl2_ops() {
{"relu_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"relu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"repeat_interleave",
XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})},
{"reshape2_grad",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::FLOAT16,
Expand Down
12 changes: 9 additions & 3 deletions paddle/phi/kernels/legacy/xpu/elementwise_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,19 +121,25 @@ PD_REGISTER_KERNEL(floor_divide_raw,
ALL_LAYOUT,
phi::FloorDivideRawKernel,
float,
phi::dtype::float16) {}
phi::dtype::float16,
int32_t,
int64_t) {}
PD_REGISTER_KERNEL(maximum_raw,
XPU,
ALL_LAYOUT,
phi::MaximumRawKernel,
float,
phi::dtype::float16) {}
phi::dtype::float16,
int32_t,
int64_t) {}
PD_REGISTER_KERNEL(minimum_raw,
XPU,
ALL_LAYOUT,
phi::MinimumRawKernel,
float,
phi::dtype::float16) {}
phi::dtype::float16,
int32_t,
int64_t) {}
PD_REGISTER_KERNEL(remainder_raw,
XPU,
ALL_LAYOUT,
Expand Down
11 changes: 9 additions & 2 deletions paddle/phi/kernels/xpu/abs_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,12 @@ void AbsKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
}
} // namespace phi

PD_REGISTER_KERNEL(
abs, XPU, ALL_LAYOUT, phi::AbsKernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL(abs,
XPU,
ALL_LAYOUT,
phi::AbsKernel,
float,
phi::dtype::float16,
int8_t,
int32_t,
int64_t) {}
24 changes: 19 additions & 5 deletions paddle/phi/kernels/xpu/elementwise_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,25 @@ PD_REGISTER_KERNEL(floor_divide,
ALL_LAYOUT,
phi::FloorDivideKernel,
float,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(
maximum, XPU, ALL_LAYOUT, phi::MaximumKernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL(
minimum, XPU, ALL_LAYOUT, phi::MinimumKernel, float, phi::dtype::float16) {}
phi::dtype::float16,
int32_t,
int64_t) {}
PD_REGISTER_KERNEL(maximum,
XPU,
ALL_LAYOUT,
phi::MaximumKernel,
float,
phi::dtype::float16,
int32_t,
int64_t) {}
PD_REGISTER_KERNEL(minimum,
XPU,
ALL_LAYOUT,
phi::MinimumKernel,
float,
phi::dtype::float16,
int32_t,
int64_t) {}
PD_REGISTER_KERNEL(remainder,
XPU,
ALL_LAYOUT,
Expand Down
28 changes: 24 additions & 4 deletions paddle/phi/kernels/xpu/index_select_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
// limitations under the License.

#include "paddle/phi/kernels/index_select_kernel.h"

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"

Expand All @@ -40,14 +40,33 @@ void IndexSelectKernel(const Context& ctx,
index_type,
phi::DataType::INT32,
phi::DataType::INT64));

auto* in_data = x.data<T>();
std::vector<int> in_shape = phi::vectorize<int>(input_dim);
int index_len = output->dims()[dim];
T* out_data = ctx.template Alloc<T>(output);
int r = 0;
xpu::ctx_guard RAII_GUARD(ctx.x_context());
int8_t* index_ptr = nullptr; // temp xpu buffer
int byte_times = SizeOf(index_type);
if (index.place() == CPUPlace()) {
index_ptr = RAII_GUARD.alloc_l3_or_gm<int8_t>(byte_times * index.numel());
PADDLE_ENFORCE_XDNN_NOT_NULL(index_ptr);
const void* cpu_idx_data = nullptr;
if (index_type == phi::DataType::INT64) {
cpu_idx_data = reinterpret_cast<const void*>(index.data<int64_t>());
} else if (index_type == phi::DataType::INT32) {
cpu_idx_data = reinterpret_cast<const void*>(index.data<int>());
}
memory_utils::Copy(ctx.GetPlace(),
reinterpret_cast<void*>(index_ptr),
CPUPlace(),
cpu_idx_data,
byte_times * index.numel());
}
if (index_type == phi::DataType::INT64) {
const int64_t* index_data = index.data<int64_t>();
const int64_t* index_data =
index_ptr ? reinterpret_cast<const int64_t*>(index_ptr)
: index.template data<int64_t>();
r = xpu::gather<T, int64_t>(ctx.x_context(),
in_data,
index_data,
Expand All @@ -56,7 +75,8 @@ void IndexSelectKernel(const Context& ctx,
index_len,
dim);
} else {
const int* index_data = index.data<int>();
const int* index_data = index_ptr ? reinterpret_cast<const int*>(index_ptr)
: index.template data<int>();
r = xpu::gather<T, int>(ctx.x_context(),
in_data,
index_data,
Expand Down

0 comments on commit c8985fa

Please sign in to comment.