diff --git a/lite/backends/arm/math/fp16/interpolate_fp16.cc b/lite/backends/arm/math/fp16/interpolate_fp16.cc index 13e9bd61fca..03488e024ab 100644 --- a/lite/backends/arm/math/fp16/interpolate_fp16.cc +++ b/lite/backends/arm/math/fp16/interpolate_fp16.cc @@ -411,13 +411,39 @@ void interpolate(lite::Tensor* X, float scale, bool with_align, int align_mode, - std::string interpolate_type) { + std::string interpolate_type, + std::vector scale_data) { int in_h = X->dims()[2]; int in_w = X->dims()[3]; + float height_scale = 0.f; + float width_scale = 0.f; + if (SizeTensor.size() > 0) { auto new_size = get_new_shape(SizeTensor); out_height = new_size[0]; out_width = new_size[1]; + } else if (scale_data.size() > 0) { + if (scale_data.size() == 1) { + if (scale_data[0] > 0) { + out_height = static_cast(in_h * scale_data[0]); + out_width = static_cast(in_w * scale_data[0]); + } else { + LOG(FATAL) << "scale data <= 0"; + } + } else if (scale_data.size() == 2) { + if (scale_data[0] > 0 && scale_data[1] > 0) { + out_height = static_cast(in_h * scale_data[0]); + out_width = static_cast(in_w * scale_data[1]); + } else { + LOG(FATAL) << "scale data <= 0"; + } + } + auto out_size = OutSize; + if (out_size != nullptr) { + auto out_size_data = get_new_data_from_tensor(out_size); + out_height = out_size_data[0]; + out_width = out_size_data[1]; + } } else { auto scale_tensor = Scale; if (scale_tensor != nullptr) { @@ -435,11 +461,14 @@ void interpolate(lite::Tensor* X, out_width = out_size_data[1]; } } - float height_scale = scale; - float width_scale = scale; + height_scale = scale; + width_scale = scale; if (out_width > 0 && out_height > 0) { height_scale = static_cast(out_height / X->dims()[2]); width_scale = static_cast(out_width / X->dims()[3]); + } else { + out_height = static_cast(X->dims()[2] * height_scale + 0.5f); + out_width = static_cast(X->dims()[3] * width_scale + 0.5f); } int num_cout = X->dims()[0]; int c_cout = X->dims()[1]; diff --git a/lite/backends/arm/math/fp16/interpolate_fp16.h b/lite/backends/arm/math/fp16/interpolate_fp16.h index 7ac96cbdf45..c78fa414da7 100644 --- a/lite/backends/arm/math/fp16/interpolate_fp16.h +++ b/lite/backends/arm/math/fp16/interpolate_fp16.h @@ -68,7 +68,8 @@ void interpolate(lite::Tensor* X, float scale, bool with_align, int align_mode, - std::string interpolate_type); + std::string interpolate_type, + std::vector scale_data); } // namespace fp16 } // namespace math diff --git a/lite/backends/arm/math/interpolate.cc b/lite/backends/arm/math/interpolate.cc index fb935d00218..0bff70459b3 100644 --- a/lite/backends/arm/math/interpolate.cc +++ b/lite/backends/arm/math/interpolate.cc @@ -509,13 +509,39 @@ void interpolate(lite::Tensor* X, float scale, bool with_align, int align_mode, - std::string interpolate_type) { + std::string interpolate_type, + std::vector scale_data) { int in_h = X->dims()[2]; int in_w = X->dims()[3]; + float height_scale = 0.f; + float width_scale = 0.f; + if (SizeTensor.size() > 0) { auto new_size = get_new_shape(SizeTensor); out_height = new_size[0]; out_width = new_size[1]; + } else if (scale_data.size() > 0) { + if (scale_data.size() == 1) { + if (scale_data[0] > 0) { + out_height = static_cast(in_h * scale_data[0]); + out_width = static_cast(in_w * scale_data[0]); + } else { + LOG(FATAL) << "scale data <= 0"; + } + } else if (scale_data.size() == 2) { + if (scale_data[0] > 0 && scale_data[1] > 0) { + out_height = static_cast(in_h * scale_data[0]); + out_width = static_cast(in_w * scale_data[1]); + } else { + LOG(FATAL) << "scale data <= 0"; + } + } + auto out_size = OutSize; + if (out_size != nullptr) { + auto out_size_data = get_new_data_from_tensor(out_size); + out_height = out_size_data[0]; + out_width = out_size_data[1]; + } } else { auto scale_tensor = Scale; if (scale_tensor != nullptr) { @@ -533,11 +559,14 @@ void interpolate(lite::Tensor* X, out_width = out_size_data[1]; } } - float height_scale = scale; - float width_scale = scale; + height_scale = scale; + width_scale = scale; if (out_width > 0 && out_height > 0) { height_scale = static_cast(out_height / X->dims()[2]); width_scale = static_cast(out_width / X->dims()[3]); + } else { + out_height = static_cast(X->dims()[2] * height_scale + 0.5f); + out_width = static_cast(X->dims()[3] * width_scale + 0.5f); } int num_cout = X->dims()[0]; int c_cout = X->dims()[1]; diff --git a/lite/backends/arm/math/interpolate.h b/lite/backends/arm/math/interpolate.h index 5c37670ec57..a4db2adf43c 100644 --- a/lite/backends/arm/math/interpolate.h +++ b/lite/backends/arm/math/interpolate.h @@ -52,7 +52,8 @@ void interpolate(lite::Tensor* X, float scale, bool with_align, int align_mode, - std::string interpolate_type); + std::string interpolate_type, + std::vector scale_data); } /* namespace math */ } /* namespace arm */ diff --git a/lite/backends/x86/math/avx/conv_utils.cc b/lite/backends/x86/math/avx/conv_utils.cc index 814e686f62a..4692d179162 100644 --- a/lite/backends/x86/math/avx/conv_utils.cc +++ b/lite/backends/x86/math/avx/conv_utils.cc @@ -861,31 +861,36 @@ void im2col_s1(const float* data_im, (width + pad_left + pad_right - (dilation_w * (kernel_w - 1) + 1)) + 1; const int in_channel_size = height * width; const int out_channel_size = output_h * output_w; - const int output_plane_size = output_h * output_w * kernel_h * kernel_w; - memset(data_col, 0, output_plane_size * channels * sizeof(float)); + const unsigned int output_plane_size = + output_h * output_w * kernel_h * kernel_w; + size_t tmp_size = static_cast(output_plane_size); + size_t mem_size = tmp_size * channels * sizeof(float); + memset(data_col, 0, mem_size); #pragma omp parallel for for (int c = 0; c < channels; c++) { - int data_im_z = c * in_channel_size; - int data_col_z1 = c * output_plane_size; + unsigned int data_im_z = c * in_channel_size; + unsigned int data_col_z1 = c * output_plane_size; for (int ky = 0, h_offset = 0; ky < kernel_h; ky++, h_offset += dilation_h) { - int data_col_z2 = ky * out_channel_size * kernel_w; + unsigned int data_col_z2 = ky * out_channel_size * kernel_w; for (int kx = 0, w_offset = 0; kx < kernel_w; kx++, w_offset += dilation_w) { - int data_col_z3 = kx * out_channel_size; - int data_col_z = data_col_z1 + data_col_z2 + data_col_z3; - int oh_begin = std::max(((pad_top - h_offset)), 0); - int oh_end = std::min(((height + pad_bottom - h_offset)), output_h); + unsigned int data_col_z3 = kx * out_channel_size; + unsigned int data_col_z = data_col_z1 + data_col_z2 + data_col_z3; + unsigned int oh_begin = std::max(((pad_top - h_offset)), 0); + unsigned int oh_end = + std::min(((height + pad_bottom - h_offset)), output_h); oh_end = std::max(oh_begin, oh_end); - int ow_begin = std::max(((pad_left - w_offset)), 0); - int ow_end = std::min(((width + pad_right - w_offset)), output_w); + unsigned int ow_begin = std::max(((pad_left - w_offset)), 0); + unsigned int ow_end = + std::min(((width + pad_right - w_offset)), output_w); ow_end = std::max(ow_begin, ow_end); - int ih = oh_begin - pad_top + h_offset; + unsigned int ih = oh_begin - pad_top + h_offset; for (int oh = oh_begin; oh < oh_end; ++oh, ++ih) { - int iw = ow_begin - pad_left + w_offset; - int ow = ow_begin; - int data_im_offset = data_im_z + ih * width; - int data_col_offset = data_col_z + oh * output_w; + unsigned int iw = ow_begin - pad_left + w_offset; + unsigned int ow = ow_begin; + unsigned int data_im_offset = data_im_z + ih * width; + unsigned int data_col_offset = data_col_z + oh * output_w; const float* data_im_ptr = data_im + data_im_offset; float* data_col_ptr = data_col + data_col_offset; #ifdef __AVX__ @@ -929,33 +934,36 @@ void im2col_s2(const float* data_im, (width + pad_left + pad_right - (dilation_w * (kernel_w - 1) + 1)) / 2 + 1; const int in_channel_size = height * width; - const int output_plane_size = output_h * output_w * kernel_h * kernel_w; - memset(data_col, 0, output_plane_size * channels * sizeof(float)); + const unsigned int output_plane_size = + output_h * output_w * kernel_h * kernel_w; + size_t tmp_size = static_cast(output_plane_size); + size_t mem_size = tmp_size * channels * sizeof(float); + memset(data_col, 0, mem_size); #pragma omp parallel for for (int c = 0; c < channels; c++) { - int data_im_z = c * in_channel_size; - int data_col_z1 = c * output_plane_size; + unsigned int data_im_z = c * in_channel_size; + unsigned int data_col_z1 = c * output_plane_size; for (int ky = 0, h_offset = 0; ky < kernel_h; ky++, h_offset += dilation_h) { - int data_col_z2 = ky * output_h * output_w * kernel_w; + unsigned int data_col_z2 = ky * output_h * output_w * kernel_w; for (int kx = 0, w_offset = 0; kx < kernel_w; kx++, w_offset += dilation_w) { - int data_col_z3 = kx * output_h * output_w; - int data_col_z = data_col_z1 + data_col_z2 + data_col_z3; - int oh_begin = std::max(((pad_top - h_offset + 1) / 2), 0); - int oh_end = + unsigned int data_col_z3 = kx * output_h * output_w; + unsigned int data_col_z = data_col_z1 + data_col_z2 + data_col_z3; + unsigned int oh_begin = std::max(((pad_top - h_offset + 1) / 2), 0); + unsigned int oh_end = std::min(((height + pad_bottom - h_offset + 1) / 2), output_h); oh_end = std::max(oh_begin, oh_end); - int ow_begin = std::max(((pad_left - w_offset + 1) / 2), 0); - int ow_end = + unsigned int ow_begin = std::max(((pad_left - w_offset + 1) / 2), 0); + unsigned int ow_end = std::min(((width + pad_right - w_offset + 1) / 2), output_w); ow_end = std::max(ow_begin, ow_end); - int ih = oh_begin * 2 - pad_top + h_offset; + unsigned int ih = oh_begin * 2 - pad_top + h_offset; for (int oh = oh_begin; oh < oh_end; ++oh, ih += 2) { - int iw = ow_begin * 2 - pad_left + w_offset; - int ow = ow_begin; - int data_im_offset = data_im_z + ih * width; - int data_col_offset = data_col_z + oh * output_w; + unsigned int iw = ow_begin * 2 - pad_left + w_offset; + unsigned int ow = ow_begin; + unsigned int data_im_offset = data_im_z + ih * width; + unsigned int data_col_offset = data_col_z + oh * output_w; const float* data_im_ptr = data_im + data_im_offset; float* data_col_ptr = data_col + data_col_offset; for (; ow + 3 < ow_end; ow += 4, iw += 8) { diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt index 6602aedaf02..1dfa44b98dc 100644 --- a/lite/kernels/arm/CMakeLists.txt +++ b/lite/kernels/arm/CMakeLists.txt @@ -37,6 +37,7 @@ add_kernel(transpose_compute_arm ARM basic SRCS transpose_compute.cc) add_kernel(shuffle_channel_compute_arm ARM basic SRCS shuffle_channel_compute.cc) add_kernel(argmax_compute_arm ARM basic SRCS argmax_compute.cc) add_kernel(conv_transpose_compute_arm ARM basic SRCS conv_transpose_compute.cc) +add_kernel(depthwise_conv_transpose_compute_arm ARM extra SRCS depthwise_conv_transpose_compute.cc) add_kernel(interpolate_compute_arm ARM basic SRCS interpolate_compute.cc) add_kernel(box_coder_compute_arm ARM basic SRCS box_coder_compute.cc) add_kernel(slice_compute_arm ARM basic SRCS slice_compute.cc) diff --git a/lite/kernels/arm/depthwise_conv_transpose_compute.cc b/lite/kernels/arm/depthwise_conv_transpose_compute.cc new file mode 100644 index 00000000000..b78ad73f598 --- /dev/null +++ b/lite/kernels/arm/depthwise_conv_transpose_compute.cc @@ -0,0 +1,92 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/arm/depthwise_conv_transpose_compute.h" +#include "lite/core/op_registry.h" +#include "lite/core/type_system.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm {} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +typedef paddle::lite::kernels::arm:: + DepthwiseConv2DTransposeCompute + DepConvTransFp32; +typedef paddle::lite::kernels::arm:: + DepthwiseConv2DTransposeCompute + DepConvTranInt8_Fp32; +typedef paddle::lite::kernels::arm:: + DepthwiseConv2DTransposeCompute + DepConvTranInt8_Int8; + +#ifdef ENABLE_ARM_FP16 +typedef paddle::lite::kernels::arm:: + DepthwiseConv2DTransposeCompute + DepConvTranFp16; + +REGISTER_LITE_KERNEL( + depthwise_conv2d_transpose, kARM, kFP16, kNCHW, DepConvTranFp16, def) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFP16))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFP16))}) + .BindInput("Filter", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFP16))}) + .BindOutput("Output", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFP16))}) + .BindPaddleOpVersion("conv2d_transpose", 1) + .Finalize(); + +#endif // ENABLE_ARM_FP16 + +REGISTER_LITE_KERNEL( + depthwise_conv2d_transpose, kARM, kFloat, kNCHW, DepConvTransFp32, def) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindPaddleOpVersion("conv2d_transpose", 1) + .Finalize(); + +REGISTER_LITE_KERNEL(depthwise_conv2d_transpose, + kARM, + kInt8, + kNCHW, + DepConvTranInt8_Fp32, + fp32_out) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))}) + .BindInput("Filter", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) + .BindOutput("Output", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))}) + .BindPaddleOpVersion("conv2d_transpose", 1) + .Finalize(); + +REGISTER_LITE_KERNEL(depthwise_conv2d_transpose, + kARM, + kInt8, + kNCHW, + DepConvTranInt8_Int8, + int8_out) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))}) + .BindInput("Filter", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) + .BindOutput("Output", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) + .BindPaddleOpVersion("conv2d_transpose", 1) + .Finalize(); diff --git a/lite/kernels/arm/depthwise_conv_transpose_compute.h b/lite/kernels/arm/depthwise_conv_transpose_compute.h new file mode 100644 index 00000000000..b70908acc33 --- /dev/null +++ b/lite/kernels/arm/depthwise_conv_transpose_compute.h @@ -0,0 +1,32 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "lite/backends/arm/math/funcs.h" +#include "lite/core/kernel.h" +#include "lite/operators/conv_transpose_op.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { +template +class DepthwiseConv2DTransposeCompute : public KernelLite { +}; +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/arm/interpolate_compute.cc b/lite/kernels/arm/interpolate_compute.cc index 0a335e06e4e..6bc6d1181c0 100644 --- a/lite/kernels/arm/interpolate_compute.cc +++ b/lite/kernels/arm/interpolate_compute.cc @@ -39,11 +39,12 @@ namespace arm { int out_h = param.out_h; \ bool align_corners = param.align_corners; \ int align_mode = param.align_mode; \ + auto scale_v = param.scale_v; \ std::string interp_method = method_name; #define INTERP_PARAM \ X, OutSize, SizeTensor, Scale, Out, out_h, out_w, scale, align_corners, \ - align_mode, interp_method + align_mode, interp_method, scale_v template <> void BilinearInterpCompute::Run() { diff --git a/lite/kernels/x86/conv_compute.cc b/lite/kernels/x86/conv_compute.cc index a1ea4585644..0a33848f17a 100644 --- a/lite/kernels/x86/conv_compute.cc +++ b/lite/kernels/x86/conv_compute.cc @@ -118,11 +118,11 @@ void Conv2dCompute::Run() { auto& ctx = ctx_->As(); INIT_PARAM bool flag_bias = (param.bias != nullptr); - int group_size_out = m * n; - int group_size_weights = m * k; - int group_size_coldata = n * k; - int channel_in_size = chin * hin * win; - int channel_out_size = chout * hout * wout; + unsigned int group_size_out = m * n; + unsigned int group_size_weights = m * k; + unsigned int group_size_coldata = n * k; + unsigned int channel_in_size = chin * hin * win; + unsigned int channel_out_size = chout * hout * wout; auto paddings = *param.paddings; auto dilations = *param.dilations; @@ -135,9 +135,9 @@ void Conv2dCompute::Run() { float* col_data = nullptr; if (!flag_1x1gemm_) { - int col_size = group * group_size_coldata; - col_data = static_cast( - TargetMalloc(TARGET(kX86), col_size * sizeof(float))); + size_t col_size = group_size_coldata * group; + size_t col_data_size = static_cast(col_size * sizeof(float)); + col_data = static_cast(TargetMalloc(TARGET(kX86), col_data_size)); } auto act_param = param.activation_param; paddle::lite::x86::math::Blas matmul(ctx);