diff --git a/paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc index 263ea5a09ca76..c80fda30990ff 100644 --- a/paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc @@ -34,6 +34,25 @@ class Scope; } // namespace framework } // namespace paddle +namespace { + +template +void ConvertTensorType(phi::DenseTensor* tensor) { + phi::DenseTensor tmp_tensor; + tmp_tensor.set_type(phi::CppTypeToDataType::Type()); + tmp_tensor.Resize(tensor->dims()); + auto* tmp_data = tmp_tensor.mutable_data(paddle::platform::CPUPlace()); + auto* data = tensor->mutable_data(paddle::platform::CPUPlace()); + for (int i = 0; i < tensor->numel(); i++) { + tmp_data[i] = static_cast(data[i]); + } + tensor->clear(); + paddle::framework::TensorCopySync( + tmp_tensor, paddle::platform::CPUPlace(), tensor); +} + +} // namespace + namespace paddle { namespace framework { namespace ir { @@ -157,15 +176,23 @@ Conv2dXPUPattern::Conv2dXPUPattern(PDPattern* pattern, if (with_bn_) { ew_bias_add_out->assert_is_op_input("batch_norm", "X"); bn_bias = pattern->NewNode(bn_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() ->assert_is_op_input("batch_norm", "Bias") ->assert_has_n_outputs(1); bn_mean = pattern->NewNode(bn_mean_repr()) + ->AsInput() + ->assert_is_persistable_var() ->assert_is_op_input("batch_norm", "Mean") ->assert_has_n_outputs(1); bn_scale = pattern->NewNode(bn_scale_repr()) + ->AsInput() + ->assert_is_persistable_var() ->assert_is_op_input("batch_norm", "Scale") ->assert_has_n_outputs(1); bn_var = pattern->NewNode(bn_var_repr()) + ->AsInput() + ->assert_is_persistable_var() ->assert_is_op_input("batch_norm", "Variance") ->assert_has_n_outputs(1); bn = pattern->NewNode(bn_repr())->assert_is_op("batch_norm"); @@ -420,13 +447,17 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, // recompute bias and weight for conv2d_xpu op auto* filter_t = scope->FindVar(conv_filter->Name())->GetMutable(); + // conv_filter fp16 --> fp32 + auto tensor_type = filter_t->dtype(); + if (tensor_type == phi::DataType::FLOAT16) { + ConvertTensorType(filter_t); + } auto filter_dims = filter_t->dims(); bool has_bias = with_bn || with_conv_bias; - bool has_branch = with_branch_x || with_branch_y; // Create conv_fusion_bias (conv bias) variable Node* fusion_bias_node = nullptr; if (has_bias) { - if (ew_bias_add != nullptr) { + if (with_conv_bias) { auto* ew_bias_add_y_t = scope->FindVar(ew_bias_add_y->Name()) ->GetMutable(); auto ew_bias_add_y_dims = ew_bias_add_y_t->dims(); @@ -439,7 +470,7 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, filter_dims[0])); PrepareBias(graph, scope, block, ew_bias_add_y, &fusion_bias_node); } - if (bn != nullptr) { + if (with_bn) { auto bn_bias_t = scope->Var(bn_bias->Name())->GetMutable(); PADDLE_ENFORCE_EQ(filter_dims[0], @@ -469,7 +500,7 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, auto filter_len = filter_t->numel(); auto filter_stride = filter_len / mean_len; float epsilon = PADDLE_GET_CONST(float, bn->Op()->GetAttr("epsilon")); - if (fusion_bias_node == nullptr) { // prev node is conv + if (!with_conv_bias) { // prev node is conv PrepareBias(graph, scope, block, bn_bias, &fusion_bias_node); } auto fusion_bias_t = scope->Var(fusion_bias_node->Name()) @@ -477,10 +508,10 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, float* fusion_bias_ptr = fusion_bias_t->mutable_data(paddle::platform::CPUPlace()); // recompute bias and weights - if (ew_bias_add == nullptr) { + if (!with_conv_bias) { // prev node is conv for (int i = 0; i < mean_len; ++i) { bn_scale_ptr[i] = bn_scale_ptr[i] / sqrtf(bn_var_ptr[i] + epsilon); - fusion_bias_ptr[i] += (0.f - bn_mean_ptr[i]) * bn_scale_ptr[i]; + fusion_bias_ptr[i] += (0.0f - bn_mean_ptr[i]) * bn_scale_ptr[i]; for (int j = 0; j < filter_stride; j++) { filter_ptr[i * filter_stride + j] *= bn_scale_ptr[i]; } @@ -488,21 +519,25 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, } else { for (int i = 0; i < mean_len; ++i) { bn_scale_ptr[i] = bn_scale_ptr[i] / sqrtf(bn_var_ptr[i] + epsilon); - bn_bias_ptr[i] += + fusion_bias_ptr[i] = + bn_bias_ptr[i] + (fusion_bias_ptr[i] - bn_mean_ptr[i]) * bn_scale_ptr[i]; for (int j = 0; j < filter_stride; j++) { filter_ptr[i * filter_stride + j] *= bn_scale_ptr[i]; } } - memcpy(fusion_bias_ptr, bn_bias_ptr, mean_len * sizeof(float)); } } } + if (tensor_type == phi::DataType::FLOAT16) { + ConvertTensorType(filter_t); + } // filter max Node* filter_int16 = nullptr; Node* filter_max = nullptr; PrepareWeight( graph, scope, block, conv_filter, &filter_int16, &filter_max, false); + bool has_branch = with_branch_x || with_branch_y; // output && output max std::string conv2d_xpu_out_name; if (!act_type.empty()) { diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index 87400ecd61c9f..d5f9a9e0d481d 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -5,14 +5,14 @@ # otherwise the operator only could be used in static mode. - op : conv2d_xpu - args : (Tensor x, Tensor x_max, Tensor filter, Tensor filter_max, Tensor bias, Tensor branch, int[] paddings, int[] dilations, int[] strides, str padding_algorithm, int groups, bool has_bias, bool has_branch, int act_type, float act_param) + args : (Tensor x, Tensor x_max, Tensor filter, Tensor filter_max, Tensor bias, Tensor branch, Tensor branch_max, int[] paddings, int[] dilations, int[] strides, str padding_algorithm, int groups, bool has_bias, bool has_branch, int act_type, float act_param) output : Tensor(out), Tensor(out_max) infer_meta : func : Conv2dXPUInferMeta kernel : func : conv2d_xpu data_type : x - optional : bias, branch, x_max + optional : bias, branch, branch_max ,x_max - op : embedding_with_eltwise_add_xpu args : (Tensor[] ids, Tensor[] tables, int64_t padding_idx) diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index 07ad62e442d76..c05b51ac7bdd5 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -58,7 +58,8 @@ XPUOpMap& get_kl2_ops() { {"atan_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"batch_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})}, - {"batch_norm", XPUKernelSet({phi::DataType::FLOAT32})}, + {"batch_norm", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"bmm", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"bmm_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, @@ -401,7 +402,8 @@ XPUOpMap& get_kl2_ops() { {"grid_sampler_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"grid_sampler", XPUKernelSet({phi::DataType::FLOAT32})}, {"hard_sigmoid_grad", XPUKernelSet({phi::DataType::FLOAT32})}, - {"hard_sigmoid", XPUKernelSet({phi::DataType::FLOAT32})}, + {"hard_sigmoid", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"hard_swish_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"hard_swish", XPUKernelSet({phi::DataType::FLOAT32})}, @@ -438,7 +440,8 @@ XPUOpMap& get_kl2_ops() { {"layer_norm", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"leaky_relu_grad", XPUKernelSet({phi::DataType::FLOAT32})}, - {"leaky_relu", XPUKernelSet({phi::DataType::FLOAT32})}, + {"leaky_relu", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"less_equal", XPUKernelSet({phi::DataType::INT64, phi::DataType::INT32, @@ -554,7 +557,8 @@ XPUOpMap& get_kl2_ops() { {"reduce_max_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"reduce_max", XPUKernelSet({phi::DataType::FLOAT32})}, {"reduce_mean_grad", XPUKernelSet({phi::DataType::FLOAT32})}, - {"reduce_mean", XPUKernelSet({phi::DataType::FLOAT32})}, + {"reduce_mean", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"reduce_min_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"reduce_min", XPUKernelSet({phi::DataType::FLOAT32})}, {"reduce_prod", XPUKernelSet({phi::DataType::FLOAT32})}, @@ -646,7 +650,8 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT64, phi::DataType::INT32, phi::DataType::FLOAT16})}, - {"sigmoid", XPUKernelSet({phi::DataType::FLOAT32})}, + {"sigmoid", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"sigmoid_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"sign", XPUKernelSet({phi::DataType::FLOAT32})}, {"slice_grad", @@ -676,7 +681,7 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, phi::DataType::INT32})}, - {"sqrt", XPUKernelSet({phi::DataType::FLOAT32})}, + {"sqrt", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"sqrt_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"square_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, @@ -733,7 +738,7 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT16, phi::DataType::INT32})}, {"sum", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, - {"swish", XPUKernelSet({phi::DataType::FLOAT32})}, + {"swish", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"swish_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"tanh_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index 262c5fb04e19f..f32ee528075db 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -41,6 +41,7 @@ void Conv2dXPUInferMeta(const MetaTensor& x, const MetaTensor& filter_max, const MetaTensor& bias, const MetaTensor& branch, + const MetaTensor& branch_max, const std::vector& paddings, const std::vector& dilations, const std::vector& strides, diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index 38f4bc8c6c5be..0fddf3995ffa0 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -28,6 +28,7 @@ void Conv2dXPUInferMeta(const MetaTensor& x, const MetaTensor& filter_max, const MetaTensor& bias, const MetaTensor& branch, + const MetaTensor& branch_max, const std::vector& paddings, const std::vector& dilations, const std::vector& strides, diff --git a/paddle/phi/kernels/activation_kernel.cc b/paddle/phi/kernels/activation_kernel.cc index ef6135d25c99d..ab367a672696f 100644 --- a/paddle/phi/kernels/activation_kernel.cc +++ b/paddle/phi/kernels/activation_kernel.cc @@ -63,7 +63,8 @@ PD_REGISTER_KERNEL(swish, #if defined PADDLE_WITH_XPU PD_REGISTER_KERNEL(relu6, XPU, ALL_LAYOUT, phi::Relu6Kernel, float) {} -PD_REGISTER_KERNEL(swish, XPU, ALL_LAYOUT, phi::SwishKernel, float) {} +PD_REGISTER_KERNEL( + swish, XPU, ALL_LAYOUT, phi::SwishKernel, float, phi::dtype::float16) {} #endif #ifdef PADDLE_WITH_MKLDNN diff --git a/paddle/phi/kernels/batch_norm_kernel.cc b/paddle/phi/kernels/batch_norm_kernel.cc index 570ba8dae06cf..bf04c99dab0a3 100644 --- a/paddle/phi/kernels/batch_norm_kernel.cc +++ b/paddle/phi/kernels/batch_norm_kernel.cc @@ -106,6 +106,10 @@ PD_REGISTER_KERNEL(batch_norm_infer, phi::dtype::float16) {} #endif #ifdef PADDLE_WITH_XPU -PD_REGISTER_KERNEL( - batch_norm_infer, XPU, ALL_LAYOUT, phi::BatchNormInferKernel, float) {} +PD_REGISTER_KERNEL(batch_norm_infer, + XPU, + ALL_LAYOUT, + phi::BatchNormInferKernel, + float, + phi::dtype::float16) {} #endif diff --git a/paddle/phi/kernels/fusion/xpu/conv2d_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/conv2d_xpu_kernel.cc index 0f7d8902de328..f82d9fdd9fdea 100644 --- a/paddle/phi/kernels/fusion/xpu/conv2d_xpu_kernel.cc +++ b/paddle/phi/kernels/fusion/xpu/conv2d_xpu_kernel.cc @@ -27,6 +27,7 @@ void Conv2dXPUKernel(const Context& ctx, const DenseTensor& filter_max, const paddle::optional& bias, const paddle::optional& branch, + const paddle::optional& branch_max, const std::vector& paddings, const std::vector& dilations, const std::vector& strides, @@ -69,10 +70,12 @@ void Conv2dXPUKernel(const Context& ctx, branch.get_ptr() == nullptr ? nullptr : reinterpret_cast(branch.get_ptr()->data()); + const float* branch_max_data = branch_max.get_ptr() == nullptr + ? nullptr + : branch_max.get_ptr()->data(); const float* bias_data = bias.get_ptr() == nullptr ? nullptr : bias.get_ptr()->data(); auto* out_data = reinterpret_cast(ctx.template Alloc(out)); - xpu::Activation_t act(static_cast(act_type)); if (act_type == xpu::Activation_t::LEAKY_RELU) { act.leaky_alpha = act_param; @@ -102,7 +105,7 @@ void Conv2dXPUKernel(const Context& ctx, /* const float* bias */ bias_data, /* const TY* branch */ branch_data, /* const baidu::xpu::api::Activation_t& act */ act, - /* const float* branch_maxptr */ nullptr, + /* const float* branch_maxptr */ branch_max_data, /* const float* scale */ nullptr); PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_xpu"); } diff --git a/paddle/phi/kernels/xpu/activation_kernel.cc b/paddle/phi/kernels/xpu/activation_kernel.cc index 490c56d13127f..41a1cceb11935 100644 --- a/paddle/phi/kernels/xpu/activation_kernel.cc +++ b/paddle/phi/kernels/xpu/activation_kernel.cc @@ -195,6 +195,13 @@ void PowKernel(const Context& dev_ctx, const DenseTensor& x, const Scalar& factor, DenseTensor* out) { + // using XPUType = typename XPUTypeTrait::Type; + // // dev_ctx.template Alloc(out); + // auto pow_factor = factor.to(); + // const auto* x_data = reinterpret_cast(x.data()); + // auto* y_data = reinterpret_cast(dev_ctx.template Alloc(out)); + // // const T* x_data = x.data(); + // // T* y_data = out->data(); dev_ctx.template Alloc(out); float pow_factor = factor.to(); const T* x_data = x.data(); @@ -534,9 +541,28 @@ PD_REGISTER_KERNEL( relu, XPU, ALL_LAYOUT, phi::ReluKernel, float, phi::dtype::float16) {} PD_REGISTER_KERNEL( silu, XPU, ALL_LAYOUT, phi::SiluKernel, float, phi::dtype::float16) {} - -#define PD_REGISTER_ACTIVATION_KERNEL(name, func) \ - PD_REGISTER_KERNEL(name, XPU, ALL_LAYOUT, phi::func, float) {} +PD_REGISTER_KERNEL( + sigmoid, XPU, ALL_LAYOUT, phi::SigmoidKernel, float, phi::dtype::float16) {} +PD_REGISTER_KERNEL(swish_raw, + XPU, + ALL_LAYOUT, + phi::SwishRawKernel, + float, + phi::dtype::float16) {} +PD_REGISTER_KERNEL(hard_sigmoid, + XPU, + ALL_LAYOUT, + phi::HardSigmoidKernel, + float, + phi::dtype::float16) {} +PD_REGISTER_KERNEL(leaky_relu, + XPU, + ALL_LAYOUT, + phi::LeakyReluKernel, + float, + phi::dtype::float16) {} +PD_REGISTER_KERNEL( + sqrt, XPU, ALL_LAYOUT, phi::SqrtKernel, float, phi::dtype::float16) {} PD_REGISTER_KERNEL( tanh, XPU, ALL_LAYOUT, phi::TanhKernel, float, phi::dtype::float16) {} @@ -547,18 +573,21 @@ PD_REGISTER_KERNEL( PD_REGISTER_KERNEL( log, XPU, ALL_LAYOUT, phi::LogKernel, float, phi::dtype::float16) {} +#define PD_REGISTER_ACTIVATION_KERNEL(name, func) \ + PD_REGISTER_KERNEL(name, XPU, ALL_LAYOUT, phi::func, float) {} + PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel) // no grad PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel) -PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel) -PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel) +// PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel) +// PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel) PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel) PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel) PD_REGISTER_ACTIVATION_KERNEL(pow, PowKernel) PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel) PD_REGISTER_ACTIVATION_KERNEL(relu6_raw, Relu6RawKernel) -PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel) -PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel) -PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel) +// PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel) +// PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel) +// PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel) PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel) PD_REGISTER_ACTIVATION_KERNEL(sin, SinKernel) PD_REGISTER_ACTIVATION_KERNEL(cos, CosKernel) diff --git a/paddle/phi/kernels/xpu/batch_norm_kernel.cc b/paddle/phi/kernels/xpu/batch_norm_kernel.cc index b3f2843315c15..b95dda1fed13d 100644 --- a/paddle/phi/kernels/xpu/batch_norm_kernel.cc +++ b/paddle/phi/kernels/xpu/batch_norm_kernel.cc @@ -39,6 +39,7 @@ void BatchNormKernel(const Context& dev_ctx, DenseTensor* saved_mean, DenseTensor* saved_variance, DenseTensor* reserve_space) { + using XPUType = typename XPUTypeTrait::Type; bool test_mode = is_test && (!trainable_statistics); bool global_stats = test_mode || use_global_stats; const auto data_layout = phi::StringToDataLayout(data_layout_str); @@ -68,12 +69,12 @@ void BatchNormKernel(const Context& dev_ctx, W = W * D; - const auto* x_data = x.data(); + const auto* x_data = reinterpret_cast(x.data()); const auto* scale_data = scale.data(); const auto* bias_data = bias.data(); // alloc memory - auto* y_data = dev_ctx.template Alloc(y); + auto* y_data = reinterpret_cast(dev_ctx.template Alloc(y)); dev_ctx.template Alloc(mean_out); dev_ctx.template Alloc(variance_out); dev_ctx.template Alloc(saved_mean); @@ -95,43 +96,48 @@ void BatchNormKernel(const Context& dev_ctx, auto* saved_mean_data = saved_mean->data(); auto* saved_variance_data = saved_variance->data(); - int r = xpu::batch_norm(dev_ctx.x_context(), - x_data, - y_data, - N, - C, - H, - W, - epsilon, - momentum, - scale_data, - bias_data, - saved_mean_data, - saved_variance_data, - mean_out_data, - variance_out_data, - is_nchw); + int r = xpu::batch_norm(dev_ctx.x_context(), + x_data, + y_data, + N, + C, + H, + W, + epsilon, + momentum, + scale_data, + bias_data, + saved_mean_data, + saved_variance_data, + mean_out_data, + variance_out_data, + is_nchw); PADDLE_ENFORCE_XDNN_SUCCESS(r, "batch_norm"); } else { const auto* mean_data = mean.data(); const auto* variance_data = variance.data(); - int r = xpu::batch_norm_infer(dev_ctx.x_context(), - x_data, - y_data, - N, - C, - H, - W, - epsilon, - scale_data, - bias_data, - mean_data, - variance_data, - is_nchw); + int r = xpu::batch_norm_infer(dev_ctx.x_context(), + x_data, + y_data, + N, + C, + H, + W, + epsilon, + scale_data, + bias_data, + mean_data, + variance_data, + is_nchw); PADDLE_ENFORCE_XDNN_SUCCESS(r, "batch_norm_infer"); } } } // namespace phi -PD_REGISTER_KERNEL(batch_norm, XPU, ALL_LAYOUT, phi::BatchNormKernel, float) {} +PD_REGISTER_KERNEL(batch_norm, + XPU, + ALL_LAYOUT, + phi::BatchNormKernel, + float, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/xpu/reduce_mean_kernel.cc b/paddle/phi/kernels/xpu/reduce_mean_kernel.cc index b9ba5178a5dd7..cb0bfb6218a88 100644 --- a/paddle/phi/kernels/xpu/reduce_mean_kernel.cc +++ b/paddle/phi/kernels/xpu/reduce_mean_kernel.cc @@ -50,4 +50,6 @@ void MeanRawKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL(mean_raw, XPU, ALL_LAYOUT, phi::MeanRawKernel, float) {} +PD_REGISTER_KERNEL( + mean_raw, XPU, ALL_LAYOUT, phi::MeanRawKernel, float, phi::dtype::float16) { +}