diff --git a/paddle/phi/backends/xpu/xpu1_op_list.cc b/paddle/phi/backends/xpu/xpu1_op_list.cc index f99805c0959927..52c9661e9f55ac 100644 --- a/paddle/phi/backends/xpu/xpu1_op_list.cc +++ b/paddle/phi/backends/xpu/xpu1_op_list.cc @@ -21,7 +21,10 @@ namespace xpu { XPUOpMap& get_kl1_ops() { // KL1支持的op,通过op_name, data_type static XPUOpMap s_xpu1_kernels{ - {"abs", XPUKernelSet({phi::DataType::FLOAT32})}, + {"abs", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::INT64})}, {"accuracy", XPUKernelSet({phi::DataType::FLOAT32})}, {"adam", XPUKernelSet({phi::DataType::FLOAT32})}, {"adamw", XPUKernelSet({phi::DataType::FLOAT32})}, @@ -34,6 +37,11 @@ XPUOpMap& get_kl1_ops() { phi::DataType::INT32, phi::DataType::INT64, phi::DataType::BOOL})}, + {"assign_value", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::INT64, + phi::DataType::BOOL})}, {"batch_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"batch_norm", XPUKernelSet({phi::DataType::FLOAT32})}, {"bilinear_interp", XPUKernelSet({phi::DataType::FLOAT32})}, @@ -48,13 +56,20 @@ XPUOpMap& get_kl1_ops() { {"cast", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT64, + phi::DataType::BOOL, phi::DataType::INT32})}, {"clip_by_norm", XPUKernelSet({phi::DataType::FLOAT32})}, {"coalesce_tensor", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT64, phi::DataType::INT32})}, - {"concat", XPUKernelSet({phi::DataType::FLOAT32})}, + {"concat", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BOOL, + phi::DataType::INT8, + phi::DataType::INT64, + phi::DataType::INT32})}, {"concat_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"conv2d", XPUKernelSet({phi::DataType::FLOAT32})}, {"conv2d_grad", XPUKernelSet({phi::DataType::FLOAT32})}, @@ -67,20 +82,39 @@ XPUOpMap& get_kl1_ops() { {"dropout_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"c_allreduce_sum", XPUKernelSet({phi::DataType::FLOAT32})}, {"c_reduce_sum", 
XPUKernelSet({phi::DataType::FLOAT32})}, - {"elementwise_add", XPUKernelSet({phi::DataType::FLOAT32})}, + {"elementwise_add", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT64, + phi::DataType::INT32})}, {"elementwise_add_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"elementwise_div_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"elementwise_div", XPUKernelSet({phi::DataType::FLOAT32})}, - {"elementwise_floordiv", XPUKernelSet({phi::DataType::FLOAT32})}, + {"elementwise_floordiv", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT64})}, {"elementwise_max_grad", XPUKernelSet({phi::DataType::FLOAT32})}, - {"elementwise_max", XPUKernelSet({phi::DataType::FLOAT32})}, + {"elementwise_max", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT64})}, {"elementwise_min_grad", XPUKernelSet({phi::DataType::FLOAT32})}, - {"elementwise_min", XPUKernelSet({phi::DataType::FLOAT32})}, + {"elementwise_min", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT64})}, {"elementwise_mul_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"elementwise_mul", XPUKernelSet({phi::DataType::FLOAT32})}, {"elementwise_pow", XPUKernelSet({phi::DataType::FLOAT32})}, {"elementwise_sub_grad", XPUKernelSet({phi::DataType::FLOAT32})}, - {"elementwise_sub", XPUKernelSet({phi::DataType::FLOAT32})}, + {"elementwise_sub", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT64})}, {"embedding_with_eltwise_add_xpu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"equal", XPUKernelSet({phi::DataType::INT64})}, @@ -115,14 +149,26 @@ XPUOpMap& get_kl1_ops() { phi::DataType::INT32, phi::DataType::INT64, })}, + {"greater_than", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::FLOAT32})}, {"hard_switch_grad", 
XPUKernelSet({phi::DataType::FLOAT32})}, {"hard_switch", XPUKernelSet({phi::DataType::FLOAT32})}, + {"index_select", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::INT64})}, {"iou_similarity", XPUKernelSet({phi::DataType::FLOAT32})}, {"lamb", XPUKernelSet({phi::DataType::FLOAT32})}, {"layer_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"layer_norm", XPUKernelSet({phi::DataType::FLOAT32})}, {"leaky_relu_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"leaky_relu", XPUKernelSet({phi::DataType::FLOAT32})}, + {"less_than", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::FLOAT32})}, {"load", XPUKernelSet({phi::DataType::FLOAT64, phi::DataType::INT8, @@ -206,7 +252,10 @@ XPUOpMap& get_kl1_ops() { {"rnn", XPUKernelSet({phi::DataType::FLOAT32})}, {"roi_align_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"roi_align", XPUKernelSet({phi::DataType::FLOAT32})}, - {"scale", XPUKernelSet({phi::DataType::FLOAT32})}, + {"scale", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT64, + phi::DataType::INT32})}, {"sgd", XPUKernelSet({phi::DataType::FLOAT32})}, {"shape", XPUKernelSet({phi::DataType::FLOAT64, @@ -218,7 +267,11 @@ XPUOpMap& get_kl1_ops() { {"sigmoid", XPUKernelSet({phi::DataType::FLOAT32})}, {"sign", XPUKernelSet({phi::DataType::FLOAT32})}, {"slice_grad", XPUKernelSet({phi::DataType::FLOAT32})}, - {"slice", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})}, + {"slice", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::INT64, + phi::DataType::FLOAT16})}, {"softmax_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"softmax_with_cross_entropy", XPUKernelSet({phi::DataType::FLOAT32})}, {"softmax_with_cross_entropy_grad", @@ -306,6 +359,10 @@ XPUOpMap& get_kl1_ops() { phi::DataType::UINT8, phi::DataType::FLOAT32})}, {"where_index", XPUKernelSet({phi::DataType::BOOL})}, + {"where", + XPUKernelSet({phi::DataType::INT32, + phi::DataType::INT64, + 
phi::DataType::FLOAT32})}, // AddMore }; diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index 65d1f5e59a2e67..0b229631709987 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -26,7 +26,12 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"add_layernorm_xpu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, - {"abs", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"abs", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT8, + phi::DataType::INT64})}, {"abs_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"accuracy", XPUKernelSet({phi::DataType::FLOAT32})}, @@ -248,15 +253,24 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT64, phi::DataType::INT32})}, {"elementwise_floordiv", - XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT64})}, {"elementwise_max_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"elementwise_max", - XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT64})}, {"elementwise_min_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"elementwise_min", - XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT64})}, {"elementwise_mul_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"elementwise_mul", @@ -665,6 +679,8 @@ XPUOpMap& get_kl2_ops() { {"relu_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"relu", XPUKernelSet({phi::DataType::FLOAT32, 
phi::DataType::FLOAT16})}, + {"repeat_interleave", + XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})}, {"reshape2_grad", XPUKernelSet({phi::DataType::FLOAT64, phi::DataType::FLOAT16, diff --git a/paddle/phi/kernels/legacy/xpu/elementwise_kernel.cc b/paddle/phi/kernels/legacy/xpu/elementwise_kernel.cc index 00aee2d41b1537..2e4bf779d26cdd 100644 --- a/paddle/phi/kernels/legacy/xpu/elementwise_kernel.cc +++ b/paddle/phi/kernels/legacy/xpu/elementwise_kernel.cc @@ -121,19 +121,25 @@ PD_REGISTER_KERNEL(floor_divide_raw, ALL_LAYOUT, phi::FloorDivideRawKernel, float, - phi::dtype::float16) {} + phi::dtype::float16, + int32_t, + int64_t) {} PD_REGISTER_KERNEL(maximum_raw, XPU, ALL_LAYOUT, phi::MaximumRawKernel, float, - phi::dtype::float16) {} + phi::dtype::float16, + int32_t, + int64_t) {} PD_REGISTER_KERNEL(minimum_raw, XPU, ALL_LAYOUT, phi::MinimumRawKernel, float, - phi::dtype::float16) {} + phi::dtype::float16, + int32_t, + int64_t) {} PD_REGISTER_KERNEL(remainder_raw, XPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/xpu/abs_kernel.cc b/paddle/phi/kernels/xpu/abs_kernel.cc index 7abdd1f0715b60..053e6410416839 100644 --- a/paddle/phi/kernels/xpu/abs_kernel.cc +++ b/paddle/phi/kernels/xpu/abs_kernel.cc @@ -31,5 +31,12 @@ void AbsKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) { } } // namespace phi -PD_REGISTER_KERNEL( - abs, XPU, ALL_LAYOUT, phi::AbsKernel, float, phi::dtype::float16) {} +PD_REGISTER_KERNEL(abs, + XPU, + ALL_LAYOUT, + phi::AbsKernel, + float, + phi::dtype::float16, + int8_t, + int32_t, + int64_t) {} diff --git a/paddle/phi/kernels/xpu/elementwise_kernel.cc b/paddle/phi/kernels/xpu/elementwise_kernel.cc index 386ad2e13ff0ed..83dce5437c9ecb 100644 --- a/paddle/phi/kernels/xpu/elementwise_kernel.cc +++ b/paddle/phi/kernels/xpu/elementwise_kernel.cc @@ -82,11 +82,25 @@ PD_REGISTER_KERNEL(floor_divide, ALL_LAYOUT, phi::FloorDivideKernel, float, - phi::dtype::float16) {} -PD_REGISTER_KERNEL( - maximum, XPU, ALL_LAYOUT, 
phi::MaximumKernel, float, phi::dtype::float16) {} -PD_REGISTER_KERNEL( - minimum, XPU, ALL_LAYOUT, phi::MinimumKernel, float, phi::dtype::float16) {} + phi::dtype::float16, + int32_t, + int64_t) {} +PD_REGISTER_KERNEL(maximum, + XPU, + ALL_LAYOUT, + phi::MaximumKernel, + float, + phi::dtype::float16, + int32_t, + int64_t) {} +PD_REGISTER_KERNEL(minimum, + XPU, + ALL_LAYOUT, + phi::MinimumKernel, + float, + phi::dtype::float16, + int32_t, + int64_t) {} PD_REGISTER_KERNEL(remainder, XPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/xpu/index_select_kernel.cc b/paddle/phi/kernels/xpu/index_select_kernel.cc index cbe6e99c43ae9e..75c19aa028bce7 100644 --- a/paddle/phi/kernels/xpu/index_select_kernel.cc +++ b/paddle/phi/kernels/xpu/index_select_kernel.cc @@ -13,8 +13,8 @@ // limitations under the License. #include "paddle/phi/kernels/index_select_kernel.h" - #include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/common/memory_utils.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/utils/data_type.h" @@ -40,14 +40,33 @@ void IndexSelectKernel(const Context& ctx, index_type, phi::DataType::INT32, phi::DataType::INT64)); - auto* in_data = x.data<T>(); std::vector<int> in_shape = phi::vectorize<int>(input_dim); int index_len = output->dims()[dim]; T* out_data = ctx.template Alloc<T>(output); int r = 0; + xpu::ctx_guard RAII_GUARD(ctx.x_context()); + int8_t* index_ptr = nullptr; // temp xpu buffer + int byte_times = SizeOf(index_type); + if (index.place() == CPUPlace()) { + index_ptr = RAII_GUARD.alloc_l3_or_gm<int8_t>(byte_times * index.numel()); + PADDLE_ENFORCE_XDNN_NOT_NULL(index_ptr); + const void* cpu_idx_data = nullptr; + if (index_type == phi::DataType::INT64) { + cpu_idx_data = reinterpret_cast<const void*>(index.data<int64_t>()); + } else if (index_type == phi::DataType::INT32) { + cpu_idx_data = reinterpret_cast<const void*>(index.data<int>()); + } + memory_utils::Copy(ctx.GetPlace(), + reinterpret_cast<void*>(index_ptr), + CPUPlace(), + cpu_idx_data, + byte_times * index.numel()); + } if
(index_type == phi::DataType::INT64) { - const int64_t* index_data = index.data<int64_t>(); + const int64_t* index_data = + index_ptr ? reinterpret_cast<const int64_t*>(index_ptr) + : index.template data<int64_t>(); r = xpu::gather<T, int64_t>(ctx.x_context(), in_data, index_data, @@ -56,7 +75,8 @@ void IndexSelectKernel(const Context& ctx, index_len, dim); } else { - const int* index_data = index.data<int>(); + const int* index_data = index_ptr ? reinterpret_cast<const int*>(index_ptr) : index.template data<int>(); r = xpu::gather<T, int>(ctx.x_context(), in_data, index_data,