From 84c77c2a6842d1d2eead6c374c522e3e84a50429 Mon Sep 17 00:00:00 2001 From: mjp9527 <54735487+mjp9527@users.noreply.github.com> Date: Mon, 27 Sep 2021 11:00:06 +0800 Subject: [PATCH 1/4] [X86/ARM] add gru mode for rnn (#7026) * [X86] Add GRU for RNN, complete elementwise op, move cast from arm to host, fix precision_profile bug * pre-commit * [ARM] add RNN-GRU OP; Optimize RNN-GRU OP * fix complie bug * fix elementwise left problem * merge develop * fix windows ci * change arm cast test to host cast test --- lite/backends/arm/math/gru.h | 248 ++++++ lite/backends/arm/math/lstm.cc | 1 + lite/backends/x86/math/elementwise.h | 39 +- lite/backends/x86/math/fill_bias_activate.cc | 7 - lite/backends/x86/math/rnn.cc | 172 ---- lite/backends/x86/math/rnn.h | 431 +++++++++- lite/core/profile/precision_profiler.h | 6 +- lite/kernels/arm/CMakeLists.txt | 1 - lite/kernels/arm/rnn_compute.cc | 562 +++++++++--- lite/kernels/host/CMakeLists.txt | 1 + lite/kernels/{arm => host}/cast_compute.cc | 23 +- lite/kernels/{arm => host}/cast_compute.h | 6 +- lite/kernels/host/compare_compute.cc | 4 +- lite/kernels/host/expand_v2_compute.cc | 4 +- lite/kernels/x86/CMakeLists.txt | 4 +- lite/kernels/x86/elementwise_compute.cc | 31 + lite/kernels/x86/elementwise_op_function.h | 802 ++++++++++++++++++ lite/kernels/x86/rnn_compute.cc | 568 ++++++++++--- lite/kernels/x86/rnn_compute.h | 2 +- lite/tests/kernels/cast_compute_test.cc | 4 +- .../elementwise_common_broadcast_test.cc | 2 +- 21 files changed, 2441 insertions(+), 477 deletions(-) create mode 100644 lite/backends/arm/math/gru.h delete mode 100644 lite/backends/x86/math/rnn.cc rename lite/kernels/{arm => host}/cast_compute.cc (91%) rename lite/kernels/{arm => host}/cast_compute.h (90%) create mode 100644 lite/kernels/x86/elementwise_op_function.h diff --git a/lite/backends/arm/math/gru.h b/lite/backends/arm/math/gru.h new file mode 100644 index 00000000000..1492c57d6a2 --- /dev/null +++ b/lite/backends/arm/math/gru.h @@ -0,0 +1,248 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "lite/backends/arm/math/sgemm.h" +#ifdef LITE_WITH_ARM +#include +#endif + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template +struct RNNGRUValue { + const T* gate_weight; + const T* state_weight; + const T* reset_bias; + T* gate_value; + T* reset_output_value; + T* output_value; + const T* prev_out_value; +}; + +template +void rnn_activation(const T* din, + T* dout, + int size, + lite_api::ActivationType act_type, + int threads) { + switch (act_type) { + case lite_api::ActivationType::kSigmoid: + act_sigmoid(din, dout, size, threads); + break; + case lite_api::ActivationType::kSigmoid_v2: + act_sigmoid(din, dout, size, threads); + break; + case lite_api::ActivationType::kTanh: + act_tanh(din, dout, size, threads); + break; + case lite_api::ActivationType::kTanh_v2: + act_tanh(din, dout, size, threads); + break; + case lite_api::ActivationType::kRelu: + act_relu(din, dout, size, threads); + break; + default: + LOG(FATAL) << "unsupport activation type:" << static_cast(act_type); + break; + } +} + +template +void compute_kernel(RNNGRUValue value, + int frame_size, + int batch_size, + lite_api::ActivationType active_node, + lite_api::ActivationType active_gate) { + auto value_reset_gate = value.gate_value; + auto value_update_gate = value.gate_value + frame_size; + auto value_reset_output = value.reset_output_value; + auto value_reset_bias = value.reset_bias; + auto cell_state_value = value.gate_value + 2 * frame_size; + auto value_output = value.output_value; + auto value_prev_out = value.prev_out_value; + + for (int b = 0; b < batch_size; b++) { + rnn_activation(value_reset_gate, + value_reset_gate, + frame_size, + lite_api::ActivationType::kSigmoid_v2, + 1); + rnn_activation(value_update_gate, + value_update_gate, + frame_size, + lite_api::ActivationType::kSigmoid_v2, + 1); + + for (int i = 0; i < frame_size; i++) { + value_reset_output[i] = + (value_reset_output[i] + value_reset_bias[i]) * value_reset_gate[i]; + cell_state_value[i] += value_reset_output[i]; + } + + rnn_activation(cell_state_value, + cell_state_value, + frame_size, + lite_api::ActivationType::kTanh_v2, + 1); + + if (value.prev_out_value) { + for (int i = 0; i < frame_size; i++) { + value_output[i] = (1.f - value_update_gate[i]) * cell_state_value[i] + + value_update_gate[i] * value_prev_out[i]; + } + } else { + for (int i = 0; i < frame_size; i++) { + value_output[i] = (1.f - value_update_gate[i]) * cell_state_value[i]; + } + } + + value_reset_gate += frame_size * 3; + value_update_gate += frame_size * 3; + value_reset_output += frame_size; + cell_state_value += frame_size * 3; + value_output += frame_size; + if (value.prev_out_value) { + value_prev_out += frame_size; + } + } +} + +template <> +void compute_kernel(RNNGRUValue value, + int frame_size, + int batch_size, + lite_api::ActivationType active_node, + lite_api::ActivationType active_gate) { + auto value_reset_gate = value.gate_value; + auto value_update_gate = value.gate_value + frame_size; + auto value_reset_output = value.reset_output_value; + auto value_reset_bias = value.reset_bias; + auto cell_state_value = value.gate_value + 2 * frame_size; + auto value_output = value.output_value; + auto value_prev_out = value.prev_out_value; + int i = 0; + float32x4_t vec_one = vdupq_n_f32(1.f); + + for (int b = 0; b < batch_size; b++) { + rnn_activation(value_reset_gate, + value_reset_gate, + frame_size, + lite_api::ActivationType::kSigmoid_v2, + 1); + rnn_activation(value_update_gate, + value_update_gate, + frame_size, + lite_api::ActivationType::kSigmoid_v2, + 1); + + for (i = 0; i + 3 < frame_size; i += 4) { + float32x4_t vec_out = vld1q_f32(value_reset_output + i); + float32x4_t vec_reset = vld1q_f32(value_reset_gate + i); + float32x4_t vec_bias = vld1q_f32(value_reset_bias + i); + vec_out = vmulq_f32(vaddq_f32(vec_out, vec_bias), vec_reset); + vst1q_f32(value_reset_output + i, vec_out); + vst1q_f32(cell_state_value + i, + vaddq_f32(vec_out, vld1q_f32(cell_state_value + i))); + } + for (; i < frame_size; i++) { + value_reset_output[i] = + (value_reset_output[i] + value_reset_bias[i]) * value_reset_gate[i]; + cell_state_value[i] += value_reset_output[i]; + } + + rnn_activation(cell_state_value, + cell_state_value, + frame_size, + lite_api::ActivationType::kTanh_v2, + 1); + + if (value.prev_out_value) { + for (i = 0; i + 3 < frame_size; i += 4) { + float32x4_t vec_vug = vld1q_f32(value_update_gate + i); + float32x4_t vec_vpo = vld1q_f32(value_prev_out + i); + float32x4_t vec_csv = vld1q_f32(cell_state_value + i); + vec_vpo = vmulq_f32(vec_vug, vec_vpo); + float32x4_t vec_out = + vmlaq_f32(vec_vpo, vsubq_f32(vec_one, vec_vug), vec_csv); + vst1q_f32(value_output + i, vec_out); + } + for (; i < frame_size; i++) { + value_output[i] = (1.f - value_update_gate[i]) * cell_state_value[i] + + value_update_gate[i] * value_prev_out[i]; + } + } else { + for (i = 0; i + 3 < frame_size; i += 4) { + float32x4_t vec_vug = vld1q_f32(value_update_gate + i); + float32x4_t vec_csv = vld1q_f32(cell_state_value + i); + float32x4_t vec_out = vmulq_f32(vsubq_f32(vec_one, vec_vug), vec_csv); + vst1q_f32(value_output + i, vec_out); + } + for (; i < frame_size; i++) { + value_output[i] = (1.f - value_update_gate[i]) * cell_state_value[i]; + } + } + + value_reset_gate += frame_size * 3; + value_update_gate += frame_size * 3; + value_reset_output += frame_size; + cell_state_value += frame_size * 3; + value_output += frame_size; + if (value.prev_out_value) { + value_prev_out += frame_size; + } + } +} + +template +struct RnnGruUnitFunctorV2 { + static void compute(ARMContext* ctx, + RNNGRUValue value, + int frame_size, + int batch_size, + lite_api::ActivationType active_node, + lite_api::ActivationType active_gate) { + if (value.prev_out_value) { + operators::ActivationParam act_param; + act_param.has_active = false; + lite::arm::math::sgemm(false, + true, + batch_size, + frame_size, + frame_size, + 1.f, + value.prev_out_value, + frame_size, + value.state_weight, + frame_size, + 0.f, + value.reset_output_value, + frame_size, + nullptr, + false, + act_param, + ctx); + } + compute_kernel(value, frame_size, batch_size, active_node, active_gate); + } +}; + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/backends/arm/math/lstm.cc b/lite/backends/arm/math/lstm.cc index cd8e012a287..bf096ed04c0 100644 --- a/lite/backends/arm/math/lstm.cc +++ b/lite/backends/arm/math/lstm.cc @@ -36,6 +36,7 @@ void add_bias_rowwise(Tensor* input, i_data += width; } } + void vector_dot( float* out, const float* in, const float* v1, int size, const float* v2) { int loop = size >> 2; diff --git a/lite/backends/x86/math/elementwise.h b/lite/backends/x86/math/elementwise.h index 6a5c8cf92f6..2de0b831783 100644 --- a/lite/backends/x86/math/elementwise.h +++ b/lite/backends/x86/math/elementwise.h @@ -150,24 +150,27 @@ namespace math { } \ } -// marco func add -ElementWiseFunc(Add) ElementWiseFuncBCast(Add) - // marco func sub - ElementWiseFunc(Sub) ElementWiseFuncBCast(Sub) - // marco func mul - ElementWiseFunc(Mul) ElementWiseFuncBCast(Mul) - // marco func max - ElementWiseFunc(Max) ElementWiseFuncBCast(Max) - // marco func min - ElementWiseFunc(Min) ElementWiseFuncBCast(Min) - // marco func div - ElementWiseFunc(Div) ElementWiseFuncBCast(Div) - // marco func floordiv - ElementWiseFunc(FloorDiv) ElementWiseFuncBCast(FloorDiv) - // marco func mod - ElementWiseFunc(Mod) ElementWiseFuncBCast(Mod) - // marco func pow - ElementWiseFunc(Pow) ElementWiseFuncBCast(Pow) +// clang-format off +ElementWiseFunc(Add) +ElementWiseFuncBCast(Add) +ElementWiseFunc(Sub) +ElementWiseFuncBCast(Sub) +ElementWiseFunc(Mul) +ElementWiseFuncBCast(Mul) +ElementWiseFunc(Max) +ElementWiseFuncBCast(Max) +ElementWiseFunc(Min) +ElementWiseFuncBCast(Min) +ElementWiseFunc(Div) +ElementWiseFuncBCast(Div) +ElementWiseFunc(FloorDiv) +ElementWiseFuncBCast(FloorDiv) +ElementWiseFunc(Mod) +ElementWiseFuncBCast(Mod) +ElementWiseFunc(Pow) +ElementWiseFuncBCast(Pow) +// clang-format on + } // namespace math } // namespace x86 } // namespace lite diff --git a/lite/backends/x86/math/fill_bias_activate.cc b/lite/backends/x86/math/fill_bias_activate.cc index 31685b90778..126fe1f382c 100644 --- a/lite/backends/x86/math/fill_bias_activate.cc +++ b/lite/backends/x86/math/fill_bias_activate.cc @@ -38,7 +38,6 @@ static void activate_relu_inplace(float *data, int len, float alpha, int mode) { __m256 vec_data = _mm256_loadu_ps(data + i); _mm256_storeu_ps(data + i, _mm256_max_ps(vec_data, vec_zero)); } -// _mm256_zeroupper(); #endif #ifdef __SSE__ __m128 vec_zero_128 = _mm_set1_ps(0.f); @@ -59,7 +58,6 @@ static void activate_relu_inplace(float *data, int len, float alpha, int mode) { _mm256_storeu_ps( data + i, _mm256_min_ps(_mm256_max_ps(vec_data, vec_zero), vec_alph)); } -// _mm256_zeroupper(); #endif #ifdef __SSE__ __m128 vec_zero_128 = _mm_set1_ps(0.f); @@ -112,7 +110,6 @@ static void activate_relu_inplace_bias(float *data, vec_data = _mm256_add_ps(vec_bias, vec_data); _mm256_storeu_ps(tmp_data + i, _mm256_max_ps(vec_data, vec_zero)); } -// _mm256_zeroupper(); #endif #ifdef __SSE__ vec_bias_128 = _mm_set1_ps(bias[j]); @@ -140,7 +137,6 @@ static void activate_relu_inplace_bias(float *data, tmp_data + i, _mm256_min_ps(_mm256_max_ps(vec_data, vec_zero), vec_alph)); } -// _mm256_zeroupper(); #endif #ifdef __SSE__ vec_bias_128 = _mm_set1_ps(bias[j]); @@ -174,7 +170,6 @@ static void activate_lrelu_inplace(float *data, int len, float alpha) { __m256 vec_mask = _mm256_cmp_ps(vec_data, vec_zero, cmp_le_os); _mm256_storeu_ps(data + i, _mm256_blendv_ps(vec_data, vec_lr, vec_mask)); } -// _mm256_zeroupper(); #endif #ifdef __SSE4_1__ // blendv need 4.1 __m128 vec_zero_128 = _mm_set1_ps(0.f); @@ -226,7 +221,6 @@ static void activate_lrelu_inplace_bias(float *data, _mm256_storeu_ps(tmp_data + i, _mm256_blendv_ps(vec_data, vec_lr, vec_mask)); } -// _mm256_zeroupper(); #endif #ifdef __SSE4_1__ vec_bias_128 = _mm_set1_ps(bias[j]); @@ -273,7 +267,6 @@ static void activate_none_inplace_bias(float *data, vec_data = _mm256_add_ps(vec_bias, vec_data); _mm256_storeu_ps(tmp_data + i, vec_data); } -// _mm256_zeroupper(); #endif #ifdef __SSE__ vec_bias_128 = _mm_set1_ps(bias[j]); diff --git a/lite/backends/x86/math/rnn.cc b/lite/backends/x86/math/rnn.cc deleted file mode 100644 index bddcf5228e4..00000000000 --- a/lite/backends/x86/math/rnn.cc +++ /dev/null @@ -1,172 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifdef __AVX__ -#include -#endif -#ifdef __SSE__ -#include -#endif - -#include "lite/backends/x86/math/activation_functions.h" -#include "lite/backends/x86/math/rnn.h" - -namespace paddle { -namespace lite { -namespace x86 { -namespace math { - -void vector_dot( - float* out, const float* in, const float* v1, int size, const float* v2) { -#if defined(__AVX__) - __m256 vec_in, vec_v1, vec_v2; -#endif -#if defined(__SSE__) - __m128 vec_in_128, vec_v1_128, vec_v2_128; -#endif - - int i = 0; - if (nullptr == v2) { - i = 0; - -// in_out * v1 -#if defined(__AVX__) - for (; i + 7 < size; i += 8) { - vec_in = _mm256_loadu_ps(in + i); - vec_v1 = _mm256_loadu_ps(v1 + i); - _mm256_storeu_ps(out + i, _mm256_mul_ps(vec_in, vec_v1)); - } -// _mm256_zeroupper(); -#endif -#if defined(__SSE__) - for (; i + 3 < size; i += 4) { - vec_in_128 = _mm_loadu_ps(in + i); - vec_v1_128 = _mm_loadu_ps(v1 + i); - _mm_storeu_ps(out + i, _mm_mul_ps(vec_in_128, vec_v1_128)); - } -#endif - for (; i < size; i++) { - out[i] = in[i] * v1[i]; - } - } else { - i = 0; - -// in_out + v1 * v2 -#if defined(__AVX__) && defined(__FMA__) - for (; i + 7 < size; i += 8) { - vec_in = _mm256_loadu_ps(in + i); - vec_v1 = _mm256_loadu_ps(v1 + i); - vec_v2 = _mm256_loadu_ps(v2 + i); - _mm256_storeu_ps(out + i, _mm256_fmadd_ps(vec_v2, vec_v1, vec_in)); - } - for (; i + 3 < size; i += 4) { - vec_in_128 = _mm_loadu_ps(in + i); - vec_v1_128 = _mm_loadu_ps(v1 + i); - vec_v2_128 = _mm_loadu_ps(v2 + i); - _mm_storeu_ps(out + i, _mm_fmadd_ps(vec_v2_128, vec_v1_128, vec_in_128)); - } -#endif - for (; i < size; i++) { - out[i] = in[i] + v1[i] * v2[i]; - } - } -} - -template <> -void act_relu(const float* din, float* dout, int size, int threads) { - int i = 0; - -#ifdef __AVX__ - for (; i + 7 < size; i += 8) { - __m256 a = _mm256_loadu_ps(din + i); - _mm256_storeu_ps(dout + i, lite::x86::math::detail::forward::avx::Relu(a)); - } -#endif - for (; i < size; i++) { - dout[i] = lite::x86::math::detail::forward::Relu(din[i]); - } -} - -template <> -void act_sigmoid(const float* din, float* dout, int size, int threads) { - int i = 0; - -#ifdef __AVX__ - for (; i + 7 < size; i += 8) { - __m256 a = _mm256_loadu_ps(din + i); - _mm256_storeu_ps(dout + i, - lite::x86::math::detail::forward::avx::Sigmoid(a)); - } -#endif - for (; i < size; i++) { - dout[i] = lite::x86::math::detail::forward::Sigmoid(din[i]); - } -} - -template <> -void act_tanh(const float* din, float* dout, int size, int threads) { - int i = 0; - -#ifdef __AVX__ - for (; i + 7 < size; i += 8) { - __m256 a = _mm256_loadu_ps(din + i); - _mm256_storeu_ps(dout + i, lite::x86::math::detail::forward::avx::Tanh(a)); - } -#endif - for (; i < size; i++) { - dout[i] = lite::x86::math::detail::forward::Tanh(din[i]); - } -} - -void fill_bias_fc(float* out, const float* bias, int num, int channel) { -#ifdef __AVX__ - __m256 vec_bias = {0.f}; - __m256 vec_data = {0.f}; -#endif -#ifdef __SSE__ - __m128 vec_bias_128 = {0.f}; - __m128 vec_data_128 = {0.f}; -#endif - int i = 0; - - for (int j = 0; j < num; j++) { - float* ptr = out + j * channel; - const float* pbias = bias; - i = 0; - -#ifdef __AVX__ - for (; i + 7 < channel; i += 8) { - vec_bias = _mm256_loadu_ps(pbias + i); - vec_data = _mm256_loadu_ps(ptr + i); - _mm256_storeu_ps(ptr + i, _mm256_add_ps(vec_data, vec_bias)); - } -// _mm256_zeroupper(); -#endif -#ifdef __SSE__ - for (; i + 3 < channel; i += 4) { - vec_bias_128 = _mm_loadu_ps(pbias + i); - vec_data_128 = _mm_loadu_ps(ptr + i); - _mm_storeu_ps(ptr + i, _mm_add_ps(vec_data_128, vec_bias_128)); - } -#endif - for (; i < channel; i++) { - *(ptr + i) = pbias[i] + ptr[i]; - } - } -} - -} // namespace math -} // namespace x86 -} // namespace lite -} // namespace paddle diff --git a/lite/backends/x86/math/rnn.h b/lite/backends/x86/math/rnn.h index 82d467fcad5..620ebe8c23d 100644 --- a/lite/backends/x86/math/rnn.h +++ b/lite/backends/x86/math/rnn.h @@ -15,14 +15,33 @@ #pragma once #include +#include "lite/backends/x86/math/activation_functions.h" +#include "lite/backends/x86/math/blas.h" #include "lite/core/tensor.h" #include "lite/utils/log/logging.h" +#ifdef __AVX__ +#include +#endif +#ifdef __SSE__ +#include +#endif + +#ifndef __FMA__ +#define _mm256_fmadd_ps(a, b, c) _mm256_add_ps((c), _mm256_mul_ps((a), (b))) +#define _mm_fmadd_ps(a, b, c) _mm_add_ps((c), _mm_mul_ps((a), (b))) +#endif + namespace paddle { namespace lite { namespace x86 { namespace math { +namespace x86_forward = paddle::lite::x86::math::detail::forward; + +//************************************** +// Class Def +//************************************** template struct LstmMetaValue { T* gate_value; @@ -35,25 +54,183 @@ struct LstmMetaValue { T* check_og; }; +template +struct GRUMetaValue { + const T* gate_weight; + const T* state_weight; + const T* reset_bias; + T* gate_value; + T* reset_output_value; + T* output_value; + const T* prev_out_value; +}; + +//********************************* +// Inline Function +//********************************* // if v2 isn't null: out[i] = in[i] + v1[i] * v2[i]; // if v2 is null: out[i] = in[i] * v1[i]; -void vector_dot(float* out, - const float* in, - const float* v1, - int size, - const float* v2 = nullptr); +inline void vector_dot(float* out, + const float* in, + const float* v1, + int size, + const float* v2 = nullptr) { +#if defined(__AVX__) + __m256 vec_in, vec_v1, vec_v2; +#endif +#if defined(__SSE__) + __m128 vec_in_128, vec_v1_128, vec_v2_128; +#endif + + int i = 0; + if (nullptr == v2) { + i = 0; + +// in_out * v1 +#if defined(__AVX__) + for (; i + 7 < size; i += 8) { + vec_in = _mm256_loadu_ps(in + i); + vec_v1 = _mm256_loadu_ps(v1 + i); + _mm256_storeu_ps(out + i, _mm256_mul_ps(vec_in, vec_v1)); + } +#endif +#if defined(__SSE__) + for (; i + 3 < size; i += 4) { + vec_in_128 = _mm_loadu_ps(in + i); + vec_v1_128 = _mm_loadu_ps(v1 + i); + _mm_storeu_ps(out + i, _mm_mul_ps(vec_in_128, vec_v1_128)); + } +#endif + for (; i < size; i++) { + out[i] = in[i] * v1[i]; + } + } else { + i = 0; + +// in_out + v1 * v2 +#if defined(__AVX__) + for (; i + 7 < size; i += 8) { + vec_in = _mm256_loadu_ps(in + i); + vec_v1 = _mm256_loadu_ps(v1 + i); + vec_v2 = _mm256_loadu_ps(v2 + i); + _mm256_storeu_ps(out + i, _mm256_fmadd_ps(vec_v2, vec_v1, vec_in)); + } +#endif +#if defined(__SSE__) + for (; i + 3 < size; i += 4) { + vec_in_128 = _mm_loadu_ps(in + i); + vec_v1_128 = _mm_loadu_ps(v1 + i); + vec_v2_128 = _mm_loadu_ps(v2 + i); + _mm_storeu_ps(out + i, _mm_fmadd_ps(vec_v2_128, vec_v1_128, vec_in_128)); + } +#endif + for (; i < size; i++) { + out[i] = in[i] + v1[i] * v2[i]; + } + } +} + +inline void fill_bias_fc(float* out, const float* bias, int num, int channel) { +#ifdef __AVX__ + __m256 vec_bias = {0.f}; + __m256 vec_data = {0.f}; +#endif +#ifdef __SSE__ + __m128 vec_bias_128 = {0.f}; + __m128 vec_data_128 = {0.f}; +#endif + int i = 0; -// only add bias -void fill_bias_fc(float* out, const float* bias, int num, int channel); + for (int j = 0; j < num; j++) { + float* ptr = out + j * channel; + const float* pbias = bias; + i = 0; +#ifdef __AVX__ + for (; i + 7 < channel; i += 8) { + vec_bias = _mm256_loadu_ps(pbias + i); + vec_data = _mm256_loadu_ps(ptr + i); + _mm256_storeu_ps(ptr + i, _mm256_add_ps(vec_data, vec_bias)); + } +#endif +#ifdef __SSE__ + for (; i + 3 < channel; i += 4) { + vec_bias_128 = _mm_loadu_ps(pbias + i); + vec_data_128 = _mm_loadu_ps(ptr + i); + _mm_storeu_ps(ptr + i, _mm_add_ps(vec_data_128, vec_bias_128)); + } +#endif + for (; i < channel; i++) { + *(ptr + i) = pbias[i] + ptr[i]; + } + } +} + +//******************************* +// Template Func +//******************************* template -void act_relu(const T* din, T* dout, int size, int threads); +void act_relu(const T* din, T* dout, int size, int threads) { + for (int i = 0; i < size; i++) { + dout[i] = x86_forward::Relu(din[i]); + } +} template -void act_sigmoid(const T* din, T* dout, int size, int threads); +void act_sigmoid(const T* din, T* dout, int size, int threads) { + for (int i = 0; i < size; i++) { + dout[i] = x86_forward::Sigmoid(din[i]); + } +} template -void act_tanh(const T* din, T* dout, int size, int threads); +void act_tanh(const T* din, T* dout, int size, int threads) { + for (int i = 0; i < size; i++) { + dout[i] = x86_forward::Tanh(din[i]); + } +} + +template <> +void act_relu(const float* din, float* dout, int size, int threads) { + int i = 0; +#ifdef __AVX__ + for (; i + 7 < size; i += 8) { + __m256 a = _mm256_loadu_ps(din + i); + _mm256_storeu_ps(dout + i, x86_forward::avx::Relu(a)); + } +#endif + for (; i < size; i++) { + dout[i] = x86_forward::Relu(din[i]); + } +} + +template <> +void act_sigmoid(const float* din, float* dout, int size, int threads) { + int i = 0; +#ifdef __AVX__ + for (; i + 7 < size; i += 8) { + __m256 a = _mm256_loadu_ps(din + i); + _mm256_storeu_ps(dout + i, x86_forward::avx::Sigmoid(a)); + } +#endif + for (; i < size; i++) { + dout[i] = x86_forward::Sigmoid(din[i]); + } +} + +template <> +void act_tanh(const float* din, float* dout, int size, int threads) { + int i = 0; +#ifdef __AVX__ + for (; i + 7 < size; i += 8) { + __m256 a = _mm256_loadu_ps(din + i); + _mm256_storeu_ps(dout + i, x86_forward::avx::Tanh(a)); + } +#endif + for (; i < size; i++) { + dout[i] = x86_forward::Tanh(din[i]); + } +} template void activation( @@ -97,6 +274,9 @@ void activation(const T* din, } } +//*********************************** +// LSTM MODE +//*********************************** template struct RnnLstmUnitFunctor { static void compute(LstmMetaValue value, @@ -163,6 +343,237 @@ struct RnnLstmUnitFunctor { } }; +//************************************ +// GRU MODE +//************************************ +template +void GruRnnComputeKernel(GRUMetaValue value, + int frame_size, + int batch_size, + lite_api::ActivationType active_node, + lite_api::ActivationType active_gate) { + auto value_reset_gate = value.gate_value; + auto value_update_gate = value.gate_value + frame_size; + auto value_reset_output = value.reset_output_value; + auto value_reset_bias = value.reset_bias; + auto cell_state_value = value.gate_value + 2 * frame_size; + auto value_output = value.output_value; + auto value_prev_out = value.prev_out_value; + + for (int b = 0; b < batch_size; b++) { + activation(value_reset_gate, + value_reset_gate, + frame_size, + lite_api::ActivationType::kSigmoid_v2, + 1); + + activation(value_update_gate, + value_update_gate, + frame_size, + lite_api::ActivationType::kSigmoid_v2, + 1); + + for (int i = 0; i < frame_size; i++) { + value_reset_output[i] = + (value_reset_output[i] + value_reset_bias[i]) * value_reset_gate[i]; + cell_state_value[i] += value_reset_output[i]; + } + + activation(cell_state_value, + cell_state_value, + frame_size, + lite_api::ActivationType::kTanh_v2, + 1); + + if (value.prev_out_value) { + for (int i = 0; i < frame_size; i++) { + value_output[i] = (1.f - value_update_gate[i]) * cell_state_value[i] + + value_update_gate[i] * value_prev_out[i]; + } + } else { + for (int i = 0; i < frame_size; i++) { + value_output[i] = (1.f - value_update_gate[i]) * cell_state_value[i]; + } + } + + value_reset_gate += frame_size * 3; + value_update_gate += frame_size * 3; + value_reset_output += frame_size; + cell_state_value += frame_size * 3; + value_output += frame_size; + if (value.prev_out_value) { + value_prev_out += frame_size; + } + } +} + +template <> +void GruRnnComputeKernel(GRUMetaValue value, + int frame_size, + int batch_size, + lite_api::ActivationType active_node, + lite_api::ActivationType active_gate) { + auto value_reset_gate = value.gate_value; + auto value_update_gate = value.gate_value + frame_size; + auto value_reset_output = value.reset_output_value; + auto value_reset_bias = value.reset_bias; + auto cell_state_value = value.gate_value + 2 * frame_size; + auto value_output = value.output_value; + auto value_prev_out = value.prev_out_value; + int i = 0; + +#ifdef __AVX__ + __m256 vec_one_256 = _mm256_set1_ps(1.0f); +#endif +#ifdef __SSE__ + __m128 vec_one_128 = _mm_set1_ps(1.0f); +#endif + + for (int b = 0; b < batch_size; b++) { + activation(value_reset_gate, + value_reset_gate, + frame_size, + lite_api::ActivationType::kSigmoid_v2, + 1); + activation(value_update_gate, + value_update_gate, + frame_size, + lite_api::ActivationType::kSigmoid_v2, + 1); + i = 0; +#ifdef __AVX__ + for (; i + 7 < frame_size; i += 8) { + __m256 vec_out = _mm256_loadu_ps(value_reset_output + i); + __m256 vec_reset = _mm256_loadu_ps(value_reset_gate + i); + __m256 vec_bias = _mm256_loadu_ps(value_reset_bias + i); + vec_out = _mm256_mul_ps(_mm256_add_ps(vec_out, vec_bias), vec_reset); + _mm256_storeu_ps(value_reset_output + i, vec_out); + _mm256_storeu_ps( + cell_state_value + i, + _mm256_add_ps(vec_out, _mm256_loadu_ps(cell_state_value + i))); + } +#endif +#ifdef __SSE__ + for (; i + 3 < frame_size; i += 4) { + __m128 vec_out = _mm_loadu_ps(value_reset_output + i); + __m128 vec_reset = _mm_loadu_ps(value_reset_gate + i); + __m128 vec_bias = _mm_loadu_ps(value_reset_bias + i); + vec_out = _mm_mul_ps(_mm_add_ps(vec_out, vec_bias), vec_reset); + _mm_storeu_ps(value_reset_output + i, vec_out); + _mm_storeu_ps(cell_state_value + i, + _mm_add_ps(vec_out, _mm_loadu_ps(cell_state_value + i))); + } +#endif + for (; i < frame_size; i++) { + value_reset_output[i] = + (value_reset_output[i] + value_reset_bias[i]) * value_reset_gate[i]; + cell_state_value[i] += value_reset_output[i]; + } + + activation(cell_state_value, + cell_state_value, + frame_size, + lite_api::ActivationType::kTanh_v2, + 1); + + if (value.prev_out_value) { + i = 0; +#ifdef __AVX__ + for (; i + 7 < frame_size; i += 8) { + __m256 vec_vug = _mm256_loadu_ps(value_update_gate + i); + __m256 vec_vpo = _mm256_loadu_ps(value_prev_out + i); + __m256 vec_csv = _mm256_loadu_ps(cell_state_value + i); + vec_vpo = _mm256_mul_ps(vec_vug, vec_vpo); +#ifdef __FMA__ + __m256 vec_out = _mm256_fmadd_ps( + vec_csv, _mm256_sub_ps(vec_one_256, vec_vug), vec_vpo); +#else + __m256 vec_out = _mm256_add_ps( + _mm256_mul_ps(vec_csv, _mm256_sub_ps(vec_one_256, vec_vug)), + vec_vpo); +#endif + _mm256_storeu_ps(value_output + i, vec_out); + } +#endif +#ifdef __SSE__ + for (; i + 3 < frame_size; i += 4) { + __m128 vec_vug = _mm_loadu_ps(value_update_gate + i); + __m128 vec_vpo = _mm_loadu_ps(value_prev_out + i); + __m128 vec_csv = _mm_loadu_ps(cell_state_value + i); + vec_vpo = _mm_mul_ps(vec_vug, vec_vpo); + __m128 vec_out = _mm_add_ps( + _mm_mul_ps(vec_csv, _mm_sub_ps(vec_one_128, vec_vug)), vec_vpo); + _mm_storeu_ps(value_output + i, vec_out); + } +#endif + for (; i < frame_size; i++) { + value_output[i] = (1.f - value_update_gate[i]) * cell_state_value[i] + + value_update_gate[i] * value_prev_out[i]; + } + } else { + i = 0; +#ifdef __AVX__ + for (; i + 7 < frame_size; i += 8) { + __m256 vec_vug = _mm256_loadu_ps(value_update_gate + i); + __m256 vec_csv = _mm256_loadu_ps(cell_state_value + i); + __m256 vec_out = + _mm256_mul_ps(_mm256_sub_ps(vec_one_256, vec_vug), vec_csv); + _mm256_storeu_ps(value_output + i, vec_out); + } +#endif +#ifdef __SSE__ + for (; i + 3 < frame_size; i += 4) { + __m128 vec_vug = _mm_loadu_ps(value_update_gate + i); + __m128 vec_csv = _mm_loadu_ps(cell_state_value + i); + __m128 vec_out = _mm_mul_ps(_mm_sub_ps(vec_one_128, vec_vug), vec_csv); + _mm_storeu_ps(value_output + i, vec_out); + } +#endif + for (; i < frame_size; i++) { + value_output[i] = (1.f - value_update_gate[i]) * cell_state_value[i]; + } + } + + value_reset_gate += frame_size * 3; + value_update_gate += frame_size * 3; + value_reset_output += frame_size; + cell_state_value += frame_size * 3; + value_output += frame_size; + if (value.prev_out_value) { + value_prev_out += frame_size; + } + } +} + +template +struct RnnGruUnitFunctorV2 { + static void compute(X86Context* ctx, + GRUMetaValue value, + int frame_size, + int batch_size, + lite_api::ActivationType active_node, + lite_api::ActivationType active_gate) { + if (value.prev_out_value) { + lite::x86::math::Blas matmul(*ctx); + matmul.GEMM(false, + true, + batch_size, + frame_size, + frame_size, + 1.f, + value.prev_out_value, + frame_size, + value.state_weight, + frame_size, + 0.f, + value.reset_output_value, + frame_size); + } + GruRnnComputeKernel( + value, frame_size, batch_size, active_node, active_gate); + } +}; + } // namespace math } // namespace x86 } // namespace lite diff --git a/lite/core/profile/precision_profiler.h b/lite/core/profile/precision_profiler.h index 18d9c4908a6..b749f25f7a8 100644 --- a/lite/core/profile/precision_profiler.h +++ b/lite/core/profile/precision_profiler.h @@ -575,8 +575,8 @@ class PrecisionProfiler { std::string out_arg_name; op->op_info()->GetOutputArgname(out_name, &out_arg_name); auto type = kernel->GetOutputDeclType(out_arg_name); - - if (type->IsTensor()) { + auto tmp = op_scope->FindVar(out_name); + if (tmp->IsType()) { const Tensor* tout = op_scope->FindVar(out_name)->GetMutable(); double mean = -999999; @@ -613,7 +613,7 @@ class PrecisionProfiler { << " " << setw(15) << left << mean_str << " " << setw(15) << left << std_dev_str << " " << setw(15) << left << ave_grow_rate_str << std::endl; - } else if (type->IsTensorList()) { + } else if (tmp->IsType>()) { auto touts = op_scope->FindVar(out_name)->GetMutable>(); for (auto t : *touts) { diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt index f282d7ee3df..28a6954531c 100644 --- a/lite/kernels/arm/CMakeLists.txt +++ b/lite/kernels/arm/CMakeLists.txt @@ -40,7 +40,6 @@ add_kernel(conv_transpose_compute_arm ARM basic SRCS conv_transpose_compute.cc) add_kernel(interpolate_compute_arm ARM basic SRCS interpolate_compute.cc) add_kernel(box_coder_compute_arm ARM basic SRCS box_coder_compute.cc) add_kernel(slice_compute_arm ARM basic SRCS slice_compute.cc) -add_kernel(cast_compute_arm ARM basic SRCS cast_compute.cc) add_kernel(reduce_mean_compute_arm ARM basic SRCS reduce_mean_compute.cc) add_kernel(affine_channel_compute_arm ARM basic SRCS affine_channel_compute.cc) add_kernel(affine_grid_compute_arm ARM basic SRCS affine_grid_compute.cc) diff --git a/lite/kernels/arm/rnn_compute.cc b/lite/kernels/arm/rnn_compute.cc index 4e8ad8f0267..8e67c82aa9a 100644 --- a/lite/kernels/arm/rnn_compute.cc +++ b/lite/kernels/arm/rnn_compute.cc @@ -19,6 +19,7 @@ #include #include "lite/backends/arm/math/concat.h" #include "lite/backends/arm/math/funcs.h" +#include "lite/backends/arm/math/gru.h" #include "lite/backends/arm/math/lstm.h" #include "lite/backends/arm/math/sgemm.h" #include "lite/backends/host/math/split.h" @@ -28,6 +29,23 @@ namespace lite { namespace kernels { namespace arm { +// layer, output_tensor, is_bidirection, offset +#define RUN_RNN_LAYER(x, y, z, w) \ + RunRnnLayer(&ctx, \ + input_temp_holder, \ + parameter_lists[x], \ + init_h_unbind, \ + init_c_unbind, \ + sequence_length, \ + &last_h_unbind, \ + &last_c_unbind, \ + y, \ + x, \ + &gate_value, \ + z, \ + w, \ + mode) + void reset_parameter_vector(const std::vector& raw_params_vec, const int& num_layers, const int& gate_num, @@ -66,25 +84,36 @@ void SwapPoniter(Tensor** a, Tensor** b) { *b = c; } -void preprocess(ARMContext* ctx, - const Tensor* input, - const Tensor& weight, - const Tensor& bias_ih, - const Tensor& bias_hh, - Tensor* cache_input, - bool is_test) { +/****************************************************** +input: + ctx:context, + input:(3D)time_step, batch, input_size, + weight:(2D)hidden_size, input_size, + bias_ih, + bias_hh, + mode:LSTM, GRU +output: + cache_input:(3D)time_step, batch, hidden_size +*******************************************************/ +static void preprocess(ARMContext* ctx, + const Tensor* input, + const Tensor& weight, + const Tensor& bias_ih, + const Tensor& bias_hh, + std::string mode, + Tensor* cache_input) { const int& hidden_size = weight.dims()[0]; int time_step = input->dims()[0]; int batch = input->dims()[1]; + std::vector cache_input_dim = {time_step, batch, hidden_size}; DDim gate_dim; gate_dim.ConstructFrom(cache_input_dim); cache_input->Resize(gate_dim); - cache_input->mutable_data(); + auto* i_data = input->data(); auto* w_data = weight.data(); auto* o_data = cache_input->mutable_data(); - bool flag_act = false; operators::ActivationParam act_param; act_param.has_active = false; @@ -112,19 +141,173 @@ void preprocess(ARMContext* ctx, act_param, ctx); lite::arm::math::fill_bias_fc(o_data, bias_ih.data(), m, n, flag_act); - lite::arm::math::fill_bias_fc(o_data, bias_hh.data(), m, n, flag_act); + + if ("GRU" == mode) { + Tensor bias_tmp_hh; + bias_tmp_hh.Resize(bias_hh.dims()); + auto bias_ptr = bias_tmp_hh.mutable_data(); + auto bias_src = bias_hh.data(); + int bias_offt = bias_hh.numel() / 3 * 2; + std::memcpy(bias_ptr, bias_src, bias_offt * sizeof(float)); + std::memset( + bias_ptr + bias_offt, 0, (bias_hh.numel() - bias_offt) * sizeof(float)); + lite::arm::math::fill_bias_fc( + o_data, bias_tmp_hh.data(), m, n, flag_act); + } else { + lite::arm::math::fill_bias_fc( + o_data, bias_hh.data(), m, n, flag_act); + } +} + +/****************************************************** +input: + ctx:context, + init_h:(2D), + init_c:(2D), + mask_tensor:(1D)input->dims()[1], + mode:LSTM, GRU +output: + output:(2D)output->dims()[1], output->dims()[2], + last_h:(2D), + last_c:(2D) +*******************************************************/ +static void postprocess(ARMContext* ctx, + Tensor* output, + const Tensor* init_h, + const Tensor* init_c, + Tensor* last_h, + Tensor* last_c, + const Tensor& mask_tensor, + std::string mode) { + Tensor mask_broadcast_1; + mask_broadcast_1.Resize(mask_tensor.dims()); + auto mask_ptr_1 = mask_broadcast_1.mutable_data(); + auto mask_ptr = mask_tensor.data(); + auto out_ptr = output->mutable_data(); + auto cur_h_ptr = last_h->mutable_data(); + auto pre_h_ptr = init_h->data(); + int offset = 0; + + // out = out * mask_broadcast + // curr_h = out * mask_broadcast + pre_h * (1 - mask_broadcast); + for (int i = 0; i < output->dims()[0]; i++) { + mask_ptr_1[i] = 1 - mask_ptr[i]; + for (int j = 0; j < output->dims()[1]; j++) { + offset = i * output->dims()[1] + j; + out_ptr[offset] *= mask_ptr[i]; + cur_h_ptr[offset] = out_ptr[offset] + pre_h_ptr[offset] * mask_ptr_1[i]; + } + } + if ("LSTM" == mode) { + auto pre_c_ptr = init_c->data(); + auto cur_c_ptr = last_c->mutable_data(); + + // curr_c = curr_c * mask_broadcast + pre_c * (1 - mask_broadcast); + for (int i = 0; i < output->dims()[0]; i++) { + for (int j = 0; j < output->dims()[1]; j++) { + offset = i * output->dims()[1] + j; + cur_c_ptr[offset] = + cur_c_ptr[offset] * mask_ptr[i] + pre_c_ptr[offset] * mask_ptr_1[i]; + } + } + } +} + +static DDim get_stride(const DDim& ddim) { + DDim strides; + strides[ddim.size() - 1] = 1; + for (int i = ddim.size() - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * ddim[i + 1]; + } + return strides; +} + +template +static void TransposeNormal(const Tensor& in, + Tensor* out, + const std::vector& axis) { + const int rank = axis.size(); + auto in_stride = get_stride(in.dims()); + auto out_stride = get_stride(out->dims()); + const T* in_ptr = in.data(); + T* out_ptr = out->mutable_data(); + + auto transpose_helper = [&](int64_t beg, int64_t end) { + for (int64_t out_idx = beg; out_idx < end; ++out_idx) { + int64_t in_idx = 0; + int64_t tmp_idx = out_idx; + // calculate the input index + for (int i = 0; i < rank; ++i) { + const int64_t coordinate = tmp_idx / out_stride[i]; + tmp_idx -= coordinate * out_stride[i]; + in_idx += coordinate * in_stride[axis[i]]; + } + out_ptr[out_idx] = in_ptr[in_idx]; + } + }; + transpose_helper(0, out->numel()); +} + +/****************************************************** +input: + sequence_length, + is_reverse +output: + mask_matrix, + min_seq_len +******************************************************/ +static void create_mask_matrix(const Tensor* sequence_length, + Tensor* mask_matrix, + const bool& is_reverse, + int* min_seq_len) { + // Tensor to vector + std::vector seq_len_vec; + seq_len_vec.resize(sequence_length->numel()); + std::memcpy(&seq_len_vec[0], + sequence_length->data(), + sequence_length->numel() * sizeof(int)); + + const int& table_width = mask_matrix->dims()[0]; + Tensor temp; + DDimLite dims( + std::vector{mask_matrix->dims()[1], mask_matrix->dims()[0]}); + temp.Resize(dims); + float* data_temp = temp.mutable_data(); + std::fill(data_temp, data_temp + mask_matrix->numel(), 1.f); + *min_seq_len = table_width; + for (unsigned int i = 0; i < seq_len_vec.size(); i++) { + // reset the mask matrix + *min_seq_len = std::min(seq_len_vec[i], *min_seq_len); + if (seq_len_vec[i] == table_width) { + continue; + } + if (is_reverse) { + std::fill(data_temp + i * table_width, + data_temp + (i + 1) * table_width - seq_len_vec[i], + 0.f); + } else { + std::fill(data_temp + i * table_width + seq_len_vec[i], + data_temp + (i + 1) * table_width, + 0.f); + } + } + mask_matrix->mutable_data(); + std::vector trans_vec; + trans_vec.emplace_back(1); + trans_vec.emplace_back(0); + TransposeNormal(temp, mask_matrix, trans_vec); } -void cell(ARMContext* ctx, - Tensor* input, - Tensor* weight_hh, - Tensor* init_h, - Tensor* init_c, - Tensor* last_h, - Tensor* last_c, - Tensor* last_c_act, - Tensor* output, - const Tensor* bias_hh) { +static void lstm_cell(ARMContext* ctx, + Tensor* input, + Tensor* weight_hh, + Tensor* init_h, + Tensor* init_c, + Tensor* last_h, + Tensor* last_c, + Tensor* last_c_act, + Tensor* output, + const Tensor* bias_hh) { bool flag_act = false; operators::ActivationParam act_param; act_param.has_active = false; @@ -201,35 +384,88 @@ void cell(ARMContext* ctx, ctx->threads()); } -// layer, output_tensor, is_bidirection, offset -#define RUN_LSTM_LAYER(x, y, z, w) \ - runLSTMLayer(&ctx, \ - input_temp_holder, \ - parameter_lists[x], \ - init_h_unbind, \ - init_c_unbind, \ - sequence_length, \ - &last_h_unbind, \ - &last_c_unbind, \ - y, \ - x, \ - &gate_value, \ - z, \ - w) - -void runLSTMLayer(ARMContext* ctx, - const Tensor* input, - std::vector vec, - std::vector init_h, - std::vector init_c, - const Tensor* sequence_length, - std::vector* last_h_ptr, - std::vector* last_c_ptr, - Tensor* output, - int layer_idx, - Tensor* gate_value, - bool is_bidirect, - int offset) { +static void gru_cell(ARMContext* ctx, + Tensor* input, + Tensor* weight_hh, + Tensor* init_h, + Tensor* init_c, + Tensor* last_h, + Tensor* last_c, + Tensor* last_c_act, + Tensor* output, + const Tensor* bias_hh, + Tensor* weight_hh_gru) { + bool flag_act = false; + operators::ActivationParam act_param; + act_param.has_active = false; + auto h_dims = init_h->dims(); + auto weight_gru_dims = weight_hh_gru->dims(); + int m = h_dims[0]; + int k = h_dims[1]; + int n = weight_gru_dims[0]; + auto i_data = input->data(); + auto w_gru = weight_hh_gru->data(); + auto h_data = init_h->data(); + + Tensor tmp_gate; + tmp_gate.Resize(input->dims()); + auto tmp_data = tmp_gate.mutable_data(); + lite::arm::math::sgemm(false, + true, + m, + n, + k, + 1.f, + h_data, + k, + w_gru, + k, + 0.f, + tmp_data, + n, + nullptr, + false, + act_param, + ctx); + for (int i = 0; i < input->dims()[0] * input->dims()[1]; i++) { + tmp_data[i] += i_data[i]; + } + + size_t frame_size = init_h->dims()[1]; + size_t batch_size = init_h->dims()[0]; + + lite::arm::math::RNNGRUValue gru_value; + gru_value.gate_weight = weight_hh->data(); + gru_value.state_weight = + weight_hh->data() + 2 * frame_size * frame_size; + gru_value.reset_bias = bias_hh->data() + 2 * frame_size; + + gru_value.gate_value = tmp_data; + gru_value.reset_output_value = last_c->mutable_data(); + gru_value.output_value = output->mutable_data(); + gru_value.prev_out_value = init_h->data(); + + lite_api::ActivationType gate_act = lite_api::ActivationType::kSigmoid_v2; + lite_api::ActivationType cand_act = lite_api::ActivationType::kTanh_v2; + + lite::arm::math::RnnGruUnitFunctorV2::compute( + ctx, gru_value, frame_size, batch_size, cand_act, gate_act); +} + +static void RunRnnLayer(ARMContext* ctx, + const Tensor* input, + std::vector vec, + std::vector init_h, + std::vector init_c, + const Tensor* sequence_length, + std::vector* last_h_ptr, + std::vector* last_c_ptr, + Tensor* output, + int layer_idx, + Tensor* gate_value, + bool is_bidirect, + int offset, + std::string mode) { bool is_reverse = false; if (is_bidirect) { layer_idx = 2 * layer_idx + offset; @@ -243,13 +479,16 @@ void runLSTMLayer(ARMContext* ctx, vec[0 + offset * 4], vec[2 + offset * 4], vec[3 + offset * 4], - gate_value, - true); + mode, + gate_value); + std::vector input_tensors, output_tensors; std::vector input_tensors_t, output_tensors_t; - std::vector stride1, stride2; + std::vector stride1, stride2, stride3; input_tensors.resize(gate_value->dims()[0]); output_tensors.resize(output->dims()[0]); + + // unbind for (int i = 0; i < gate_value->dims()[0]; i++) { stride1.push_back(1); int dim1 = gate_value->dims()[1]; @@ -272,68 +511,140 @@ void runLSTMLayer(ARMContext* ctx, auto sd = output->mutable_data(); if (is_reverse) { + // don't need to reverse input_tensors_t becauese of unuseful std::reverse(input_tensors.begin(), input_tensors.end()); } bool has_sequence_length = false; + if (sequence_length != nullptr) { + has_sequence_length = true; + } + // unbind + Tensor mask_matrix; + std::vector mask_vec; + std::vector mask_tensor_list; + int mask_min_length = time_step; + /* - TODO has_sequence_length + to be verifying! */ - - int mask_min_length = time_step; + if (has_sequence_length) { + mask_matrix.Resize(DDimLite({time_step, input->dims()[1]})); + create_mask_matrix( + sequence_length, &mask_matrix, is_reverse, &mask_min_length); + for (int i = 0; i < time_step; i++) { + stride3.push_back(1); + DDimLite ddims(std::vector{input->dims()[1]}); + mask_vec[i].Resize(ddims); + mask_tensor_list.push_back(&mask_vec[i]); + } + lite::host::math::split( + mask_matrix.data(), mask_tensor_list, 0, stride3); + } if (is_reverse) { mask_min_length = mask_min_length - time_step + 1; } + bool has_allocate_mem_c = false; bool has_use_last_h_holder = false; const int& reverse_flag = is_reverse ? -1 : 1; + + // define the init_h holder for the swap Tensor init_h_temp; + init_h_temp.Resize(init_h[layer_idx].dims()); init_h_temp.CopyDataFrom(init_h[layer_idx]); Tensor* init_h_holder = &init_h_temp; Tensor* last_h_holder = nullptr; - if (0 < mask_min_length) { last_h_holder = &(output_tensors[0]); } else { last_h_holder = &(*last_h_ptr)[layer_idx]; has_use_last_h_holder = true; } + Tensor* init_c_holder = nullptr; Tensor* init_c_temp_holder = nullptr; Tensor init_c_temp; Tensor* last_c_holder = nullptr; Tensor last_c_temp; - last_c_holder = &(*last_c_ptr)[layer_idx]; - init_c_temp_holder = &init_c[layer_idx]; + + if ("LSTM" == mode) { + last_c_holder = &(*last_c_ptr)[layer_idx]; + init_c_temp_holder = &init_c[layer_idx]; + } else if ("GRU" == mode) { + // for reset output value + last_c_temp.Resize(init_h[layer_idx].dims()); + last_c_temp.mutable_data(); + last_c_holder = &last_c_temp; + } + + Tensor weight_hh_tmp; // for gru + std::vector weight_hh_tmp_ubind; + std::vector weight_hh_tmp_ubind_t; + std::vector stride_w; + if ("GRU" == mode) { + weight_hh_tmp.Resize(vec[1 + offset * 4].dims()); + weight_hh_tmp.mutable_data(); + weight_hh_tmp.CopyDataFrom(vec[1 + offset * 4]); + int size = weight_hh_tmp.numel() / 3; + std::memset(weight_hh_tmp.mutable_data() + size * 2, + 0, + size * sizeof(float)); + } + for (int i = 0; i < time_step; i++) { bool in_mask = (reverse_flag * i) >= mask_min_length; if (i > 0) { if (!has_allocate_mem_c) { - init_c_temp.Resize(init_h[layer_idx].dims()); - init_c_temp.mutable_data(); - init_c_holder = &init_c_temp; + if (("LSTM" == mode) || ("GRU" == mode)) { + init_c_temp.Resize(init_h[layer_idx].dims()); + init_c_temp.mutable_data(); + init_c_holder = &init_c_temp; + } has_allocate_mem_c = true; } SwapPoniter(&init_c_holder, &last_c_holder); init_c_temp_holder = init_c_holder; } - // LSTMCELL - cell(ctx, - &input_tensors[i], - &vec[1 + offset * 4], - init_h_holder, - init_c_temp_holder, - last_h_holder, - last_c_holder, - nullptr, - &output_tensors[i], - &vec[3 + offset * 4]); + if ("LSTM" == mode) { + lstm_cell(ctx, + &input_tensors[i], + &vec[1 + offset * 4], + init_h_holder, + init_c_temp_holder, + last_h_holder, + last_c_holder, + nullptr, + &output_tensors[i], + &vec[3 + offset * 4]); + } else if ("GRU" == mode) { + gru_cell(ctx, + &input_tensors[i], + &vec[1 + offset * 4], + init_h_holder, + init_c_temp_holder, + last_h_holder, + last_c_holder, + nullptr, + &output_tensors[i], + &vec[3 + offset * 4], + &weight_hh_tmp); + } + /* + to be verifying! + */ if (in_mask) { - /* - TODO in_mask - */ + postprocess(ctx, + &output_tensors[i], + init_h_holder, + init_c_temp_holder, + last_h_holder, + last_c_holder, + mask_vec[i], + mode); } + // prepare next step if (i + 1 < time_step) { bool next_step_mask = (reverse_flag * (i + 1)) >= mask_min_length; @@ -347,6 +658,7 @@ void runLSTMLayer(ARMContext* ctx, SwapPoniter(&init_h_holder, &last_h_holder); } } + if (is_reverse) { std::reverse(output_tensors.begin(), output_tensors.end()); } @@ -364,7 +676,7 @@ void runLSTMLayer(ARMContext* ctx, } else { (*last_h_ptr)[layer_idx].CopyDataFrom(output_tensors[time_step - 1]); } - if (time_step % 2 == 0) { + if ((0 == (time_step % 2)) && ("LSTM" == mode)) { (*last_c_ptr)[layer_idx].CopyDataFrom(*last_c_holder); } } @@ -375,25 +687,30 @@ void RnnCompute::Run() { std::string mode = param.mode; auto input = param.Input; auto weight_list = param.WeightList; - auto reserve_data = param.Reserve; auto pre_state = param.PreState; auto state = param.State; - auto dropout_state = param.DropoutState; auto output = param.Out; bool is_bidirec = param.is_bidirec; int num_layers = param.num_layers; - int input_size = param.input_size; - int hidden_size = param.hidden_size; - bool is_test = param.is_test; - float dropout_prob = param.dropout_prob; - int seed = param.seed; const Tensor* sequence_length = param.SequenceLength; + int gate_num = 0; + + if ("LSTM" == mode) { + gate_num = 4; + } else if ("GRU" == mode) { + gate_num = 3; + } else { + LOG(FATAL) << "X86 RNN ERROR: unsupport mode except gru and lstm," + " present mode is " + << mode; + return; + } state[0]->mutable_data(); - state[1]->mutable_data(); + if ("LSTM" == mode) { + state[1]->mutable_data(); + } - // lstmCell begin - int gate_num = 4; std::vector> parameter_lists; parameter_lists.reserve(num_layers); reset_parameter_vector( @@ -407,31 +724,56 @@ void RnnCompute::Run() { last_c_unbind; std::vector init_h_unbind_t, init_c_unbind_t, last_h_unbind_t, last_c_unbind_t; - init_h_unbind.resize(4); - init_c_unbind.resize(pre_state[1]->dims()[0]); + init_h_unbind.resize(pre_state[0]->dims()[0]); last_h_unbind.resize(state[0]->dims()[0]); - last_c_unbind.resize(state[1]->dims()[0]); - std::vector stride; + if ("LSTM" == mode) { + init_c_unbind.resize(pre_state[1]->dims()[0]); + last_c_unbind.resize(state[1]->dims()[0]); + } + + std::vector stride1, stride2; + // unbind for (int i = 0; i < pre_state[0]->dims()[0]; i++) { - stride.push_back(1); + stride1.push_back(1); int dim1 = pre_state[0]->dims()[1]; int dim2 = pre_state[0]->dims()[2]; DDimLite dims(std::vector{dim1, dim2}); init_h_unbind[i].Resize(dims); - init_c_unbind[i].Resize(dims); last_h_unbind[i].Resize(dims); - last_c_unbind[i].Resize(dims); init_h_unbind_t.push_back(&init_h_unbind[i]); - init_c_unbind_t.push_back(&init_c_unbind[i]); last_h_unbind_t.push_back(&last_h_unbind[i]); - last_c_unbind_t.push_back(&last_c_unbind[i]); } lite::host::math::split( - pre_state[0]->data(), init_h_unbind_t, 0, stride); - lite::host::math::split( - pre_state[1]->data(), init_c_unbind_t, 0, stride); - lite::host::math::split(state[0]->data(), last_h_unbind_t, 0, stride); - lite::host::math::split(state[1]->data(), last_c_unbind_t, 0, stride); + pre_state[0]->data(), init_h_unbind_t, 0, stride1); + lite::host::math::split(state[0]->data(), last_h_unbind_t, 0, stride1); + + if ("LSTM" == mode) { + for (int i = 0; i < pre_state[1]->dims()[0]; i++) { + stride2.push_back(1); + int dim1 = pre_state[1]->dims()[1]; + int dim2 = pre_state[1]->dims()[2]; + DDimLite dims(std::vector{dim1, dim2}); + init_c_unbind[i].Resize(dims); + last_c_unbind[i].Resize(dims); + init_c_unbind_t.push_back(&init_c_unbind[i]); + last_c_unbind_t.push_back(&last_c_unbind[i]); + } + lite::host::math::split( + pre_state[1]->data(), init_c_unbind_t, 0, stride2); + lite::host::math::split( + state[1]->data(), last_c_unbind_t, 0, stride2); + } + + std::vector output_vec(2); + int time_step = input->dims()[0]; + int batch_size = input->dims()[1]; + int hidden_size = output->dims()[2]; + if (is_bidirec) { + for (int i = 0; i < 2; ++i) { + output_vec[i].Resize({time_step, batch_size, hidden_size / 2}); + output_vec[i].mutable_data(); + } + } for (int i = 0; i < num_layers; i++) { if (i > 0) { @@ -450,27 +792,21 @@ void RnnCompute::Run() { } if (is_bidirec) { - std::vector output_vec(2); - int time_step = input->dims()[0]; - int batch_size = input->dims()[1]; - int hidden_size = output->dims()[2]; - for (int i = 0; i < 2; ++i) { - output_vec[i].Resize({time_step, batch_size, hidden_size / 2}); - output_vec[i].mutable_data(); - } - - RUN_LSTM_LAYER(i, &output_vec[0], true, 0); - RUN_LSTM_LAYER(i, &output_vec[1], true, 1); - + RUN_RNN_LAYER(i, &output_vec[0], true, 0); + RUN_RNN_LAYER(i, &output_vec[1], true, 1); std::vector output_vec_t = {&output_vec[0], &output_vec[1]}; lite::arm::math::concat_func(output_vec_t, 2, output_holder); } else { - RUN_LSTM_LAYER(i, output_holder, false, 0); + RUN_RNN_LAYER(i, output_holder, false, 0); } if (num_layers % 2 == 0) { output->CopyDataFrom(*output_holder); } } + // output_holder != output + if (num_layers % 2 == 0) { + output->CopyDataFrom(*output_holder); + } } } // namespace arm } // namespace kernels diff --git a/lite/kernels/host/CMakeLists.txt b/lite/kernels/host/CMakeLists.txt index 439f6e5722f..8ddcc0d0819 100644 --- a/lite/kernels/host/CMakeLists.txt +++ b/lite/kernels/host/CMakeLists.txt @@ -28,6 +28,7 @@ add_kernel(argmax_compute_host Host basic SRCS argmax_compute.cc) add_kernel(assign_value_compute_host Host basic SRCS assign_value_compute.cc) add_kernel(yolo_box_compute_host Host basic SRCS yolo_box_compute.cc) add_kernel(write_back_compute_host Host basic SRCS write_back_compute.cc) +add_kernel(cast_compute_host Host basic SRCS cast_compute.cc) # extra kernels add_kernel(reverse_compute_host Host extra SRCS reverse_compute.cc) diff --git a/lite/kernels/arm/cast_compute.cc b/lite/kernels/host/cast_compute.cc similarity index 91% rename from lite/kernels/arm/cast_compute.cc rename to lite/kernels/host/cast_compute.cc index 874bbea0665..edddbc79b57 100644 --- a/lite/kernels/arm/cast_compute.cc +++ b/lite/kernels/host/cast_compute.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "lite/kernels/arm/cast_compute.h" +#include "lite/kernels/host/cast_compute.h" #include #ifdef ENABLE_ARM_FP16 #include "lite/backends/arm/math/fp16/funcs_fp16.h" @@ -20,7 +20,7 @@ namespace paddle { namespace lite { namespace kernels { -namespace arm { +namespace host { template out_type TransOp(in_type in) { @@ -30,7 +30,6 @@ out_type TransOp(in_type in) { void CastCompute::PrepareForRun() {} void CastCompute::Run() { - auto& ctx = this->ctx_->template As(); auto& param = this->Param(); auto input_dims = param.X->dims(); if (param.X->precision() == PrecisionType::kFloat) { @@ -128,7 +127,7 @@ void CastCompute::Run() { int32_t* out_data = param.Out->mutable_data(); std::transform( x_data_begin, x_data_end, out_data, TransOp); -#ifdef ENABLE_ARM_FP16 +#if defined(ENABLE_ARM_FP16) && defined(LITE_WITH_ARM) } else if (param.in_dtype == 4 && param.out_dtype == 5) { // float16 -> float32 const float16_t* in_data = param.X->data(); @@ -146,21 +145,13 @@ void CastCompute::Run() { } } -} // namespace arm +} // namespace host } // namespace kernels } // namespace lite } // namespace paddle REGISTER_LITE_KERNEL( - cast, kARM, kAny, kNCHW, paddle::lite::kernels::arm::CastCompute, def) - .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) - .Finalize(); - -#ifdef LITE_BUILD_EXTRA -REGISTER_LITE_KERNEL( - cast, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::CastCompute, def) - .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) + cast, kHost, kAny, kNCHW, paddle::lite::kernels::host::CastCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))}) .Finalize(); -#endif // LITE_BUILD_EXTRA diff --git a/lite/kernels/arm/cast_compute.h b/lite/kernels/host/cast_compute.h similarity index 90% rename from lite/kernels/arm/cast_compute.h rename to lite/kernels/host/cast_compute.h index cd23d62198a..f344b84e8b2 100644 --- a/lite/kernels/arm/cast_compute.h +++ b/lite/kernels/host/cast_compute.h @@ -20,9 +20,9 @@ namespace paddle { namespace lite { namespace kernels { -namespace arm { +namespace host { -class CastCompute : public KernelLite { +class CastCompute : public KernelLite { public: using param_t = operators::CastParam; @@ -35,7 +35,7 @@ class CastCompute : public KernelLite { private: }; -} // namespace arm +} // namespace host } // namespace kernels } // namespace lite } // namespace paddle diff --git a/lite/kernels/host/compare_compute.cc b/lite/kernels/host/compare_compute.cc index 4eae6fb4156..f5bf30d4a7b 100644 --- a/lite/kernels/host/compare_compute.cc +++ b/lite/kernels/host/compare_compute.cc @@ -591,10 +591,10 @@ REGISTER_LITE_KERNEL( .Finalize(); using greater_equal_int64 = paddle::lite::kernels::host::CompareCompute< - PRECISION(kInt64), + PRECISION(kFloat), paddle::lite::kernels::host::_GreaterEqualFunctor>; REGISTER_LITE_KERNEL( - greater_equal, kHost, kInt64, kAny, greater_equal_float, def) + greater_equal, kHost, kFloat, kAny, greater_equal_float, def_int64) .BindInput("X", {LiteType::GetTensorTy( TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kAny), -1)}) diff --git a/lite/kernels/host/expand_v2_compute.cc b/lite/kernels/host/expand_v2_compute.cc index 928363f104e..d0e0e2ffdb3 100644 --- a/lite/kernels/host/expand_v2_compute.cc +++ b/lite/kernels/host/expand_v2_compute.cc @@ -140,8 +140,8 @@ REGISTER_LITE_KERNEL(expand_v2, kHost, kInt32, kAny, expand_v2_int32, def) .Finalize(); using expand_v2_int64 = - paddle::lite::kernels::host::ExpandV2Compute; -REGISTER_LITE_KERNEL(expand_v2, kHost, kInt64, kAny, expand_v2_int64, def) + paddle::lite::kernels::host::ExpandV2Compute; +REGISTER_LITE_KERNEL(expand_v2, kHost, kFloat, kAny, expand_v2_int64, def_int64) .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64), diff --git a/lite/kernels/x86/CMakeLists.txt b/lite/kernels/x86/CMakeLists.txt index 94e8cd6b598..2ec9a87ca47 100755 --- a/lite/kernels/x86/CMakeLists.txt +++ b/lite/kernels/x86/CMakeLists.txt @@ -10,7 +10,7 @@ endif() add_kernel(activation_compute_x86 X86 basic SRCS activation_compute.cc) add_kernel(scale_compute_x86 X86 basic SRCS scale_compute.cc) -add_kernel(cast_compute_x86 X86 basic SRCS cast_compute.cc) +#add_kernel(cast_compute_x86 X86 basic SRCS cast_compute.cc) add_kernel(slice_compute_x86 X86 basic SRCS slice_compute.cc) if(WITH_AVX AND AVX_FOUND) add_kernel(conv_depthwise_x86 X86 basic SRCS conv_depthwise.cc) @@ -84,7 +84,7 @@ lite_cc_test(test_softmax_compute_x86 SRCS softmax_compute_test.cc) lite_cc_test(test_sequence_expand_as_compute_x86 SRCS sequence_expand_as_compute_test.cc) lite_cc_test(test_gru_compute_x86 SRCS gru_compute_test.cc) lite_cc_test(test_matmul_compute_x86 SRCS matmul_compute_test.cc) -lite_cc_test(test_cast_compute_x86 SRCS cast_compute_test.cc) +#lite_cc_test(test_cast_compute_x86 SRCS cast_compute_test.cc) lite_cc_test(test_pool2d_compute_x86 SRCS pool_compute_test.cc) lite_cc_test(test_layer_norm_compute_x86 SRCS layer_norm_compute_test.cc) lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc) diff --git a/lite/kernels/x86/elementwise_compute.cc b/lite/kernels/x86/elementwise_compute.cc index 61e1ab87985..2aa948cc708 100644 --- a/lite/kernels/x86/elementwise_compute.cc +++ b/lite/kernels/x86/elementwise_compute.cc @@ -392,6 +392,7 @@ REGISTER_LITE_KERNEL(elementwise_add, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_add, kX86, kFloat, @@ -402,6 +403,7 @@ REGISTER_LITE_KERNEL(elementwise_add, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_add, kX86, kFloat, @@ -412,6 +414,7 @@ REGISTER_LITE_KERNEL(elementwise_add, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .Finalize(); + REGISTER_LITE_KERNEL( fusion_elementwise_add_activation, kX86, @@ -423,6 +426,7 @@ REGISTER_LITE_KERNEL( .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_sub, kX86, kFloat, @@ -433,6 +437,7 @@ REGISTER_LITE_KERNEL(elementwise_sub, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_sub, kX86, kFloat, @@ -443,6 +448,7 @@ REGISTER_LITE_KERNEL(elementwise_sub, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_sub, kX86, kFloat, @@ -453,6 +459,7 @@ REGISTER_LITE_KERNEL(elementwise_sub, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .Finalize(); + REGISTER_LITE_KERNEL( fusion_elementwise_sub_activation, kX86, @@ -464,6 +471,7 @@ REGISTER_LITE_KERNEL( .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_mul, kX86, kFloat, @@ -474,6 +482,7 @@ REGISTER_LITE_KERNEL(elementwise_mul, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_mul, kX86, kFloat, @@ -484,6 +493,7 @@ REGISTER_LITE_KERNEL(elementwise_mul, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_mul, kX86, kFloat, @@ -494,6 +504,7 @@ REGISTER_LITE_KERNEL(elementwise_mul, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .Finalize(); + REGISTER_LITE_KERNEL( fusion_elementwise_mul_activation, kX86, @@ -505,6 +516,7 @@ REGISTER_LITE_KERNEL( .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_div, kX86, kFloat, @@ -515,6 +527,7 @@ REGISTER_LITE_KERNEL(elementwise_div, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_div, kX86, kFloat, @@ -525,6 +538,7 @@ REGISTER_LITE_KERNEL(elementwise_div, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_div, kX86, kFloat, @@ -535,6 +549,7 @@ REGISTER_LITE_KERNEL(elementwise_div, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .Finalize(); + REGISTER_LITE_KERNEL( fusion_elementwise_div_activation, kX86, @@ -546,6 +561,7 @@ REGISTER_LITE_KERNEL( .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .Finalize(); + REGISTER_LITE_KERNEL( elementwise_floordiv, kX86, @@ -557,6 +573,7 @@ REGISTER_LITE_KERNEL( .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .Finalize(); + REGISTER_LITE_KERNEL( elementwise_floordiv, kX86, @@ -568,6 +585,7 @@ REGISTER_LITE_KERNEL( .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .Finalize(); + REGISTER_LITE_KERNEL( elementwise_floordiv, kX86, @@ -579,6 +597,7 @@ REGISTER_LITE_KERNEL( .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_pow, kX86, kFloat, @@ -589,6 +608,7 @@ REGISTER_LITE_KERNEL(elementwise_pow, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_pow, kX86, kFloat, @@ -599,6 +619,7 @@ REGISTER_LITE_KERNEL(elementwise_pow, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_pow, kX86, kFloat, @@ -609,6 +630,7 @@ REGISTER_LITE_KERNEL(elementwise_pow, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_mod, kX86, kFloat, @@ -619,6 +641,7 @@ REGISTER_LITE_KERNEL(elementwise_mod, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_mod, kX86, kFloat, @@ -629,6 +652,7 @@ REGISTER_LITE_KERNEL(elementwise_mod, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_max, kX86, kFloat, @@ -639,6 +663,7 @@ REGISTER_LITE_KERNEL(elementwise_max, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_max, kX86, kFloat, @@ -649,6 +674,7 @@ REGISTER_LITE_KERNEL(elementwise_max, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_max, kX86, kFloat, @@ -659,6 +685,7 @@ REGISTER_LITE_KERNEL(elementwise_max, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .Finalize(); + REGISTER_LITE_KERNEL( fusion_elementwise_max_activation, kX86, @@ -670,6 +697,7 @@ REGISTER_LITE_KERNEL( .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .Finalize(); + REGISTER_LITE_KERNEL( fusion_elementwise_min_activation, kX86, @@ -681,6 +709,7 @@ REGISTER_LITE_KERNEL( .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_min, kX86, kFloat, @@ -691,6 +720,7 @@ REGISTER_LITE_KERNEL(elementwise_min, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_min, kX86, kFloat, @@ -701,6 +731,7 @@ REGISTER_LITE_KERNEL(elementwise_min, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_min, kX86, kFloat, diff --git a/lite/kernels/x86/elementwise_op_function.h b/lite/kernels/x86/elementwise_op_function.h new file mode 100644 index 00000000000..5ab94ff88bb --- /dev/null +++ b/lite/kernels/x86/elementwise_op_function.h @@ -0,0 +1,802 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include "lite/backends/x86/fluid/eigen.h" +#include "lite/backends/x86/fluid/for_range.h" +#include "lite/backends/x86/fluid/transform.h" +#include "lite/backends/x86/math/math_function.h" +#include "lite/utils/log/cp_logging.h" +#include "lite/utils/variant.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +/* + * Out = X ⊙ Y + * If Y's shape does not match X' shape, they will be reshaped. + * For example: + * 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1 + * pre=2, n=3*4, post=5 + * x.shape(2, 12, 5) * y.shape(1, 12, 1).broadcast(2, 12, 5) + * 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5) + * pre=2*3, n=4*5, post=1 + * x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1) + * + * New parameter: *mid_flag* is added to solve m*n*k & m*1*k + * broadcast cases. + */ +inline void get_mid_dims(const lite::DDim &x_dims, + const lite::DDim &y_dims, + const int axis, + int *pre, + int *n, + int *post, + int *mid_flag = NULL) { + *pre = 1; + *n = 1; + *post = 1; + if (mid_flag != NULL) { + *mid_flag = 0; + for (int i = 0; i < axis; ++i) { + (*pre) *= x_dims[i]; + } + for (size_t i = 0; i < y_dims.size(); ++i) { + if (x_dims[i + axis] != y_dims[i]) { + CHECK_EQ(y_dims[i] == 1 || x_dims[i + axis] == 1, true) + << "Broadcast y or x dimension is not 1."; + *mid_flag = 1; + return; + } + (*n) *= y_dims[i]; + } + for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) { + (*post) *= x_dims[i]; + } + } else { // for fused_elementwise_activation_op. keep the old version. + for (int i = 0; i < axis; ++i) { + (*pre) *= x_dims[i]; + } + + for (size_t i = 0; i < y_dims.size(); ++i) { + CHECK_EQ(x_dims[i + axis], y_dims[i]) << "Broadcast dimension mismatch."; + (*n) *= y_dims[i]; + } + + for (size_t i = axis + y_dims.size(); i < x_dims.size(); ++i) { + (*post) *= x_dims[i]; + } + } +} + +inline lite::DDim trim_trailing_singular_dims(const lite::DDim &dims) { + // Remove trailing dimensions of size 1 for y + auto actual_dims_size = dims.size(); + for (; actual_dims_size != 0; --actual_dims_size) { + if (dims[actual_dims_size - 1] != 1) break; + } + if (actual_dims_size == dims.size()) return dims; + std::vector trim_dims; + trim_dims.resize(actual_dims_size); + for (size_t i = 0; i < actual_dims_size; ++i) { + trim_dims[i] = dims[i]; + } + if (trim_dims.size() == 0) { + return lite::DDim(); + } + lite::DDim actual_dims = lite::DDim(trim_dims); + return actual_dims; +} + +template +class RowwiseTransformIterator; + +template +class MidWiseTransformIterator; + +// NOTE(dzhwinter): ptrdiff_t in iterator is deperecated in c++17 +template +class RowwiseTransformIterator + : public std::iterator { + public: + RowwiseTransformIterator(const T *ptr, int n) : ptr_(ptr), i_(0), n_(n) {} + + RowwiseTransformIterator &operator++() { + ++i_; + if (UNLIKELY(i_ == n_)) { + i_ = 0; + } + return *this; + } + + RowwiseTransformIterator &operator+(int n) { + while (n-- > 0) { + ++i_; + if (UNLIKELY(i_ == n_)) { + i_ = 0; + } + } + + return *this; + } + + bool operator==( + const RowwiseTransformIterator &rhs) const { + return (ptr_ + i_) == &(*rhs); + } + + bool operator!=( + const RowwiseTransformIterator &rhs) const { + return (ptr_ + i_) != &(*rhs); + } + + const T &operator*() { return ptr_[i_]; } + + private: + const T *ptr_; + int i_; + int64_t n_; +}; + +template +class MidWiseTransformIterator + : public std::iterator { + public: + MidWiseTransformIterator(const T *ptr, int n, int post) + : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {} + + MidWiseTransformIterator &operator++() { + ++j_; + if (UNLIKELY(j_ == post_)) { + ++i_; + j_ = 0; + if (UNLIKELY(i_ == n_)) { + i_ = 0; + } + } + return *this; + } + + MidWiseTransformIterator &operator+(int n) { + while (n-- > 0) { + ++j_; + if (UNLIKELY(j_ == post_)) { + ++i_; + j_ = 0; + if (UNLIKELY(i_ == n_)) { + i_ = 0; + } + } + } + return *this; + } + + bool operator==( + const MidWiseTransformIterator &rhs) const { + return (ptr_ + i_) == &(*rhs); + } + + bool operator!=( + const MidWiseTransformIterator &rhs) const { + return (ptr_ + i_) != &(*rhs); + } + + const T &operator*() { return ptr_[i_]; } + + private: + const T *ptr_; + int64_t i_; + int64_t j_; + int64_t n_; + int64_t post_; +}; + +template +class TransformFunctor { + public: + TransformFunctor(const lite::Tensor *x, + const lite::Tensor *y, + lite::Tensor *z, + const lite::Context &ctx, + Functor func, + const bool is_xsize_larger = true) + : x_(x->template data()), + y_(y->template data()), + z_(z->mutable_data()), + nx_(x->numel()), + ctx_(ctx), + func_(func), + is_xsize_larger_(is_xsize_larger) { + if (is_xsize_larger_ == false) { + nx_ = y->numel(); + } + } + + inline void Run() const { + lite::fluid::Transform trans; + trans(ctx_, x_, x_ + nx_, y_, z_, func_); + } + + inline void RunRowWise(int n, int pre) const { + lite::fluid::Transform trans; + if (is_xsize_larger_) { + trans(ctx_, + x_, + x_ + nx_, + RowwiseTransformIterator(y_, n), + z_, + func_); + } else { + trans(ctx_, + y_, + y_ + nx_, + RowwiseTransformIterator(x_, n), + z_, + func_); + } + } + + inline void RunMidWise(int n, int pre, int post) const { + lite::fluid::Transform trans; + if (is_xsize_larger_) { + trans(ctx_, + x_, + x_ + nx_, + MidWiseTransformIterator(y_, n, post), + z_, + func_); + } else { + trans(ctx_, + y_, + y_ + nx_, + MidWiseTransformIterator(x_, n, post), + z_, + func_); + } + } + + private: + const T *x_; + const T *y_; + OutType *z_; + int64_t nx_; + const lite::Context &ctx_; + Functor func_; + bool is_xsize_larger_; +}; + +inline void GetBroadcastDimsArrays(const DDim &x_dims, + const DDim &y_dims, + int *x_dims_array, + int *y_dims_array, + int *out_dims_array, + const int max_dim, + const int axis) { + CHECK_GE(axis, 0) << "Axis should be great than or equal to 0."; + CHECK_LT(axis, max_dim) << "Axis should be less than max(x_dim, y_dim)."; + + if (x_dims.size() > y_dims.size()) { + std::fill(y_dims_array, y_dims_array + axis, 1); + if (axis + y_dims.size() < max_dim) { + std::fill(y_dims_array + axis + y_dims.size(), y_dims_array + max_dim, 1); + } + for (int i = 0; i < x_dims.size(); i++) x_dims_array[i] = x_dims[i]; + for (int i = 0; i < y_dims.size(); i++) + *(y_dims_array + axis + i) = y_dims[i]; + } else { + std::fill(x_dims_array, x_dims_array + axis, 1); + if (axis + x_dims.size() < max_dim) { + std::fill(x_dims_array + axis + x_dims.size(), x_dims_array + max_dim, 1); + } + for (int i = 0; i < x_dims.size(); i++) + *(x_dims_array + axis + i) = x_dims[i]; + for (int i = 0; i < y_dims.size(); i++) *(y_dims_array + i) = y_dims[i]; + } + + for (int i = 0; i < max_dim; i++) { + CHECK_EQ(x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 || + y_dims_array[i] <= 1, + true) + << "Broadcast dimension mismatch. Operands could not be broadcast."; + + if ((x_dims_array[i] > 1 || y_dims_array[i] > 1) || + (x_dims_array[i] == 1 && y_dims_array[i] == 1)) { + out_dims_array[i] = std::max(x_dims_array[i], y_dims_array[i]); + } else { + out_dims_array[i] = -1; + } + } +} + +inline int GetElementwiseIndex(const int *x_dims_array, + const int max_dim, + const int *index_array) { + int index_ = 0; + for (int i = 0; i < max_dim; i++) { + if (x_dims_array[i] > 1) { + index_ = index_ * x_dims_array[i] + index_array[i]; + } + } + return index_; +} + +inline void UpdateElementwiseIndexArray(const int *out_dims_array, + const int max_dim, + int *index_array) { + for (int i = max_dim - 1; i >= 0; --i) { + ++index_array[i]; + if (index_array[i] >= out_dims_array[i]) { + index_array[i] -= out_dims_array[i]; + } else { + break; + } + } +} + +template +void CommonForwardBroadcastCPU(const Tensor *x, + const Tensor *y, + Tensor *z, + int *x_dims_array, + int *y_dims_array, + int *out_dims_array, + int max_dim, + Functor func, + const bool is_xsize_larger = true) { + std::vector index_array(max_dim, 0); + const T *x_data = x->data(); + const T *y_data = y->data(); + CHECK_EQ(x_data != nullptr, true) << "The input X should not be empty."; + CHECK_EQ(y_data != nullptr, true) << "The input Y should not be empty."; + + OutType *out_data = z->mutable_data(); + const int out_size = std::accumulate( + out_dims_array, out_dims_array + max_dim, 1, std::multiplies()); + int x_index, y_index; + for (int out_index = 0; out_index < out_size; ++out_index) { + x_index = GetElementwiseIndex(x_dims_array, max_dim, index_array.data()); + y_index = GetElementwiseIndex(y_dims_array, max_dim, index_array.data()); + if (is_xsize_larger) { + out_data[out_index] = func(x_data[x_index], y_data[y_index]); + } else { + out_data[out_index] = func(y_data[y_index], x_data[x_index]); + } + + UpdateElementwiseIndexArray(out_dims_array, max_dim, index_array.data()); + } +} + +template +void CommonElementwiseBroadcastForward(const Tensor *x, + const Tensor *y, + Tensor *z, + const DDim &x_dims, + const DDim &y_dims, + Functor func, + int axis, + const bool is_xsize_larger = true) { + int max_dim = std::max(x_dims.size(), y_dims.size()); + axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis); + CHECK_GE(axis, 0) << "Axis should be great than or equal to 0."; + CHECK_LT(axis, max_dim) << "Axis should be less than max(x_dim, y_dim)."; + + std::vector x_dims_array(max_dim); + std::vector y_dims_array(max_dim); + std::vector out_dims_array(max_dim); + GetBroadcastDimsArrays(x_dims, + y_dims, + x_dims_array.data(), + y_dims_array.data(), + out_dims_array.data(), + max_dim, + axis); + + CommonForwardBroadcastCPU(x, + y, + z, + x_dims_array.data(), + y_dims_array.data(), + out_dims_array.data(), + max_dim, + func, + is_xsize_larger); +} + +template +void ElementwiseComputeEx(const lite::Context &ctx, + const lite::Tensor *x, + const lite::Tensor *y, + int axis, + Functor func, + lite::Tensor *z) { + auto x_dims = x->dims(); + auto y_dims = y->dims(); + bool is_xsize_larger = true; + int max_dim = x_dims.size(); + if (x_dims.size() < y_dims.size()) { + is_xsize_larger = false; + max_dim = y_dims.size(); + } + TransformFunctor functor( + x, y, z, ctx, func, is_xsize_larger); + if (x_dims == y_dims) { + functor.Run(); + return; + } + + int tmp = std::abs(static_cast(x_dims.size()) - + static_cast(y_dims.size())); + axis = (axis == static_cast(-1) ? tmp : axis); + + CHECK_GE(axis, 0) << "Axis should be great than or equal to 0."; + CHECK_LT(axis, max_dim) << "Axis should be less than max(x_dim, y_dim)."; + + int pre, n, post, is_run_common_broadcast, axis_trim = 0; + if (is_xsize_larger) { + auto y_dims_trimed = trim_trailing_singular_dims(y_dims); + axis_trim = (y_dims_trimed.size() == 0) ? x_dims.size() : axis; + get_mid_dims(x_dims, + y_dims_trimed, + axis_trim, + &pre, + &n, + &post, + &is_run_common_broadcast); + } else { + auto x_dims_trimed = trim_trailing_singular_dims(x_dims); + axis_trim = (x_dims_trimed.size() == 0) ? y_dims.size() : axis; + get_mid_dims(y_dims, + x_dims_trimed, + axis_trim, + &pre, + &n, + &post, + &is_run_common_broadcast); + } + // special case for common implementation. + // case 1: x=[2,3,1,5], y=[2,1,4,1] + // case 2: x=[2,3,4], y=[1,1,4] + if (is_run_common_broadcast == 1) { + CommonElementwiseBroadcastForward( + x, y, z, x_dims, y_dims, func, axis, is_xsize_larger); + return; + } + if (post == 1) { + functor.RunRowWise(n, pre); + return; + } else { + functor.RunMidWise(n, pre, post); + return; + } +} + +// FusedElemwiseAndAct +// --- forward +template +struct FusedElemwiseAndActNoBroadcast { + HOSTDEVICE void operator()(size_t i) { + T y_val = y_[i]; + T x_val = x_[i]; + if (KeepIntermediateOut) { + T intermeidiate_out = compound_functor_.GetIntermediateOut(x_val, y_val); + intermediate_out_[i] = intermeidiate_out; + out_[i] = + compound_functor_.GetOutUseIntermediateOut(x_val, intermeidiate_out); + } else { + out_[i] = compound_functor_.GetOut(x_val, y_val); + } + } + + const T *x_; + const T *y_; + CompoundFunctor compound_functor_; + T *out_; + T *intermediate_out_; +}; + +// FusedElemwiseAndActBroadcast1: +// In this case, X and Y can be reshaped to a matrix. +// For example shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5) and axis = -1 or 2, +// X can be reshaped to (6, 20) and Y can be reshaped to (1, 20) +template +static void FusedElemwiseAndActBroadcast1CPU(const T *x, + const T *y, + CompoundFunctor compound_functor, + int h, + int w, + T *out, + T *intermediate_out) { + for (int i = 0; i < h; ++i) { + for (int j = 0; j < w; ++j) { + int offset = i * w + j; + + T y_val = BcastY ? y[j] : y[offset]; + T x_val = BcastY ? x[offset] : x[j]; + int64_t intermediate_out_offset; + if (KeepIntermediateOut) { + T intermeidiate_out = compound_functor.GetIntermediateOut(x_val, y_val); + + if (SameShapeOfIntermediateOutAndOut) { + // for the case of f1(f2(x, y)) + intermediate_out_offset = offset; + } else if (BcastY) { + intermediate_out_offset = j; + } else { + intermediate_out_offset = offset; + } + + intermediate_out[intermediate_out_offset] = intermeidiate_out; + out[offset] = + compound_functor.GetOutUseIntermediateOut(x_val, intermeidiate_out); + } else { + out[offset] = compound_functor.GetOut(x_val, y_val); + } + } + } +} + +// FusedElemwiseAndActBroadcast2 +// In this case, X and Y can be reshaped to a matrix. +// For example shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4) and axis = 1, +// X can be reshaped to (2, 12, 5) and Y can be reshaped to (1, 12, 1) +// pre = 2, n = 12, post = 5 +template +static void FusedElemwiseAndActBroadcast2CPU(const T *x, + const T *y, + int pre, + int n, + int post, + CompoundFunctor compound_functor, + T *out, + T *intermediate_out) { + for (int i = 0; i < pre; ++i) { + for (int j = 0; j < n; ++j) { + for (int k = 0; k < post; ++k) { + int offset = i * n * post + j * post + k; + + T y_val = BcastY ? y[j] : y[offset]; + T x_val = BcastY ? x[offset] : x[j]; + int64_t intermediate_out_offset; + + if (KeepIntermediateOut) { + T intermeidiate_out = + compound_functor.GetIntermediateOut(x_val, y_val); + + if (SameShapeOfIntermediateOutAndOut) { + // for the case of f1(f2(x, y)) + intermediate_out_offset = offset; + } else if (BcastY) { + intermediate_out_offset = j; + } else { + intermediate_out_offset = offset; + } + + intermediate_out[intermediate_out_offset] = intermeidiate_out; + out[offset] = compound_functor.GetOutUseIntermediateOut( + x_val, intermeidiate_out); + } else { + out[offset] = compound_functor.GetOut(x_val, y_val); + } + } + } + } +} + +template +void FusedElemwiseAndActComputeNoBroadcast(const lite::Context &ctx, + const lite::DDim &x_dim, + const lite::Tensor &x, + const lite::Tensor &y, + CompoundFunctor compound_functor, + lite::Tensor *out, + lite::Tensor *intermediate_out) { + size_t N = static_cast(x_dim.production()); + + lite::fluid::ForRange for_range(ctx, N); + + for_range( + FusedElemwiseAndActNoBroadcast{ + x.data(), + y.data(), + compound_functor, + out->template mutable_data(), + intermediate_out == nullptr + ? nullptr + : intermediate_out->template mutable_data()}); +} + +template +void FusedElemwiseAndActComputeWithBroadcast(const lite::Context &ctx, + const lite::DDim &x_dim, + const lite::DDim &y_dim_untrimed, + const lite::Tensor &x, + const lite::Tensor &y, + CompoundFunctor compound_functor, + int axis, + lite::Tensor *out, + lite::Tensor *intermediate_out) { + axis = (axis == -1 ? x_dim.size() - y_dim_untrimed.size() : axis); + auto y_dim = trim_trailing_singular_dims(y_dim_untrimed); + axis = (y_dim.size() == 0) ? x_dim.size() : axis; + + int pre, n, post; + get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post); + + if (post == 1) { + int h = pre; + int w = n; + FusedElemwiseAndActBroadcast1CPU( + x.data(), + y.data(), + compound_functor, + h, + w, + out->template mutable_data(), + intermediate_out == nullptr + ? nullptr + : intermediate_out->template mutable_data()); + + } else { + FusedElemwiseAndActBroadcast2CPU( + x.data(), + y.data(), + pre, + n, + post, + compound_functor, + out->template mutable_data(), + intermediate_out == nullptr + ? nullptr + : intermediate_out->template mutable_data()); + } +} + +template +void FusedElemwiseAndActComputeEx(const lite::Context &ctx, + const lite::Tensor &x, + const lite::Tensor &y, + int axis, + CompoundFunctor compound_functor, + lite::Tensor *out, + lite::Tensor *intermediate_out) { + if (KeepIntermediateOut) { + CHECK(intermediate_out) << "The save_intermediate_out is opened, " + "intermediate_out should not be nullptr."; + } + + const lite::DDim &x_dim = x.dims(); + const lite::DDim &y_dim = y.dims(); + if (x.dims() == y.dims()) { + FusedElemwiseAndActComputeNoBroadcast( + ctx, x_dim, x, y, compound_functor, out, intermediate_out); + } else { + // Whether the shape of Y is a continuous subsequence of X, + // For more information please refer to the op's introduction. + bool bcast_y = x.dims().size() >= y.dims().size(); + if (x.dims().size() == y.dims().size()) { + for (int i = 0; i < x.dims().size(); ++i) { + if (x.dims()[i] < y.dims()[i]) { + bcast_y = false; + break; + } + } + } + + // z = f1(x, f2(y)) + // z = f1(f2(x, y)) + if (bcast_y) { // Y should be broadcast. + // In this case, + // for 'f2(y)', the shape of intermediate_out should be equal to the + // shape + // of Y. + // for 'f2(x, y)', the shape of intermediate_out should be equal to the + // shape of Out. + // the shape of Out should be equal to the shape of X. + FusedElemwiseAndActComputeWithBroadcast( + ctx, + x_dim /*OutShape*/, + y_dim, + x, + y, + compound_functor, + axis, + out, + intermediate_out); + } else { + // In this case, + // for 'f2(y)', the shape of intermediate_out should be equal to the + // shape + // of Out. + // for 'f2(x, y)', the shape of intermediate_out should be equal to the + // shape of Out. + // the shape of Out should be equal to the shape of Y. + FusedElemwiseAndActComputeWithBroadcast( + ctx, + y_dim /*OutShape*/, + x_dim, + x, + y, + compound_functor, + axis, + out, + intermediate_out); + } + } +} + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/rnn_compute.cc b/lite/kernels/x86/rnn_compute.cc index 9fb996366b9..438c08f246e 100644 --- a/lite/kernels/x86/rnn_compute.cc +++ b/lite/kernels/x86/rnn_compute.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "lite/kernels/x86/rnn_compute.h" +#include "lite/backends/x86/math/rnn.h" #include #include #include @@ -20,7 +20,6 @@ #include "lite/backends/host/math/split.h" #include "lite/backends/x86/math/blas.h" #include "lite/backends/x86/math/concat_and_split.h" -#include "lite/backends/x86/math/rnn.h" #include "lite/kernels/x86/rnn_compute.h" namespace paddle { @@ -28,26 +27,28 @@ namespace lite { namespace kernels { namespace x86 { -#define RUN_LSTM_LAYER(x, y, z, w) \ - runLSTMLayer(&ctx, \ - input_temp_holder, \ - parameter_lists[x], \ - init_h_unbind, \ - init_c_unbind, \ - sequence_length, \ - &last_h_unbind, \ - &last_c_unbind, \ - y, \ - x, \ - &gate_value, \ - z, \ - w) - -void reset_parameter_vector(const std::vector& raw_params_vec, - const int& num_layers, - const int& gate_num, - const bool& is_bidirec, - std::vector>* params_vec) { +#define RUN_RNN_LAYER(x, y, z, w) \ + RunRnnLayer(&ctx, \ + input_temp_holder, \ + parameter_lists[x], \ + init_h_unbind, \ + init_c_unbind, \ + sequence_length, \ + &last_h_unbind, \ + &last_c_unbind, \ + y, \ + x, \ + &gate_value, \ + z, \ + w, \ + mode) + +static void reset_parameter_vector( + const std::vector& raw_params_vec, + const int& num_layers, + const int& gate_num, + const bool& is_bidirec, + std::vector>* params_vec) { // the parameter raw seuquence is [FWhi, FWhh, BWhi, BWhh] * num_layers // + [FBhi, FBhh, BBhi, BBhh] * num_layers, we will reset the parameter to // ([FWhi, FWhh, FBhi, FBhh] + [BWhi, BWhh, BBhi, BBhh]) * num_layers @@ -75,19 +76,30 @@ void reset_parameter_vector(const std::vector& raw_params_vec, } } -void SwapPoniter(Tensor** a, Tensor** b) { +static void SwapPoniter(Tensor** a, Tensor** b) { Tensor* c = *a; *a = *b; *b = c; } -void preprocess(X86Context* ctx, - const Tensor* input, - const Tensor& weight, - const Tensor& bias_ih, - const Tensor& bias_hh, - Tensor* cache_input, - bool is_test) { +/****************************************************** +input: + ctx:context, + input:(3D)time_step, batch, input_size, + weight:(2D)hidden_size, input_size, + bias_ih, + bias_hh, + mode:LSTM, GRU +output: + cache_input:(3D)time_step, batch, hidden_size +*******************************************************/ +static void preprocess(X86Context* ctx, + const Tensor* input, + const Tensor& weight, + const Tensor& bias_ih, + const Tensor& bias_hh, + std::string mode, + Tensor* cache_input) { const int& hidden_size = weight.dims()[0]; int time_step = input->dims()[0]; int batch = input->dims()[1]; @@ -106,23 +118,175 @@ void preprocess(X86Context* ctx, int k = input_dims[2]; int n = weight_input_dims[0]; - paddle::lite::x86::math::Blas matmul(*ctx); + lite::x86::math::Blas matmul(*ctx); matmul.GEMM( false, true, m, n, k, 1.f, i_data, k, w_data, k, 0.f, o_data, n); lite::x86::math::fill_bias_fc(o_data, bias_ih.data(), m, n); - lite::x86::math::fill_bias_fc(o_data, bias_hh.data(), m, n); + + if ("GRU" == mode) { + Tensor bias_tmp_hh; + bias_tmp_hh.Resize(bias_hh.dims()); + auto bias_ptr = bias_tmp_hh.mutable_data(); + auto bias_src = bias_hh.data(); + int bias_offt = bias_hh.numel() / 3 * 2; + std::memcpy(bias_ptr, bias_src, bias_offt * sizeof(float)); + std::memset( + bias_ptr + bias_offt, 0, (bias_hh.numel() - bias_offt) * sizeof(float)); + lite::x86::math::fill_bias_fc(o_data, bias_tmp_hh.data(), m, n); + } else { + lite::x86::math::fill_bias_fc(o_data, bias_hh.data(), m, n); + } +} + +/****************************************************** +input: + ctx:context, + init_h:(2D), + init_c:(2D), + mask_tensor:(1D)input->dims()[1], + mode:LSTM, GRU +output: + output:(2D)output->dims()[1], output->dims()[2], + last_h:(2D), + last_c:(2D) +*******************************************************/ +static void postprocess(X86Context* ctx, + Tensor* output, + const Tensor* init_h, + const Tensor* init_c, + Tensor* last_h, + Tensor* last_c, + const Tensor& mask_tensor, + std::string mode) { + Tensor mask_broadcast_1; + mask_broadcast_1.Resize(mask_tensor.dims()); + auto mask_ptr_1 = mask_broadcast_1.mutable_data(); + auto mask_ptr = mask_tensor.data(); + auto out_ptr = output->mutable_data(); + auto cur_h_ptr = last_h->mutable_data(); + auto pre_h_ptr = init_h->data(); + int offset = 0; + + // out = out * mask_broadcast + // curr_h = out * mask_broadcast + pre_h * (1 - mask_broadcast); + for (int i = 0; i < output->dims()[0]; i++) { + mask_ptr_1[i] = 1 - mask_ptr[i]; + for (int j = 0; j < output->dims()[1]; j++) { + offset = i * output->dims()[1] + j; + out_ptr[offset] *= mask_ptr[i]; + cur_h_ptr[offset] = out_ptr[offset] + pre_h_ptr[offset] * mask_ptr_1[i]; + } + } + if ("LSTM" == mode) { + auto pre_c_ptr = init_c->data(); + auto cur_c_ptr = last_c->mutable_data(); + + // curr_c = curr_c * mask_broadcast + pre_c * (1 - mask_broadcast); + for (int i = 0; i < output->dims()[0]; i++) { + for (int j = 0; j < output->dims()[1]; j++) { + offset = i * output->dims()[1] + j; + cur_c_ptr[offset] = + cur_c_ptr[offset] * mask_ptr[i] + pre_c_ptr[offset] * mask_ptr_1[i]; + } + } + } +} + +static DDim get_stride(const DDim& ddim) { + DDim strides; + strides[ddim.size() - 1] = 1; + for (int i = ddim.size() - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * ddim[i + 1]; + } + return strides; +} + +template +static void TransposeNormal(const Tensor& in, + Tensor* out, + const std::vector& axis) { + const int rank = axis.size(); + auto in_stride = get_stride(in.dims()); + auto out_stride = get_stride(out->dims()); + const T* in_ptr = in.data(); + T* out_ptr = out->mutable_data(); + + auto transpose_helper = [&](int64_t beg, int64_t end) { + for (int64_t out_idx = beg; out_idx < end; ++out_idx) { + int64_t in_idx = 0; + int64_t tmp_idx = out_idx; + // calculate the input index + for (int i = 0; i < rank; ++i) { + const int64_t coordinate = tmp_idx / out_stride[i]; + tmp_idx -= coordinate * out_stride[i]; + in_idx += coordinate * in_stride[axis[i]]; + } + out_ptr[out_idx] = in_ptr[in_idx]; + } + }; + transpose_helper(0, out->numel()); +} + +/****************************************************** +input: + sequence_length, + is_reverse +output: + mask_matrix, + min_seq_len +******************************************************/ +static void create_mask_matrix(const Tensor* sequence_length, + Tensor* mask_matrix, + const bool& is_reverse, + int* min_seq_len) { + // Tensor to vector + std::vector seq_len_vec; + seq_len_vec.resize(sequence_length->numel()); + std::memcpy(&seq_len_vec[0], + sequence_length->data(), + sequence_length->numel() * sizeof(int)); + + const int& table_width = mask_matrix->dims()[0]; + Tensor temp; + DDimLite dims( + std::vector{mask_matrix->dims()[1], mask_matrix->dims()[0]}); + temp.Resize(dims); + float* data_temp = temp.mutable_data(); + std::fill(data_temp, data_temp + mask_matrix->numel(), 1.f); + *min_seq_len = table_width; + for (unsigned int i = 0; i < seq_len_vec.size(); i++) { + // reset the mask matrix + *min_seq_len = std::min(seq_len_vec[i], *min_seq_len); + if (seq_len_vec[i] == table_width) { + continue; + } + if (is_reverse) { + std::fill(data_temp + i * table_width, + data_temp + (i + 1) * table_width - seq_len_vec[i], + 0.f); + } else { + std::fill(data_temp + i * table_width + seq_len_vec[i], + data_temp + (i + 1) * table_width, + 0.f); + } + } + mask_matrix->mutable_data(); + std::vector trans_vec; + trans_vec.emplace_back(1); + trans_vec.emplace_back(0); + TransposeNormal(temp, mask_matrix, trans_vec); } -void cell(X86Context* ctx, - Tensor* input, - Tensor* weight_hh, - Tensor* init_h, - Tensor* init_c, - Tensor* last_h, - Tensor* last_c, - Tensor* last_c_act, - Tensor* output, - const Tensor* bias_hh) { +static void lstm_cell(X86Context* ctx, + Tensor* input, + Tensor* weight_hh, + Tensor* init_h, + Tensor* init_c, + Tensor* last_h, + Tensor* last_c, + Tensor* last_c_act, + Tensor* output, + const Tensor* bias_hh) { auto h_dims = init_h->dims(); auto weight_input_dims = weight_hh->dims(); int m = h_dims[0]; @@ -136,7 +300,7 @@ void cell(X86Context* ctx, tmp_gate.Resize(input->dims()); auto tmp_data = tmp_gate.mutable_data(); - paddle::lite::x86::math::Blas matmul(*ctx); + lite::x86::math::Blas matmul(*ctx); matmul.GEMM( false, true, m, n, k, 1.f, h_data, k, w_data, k, 0.f, tmp_data, n); for (int i = 0; i < input->dims()[0] * input->dims()[1]; i++) { @@ -183,19 +347,72 @@ void cell(X86Context* ctx, 1); } -void runLSTMLayer(X86Context* ctx, - const Tensor* input, - std::vector vec, - std::vector init_h, - std::vector init_c, - const Tensor* sequence_length, - std::vector* last_h_ptr, - std::vector* last_c_ptr, - Tensor* output, - int layer_idx, - Tensor* gate_value, - bool is_bidirect, - int offset) { +static void gru_cell(X86Context* ctx, + Tensor* input, + Tensor* weight_hh, + Tensor* init_h, + Tensor* init_c, + Tensor* last_h, + Tensor* last_c, + Tensor* last_c_act, + Tensor* output, + const Tensor* bias_hh, + Tensor* weight_hh_gru) { + auto h_dims = init_h->dims(); + auto weight_gru_dims = weight_hh_gru->dims(); + int m = h_dims[0]; + int k = h_dims[1]; + int n = weight_gru_dims[0]; + auto i_data = input->data(); + auto w_gru = weight_hh_gru->data(); + auto h_data = init_h->data(); + + Tensor tmp_gate; + tmp_gate.Resize(input->dims()); + auto tmp_data = tmp_gate.mutable_data(); + + lite::x86::math::Blas matmul(*ctx); + matmul.GEMM( + false, true, m, n, k, 1.f, h_data, k, w_gru, k, 0.f, tmp_data, n); + for (int i = 0; i < input->dims()[0] * input->dims()[1]; i++) { + tmp_data[i] += i_data[i]; + } + + size_t frame_size = init_h->dims()[1]; + size_t batch_size = init_h->dims()[0]; + + lite::x86::math::GRUMetaValue gru_value; + gru_value.gate_weight = weight_hh->data(); + gru_value.state_weight = + weight_hh->data() + 2 * frame_size * frame_size; + gru_value.reset_bias = bias_hh->data() + 2 * frame_size; + + gru_value.gate_value = tmp_data; + gru_value.reset_output_value = last_c->mutable_data(); + gru_value.output_value = output->mutable_data(); + gru_value.prev_out_value = init_h->data(); + + auto gate_act = lite_api::ActivationType::kSigmoid_v2; + auto cand_act = lite_api::ActivationType::kTanh_v2; + + lite::x86::math::RnnGruUnitFunctorV2::compute( + ctx, gru_value, frame_size, batch_size, cand_act, gate_act); +} + +static void RunRnnLayer(X86Context* ctx, + const Tensor* input, + std::vector vec, + std::vector init_h, + std::vector init_c, + const Tensor* sequence_length, + std::vector* last_h_ptr, + std::vector* last_c_ptr, + Tensor* output, + int layer_idx, + Tensor* gate_value, + bool is_bidirect, + int offset, + std::string mode) { bool is_reverse = false; if (is_bidirect) { layer_idx = 2 * layer_idx + offset; @@ -210,25 +427,24 @@ void runLSTMLayer(X86Context* ctx, vec[0 + offset * 4], vec[2 + offset * 4], vec[3 + offset * 4], - gate_value, - true); + mode, + gate_value); + std::vector input_tensors, output_tensors; std::vector input_tensors_t, output_tensors_t; - std::vector stride1, stride2; - input_tensors.resize(gate_value->dims()[0]); // time_step + std::vector stride1, stride2, stride3; + input_tensors.resize(gate_value->dims()[0]); output_tensors.resize(output->dims()[0]); - // alloc input + // unbind for (int i = 0; i < gate_value->dims()[0]; i++) { stride1.push_back(1); - int dim1 = gate_value->dims()[1]; // batch - int dim2 = gate_value->dims()[2]; // hidden + int dim1 = gate_value->dims()[1]; + int dim2 = gate_value->dims()[2]; DDimLite dims(std::vector{dim1, dim2}); input_tensors[i].Resize(dims); input_tensors_t.push_back(&input_tensors[i]); } - - // alloc output for (int i = 0; i < output->dims()[0]; i++) { stride2.push_back(1); int dim1 = output->dims()[1]; @@ -237,32 +453,55 @@ void runLSTMLayer(X86Context* ctx, output_tensors[i].Resize(dims); output_tensors_t.push_back(&output_tensors[i]); } - lite::host::math::split( gate_value->data(), input_tensors_t, 0, stride1); lite::host::math::split(output->data(), output_tensors_t, 0, stride2); auto sd = output->mutable_data(); if (is_reverse) { + // don't need to reverse input_tensors_t becauese of unuseful std::reverse(input_tensors.begin(), input_tensors.end()); } bool has_sequence_length = false; + if (sequence_length != nullptr) { + has_sequence_length = true; + } + // unbind + Tensor mask_matrix; + std::vector mask_vec; + std::vector mask_tensor_list; + int mask_min_length = time_step; + /* - TODO has_sequence_length + to be verifying! */ - - int mask_min_length = time_step; + if (has_sequence_length) { + mask_matrix.Resize(DDimLite({time_step, input->dims()[1]})); + create_mask_matrix( + sequence_length, &mask_matrix, is_reverse, &mask_min_length); + for (int i = 0; i < time_step; i++) { + stride3.push_back(1); + DDimLite ddims(std::vector{input->dims()[1]}); + mask_vec[i].Resize(ddims); + mask_tensor_list.push_back(&mask_vec[i]); + } + lite::host::math::split( + mask_matrix.data(), mask_tensor_list, 0, stride3); + } if (is_reverse) { mask_min_length = mask_min_length - time_step + 1; } - bool has_allocate_mem_c = false; + bool has_use_last_h_holder = false; const int& reverse_flag = is_reverse ? -1 : 1; + bool has_allocate_mem_c = false; + + // define the init_h holder for the swap Tensor init_h_temp; + init_h_temp.Resize(init_h[layer_idx].dims()); init_h_temp.CopyDataFrom(init_h[layer_idx]); Tensor* init_h_holder = &init_h_temp; Tensor* last_h_holder = nullptr; - if (0 < mask_min_length) { last_h_holder = &(output_tensors[0]); } else { @@ -275,38 +514,83 @@ void runLSTMLayer(X86Context* ctx, Tensor init_c_temp; Tensor* last_c_holder = nullptr; Tensor last_c_temp; - last_c_holder = &(*last_c_ptr)[layer_idx]; - init_c_temp_holder = &init_c[layer_idx]; + + if ("LSTM" == mode) { + last_c_holder = &(*last_c_ptr)[layer_idx]; + init_c_temp_holder = &init_c[layer_idx]; + } else if ("GRU" == mode) { + // for reset output value + last_c_temp.Resize(init_h[layer_idx].dims()); + last_c_temp.mutable_data(); + last_c_holder = &last_c_temp; + } + + Tensor weight_hh_tmp; // for gru + std::vector weight_hh_tmp_ubind; + std::vector weight_hh_tmp_ubind_t; + std::vector stride_w; + if ("GRU" == mode) { + weight_hh_tmp.Resize(vec[1 + offset * 4].dims()); + weight_hh_tmp.mutable_data(); + weight_hh_tmp.CopyDataFrom(vec[1 + offset * 4]); + int size = weight_hh_tmp.numel() / 3; + std::memset(weight_hh_tmp.mutable_data() + size * 2, + 0, + size * sizeof(float)); + } for (int i = 0; i < time_step; i++) { bool in_mask = (reverse_flag * i) >= mask_min_length; if (i > 0) { if (!has_allocate_mem_c) { - init_c_temp.Resize(init_h[layer_idx].dims()); - init_c_temp.mutable_data(); - init_c_holder = &init_c_temp; + if (("LSTM" == mode) || ("GRU" == mode)) { + init_c_temp.Resize(init_h[layer_idx].dims()); + init_c_temp.mutable_data(); + init_c_holder = &init_c_temp; + } has_allocate_mem_c = true; } SwapPoniter(&init_c_holder, &last_c_holder); init_c_temp_holder = init_c_holder; } - // LSTMCELL - cell(ctx, - &input_tensors[i], - &vec[1 + offset * 4], - init_h_holder, - init_c_temp_holder, - last_h_holder, - last_c_holder, - nullptr, - &output_tensors[i], - &vec[3 + offset * 4]); + if ("LSTM" == mode) { + lstm_cell(ctx, + &input_tensors[i], + &vec[1 + offset * 4], + init_h_holder, + init_c_temp_holder, + last_h_holder, + last_c_holder, + nullptr, + &output_tensors[i], + &vec[3 + offset * 4]); + } else if ("GRU" == mode) { + gru_cell(ctx, + &input_tensors[i], + &vec[1 + offset * 4], + init_h_holder, + init_c_temp_holder, + last_h_holder, + last_c_holder, + nullptr, + &output_tensors[i], + &vec[3 + offset * 4], + &weight_hh_tmp); + } + /* + to be verifying! + */ if (in_mask) { - /* - TODO in_mask - */ + postprocess(ctx, + &output_tensors[i], + init_h_holder, + init_c_temp_holder, + last_h_holder, + last_c_holder, + mask_vec[i], + mode); } // prepare next step @@ -322,6 +606,9 @@ void runLSTMLayer(X86Context* ctx, SwapPoniter(&init_h_holder, &last_h_holder); } } + + // unbind vector and source are not in the same address, need copy + // different from paddle if (is_reverse) { std::reverse(output_tensors.begin(), output_tensors.end()); } @@ -339,7 +626,7 @@ void runLSTMLayer(X86Context* ctx, } else { (*last_h_ptr)[layer_idx].CopyDataFrom(output_tensors[time_step - 1]); } - if (time_step % 2 == 0) { + if ((0 == (time_step % 2)) && ("LSTM" == mode)) { (*last_c_ptr)[layer_idx].CopyDataFrom(*last_c_holder); } } @@ -356,12 +643,24 @@ void RnnCompute::Run() { bool is_bidirec = param.is_bidirec; int num_layers = param.num_layers; const Tensor* sequence_length = param.SequenceLength; + int gate_num = 0; - state[0]->mutable_data(); - state[1]->mutable_data(); + if ("LSTM" == mode) { + gate_num = 4; + } else if ("GRU" == mode) { + gate_num = 3; + } else { + LOG(FATAL) << "X86 RNN ERROR: unsupport mode except gru and lstm," + " present mode is " + << mode; + return; + } - // lstmCell begin - int gate_num = 4; + state[0]->mutable_data(); + if ("LSTM" == mode) { + state[1]->mutable_data(); + } + // reset the parameter to sorted order and allocate the memory std::vector> parameter_lists; parameter_lists.reserve(num_layers); reset_parameter_vector( @@ -375,31 +674,57 @@ void RnnCompute::Run() { last_c_unbind; std::vector init_h_unbind_t, init_c_unbind_t, last_h_unbind_t, last_c_unbind_t; - init_h_unbind.resize(4); - init_c_unbind.resize(pre_state[1]->dims()[0]); + init_h_unbind.resize(pre_state[0]->dims()[0]); last_h_unbind.resize(state[0]->dims()[0]); - last_c_unbind.resize(state[1]->dims()[0]); - std::vector stride; + + if ("LSTM" == mode) { + init_c_unbind.resize(pre_state[1]->dims()[0]); + last_c_unbind.resize(state[1]->dims()[0]); + } + std::vector stride1, stride2; + + // unbind for (int i = 0; i < pre_state[0]->dims()[0]; i++) { - stride.push_back(1); + stride1.push_back(1); int dim1 = pre_state[0]->dims()[1]; int dim2 = pre_state[0]->dims()[2]; DDimLite dims(std::vector{dim1, dim2}); init_h_unbind[i].Resize(dims); - init_c_unbind[i].Resize(dims); last_h_unbind[i].Resize(dims); - last_c_unbind[i].Resize(dims); init_h_unbind_t.push_back(&init_h_unbind[i]); - init_c_unbind_t.push_back(&init_c_unbind[i]); last_h_unbind_t.push_back(&last_h_unbind[i]); - last_c_unbind_t.push_back(&last_c_unbind[i]); } lite::host::math::split( - pre_state[0]->data(), init_h_unbind_t, 0, stride); - lite::host::math::split( - pre_state[1]->data(), init_c_unbind_t, 0, stride); - lite::host::math::split(state[0]->data(), last_h_unbind_t, 0, stride); - lite::host::math::split(state[1]->data(), last_c_unbind_t, 0, stride); + pre_state[0]->data(), init_h_unbind_t, 0, stride1); + lite::host::math::split(state[0]->data(), last_h_unbind_t, 0, stride1); + + if ("LSTM" == mode) { + for (int i = 0; i < pre_state[1]->dims()[0]; i++) { + stride2.push_back(1); + int dim1 = pre_state[1]->dims()[1]; + int dim2 = pre_state[1]->dims()[2]; + DDimLite dims(std::vector{dim1, dim2}); + init_c_unbind[i].Resize(dims); + last_c_unbind[i].Resize(dims); + init_c_unbind_t.push_back(&init_c_unbind[i]); + last_c_unbind_t.push_back(&last_c_unbind[i]); + } + lite::host::math::split( + pre_state[1]->data(), init_c_unbind_t, 0, stride2); + lite::host::math::split( + state[1]->data(), last_c_unbind_t, 0, stride2); + } + + std::vector output_vec(2); + int time_step = input->dims()[0]; + int batch_size = input->dims()[1]; + int hidden_size = output->dims()[2]; + if (is_bidirec) { + for (int i = 0; i < 2; ++i) { + output_vec[i].Resize({time_step, batch_size, hidden_size / 2}); + output_vec[i].mutable_data(); + } + } for (int i = 0; i < num_layers; i++) { if (i > 0) { @@ -418,25 +743,18 @@ void RnnCompute::Run() { } if (is_bidirec) { - std::vector output_vec(2); - int time_step = input->dims()[0]; - int batch_size = input->dims()[1]; - int hidden_size = output->dims()[2]; - for (int i = 0; i < 2; ++i) { - output_vec[i].Resize({time_step, batch_size, hidden_size / 2}); - output_vec[i].mutable_data(); - } - - RUN_LSTM_LAYER(i, &output_vec[0], true, 0); - RUN_LSTM_LAYER(i, &output_vec[1], true, 1); - - paddle::lite::x86::math::ConcatFunctor - concat_x86; - concat_x86(ctx, output_vec, 2, output); + RUN_RNN_LAYER(i, &output_vec[0], true, 0); + RUN_RNN_LAYER(i, &output_vec[1], true, 1); + lite::x86::math::ConcatFunctor concat_x86; + concat_x86(ctx, output_vec, 2, output_holder); } else { - RUN_LSTM_LAYER(i, output_holder, false, 0); + RUN_RNN_LAYER(i, output_holder, false, 0); } } + // output_holder != output + if (num_layers % 2 == 0) { + output->CopyDataFrom(*output_holder); + } } } // namespace x86 diff --git a/lite/kernels/x86/rnn_compute.h b/lite/kernels/x86/rnn_compute.h index e4dff2a3014..70d43a3c5bc 100644 --- a/lite/kernels/x86/rnn_compute.h +++ b/lite/kernels/x86/rnn_compute.h @@ -29,7 +29,7 @@ class RnnCompute : public KernelLite { virtual ~RnnCompute() = default; }; -} // namespace arm +} // namespace x86 } // namespace kernels } // namespace lite } // namespace paddle diff --git a/lite/tests/kernels/cast_compute_test.cc b/lite/tests/kernels/cast_compute_test.cc index 38bf62f42de..363eea8f144 100644 --- a/lite/tests/kernels/cast_compute_test.cc +++ b/lite/tests/kernels/cast_compute_test.cc @@ -145,7 +145,9 @@ TEST(Cast, precision) { #elif defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL) place = TARGET(kXPU); #elif defined(LITE_WITH_ARM) - place = TARGET(kARM); + place = TARGET(kHost); +#elif defined(LITE_WITH_X86) + place = TARGET(kHost); #else return; #endif diff --git a/lite/tests/kernels/elementwise_common_broadcast_test.cc b/lite/tests/kernels/elementwise_common_broadcast_test.cc index 6d8b93cb354..9b982df6039 100644 --- a/lite/tests/kernels/elementwise_common_broadcast_test.cc +++ b/lite/tests/kernels/elementwise_common_broadcast_test.cc @@ -506,4 +506,4 @@ TEST(elementwise_broadcast, compute_i64) { } } -#endif // LITE_WITH_ARM +#endif // LITE_WITH_X86 From 33d03b35136d12c08635c8ad68f222336745d9b7 Mon Sep 17 00:00:00 2001 From: mjp9527 <54735487+mjp9527@users.noreply.github.com> Date: Wed, 13 Oct 2021 10:50:47 +0800 Subject: [PATCH 2/4] [X86|ARM] support BIGRU (#7119) --- lite/api/cxx_api.cc | 4 +- lite/api/light_api.cc | 4 +- lite/backends/arm/math/gru.h | 2 +- lite/backends/arm/math/reduce_max.cc | 249 +--------------- lite/backends/arm/math/reduce_max.h | 280 ++++++++++++++---- lite/backends/arm/math/reduce_max_min.cc | 40 ++- lite/backends/arm/math/reduce_max_min.h | 32 +- lite/backends/arm/math/reduce_min.cc | 246 +-------------- lite/backends/arm/math/reduce_min.h | 276 +++++++++++++---- lite/core/profile/precision_profiler.h | 8 +- lite/kernels/arm/reduce_max_compute.cc | 60 ++-- lite/kernels/arm/reduce_max_compute.h | 1 + lite/kernels/arm/reduce_min_compute.cc | 64 ++-- lite/kernels/arm/reduce_min_compute.h | 1 + lite/kernels/arm/rnn_compute.cc | 2 +- lite/kernels/arm/slice_compute.cc | 69 ++++- lite/kernels/host/compare_compute.cc | 2 +- lite/kernels/host/flip_compute.h | 12 +- lite/operators/conditional_block_op.h | 3 + .../fill_constant_batch_size_like_op.h | 2 + lite/operators/op_params.h | 2 + lite/operators/slice_op.cc | 62 +++- 22 files changed, 727 insertions(+), 694 deletions(-) diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc index ba3d64a1334..8225f14deae 100644 --- a/lite/api/cxx_api.cc +++ b/lite/api/cxx_api.cc @@ -525,7 +525,9 @@ void Predictor::ClearTensorArray( for (size_t var_idx = 0; var_idx < block->VarsSize(); var_idx++) { const cpp::VarDesc *var = block->GetVar(var_idx); CHECK(var); - if (var->GetType() == lite::VarDataType::LOD_TENSOR_ARRAY) { + + auto tmp = program_->exec_scope()->FindVar(var->Name()); + if (tmp->IsType>()) { std::vector *tensor_array_var = program_->exec_scope()->FindMutableTensorList(var->Name()); CHECK(tensor_array_var); diff --git a/lite/api/light_api.cc b/lite/api/light_api.cc index efa15330d57..8b99aa96f24 100644 --- a/lite/api/light_api.cc +++ b/lite/api/light_api.cc @@ -400,7 +400,9 @@ void LightPredictor::ClearTensorArray( for (size_t var_idx = 0; var_idx < block->VarsSize(); var_idx++) { const cpp::VarDesc* var = block->GetVar(var_idx); CHECK(var); - if (var->GetType() == lite::VarDataType::LOD_TENSOR_ARRAY) { + + auto tmp = program_->exec_scope()->FindVar(var->Name()); + if (tmp->IsType>()) { std::vector* tensor_array_var = program_->exec_scope()->FindMutableTensorList(var->Name()); CHECK(tensor_array_var); diff --git a/lite/backends/arm/math/gru.h b/lite/backends/arm/math/gru.h index 1492c57d6a2..14f2c8355fb 100644 --- a/lite/backends/arm/math/gru.h +++ b/lite/backends/arm/math/gru.h @@ -1,4 +1,4 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/lite/backends/arm/math/reduce_max.cc b/lite/backends/arm/math/reduce_max.cc index b20b93cadf3..6e5a6f853f6 100644 --- a/lite/backends/arm/math/reduce_max.cc +++ b/lite/backends/arm/math/reduce_max.cc @@ -1,11 +1,8 @@ /* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -19,251 +16,7 @@ limitations under the License. */ namespace paddle { namespace lite { namespace arm { -namespace math { - -template <> -void reduce_n(const float* src, - float* dst, - int num_in, - int channel_in, - int height_in, - int width_in) { - int hw_size = height_in * width_in; - int chw_size = channel_in * hw_size; - int data_index, src_index, src_index0; - for (int c = 0; c < channel_in; ++c) { - for (int h = 0; h < height_in; ++h) { - for (int w = 0; w < width_in; ++w) { - data_index = c * hw_size + h * width_in + w; - dst[data_index] = src[data_index]; - for (int n = 1; n < num_in; ++n) { - src_index = n * chw_size + data_index; - dst[data_index] = dst[data_index] > src[src_index] ? dst[data_index] - : src[src_index]; - } - } - } - } -} - -template <> -void reduce_first_of_three( - const float* src, float* dst, int first_in, int second_in, int third_in) { - for (int i = 0; i < second_in; i++) { - for (int j = 0; j < third_in; j++) { - dst[i * third_in + j] = src[i * third_in + j]; - for (int k = 1; k < first_in; k++) { - dst[i * third_in + j] = - src[k * second_in * third_in + i * third_in + j] > - dst[i * third_in + j] - ? src[k * second_in * third_in + i * third_in + j] - : dst[i * third_in + j]; - } - } - } -} - -template <> -void reduce_second_of_three( - const float* src, float* dst, int first_in, int second_in, int third_in) { - for (int i = 0; i < first_in; i++) { - for (int j = 0; j < third_in; j++) { - dst[i * third_in + j] = src[i * second_in * third_in + j]; - for (int k = 1; k < second_in; k++) { - dst[i * third_in + j] = - src[i * second_in * third_in + third_in * k + j] > - dst[i * third_in + j] - ? src[i * second_in * third_in + third_in * k + j] - : dst[i * third_in + j]; - } - } - } -} - -template <> -void reduce_third_of_three( - const float* src, float* dst, int first_in, int second_in, int third_in) { - for (int i = 0; i < first_in; i++) { - for (int j = 0; j < second_in; j++) { - dst[i * second_in + j] = src[i * second_in * third_in + j * second_in]; - for (int k = 0; k < third_in; k++) { - dst[i * second_in + j] = - src[i * second_in * third_in + j * third_in + k] > - dst[i * second_in + j] - ? src[i * second_in * third_in + j * third_in + k] - : dst[i * second_in + j]; - } - } - } -} - -template <> -void reduce_all_of_three( - const float* src, float* dst, int first_in, int second_in, int third_in) { - float max = src[0]; - int total_element = first_in * second_in * third_in; - for (int i = 0; i < total_element; i++) { - max = src[i] > max ? src[i] : max; - } - dst[0] = max; -} - -template <> -void reduce_c(const float* src, - float* dst, - int num_in, - int channel_in, - int height_in, - int width_in) { - int hw_size = height_in * width_in; - int chw_size = hw_size * channel_in; - int data_index, src_index0, src_index; - for (int n = 0; n < num_in; ++n) { - for (int h = 0; h < height_in; ++h) { - for (int w = 0; w < width_in; ++w) { - data_index = n * hw_size + h * width_in + w; - src_index0 = n * chw_size + h * width_in + w; - dst[data_index] = src[src_index0]; - for (int c = 1; c < channel_in; ++c) { - src_index = src_index0 + c * hw_size; - dst[data_index] = dst[data_index] > src[src_index] ? dst[data_index] - : src[src_index]; - } - } - } - } -} - -template <> -void reduce_h(const float* src, - float* dst, - int num_in, - int channel_in, - int height_in, - int width_in) { - int cw_size = channel_in * width_in; - int chw_size = cw_size * height_in; - int hw_size = height_in * width_in; - int data_index, src_index, src_index0; - for (int n = 0; n < num_in; ++n) { - for (int c = 0; c < channel_in; ++c) { - for (int w = 0; w < width_in; ++w) { - data_index = n * cw_size + c * width_in + w; - src_index0 = n * chw_size + c * hw_size + w; - dst[data_index] = src[src_index0]; - for (int h = 1; h < height_in; ++h) { - src_index = src_index0 + h * width_in; - dst[data_index] = dst[data_index] > src[src_index] ? dst[data_index] - : src[src_index]; - } - } - } - } -} - -template <> -void reduce_w(const float* src, - float* dst, - int num_in, - int channel_in, - int height_in, - int width_in) { - int ch_size = channel_in * height_in; - int hw_size = height_in * width_in; - int chw_size = ch_size * width_in; - int data_index = 0; - int src_index0 = 0; - int src_index = 0; - for (int n = 0; n < num_in; ++n) { - for (int c = 0; c < channel_in; ++c) { - for (int h = 0; h < height_in; ++h) { - data_index = n * ch_size + c * height_in + h; - src_index0 = n * chw_size + c * hw_size + h * width_in; - dst[data_index] = src[src_index0]; - for (int w = 1; w < width_in; ++w) { - src_index = src_index0 + w; - dst[data_index] = dst[data_index] > src[src_index] ? dst[data_index] - : src[src_index]; - } - } - } - } -} - -template <> -void reduce_all(const float* src, - float* dst, - int num_in, - int channel_in, - int height_in, - int width_in) { - float max = src[0]; - int src_index; - int n_id, c_id; - for (int n = 0; n < num_in; ++n) { - n_id = n * channel_in * height_in * width_in; - for (int c = 0; c < channel_in; ++c) { - c_id = c * height_in * width_in; - for (int h = 0; h < height_in; ++h) { - for (int w = 0; w < width_in; ++w) { - src_index = n_id + c_id + h * width_in + w; - max = src[src_index] > max ? src[src_index] : max; - } - } - } - } - dst[0] = max; -} - -template <> -void reduce_nc(const float* src, - float* dst, - int num_in, - int channel_in, - int height_in, - int width_in) { - // reduce n first. - DDimLite ddimA({1, channel_in, height_in, width_in}); - lite::Tensor tensor_tmp; - tensor_tmp.Resize(ddimA); - float* tmp_out = tensor_tmp.mutable_data(); - reduce_n(src, tmp_out, num_in, channel_in, height_in, width_in); - reduce_c(tmp_out, dst, 1, channel_in, height_in, width_in); -} - -template <> -void reduce_ch(const float* src, - float* dst, - int num_in, - int channel_in, - int height_in, - int width_in) { - // reduce c first - DDimLite ddimA({num_in, 1, height_in, width_in}); - lite::Tensor tensor_tmp; - tensor_tmp.Resize(ddimA); - float* tmp_out = tensor_tmp.mutable_data(); - reduce_c(src, tmp_out, num_in, channel_in, height_in, width_in); - reduce_h(tmp_out, dst, num_in, 1, height_in, width_in); -} - -template <> -void reduce_hw(const float* src, - float* dst, - int num_in, - int channel_in, - int height_in, - int width_in) { - // reduce h first - DDimLite ddimA({num_in, channel_in, 1, width_in}); - lite::Tensor tensor_tmp; - tensor_tmp.Resize(ddimA); - float* tmp_out = tensor_tmp.mutable_data(); - reduce_h(src, tmp_out, num_in, channel_in, height_in, width_in); - reduce_w(tmp_out, dst, num_in, channel_in, 1, width_in); -} - -} // namespace math +namespace math {} // namespace math } // namespace arm } // namespace lite } // namespace paddle diff --git a/lite/backends/arm/math/reduce_max.h b/lite/backends/arm/math/reduce_max.h index e8dafd07653..f87efb5add9 100644 --- a/lite/backends/arm/math/reduce_max.h +++ b/lite/backends/arm/math/reduce_max.h @@ -1,11 +1,8 @@ /* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -14,90 +11,255 @@ limitations under the License. */ #pragma once +#include "lite/backends/arm/math/funcs.h" +#include "lite/core/tensor.h" + namespace paddle { namespace lite { namespace arm { namespace math { template -void reduce_n(const T* src, - T* dst, - int num_in, - int channel_in, - int height_in, - int width_in); +inline void reduce_n(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int hw_size = height_in * width_in; + int chw_size = channel_in * hw_size; + int data_index, src_index, src_index0; + for (int c = 0; c < channel_in; ++c) { + for (int h = 0; h < height_in; ++h) { + for (int w = 0; w < width_in; ++w) { + data_index = c * hw_size + h * width_in + w; + dst[data_index] = src[data_index]; + for (int n = 1; n < num_in; ++n) { + src_index = n * chw_size + data_index; + dst[data_index] = dst[data_index] > src[src_index] ? dst[data_index] + : src[src_index]; + } + } + } + } +} template -void reduce_c(const T* src, - T* dst, - int num_in, - int channel_in, - int height_in, - int width_in); +inline void reduce_first_of_three( + const T* src, T* dst, int first_in, int second_in, int third_in) { + for (int i = 0; i < second_in; i++) { + for (int j = 0; j < third_in; j++) { + dst[i * third_in + j] = src[i * third_in + j]; + for (int k = 1; k < first_in; k++) { + dst[i * third_in + j] = + src[k * second_in * third_in + i * third_in + j] > + dst[i * third_in + j] + ? src[k * second_in * third_in + i * third_in + j] + : dst[i * third_in + j]; + } + } + } +} template -void reduce_all_of_three( - const T* src, T* dst, int first_in, int second_in, int third_in); +inline void reduce_second_of_three( + const T* src, T* dst, int first_in, int second_in, int third_in) { + for (int i = 0; i < first_in; i++) { + for (int j = 0; j < third_in; j++) { + dst[i * third_in + j] = src[i * second_in * third_in + j]; + for (int k = 1; k < second_in; k++) { + dst[i * third_in + j] = + src[i * second_in * third_in + third_in * k + j] > + dst[i * third_in + j] + ? src[i * second_in * third_in + third_in * k + j] + : dst[i * third_in + j]; + } + } + } +} template -void reduce_first_of_three( - const T* src, T* dst, int first_in, int second_in, int third_in); +inline void reduce_third_of_three( + const T* src, T* dst, int first_in, int second_in, int third_in) { + for (int i = 0; i < first_in; i++) { + for (int j = 0; j < second_in; j++) { + dst[i * second_in + j] = src[i * second_in * third_in + j * second_in]; + for (int k = 0; k < third_in; k++) { + dst[i * second_in + j] = + src[i * second_in * third_in + j * third_in + k] > + dst[i * second_in + j] + ? src[i * second_in * third_in + j * third_in + k] + : dst[i * second_in + j]; + } + } + } +} template -void reduce_second_of_three( - const T* src, T* dst, int first_in, int second_in, int third_in); +inline void reduce_all_of_three( + const T* src, T* dst, int first_in, int second_in, int third_in) { + T max = src[0]; + int total_element = first_in * second_in * third_in; + for (int i = 0; i < total_element; i++) { + max = src[i] > max ? src[i] : max; + } + dst[0] = max; +} template -void reduce_third_of_three( - const T* src, T* dst, int first_in, int second_in, int third_in); +inline void reduce_c(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int hw_size = height_in * width_in; + int chw_size = hw_size * channel_in; + int data_index, src_index0, src_index; + for (int n = 0; n < num_in; ++n) { + for (int h = 0; h < height_in; ++h) { + for (int w = 0; w < width_in; ++w) { + data_index = n * hw_size + h * width_in + w; + src_index0 = n * chw_size + h * width_in + w; + dst[data_index] = src[src_index0]; + for (int c = 1; c < channel_in; ++c) { + src_index = src_index0 + c * hw_size; + dst[data_index] = dst[data_index] > src[src_index] ? dst[data_index] + : src[src_index]; + } + } + } + } +} template -void reduce_h(const T* src, - T* dst, - int num_in, - int channel_in, - int height_in, - int width_in); +inline void reduce_h(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int cw_size = channel_in * width_in; + int chw_size = cw_size * height_in; + int hw_size = height_in * width_in; + int data_index, src_index, src_index0; + for (int n = 0; n < num_in; ++n) { + for (int c = 0; c < channel_in; ++c) { + for (int w = 0; w < width_in; ++w) { + data_index = n * cw_size + c * width_in + w; + src_index0 = n * chw_size + c * hw_size + w; + dst[data_index] = src[src_index0]; + for (int h = 1; h < height_in; ++h) { + src_index = src_index0 + h * width_in; + dst[data_index] = dst[data_index] > src[src_index] ? dst[data_index] + : src[src_index]; + } + } + } + } +} template -void reduce_w(const T* src, - T* dst, - int num_in, - int channel_in, - int height_in, - int width_in); +inline void reduce_w(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int ch_size = channel_in * height_in; + int hw_size = height_in * width_in; + int chw_size = ch_size * width_in; + int data_index = 0; + int src_index0 = 0; + int src_index = 0; + for (int n = 0; n < num_in; ++n) { + for (int c = 0; c < channel_in; ++c) { + for (int h = 0; h < height_in; ++h) { + data_index = n * ch_size + c * height_in + h; + src_index0 = n * chw_size + c * hw_size + h * width_in; + dst[data_index] = src[src_index0]; + for (int w = 1; w < width_in; ++w) { + src_index = src_index0 + w; + dst[data_index] = dst[data_index] > src[src_index] ? dst[data_index] + : src[src_index]; + } + } + } + } +} template -void reduce_nc(const T* src, - T* dst, - int num_in, - int channel_in, - int height_in, - int width_in); +inline void reduce_all(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + T max = src[0]; + int src_index; + int n_id, c_id; + for (int n = 0; n < num_in; ++n) { + n_id = n * channel_in * height_in * width_in; + for (int c = 0; c < channel_in; ++c) { + c_id = c * height_in * width_in; + for (int h = 0; h < height_in; ++h) { + for (int w = 0; w < width_in; ++w) { + src_index = n_id + c_id + h * width_in + w; + max = src[src_index] > max ? src[src_index] : max; + } + } + } + } + dst[0] = max; +} template -void reduce_ch(const T* src, - T* dst, - int num_in, - int channel_in, - int height_in, - int width_in); +inline void reduce_nc(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + // reduce n first. + DDimLite ddimA({1, channel_in, height_in, width_in}); + lite::Tensor tensor_tmp; + tensor_tmp.Resize(ddimA); + T* tmp_out = tensor_tmp.mutable_data(); + reduce_n(src, tmp_out, num_in, channel_in, height_in, width_in); + reduce_c(tmp_out, dst, 1, channel_in, height_in, width_in); +} template -void reduce_hw(const T* src, - T* dst, - int num_in, - int channel_in, - int height_in, - int width_in); +inline void reduce_ch(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + // reduce c first + DDimLite ddimA({num_in, 1, height_in, width_in}); + lite::Tensor tensor_tmp; + tensor_tmp.Resize(ddimA); + T* tmp_out = tensor_tmp.mutable_data(); + reduce_c(src, tmp_out, num_in, channel_in, height_in, width_in); + reduce_h(tmp_out, dst, num_in, 1, height_in, width_in); +} template -void reduce_all(const T* src, - T* dst, - int num_in, - int channel_in, - int height_in, - int width_in); +inline void reduce_hw(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + // reduce h first + DDimLite ddimA({num_in, channel_in, 1, width_in}); + lite::Tensor tensor_tmp; + tensor_tmp.Resize(ddimA); + T* tmp_out = tensor_tmp.mutable_data(); + reduce_h(src, tmp_out, num_in, channel_in, height_in, width_in); + reduce_w(tmp_out, dst, num_in, channel_in, 1, width_in); +} } // namespace math } // namespace arm diff --git a/lite/backends/arm/math/reduce_max_min.cc b/lite/backends/arm/math/reduce_max_min.cc index 19872ab3f7f..e82205fb48b 100644 --- a/lite/backends/arm/math/reduce_max_min.cc +++ b/lite/backends/arm/math/reduce_max_min.cc @@ -13,10 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "lite/backends/arm/math/reduce_max_min.h" -#include -#include -#include "lite/backends/arm/math/funcs.h" -#include "lite/core/tensor.h" namespace paddle { namespace lite { @@ -59,6 +55,42 @@ void reduce_first_of_two(const float* src, } } +template <> +void reduce_second_of_two(const int64_t* src, + int64_t* dst, + int first_in, + int second_in, + MaxMinType max_min_selector) { + // max_min_selector == true, do reduce max; else do reduce min + for (int j = 0; j < second_in; j++) { + dst[j * first_in] = src[j * first_in]; + for (int k = 1; k < first_in; k++) { + dst[j * first_in] = (src[j * first_in + k] <= dst[j * first_in]) ^ + static_cast(max_min_selector) + ? src[j * first_in + k] + : dst[j * first_in]; + } + } +} + +template <> +void reduce_first_of_two(const int64_t* src, + int64_t* dst, + int first_in, + int second_in, + MaxMinType max_min_selector) { + // max_min_selector == true, do reduce max; else do reduce min + for (int j = 0; j < first_in; j++) { + dst[j] = src[j]; + for (int k = 1; k < second_in; k++) { + dst[j] = (src[j + k * first_in] <= dst[j]) ^ + static_cast(max_min_selector) + ? src[j + k * first_in] + : dst[j]; + } + } +} + } // namespace math } // namespace arm } // namespace lite diff --git a/lite/backends/arm/math/reduce_max_min.h b/lite/backends/arm/math/reduce_max_min.h index 01e24a8d282..2ddfc699e0a 100644 --- a/lite/backends/arm/math/reduce_max_min.h +++ b/lite/backends/arm/math/reduce_max_min.h @@ -14,22 +14,46 @@ limitations under the License. */ #pragma once +#include +#include +#include "lite/backends/arm/math/funcs.h" +#include "lite/core/tensor.h" + namespace paddle { namespace lite { namespace arm { namespace math { enum class MaxMinType : bool { kMin = false, kMax = true }; + +template +inline void reduce_one_line_max(const DataType* src, DataType* dst, int size) { + DataType tmp = src[0]; + for (int i = 0; i < size; i++) { + if (tmp <= src[i]) tmp = src[i]; + } + *dst = tmp; +} + +template +inline void reduce_one_line_min(const DataType* src, DataType* dst, int size) { + DataType tmp = src[0]; + for (int i = 0; i < size; i++) { + if (tmp > src[i]) tmp = src[i]; + } + *dst = tmp; +} + template -void reduce_first_of_two(const float* src, - float* dst, +void reduce_first_of_two(const DataType* src, + DataType* dst, int first_in, int second_in, MaxMinType compare_functor); template -void reduce_second_of_two(const float* src, - float* dst, +void reduce_second_of_two(const DataType* src, + DataType* dst, int first_in, int second_in, MaxMinType max_min_selector); diff --git a/lite/backends/arm/math/reduce_min.cc b/lite/backends/arm/math/reduce_min.cc index eafdb32601e..43634d5b570 100644 --- a/lite/backends/arm/math/reduce_min.cc +++ b/lite/backends/arm/math/reduce_min.cc @@ -19,251 +19,7 @@ limitations under the License. */ namespace paddle { namespace lite { namespace arm { -namespace math { - -template <> -void reduce_min_n(const float* src, - float* dst, - int num_in, - int channel_in, - int height_in, - int width_in) { - int hw_size = height_in * width_in; - int chw_size = channel_in * hw_size; - int data_index, src_index, src_index0; - for (int c = 0; c < channel_in; ++c) { - for (int h = 0; h < height_in; ++h) { - for (int w = 0; w < width_in; ++w) { - data_index = c * hw_size + h * width_in + w; - dst[data_index] = src[data_index]; - for (int n = 1; n < num_in; ++n) { - src_index = n * chw_size + data_index; - dst[data_index] = dst[data_index] < src[src_index] ? dst[data_index] - : src[src_index]; - } - } - } - } -} - -template <> -void reduce_min_first_of_three( - const float* src, float* dst, int first_in, int second_in, int third_in) { - for (int i = 0; i < second_in; i++) { - for (int j = 0; j < third_in; j++) { - dst[i * third_in + j] = src[i * third_in + j]; - for (int k = 1; k < first_in; k++) { - dst[i * third_in + j] = - src[k * second_in * third_in + i * third_in + j] < - dst[i * third_in + j] - ? src[k * second_in * third_in + i * third_in + j] - : dst[i * third_in + j]; - } - } - } -} - -template <> -void reduce_min_second_of_three( - const float* src, float* dst, int first_in, int second_in, int third_in) { - for (int i = 0; i < first_in; i++) { - for (int j = 0; j < third_in; j++) { - dst[i * third_in + j] = src[i * second_in * third_in + j]; - for (int k = 1; k < second_in; k++) { - dst[i * third_in + j] = - src[i * second_in * third_in + third_in * k + j] < - dst[i * third_in + j] - ? src[i * second_in * third_in + third_in * k + j] - : dst[i * third_in + j]; - } - } - } -} - -template <> -void reduce_min_third_of_three( - const float* src, float* dst, int first_in, int second_in, int third_in) { - for (int i = 0; i < first_in; i++) { - for (int j = 0; j < second_in; j++) { - dst[i * second_in + j] = src[i * second_in * third_in + j * third_in]; - for (int k = 0; k < third_in; k++) { - dst[i * second_in + j] = - src[i * second_in * third_in + j * third_in + k] < - dst[i * second_in + j] - ? src[i * second_in * third_in + j * third_in + k] - : dst[i * second_in + j]; - } - } - } -} - -template <> -void reduce_min_all_of_three( - const float* src, float* dst, int first_in, int second_in, int third_in) { - float min = src[0]; - int total_element = first_in * second_in * third_in; - for (int i = 0; i < total_element; i++) { - min = src[i] < min ? src[i] : min; - } - dst[0] = min; -} - -template <> -void reduce_min_c(const float* src, - float* dst, - int num_in, - int channel_in, - int height_in, - int width_in) { - int hw_size = height_in * width_in; - int chw_size = hw_size * channel_in; - int data_index, src_index0, src_index; - for (int n = 0; n < num_in; ++n) { - for (int h = 0; h < height_in; ++h) { - for (int w = 0; w < width_in; ++w) { - data_index = n * hw_size + h * width_in + w; - src_index0 = n * chw_size + h * width_in + w; - dst[data_index] = src[src_index0]; - for (int c = 1; c < channel_in; ++c) { - src_index = src_index0 + c * hw_size; - dst[data_index] = dst[data_index] < src[src_index] ? dst[data_index] - : src[src_index]; - } - } - } - } -} - -template <> -void reduce_min_h(const float* src, - float* dst, - int num_in, - int channel_in, - int height_in, - int width_in) { - int cw_size = channel_in * width_in; - int chw_size = cw_size * height_in; - int hw_size = height_in * width_in; - int data_index, src_index, src_index0; - for (int n = 0; n < num_in; ++n) { - for (int c = 0; c < channel_in; ++c) { - for (int w = 0; w < width_in; ++w) { - data_index = n * cw_size + c * width_in + w; - src_index0 = n * chw_size + c * hw_size + w; - dst[data_index] = src[src_index0]; - for (int h = 1; h < height_in; ++h) { - src_index = src_index0 + h * width_in; - dst[data_index] = dst[data_index] < src[src_index] ? dst[data_index] - : src[src_index]; - } - } - } - } -} - -template <> -void reduce_min_w(const float* src, - float* dst, - int num_in, - int channel_in, - int height_in, - int width_in) { - int ch_size = channel_in * height_in; - int hw_size = height_in * width_in; - int chw_size = ch_size * width_in; - int data_index = 0; - int src_index0 = 0; - int src_index = 0; - for (int n = 0; n < num_in; ++n) { - for (int c = 0; c < channel_in; ++c) { - for (int h = 0; h < height_in; ++h) { - data_index = n * ch_size + c * height_in + h; - src_index0 = n * chw_size + c * hw_size + h * width_in; - dst[data_index] = src[src_index0]; - for (int w = 1; w < width_in; ++w) { - src_index = src_index0 + w; - dst[data_index] = dst[data_index] < src[src_index] ? dst[data_index] - : src[src_index]; - } - } - } - } -} - -template <> -void reduce_min_all(const float* src, - float* dst, - int num_in, - int channel_in, - int height_in, - int width_in) { - float min = src[0]; - int src_index; - int n_id, c_id; - for (int n = 0; n < num_in; ++n) { - n_id = n * channel_in * height_in * width_in; - for (int c = 0; c < channel_in; ++c) { - c_id = c * height_in * width_in; - for (int h = 0; h < height_in; ++h) { - for (int w = 0; w < width_in; ++w) { - src_index = n_id + c_id + h * width_in + w; - min = src[src_index] < min ? src[src_index] : min; - } - } - } - } - dst[0] = min; -} - -template <> -void reduce_min_nc(const float* src, - float* dst, - int num_in, - int channel_in, - int height_in, - int width_in) { - // reduce n first. - DDimLite ddimA({1, channel_in, height_in, width_in}); - lite::Tensor tensor_tmp; - tensor_tmp.Resize(ddimA); - float* tmp_out = tensor_tmp.mutable_data(); - reduce_min_n(src, tmp_out, num_in, channel_in, height_in, width_in); - reduce_min_c(tmp_out, dst, 1, channel_in, height_in, width_in); -} - -template <> -void reduce_min_ch(const float* src, - float* dst, - int num_in, - int channel_in, - int height_in, - int width_in) { - // reduce c first - DDimLite ddimA({num_in, 1, height_in, width_in}); - lite::Tensor tensor_tmp; - tensor_tmp.Resize(ddimA); - float* tmp_out = tensor_tmp.mutable_data(); - reduce_min_c(src, tmp_out, num_in, channel_in, height_in, width_in); - reduce_min_h(tmp_out, dst, num_in, 1, height_in, width_in); -} - -template <> -void reduce_min_hw(const float* src, - float* dst, - int num_in, - int channel_in, - int height_in, - int width_in) { - // reduce h first - DDimLite ddimA({num_in, channel_in, 1, width_in}); - lite::Tensor tensor_tmp; - tensor_tmp.Resize(ddimA); - float* tmp_out = tensor_tmp.mutable_data(); - reduce_min_h(src, tmp_out, num_in, channel_in, height_in, width_in); - reduce_min_w(tmp_out, dst, num_in, channel_in, 1, width_in); -} - -} // namespace math +namespace math {} // namespace math } // namespace arm } // namespace lite } // namespace paddle diff --git a/lite/backends/arm/math/reduce_min.h b/lite/backends/arm/math/reduce_min.h index e41b57c8bf2..eb9123e93bf 100644 --- a/lite/backends/arm/math/reduce_min.h +++ b/lite/backends/arm/math/reduce_min.h @@ -14,90 +14,254 @@ limitations under the License. */ #pragma once +#include "lite/core/tensor.h" + namespace paddle { namespace lite { namespace arm { namespace math { template -void reduce_min_n(const T* src, - T* dst, - int num_in, - int channel_in, - int height_in, - int width_in); +inline void reduce_min_n(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int hw_size = height_in * width_in; + int chw_size = channel_in * hw_size; + int data_index, src_index, src_index0; + for (int c = 0; c < channel_in; ++c) { + for (int h = 0; h < height_in; ++h) { + for (int w = 0; w < width_in; ++w) { + data_index = c * hw_size + h * width_in + w; + dst[data_index] = src[data_index]; + for (int n = 1; n < num_in; ++n) { + src_index = n * chw_size + data_index; + dst[data_index] = dst[data_index] < src[src_index] ? dst[data_index] + : src[src_index]; + } + } + } + } +} template -void reduce_min_c(const T* src, - T* dst, - int num_in, - int channel_in, - int height_in, - int width_in); +inline void reduce_min_c(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int hw_size = height_in * width_in; + int chw_size = hw_size * channel_in; + int data_index, src_index0, src_index; + for (int n = 0; n < num_in; ++n) { + for (int h = 0; h < height_in; ++h) { + for (int w = 0; w < width_in; ++w) { + data_index = n * hw_size + h * width_in + w; + src_index0 = n * chw_size + h * width_in + w; + dst[data_index] = src[src_index0]; + for (int c = 1; c < channel_in; ++c) { + src_index = src_index0 + c * hw_size; + dst[data_index] = dst[data_index] < src[src_index] ? dst[data_index] + : src[src_index]; + } + } + } + } +} template -void reduce_min_all_of_three( - const T* src, T* dst, int first_in, int second_in, int third_in); +inline void reduce_min_all_of_three( + const T* src, T* dst, int first_in, int second_in, int third_in) { + T min = src[0]; + int total_element = first_in * second_in * third_in; + for (int i = 0; i < total_element; i++) { + min = src[i] < min ? src[i] : min; + } + dst[0] = min; +} template -void reduce_min_first_of_three( - const T* src, T* dst, int first_in, int second_in, int third_in); +inline void reduce_min_first_of_three( + const T* src, T* dst, int first_in, int second_in, int third_in) { + for (int i = 0; i < second_in; i++) { + for (int j = 0; j < third_in; j++) { + dst[i * third_in + j] = src[i * third_in + j]; + for (int k = 1; k < first_in; k++) { + dst[i * third_in + j] = + src[k * second_in * third_in + i * third_in + j] < + dst[i * third_in + j] + ? src[k * second_in * third_in + i * third_in + j] + : dst[i * third_in + j]; + } + } + } +} template -void reduce_min_second_of_three( - const T* src, T* dst, int first_in, int second_in, int third_in); +inline void reduce_min_second_of_three( + const T* src, T* dst, int first_in, int second_in, int third_in) { + for (int i = 0; i < first_in; i++) { + for (int j = 0; j < third_in; j++) { + dst[i * third_in + j] = src[i * second_in * third_in + j]; + for (int k = 1; k < second_in; k++) { + dst[i * third_in + j] = + src[i * second_in * third_in + third_in * k + j] < + dst[i * third_in + j] + ? src[i * second_in * third_in + third_in * k + j] + : dst[i * third_in + j]; + } + } + } +} template -void reduce_min_third_of_three( - const T* src, T* dst, int first_in, int second_in, int third_in); +inline void reduce_min_third_of_three( + const T* src, T* dst, int first_in, int second_in, int third_in) { + for (int i = 0; i < first_in; i++) { + for (int j = 0; j < second_in; j++) { + dst[i * second_in + j] = src[i * second_in * third_in + j * third_in]; + for (int k = 0; k < third_in; k++) { + dst[i * second_in + j] = + src[i * second_in * third_in + j * third_in + k] < + dst[i * second_in + j] + ? src[i * second_in * third_in + j * third_in + k] + : dst[i * second_in + j]; + } + } + } +} template -void reduce_min_h(const T* src, - T* dst, - int num_in, - int channel_in, - int height_in, - int width_in); +inline void reduce_min_h(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int cw_size = channel_in * width_in; + int chw_size = cw_size * height_in; + int hw_size = height_in * width_in; + int data_index, src_index, src_index0; + for (int n = 0; n < num_in; ++n) { + for (int c = 0; c < channel_in; ++c) { + for (int w = 0; w < width_in; ++w) { + data_index = n * cw_size + c * width_in + w; + src_index0 = n * chw_size + c * hw_size + w; + dst[data_index] = src[src_index0]; + for (int h = 1; h < height_in; ++h) { + src_index = src_index0 + h * width_in; + dst[data_index] = dst[data_index] < src[src_index] ? dst[data_index] + : src[src_index]; + } + } + } + } +} template -void reduce_min_w(const T* src, - T* dst, - int num_in, - int channel_in, - int height_in, - int width_in); +inline void reduce_min_w(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + int ch_size = channel_in * height_in; + int hw_size = height_in * width_in; + int chw_size = ch_size * width_in; + int data_index = 0; + int src_index0 = 0; + int src_index = 0; + for (int n = 0; n < num_in; ++n) { + for (int c = 0; c < channel_in; ++c) { + for (int h = 0; h < height_in; ++h) { + data_index = n * ch_size + c * height_in + h; + src_index0 = n * chw_size + c * hw_size + h * width_in; + dst[data_index] = src[src_index0]; + for (int w = 1; w < width_in; ++w) { + src_index = src_index0 + w; + dst[data_index] = dst[data_index] < src[src_index] ? dst[data_index] + : src[src_index]; + } + } + } + } +} template -void reduce_min_nc(const T* src, - T* dst, - int num_in, - int channel_in, - int height_in, - int width_in); +inline void reduce_min_nc(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + // reduce n first. + DDimLite ddimA({1, channel_in, height_in, width_in}); + lite::Tensor tensor_tmp; + tensor_tmp.Resize(ddimA); + T* tmp_out = tensor_tmp.mutable_data(); + reduce_min_n(src, tmp_out, num_in, channel_in, height_in, width_in); + reduce_min_c(tmp_out, dst, 1, channel_in, height_in, width_in); +} template -void reduce_min_ch(const T* src, - T* dst, - int num_in, - int channel_in, - int height_in, - int width_in); +inline void reduce_min_ch(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + // reduce c first + DDimLite ddimA({num_in, 1, height_in, width_in}); + lite::Tensor tensor_tmp; + tensor_tmp.Resize(ddimA); + T* tmp_out = tensor_tmp.mutable_data(); + reduce_min_c(src, tmp_out, num_in, channel_in, height_in, width_in); + reduce_min_h(tmp_out, dst, num_in, 1, height_in, width_in); +} template -void reduce_min_hw(const T* src, - T* dst, - int num_in, - int channel_in, - int height_in, - int width_in); +inline void reduce_min_hw(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + // reduce h first + DDimLite ddimA({num_in, channel_in, 1, width_in}); + lite::Tensor tensor_tmp; + tensor_tmp.Resize(ddimA); + T* tmp_out = tensor_tmp.mutable_data(); + reduce_min_h(src, tmp_out, num_in, channel_in, height_in, width_in); + reduce_min_w(tmp_out, dst, num_in, channel_in, 1, width_in); +} template -void reduce_min_all(const T* src, - T* dst, - int num_in, - int channel_in, - int height_in, - int width_in); +inline void reduce_min_all(const T* src, + T* dst, + int num_in, + int channel_in, + int height_in, + int width_in) { + T min = src[0]; + int src_index; + int n_id, c_id; + for (int n = 0; n < num_in; ++n) { + n_id = n * channel_in * height_in * width_in; + for (int c = 0; c < channel_in; ++c) { + c_id = c * height_in * width_in; + for (int h = 0; h < height_in; ++h) { + for (int w = 0; w < width_in; ++w) { + src_index = n_id + c_id + h * width_in + w; + min = src[src_index] < min ? src[src_index] : min; + } + } + } + } + dst[0] = min; +} } // namespace math } // namespace arm diff --git a/lite/core/profile/precision_profiler.h b/lite/core/profile/precision_profiler.h index b749f25f7a8..1b1715f9580 100644 --- a/lite/core/profile/precision_profiler.h +++ b/lite/core/profile/precision_profiler.h @@ -277,9 +277,11 @@ class PrecisionProfiler { } #endif case PRECISION(kBool): { - *mean = -333333333333; - *std_dev = -33333333333; - *ave_grow_rate = -33333333333; + auto ptr = in->data(); + *mean = compute_mean(ptr, in->numel()); + *std_dev = + compute_standard_deviation(ptr, in->numel(), true, *mean); + *ave_grow_rate = compute_average_grow_rate(ptr, in->numel()); if (write_result_to_file) { write_tensorfile(in, name, log_dir_); } diff --git a/lite/kernels/arm/reduce_max_compute.cc b/lite/kernels/arm/reduce_max_compute.cc index c2e3c7f5f5b..e2f1100f136 100644 --- a/lite/kernels/arm/reduce_max_compute.cc +++ b/lite/kernels/arm/reduce_max_compute.cc @@ -13,23 +13,25 @@ // limitations under the License. #include "lite/kernels/arm/reduce_max_compute.h" - #include - #include "lite/backends/arm/math/funcs.h" +#include "lite/backends/arm/math/reduce_max.h" +#include "lite/backends/arm/math/reduce_max_min.h" +#include "lite/core/op_registry.h" namespace paddle { namespace lite { namespace kernels { namespace arm { -void ReduceMaxCompute::Run() { +template +void ReduceMaxCompute::Run() { auto& param = Param(); - const float* input = param.X->data(); + const T* input = param.X->template data(); auto x_dims = param.X->dims(); int x_rank = x_dims.size(); - float* output = param.Out->mutable_data(); + T* output = param.Out->template mutable_data(); bool keep_dim = param.keep_dim; auto dim = param.dim; @@ -43,21 +45,21 @@ void ReduceMaxCompute::Run() { if (x_dims.size() == 3) { if (dim.size() == 0 || dim.size() == 3) { - lite::arm::math::reduce_all_of_three( + lite::arm::math::reduce_all_of_three( input, output, x_dims[0], x_dims[1], x_dims[2]); } else if (dim.size() == 1) { switch (dim[0]) { case 0: - lite::arm::math::reduce_first_of_three( + lite::arm::math::reduce_first_of_three( input, output, x_dims[0], x_dims[1], x_dims[2]); break; case 1: - lite::arm::math::reduce_second_of_three( + lite::arm::math::reduce_second_of_three( input, output, x_dims[0], x_dims[1], x_dims[2]); break; case 2: - lite::arm::math::reduce_third_of_three( + lite::arm::math::reduce_third_of_three( input, output, x_dims[0], x_dims[1], x_dims[2]); break; default: @@ -75,31 +77,31 @@ void ReduceMaxCompute::Run() { int w_in = x_dims[3]; if (dim.size() == 0) { - lite::arm::math::reduce_all(input, output, n_in, c_in, h_in, w_in); + lite::arm::math::reduce_all(input, output, n_in, c_in, h_in, w_in); } else if (dim.size() == 1) { switch (dim[0]) { case 0: - lite::arm::math::reduce_n(input, output, n_in, c_in, h_in, w_in); + lite::arm::math::reduce_n(input, output, n_in, c_in, h_in, w_in); break; case 1: - lite::arm::math::reduce_c(input, output, n_in, c_in, h_in, w_in); + lite::arm::math::reduce_c(input, output, n_in, c_in, h_in, w_in); break; case 2: - lite::arm::math::reduce_h(input, output, n_in, c_in, h_in, w_in); + lite::arm::math::reduce_h(input, output, n_in, c_in, h_in, w_in); break; case 3: - lite::arm::math::reduce_w(input, output, n_in, c_in, h_in, w_in); + lite::arm::math::reduce_w(input, output, n_in, c_in, h_in, w_in); break; default: LOG(FATAL) << "error!!!"; } } else if (dim.size() == 2) { if (dim[0] == 0 && dim[1] == 1) { - lite::arm::math::reduce_nc(input, output, n_in, c_in, h_in, w_in); + lite::arm::math::reduce_nc(input, output, n_in, c_in, h_in, w_in); } else if (dim[0] == 1 && dim[1] == 2) { - lite::arm::math::reduce_ch(input, output, n_in, c_in, h_in, w_in); + lite::arm::math::reduce_ch(input, output, n_in, c_in, h_in, w_in); } else if (dim[0] == 2 && dim[1] == 3) { - lite::arm::math::reduce_hw(input, output, n_in, c_in, h_in, w_in); + lite::arm::math::reduce_hw(input, output, n_in, c_in, h_in, w_in); } else { LOG(FATAL) << "invalid dim!!"; } @@ -112,7 +114,7 @@ void ReduceMaxCompute::Run() { if (dim.size() == 1) { switch (dim[0]) { case 0: - lite::arm::math::reduce_first_of_two( + lite::arm::math::reduce_first_of_two( input, output, first_in, @@ -120,7 +122,7 @@ void ReduceMaxCompute::Run() { lite::arm::math::MaxMinType::kMax); break; case 1: - lite::arm::math::reduce_second_of_two( + lite::arm::math::reduce_second_of_two( input, output, first_in, @@ -133,9 +135,11 @@ void ReduceMaxCompute::Run() { } else { LOG(FATAL) << "dim's size over than 1, which is not supported now!!"; } // x_dims == 2 && dim.size() == 1 + } else if (x_dims.size() == 1) { + lite::arm::math::reduce_one_line_max(input, output, x_dims[0]); } else { - LOG(FATAL) << "only support input with 2&3&4 dimensions now!!"; - } // x_dims == 2 + LOG(FATAL) << "only support input with 1 to 4 dimensions now!!"; + } } } // namespace arm @@ -143,12 +147,14 @@ void ReduceMaxCompute::Run() { } // namespace lite } // namespace paddle -REGISTER_LITE_KERNEL(reduce_max, - kARM, - kFloat, - kNCHW, - paddle::lite::kernels::arm::ReduceMaxCompute, - def) +using float_reduce_max = paddle::lite::kernels::arm::ReduceMaxCompute; +REGISTER_LITE_KERNEL(reduce_max, kARM, kFloat, kNCHW, float_reduce_max, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); + +using int64_reduce_max = paddle::lite::kernels::arm::ReduceMaxCompute; +REGISTER_LITE_KERNEL(reduce_max, kARM, kFloat, kNCHW, int64_reduce_max, i64) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) + .Finalize(); diff --git a/lite/kernels/arm/reduce_max_compute.h b/lite/kernels/arm/reduce_max_compute.h index f2228284aee..864387e1601 100644 --- a/lite/kernels/arm/reduce_max_compute.h +++ b/lite/kernels/arm/reduce_max_compute.h @@ -23,6 +23,7 @@ namespace lite { namespace kernels { namespace arm { +template class ReduceMaxCompute : public KernelLite { public: void Run() override; diff --git a/lite/kernels/arm/reduce_min_compute.cc b/lite/kernels/arm/reduce_min_compute.cc index f0ca0da771b..a05b20e69c7 100644 --- a/lite/kernels/arm/reduce_min_compute.cc +++ b/lite/kernels/arm/reduce_min_compute.cc @@ -15,19 +15,22 @@ #include "lite/kernels/arm/reduce_min_compute.h" #include #include "lite/backends/arm/math/funcs.h" +#include "lite/backends/arm/math/reduce_max_min.h" +#include "lite/backends/arm/math/reduce_min.h" namespace paddle { namespace lite { namespace kernels { namespace arm { -void ReduceMinCompute::Run() { +template +void ReduceMinCompute::Run() { auto& param = Param(); - const float* input = param.X->data(); + const T* input = param.X->template data(); auto x_dims = param.X->dims(); int x_rank = x_dims.size(); - float* output = param.Out->mutable_data(); + T* output = param.Out->template mutable_data(); bool keep_dim = param.keep_dim; auto dim = param.dim; @@ -41,21 +44,21 @@ void ReduceMinCompute::Run() { if (x_dims.size() == 3) { if (dim.size() == 0 || dim.size() == 3) { - lite::arm::math::reduce_min_all_of_three( + lite::arm::math::reduce_min_all_of_three( input, output, x_dims[0], x_dims[1], x_dims[2]); } else if (dim.size() == 1) { switch (dim[0]) { case 0: - lite::arm::math::reduce_min_first_of_three( + lite::arm::math::reduce_min_first_of_three( input, output, x_dims[0], x_dims[1], x_dims[2]); break; case 1: - lite::arm::math::reduce_min_second_of_three( + lite::arm::math::reduce_min_second_of_three( input, output, x_dims[0], x_dims[1], x_dims[2]); break; case 2: - lite::arm::math::reduce_min_third_of_three( + lite::arm::math::reduce_min_third_of_three( input, output, x_dims[0], x_dims[1], x_dims[2]); break; default: @@ -73,31 +76,38 @@ void ReduceMinCompute::Run() { int w_in = x_dims[3]; if (dim.size() == 0) { - lite::arm::math::reduce_min_all(input, output, n_in, c_in, h_in, w_in); + lite::arm::math::reduce_min_all(input, output, n_in, c_in, h_in, w_in); } else if (dim.size() == 1) { switch (dim[0]) { case 0: - lite::arm::math::reduce_min_n(input, output, n_in, c_in, h_in, w_in); + lite::arm::math::reduce_min_n( + input, output, n_in, c_in, h_in, w_in); break; case 1: - lite::arm::math::reduce_min_c(input, output, n_in, c_in, h_in, w_in); + lite::arm::math::reduce_min_c( + input, output, n_in, c_in, h_in, w_in); break; case 2: - lite::arm::math::reduce_min_h(input, output, n_in, c_in, h_in, w_in); + lite::arm::math::reduce_min_h( + input, output, n_in, c_in, h_in, w_in); break; case 3: - lite::arm::math::reduce_min_w(input, output, n_in, c_in, h_in, w_in); + lite::arm::math::reduce_min_w( + input, output, n_in, c_in, h_in, w_in); break; default: LOG(FATAL) << "error!!!"; } } else if (dim.size() == 2) { if (dim[0] == 0 && dim[1] == 1) { - lite::arm::math::reduce_min_nc(input, output, n_in, c_in, h_in, w_in); + lite::arm::math::reduce_min_nc( + input, output, n_in, c_in, h_in, w_in); } else if (dim[0] == 1 && dim[1] == 2) { - lite::arm::math::reduce_min_ch(input, output, n_in, c_in, h_in, w_in); + lite::arm::math::reduce_min_ch( + input, output, n_in, c_in, h_in, w_in); } else if (dim[0] == 2 && dim[1] == 3) { - lite::arm::math::reduce_min_hw(input, output, n_in, c_in, h_in, w_in); + lite::arm::math::reduce_min_hw( + input, output, n_in, c_in, h_in, w_in); } else { LOG(FATAL) << "invalid dim!!"; } @@ -110,7 +120,7 @@ void ReduceMinCompute::Run() { if (dim.size() == 1) { switch (dim[0]) { case 0: - lite::arm::math::reduce_first_of_two( + lite::arm::math::reduce_first_of_two( input, output, first_in, @@ -118,7 +128,7 @@ void ReduceMinCompute::Run() { lite::arm::math::MaxMinType::kMin); break; case 1: - lite::arm::math::reduce_second_of_two( + lite::arm::math::reduce_second_of_two( input, output, first_in, @@ -128,9 +138,13 @@ void ReduceMinCompute::Run() { default: LOG(FATAL) << "error!!!"; } - } + } else { + LOG(FATAL) << "dim's size over than 1, which is not supported now!!"; + } // x_dims == 2 && dim.size() == 1 + } else if (x_dims.size() == 1) { + lite::arm::math::reduce_one_line_min(input, output, x_dims[0]); } else { - LOG(FATAL) << "only support input with 3&4 dimensions now!!"; + LOG(FATAL) << "only support input with 1 to 4 dimensions now!!"; } } @@ -143,8 +157,18 @@ REGISTER_LITE_KERNEL(reduce_min, kARM, kFloat, kNCHW, - paddle::lite::kernels::arm::ReduceMinCompute, + paddle::lite::kernels::arm::ReduceMinCompute, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); + +REGISTER_LITE_KERNEL(reduce_min, + kARM, + kFloat, + kNCHW, + paddle::lite::kernels::arm::ReduceMinCompute, + def_int64) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/lite/kernels/arm/reduce_min_compute.h b/lite/kernels/arm/reduce_min_compute.h index 530738313f2..1d1a66de3f2 100644 --- a/lite/kernels/arm/reduce_min_compute.h +++ b/lite/kernels/arm/reduce_min_compute.h @@ -23,6 +23,7 @@ namespace lite { namespace kernels { namespace arm { +template class ReduceMinCompute : public KernelLite { public: void Run() override; diff --git a/lite/kernels/arm/rnn_compute.cc b/lite/kernels/arm/rnn_compute.cc index 8e67c82aa9a..a7ad3a61cbd 100644 --- a/lite/kernels/arm/rnn_compute.cc +++ b/lite/kernels/arm/rnn_compute.cc @@ -700,7 +700,7 @@ void RnnCompute::Run() { } else if ("GRU" == mode) { gate_num = 3; } else { - LOG(FATAL) << "X86 RNN ERROR: unsupport mode except gru and lstm," + LOG(FATAL) << "ARM RNN ERROR: unsupport mode except gru and lstm," " present mode is " << mode; return; diff --git a/lite/kernels/arm/slice_compute.cc b/lite/kernels/arm/slice_compute.cc index a484724a547..a25e2ebd0ab 100644 --- a/lite/kernels/arm/slice_compute.cc +++ b/lite/kernels/arm/slice_compute.cc @@ -23,6 +23,45 @@ namespace lite { namespace kernels { namespace arm { +void DealTensorArray(const operators::SliceParam& param, + const std::vector& starts, + const std::vector& ends, + bool out_is_array) { + auto in_array = param.XTensorList; + // If the input is LoDTensorArray, the rank of input is 1. + int64_t in_size = in_array->size(); + int64_t start = starts[0] < 0 ? (starts[0] + in_size) : starts[0]; + int64_t end = ends[0] < 0 ? (ends[0] + in_size) : ends[0]; + + start = std::max(start, static_cast(0)); + end = std::max(end, static_cast(0)); + end = std::min(end, in_size); + + CHECK_GT(end, start) << "end should greater than start"; + int64_t out_size = end - start; + + if (out_is_array) { + auto out_array = param.OutTensorList; + out_array->resize(out_size); + for (int i = 0; i < out_size; ++i) { + auto* out_tensor = &out_array->at(i); + auto in_tensor = in_array->at(i + start); + out_tensor->set_lod(in_tensor.lod()); + if (in_tensor.memory_size() > 0) { + out_tensor->CopyDataFrom(in_tensor); + } else { + VLOG(4) << "WARNING: The input tensor 'x_tensor' holds no memory, so " + "nothing has been written to output array[" + << i << "]."; + } + } + } else { + auto out_tensor = param.Out; + auto in_tensor = in_array->at(start); + out_tensor->CopyDataFrom(in_tensor); + } +} + inline std::vector get_new_data_from_tensorlist( const std::vector& list_new_data_tensor) { // get tensor @@ -36,9 +75,10 @@ inline std::vector get_new_data_from_tensorlist( } else if (tensor->precision() == PrecisionType::kInt64) { vec_new_data.push_back(static_cast(*tensor->data())); } else { - LOG(FATAL) << "slice StartsTensor or EndsTensor :The dtype of Tensor " - "must be int32 " - "or int64"; + vec_new_data.push_back(static_cast(*tensor->data())); + LOG(WARNING) << "slice StartsTensor or EndsTensor :The dtype of Tensor " + "must be int32 " + "or int64"; } } return vec_new_data; @@ -57,9 +97,13 @@ inline std::vector get_new_data_from_tensor( vec_new_data = std::vector(new_data, new_data + new_data_tensor->numel()); } else { - LOG(FATAL) << "slice StartsTensor or EndsTensor :The dtype of Tensor must " - "be int32 " - "or int64"; + auto* new_data = new_data_tensor->data(); + vec_new_data = + std::vector(new_data, new_data + new_data_tensor->numel()); + LOG(WARNING) + << "slice StartsTensor or EndsTensor :The dtype of Tensor must " + "be int32 " + "or int64"; } return vec_new_data; } @@ -92,6 +136,7 @@ void SliceCompute::Run() { if (list_new_starts_tensor.size() > 0 || list_new_ends_tensor.size() > 0) { need_infer = true; } + if (need_infer) { if (param.StartsTensor) { starts = get_new_data_from_tensor(param.StartsTensor); @@ -105,10 +150,18 @@ void SliceCompute::Run() { } else if (list_new_ends_tensor.size() > 0) { ends = get_new_data_from_tensorlist(list_new_ends_tensor); } + CHECK_EQ(ends.size(), axes.size()) << "The size of ends must be equal to the size of axes."; out_dims = in_dims; int64_t dim_value, start, end; + if (param.X == nullptr && param.XTensorList != nullptr) { + DealTensorArray(param, + starts, + ends, + (param.Out == nullptr && param.OutTensorList != nullptr)); + } + for (size_t i = 0; i < axes.size(); ++i) { dim_value = out_dims[axes[i]]; if (dim_value > 0) { @@ -130,6 +183,7 @@ void SliceCompute::Run() { out_dims[axes[i]] = end - start; } } + out->Resize(out_dims); // generate new shape if (decrease_axis.size() > 0) { @@ -155,7 +209,7 @@ void SliceCompute::Run() { // resize out dims if (decrease_axis.size() > 0) { - if (decrease_axis.size() == (size_t)in_dims.size()) { + if (decrease_axis.size() == static_cast(in_dims.size())) { std::vector vec_origin_out_shape(decrease_axis.size(), 1); out->Resize(DDim(vec_origin_out_shape)); } else { @@ -177,7 +231,6 @@ void SliceCompute::Run() { out->Resize(DDim(vec_origin_out_shape)); } } - auto new_out_dims = out->dims(); const auto* x_data = in->template data(); auto* o_data = out->template mutable_data(); diff --git a/lite/kernels/host/compare_compute.cc b/lite/kernels/host/compare_compute.cc index f5bf30d4a7b..db9d6b5625a 100644 --- a/lite/kernels/host/compare_compute.cc +++ b/lite/kernels/host/compare_compute.cc @@ -594,7 +594,7 @@ using greater_equal_int64 = paddle::lite::kernels::host::CompareCompute< PRECISION(kFloat), paddle::lite::kernels::host::_GreaterEqualFunctor>; REGISTER_LITE_KERNEL( - greater_equal, kHost, kFloat, kAny, greater_equal_float, def_int64) + greater_equal, kHost, kFloat, kAny, greater_equal_int64, def_int64) .BindInput("X", {LiteType::GetTensorTy( TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kAny), -1)}) diff --git a/lite/kernels/host/flip_compute.h b/lite/kernels/host/flip_compute.h index 6f4a082c506..a76f861c2e0 100644 --- a/lite/kernels/host/flip_compute.h +++ b/lite/kernels/host/flip_compute.h @@ -22,6 +22,16 @@ namespace lite { namespace kernels { namespace host { +DDimLite stride_flip(const DDimLite& ddim) { + std::vector tmp(ddim.size(), 0); + DDimLite strides(tmp); + strides[ddim.size() - 1] = 1; + for (int i = ddim.size() - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * ddim[i + 1]; + } + return strides; +} + template class FlipCompute : public KernelLite { public: @@ -43,7 +53,7 @@ class FlipCompute : public KernelLite { } dim_bitset[dim] = true; } - auto x_strides = x_dims.Vectorize(); + auto x_strides = stride_flip(x_dims); auto numel = x->numel(); const T* x_data = x->template data(); T* out_data = out->template mutable_data(); diff --git a/lite/operators/conditional_block_op.h b/lite/operators/conditional_block_op.h index adcd8acdff3..0274663435d 100644 --- a/lite/operators/conditional_block_op.h +++ b/lite/operators/conditional_block_op.h @@ -42,6 +42,9 @@ class ConditionalBlockOp : public OpLite { void SetProgramDesc(std::shared_ptr program_desc) { param_.program_desc = program_desc; } + + bool InferType() override { return true; } + std::shared_ptr GetProgramDesc() { return param_.program_desc; } diff --git a/lite/operators/fill_constant_batch_size_like_op.h b/lite/operators/fill_constant_batch_size_like_op.h index 3c576ab2822..a9130e7fdab 100644 --- a/lite/operators/fill_constant_batch_size_like_op.h +++ b/lite/operators/fill_constant_batch_size_like_op.h @@ -36,6 +36,8 @@ class FillConstantBatchSizeLikeOp : public OpLite { bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + bool InferType() override { return true; } + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } std::string DebugString() const override { return "fill_constant_batch_size_like"; diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 637fd7ee52a..cb4797cfe6b 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -1493,6 +1493,8 @@ struct CastParam : ParamBase { struct SliceParam : ParamBase { const lite::Tensor* X{nullptr}; lite::Tensor* Out{nullptr}; + const std::vector* XTensorList{nullptr}; + std::vector* OutTensorList{nullptr}; std::vector axes{}; std::vector starts{}; std::vector ends{}; diff --git a/lite/operators/slice_op.cc b/lite/operators/slice_op.cc index feb72d1764b..c42afc54ca9 100644 --- a/lite/operators/slice_op.cc +++ b/lite/operators/slice_op.cc @@ -22,8 +22,8 @@ namespace lite { namespace operators { bool SliceOp::CheckShape() const { - CHECK(param_.X); - CHECK(param_.Out); + CHECK(!(param_.X == nullptr && param_.XTensorList == nullptr)); + CHECK(!(param_.Out == nullptr && param_.OutTensorList == nullptr)); CHECK_LT(param_.X->dims().size(), 7u) << "The rank of input X should be less than 7"; return true; @@ -89,12 +89,26 @@ bool SliceOp::InferShapeImpl() const { bool SliceOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { AttachParam(¶m_); - param_.X = - scope->FindVar(opdesc.Input("Input").front())->GetMutable(); - param_.Out = - scope->FindVar(opdesc.Output("Out").front())->GetMutable(); - CHECK(param_.X); - CHECK(param_.Out); + auto input_var = scope->FindVar(opdesc.Input("Input").front()); + auto output_var = scope->FindVar(opdesc.Output("Out").front()); + bool input_is_array = input_var->IsType>(); + bool out_is_array = output_var->IsType>(); + if (input_is_array) { + param_.XTensorList = input_var->GetMutable>(); + CHECK(param_.XTensorList); + } else { + param_.X = scope->FindVar(opdesc.Input("Input").front()) + ->GetMutable(); + CHECK(param_.X); + } + if (out_is_array) { + param_.OutTensorList = output_var->GetMutable>(); + CHECK(param_.OutTensorList); + } else { + param_.Out = scope->FindVar(opdesc.Output("Out").front()) + ->GetMutable(); + CHECK(param_.Out); + } param_.axes = opdesc.GetAttr>("axes"); if (opdesc.HasAttr("infer_flags")) { @@ -125,9 +139,19 @@ bool SliceOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { !opdesc.Input("StartsTensorList").empty()) { param_.StartsTensorList.clear(); auto StartsTensorList = opdesc.Input("StartsTensorList"); - for (auto var : StartsTensorList) { - param_.StartsTensorList.push_back( - scope->FindVar(var)->GetMutable()); + if (!StartsTensorList.empty() && + scope->FindVar(StartsTensorList[0]) + ->IsType>()) { + auto tmp_tensor_list = scope->FindVar(StartsTensorList[0]) + ->GetMutable>(); + for (auto tensor : *tmp_tensor_list) { + param_.StartsTensorList.push_back(&tensor); + } + } else { + for (auto var : StartsTensorList) { + param_.StartsTensorList.push_back( + scope->FindVar(var)->GetMutable()); + } } CHECK_GT(param_.StartsTensorList.size(), 0u) << "StartsTensorList size can't be zero"; @@ -138,9 +162,19 @@ bool SliceOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { !opdesc.Input("EndsTensorList").empty()) { param_.EndsTensorList.clear(); auto EndsTensorList = opdesc.Input("EndsTensorList"); - for (auto var : EndsTensorList) { - param_.EndsTensorList.push_back( - scope->FindVar(var)->GetMutable()); + if (!EndsTensorList.empty() && + scope->FindVar(EndsTensorList[0]) + ->IsType>()) { + auto tmp_tensor_list = scope->FindVar(EndsTensorList[0]) + ->GetMutable>(); + for (auto tensor : *tmp_tensor_list) { + param_.EndsTensorList.push_back(&tensor); + } + } else { + for (auto var : EndsTensorList) { + param_.EndsTensorList.push_back( + scope->FindVar(var)->GetMutable()); + } } CHECK_GT(param_.EndsTensorList.size(), 0u) << "EndsTensorList size can't be zero"; From bea5aebe33fb2b29c7ea72dbb4ce5a11095c49be Mon Sep 17 00:00:00 2001 From: mjp9527 <54735487+mjp9527@users.noreply.github.com> Date: Thu, 21 Oct 2021 15:11:48 +0800 Subject: [PATCH 3/4] [ARM] fix opt report:no kernel for depthwise_conv_transpose, fix interpolate input scale shape error. (#7270) --- .../arm/math/fp16/interpolate_fp16.cc | 35 ++++++- .../backends/arm/math/fp16/interpolate_fp16.h | 3 +- lite/backends/arm/math/interpolate.cc | 35 ++++++- lite/backends/arm/math/interpolate.h | 3 +- lite/kernels/arm/CMakeLists.txt | 1 + .../arm/depthwise_conv_transpose_compute.cc | 92 +++++++++++++++++++ .../arm/depthwise_conv_transpose_compute.h | 32 +++++++ lite/kernels/arm/interpolate_compute.cc | 3 +- 8 files changed, 195 insertions(+), 9 deletions(-) create mode 100644 lite/kernels/arm/depthwise_conv_transpose_compute.cc create mode 100644 lite/kernels/arm/depthwise_conv_transpose_compute.h diff --git a/lite/backends/arm/math/fp16/interpolate_fp16.cc b/lite/backends/arm/math/fp16/interpolate_fp16.cc index 13e9bd61fca..03488e024ab 100644 --- a/lite/backends/arm/math/fp16/interpolate_fp16.cc +++ b/lite/backends/arm/math/fp16/interpolate_fp16.cc @@ -411,13 +411,39 @@ void interpolate(lite::Tensor* X, float scale, bool with_align, int align_mode, - std::string interpolate_type) { + std::string interpolate_type, + std::vector scale_data) { int in_h = X->dims()[2]; int in_w = X->dims()[3]; + float height_scale = 0.f; + float width_scale = 0.f; + if (SizeTensor.size() > 0) { auto new_size = get_new_shape(SizeTensor); out_height = new_size[0]; out_width = new_size[1]; + } else if (scale_data.size() > 0) { + if (scale_data.size() == 1) { + if (scale_data[0] > 0) { + out_height = static_cast(in_h * scale_data[0]); + out_width = static_cast(in_w * scale_data[0]); + } else { + LOG(FATAL) << "scale data <= 0"; + } + } else if (scale_data.size() == 2) { + if (scale_data[0] > 0 && scale_data[1] > 0) { + out_height = static_cast(in_h * scale_data[0]); + out_width = static_cast(in_w * scale_data[1]); + } else { + LOG(FATAL) << "scale data <= 0"; + } + } + auto out_size = OutSize; + if (out_size != nullptr) { + auto out_size_data = get_new_data_from_tensor(out_size); + out_height = out_size_data[0]; + out_width = out_size_data[1]; + } } else { auto scale_tensor = Scale; if (scale_tensor != nullptr) { @@ -435,11 +461,14 @@ void interpolate(lite::Tensor* X, out_width = out_size_data[1]; } } - float height_scale = scale; - float width_scale = scale; + height_scale = scale; + width_scale = scale; if (out_width > 0 && out_height > 0) { height_scale = static_cast(out_height / X->dims()[2]); width_scale = static_cast(out_width / X->dims()[3]); + } else { + out_height = static_cast(X->dims()[2] * height_scale + 0.5f); + out_width = static_cast(X->dims()[3] * width_scale + 0.5f); } int num_cout = X->dims()[0]; int c_cout = X->dims()[1]; diff --git a/lite/backends/arm/math/fp16/interpolate_fp16.h b/lite/backends/arm/math/fp16/interpolate_fp16.h index 7ac96cbdf45..c78fa414da7 100644 --- a/lite/backends/arm/math/fp16/interpolate_fp16.h +++ b/lite/backends/arm/math/fp16/interpolate_fp16.h @@ -68,7 +68,8 @@ void interpolate(lite::Tensor* X, float scale, bool with_align, int align_mode, - std::string interpolate_type); + std::string interpolate_type, + std::vector scale_data); } // namespace fp16 } // namespace math diff --git a/lite/backends/arm/math/interpolate.cc b/lite/backends/arm/math/interpolate.cc index fb935d00218..0bff70459b3 100644 --- a/lite/backends/arm/math/interpolate.cc +++ b/lite/backends/arm/math/interpolate.cc @@ -509,13 +509,39 @@ void interpolate(lite::Tensor* X, float scale, bool with_align, int align_mode, - std::string interpolate_type) { + std::string interpolate_type, + std::vector scale_data) { int in_h = X->dims()[2]; int in_w = X->dims()[3]; + float height_scale = 0.f; + float width_scale = 0.f; + if (SizeTensor.size() > 0) { auto new_size = get_new_shape(SizeTensor); out_height = new_size[0]; out_width = new_size[1]; + } else if (scale_data.size() > 0) { + if (scale_data.size() == 1) { + if (scale_data[0] > 0) { + out_height = static_cast(in_h * scale_data[0]); + out_width = static_cast(in_w * scale_data[0]); + } else { + LOG(FATAL) << "scale data <= 0"; + } + } else if (scale_data.size() == 2) { + if (scale_data[0] > 0 && scale_data[1] > 0) { + out_height = static_cast(in_h * scale_data[0]); + out_width = static_cast(in_w * scale_data[1]); + } else { + LOG(FATAL) << "scale data <= 0"; + } + } + auto out_size = OutSize; + if (out_size != nullptr) { + auto out_size_data = get_new_data_from_tensor(out_size); + out_height = out_size_data[0]; + out_width = out_size_data[1]; + } } else { auto scale_tensor = Scale; if (scale_tensor != nullptr) { @@ -533,11 +559,14 @@ void interpolate(lite::Tensor* X, out_width = out_size_data[1]; } } - float height_scale = scale; - float width_scale = scale; + height_scale = scale; + width_scale = scale; if (out_width > 0 && out_height > 0) { height_scale = static_cast(out_height / X->dims()[2]); width_scale = static_cast(out_width / X->dims()[3]); + } else { + out_height = static_cast(X->dims()[2] * height_scale + 0.5f); + out_width = static_cast(X->dims()[3] * width_scale + 0.5f); } int num_cout = X->dims()[0]; int c_cout = X->dims()[1]; diff --git a/lite/backends/arm/math/interpolate.h b/lite/backends/arm/math/interpolate.h index 5c37670ec57..a4db2adf43c 100644 --- a/lite/backends/arm/math/interpolate.h +++ b/lite/backends/arm/math/interpolate.h @@ -52,7 +52,8 @@ void interpolate(lite::Tensor* X, float scale, bool with_align, int align_mode, - std::string interpolate_type); + std::string interpolate_type, + std::vector scale_data); } /* namespace math */ } /* namespace arm */ diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt index 6602aedaf02..1dfa44b98dc 100644 --- a/lite/kernels/arm/CMakeLists.txt +++ b/lite/kernels/arm/CMakeLists.txt @@ -37,6 +37,7 @@ add_kernel(transpose_compute_arm ARM basic SRCS transpose_compute.cc) add_kernel(shuffle_channel_compute_arm ARM basic SRCS shuffle_channel_compute.cc) add_kernel(argmax_compute_arm ARM basic SRCS argmax_compute.cc) add_kernel(conv_transpose_compute_arm ARM basic SRCS conv_transpose_compute.cc) +add_kernel(depthwise_conv_transpose_compute_arm ARM extra SRCS depthwise_conv_transpose_compute.cc) add_kernel(interpolate_compute_arm ARM basic SRCS interpolate_compute.cc) add_kernel(box_coder_compute_arm ARM basic SRCS box_coder_compute.cc) add_kernel(slice_compute_arm ARM basic SRCS slice_compute.cc) diff --git a/lite/kernels/arm/depthwise_conv_transpose_compute.cc b/lite/kernels/arm/depthwise_conv_transpose_compute.cc new file mode 100644 index 00000000000..b78ad73f598 --- /dev/null +++ b/lite/kernels/arm/depthwise_conv_transpose_compute.cc @@ -0,0 +1,92 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/arm/depthwise_conv_transpose_compute.h" +#include "lite/core/op_registry.h" +#include "lite/core/type_system.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm {} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +typedef paddle::lite::kernels::arm:: + DepthwiseConv2DTransposeCompute + DepConvTransFp32; +typedef paddle::lite::kernels::arm:: + DepthwiseConv2DTransposeCompute + DepConvTranInt8_Fp32; +typedef paddle::lite::kernels::arm:: + DepthwiseConv2DTransposeCompute + DepConvTranInt8_Int8; + +#ifdef ENABLE_ARM_FP16 +typedef paddle::lite::kernels::arm:: + DepthwiseConv2DTransposeCompute + DepConvTranFp16; + +REGISTER_LITE_KERNEL( + depthwise_conv2d_transpose, kARM, kFP16, kNCHW, DepConvTranFp16, def) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFP16))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFP16))}) + .BindInput("Filter", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFP16))}) + .BindOutput("Output", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFP16))}) + .BindPaddleOpVersion("conv2d_transpose", 1) + .Finalize(); + +#endif // ENABLE_ARM_FP16 + +REGISTER_LITE_KERNEL( + depthwise_conv2d_transpose, kARM, kFloat, kNCHW, DepConvTransFp32, def) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindPaddleOpVersion("conv2d_transpose", 1) + .Finalize(); + +REGISTER_LITE_KERNEL(depthwise_conv2d_transpose, + kARM, + kInt8, + kNCHW, + DepConvTranInt8_Fp32, + fp32_out) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))}) + .BindInput("Filter", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) + .BindOutput("Output", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))}) + .BindPaddleOpVersion("conv2d_transpose", 1) + .Finalize(); + +REGISTER_LITE_KERNEL(depthwise_conv2d_transpose, + kARM, + kInt8, + kNCHW, + DepConvTranInt8_Int8, + int8_out) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))}) + .BindInput("Filter", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) + .BindOutput("Output", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) + .BindPaddleOpVersion("conv2d_transpose", 1) + .Finalize(); diff --git a/lite/kernels/arm/depthwise_conv_transpose_compute.h b/lite/kernels/arm/depthwise_conv_transpose_compute.h new file mode 100644 index 00000000000..b70908acc33 --- /dev/null +++ b/lite/kernels/arm/depthwise_conv_transpose_compute.h @@ -0,0 +1,32 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "lite/backends/arm/math/funcs.h" +#include "lite/core/kernel.h" +#include "lite/operators/conv_transpose_op.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { +template +class DepthwiseConv2DTransposeCompute : public KernelLite { +}; +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/arm/interpolate_compute.cc b/lite/kernels/arm/interpolate_compute.cc index 0a335e06e4e..6bc6d1181c0 100644 --- a/lite/kernels/arm/interpolate_compute.cc +++ b/lite/kernels/arm/interpolate_compute.cc @@ -39,11 +39,12 @@ namespace arm { int out_h = param.out_h; \ bool align_corners = param.align_corners; \ int align_mode = param.align_mode; \ + auto scale_v = param.scale_v; \ std::string interp_method = method_name; #define INTERP_PARAM \ X, OutSize, SizeTensor, Scale, Out, out_h, out_w, scale, align_corners, \ - align_mode, interp_method + align_mode, interp_method, scale_v template <> void BilinearInterpCompute::Run() { From f66a46f58ce5eaa74bf39f512f4177c0dd36fa25 Mon Sep 17 00:00:00 2001 From: mjp9527 <54735487+mjp9527@users.noreply.github.com> Date: Tue, 19 Oct 2021 17:37:45 +0800 Subject: [PATCH 4/4] [x86] fix im2col when datasize > 2^31 (the max value that int can express) (#7300) * [x86] fix im2col when datasize > 2^31 (the max value that int can express) * fix windows ci --- lite/backends/x86/math/avx/conv_utils.cc | 72 +++++++++++++----------- lite/kernels/x86/conv_compute.cc | 16 +++--- 2 files changed, 48 insertions(+), 40 deletions(-) diff --git a/lite/backends/x86/math/avx/conv_utils.cc b/lite/backends/x86/math/avx/conv_utils.cc index 814e686f62a..4692d179162 100644 --- a/lite/backends/x86/math/avx/conv_utils.cc +++ b/lite/backends/x86/math/avx/conv_utils.cc @@ -861,31 +861,36 @@ void im2col_s1(const float* data_im, (width + pad_left + pad_right - (dilation_w * (kernel_w - 1) + 1)) + 1; const int in_channel_size = height * width; const int out_channel_size = output_h * output_w; - const int output_plane_size = output_h * output_w * kernel_h * kernel_w; - memset(data_col, 0, output_plane_size * channels * sizeof(float)); + const unsigned int output_plane_size = + output_h * output_w * kernel_h * kernel_w; + size_t tmp_size = static_cast(output_plane_size); + size_t mem_size = tmp_size * channels * sizeof(float); + memset(data_col, 0, mem_size); #pragma omp parallel for for (int c = 0; c < channels; c++) { - int data_im_z = c * in_channel_size; - int data_col_z1 = c * output_plane_size; + unsigned int data_im_z = c * in_channel_size; + unsigned int data_col_z1 = c * output_plane_size; for (int ky = 0, h_offset = 0; ky < kernel_h; ky++, h_offset += dilation_h) { - int data_col_z2 = ky * out_channel_size * kernel_w; + unsigned int data_col_z2 = ky * out_channel_size * kernel_w; for (int kx = 0, w_offset = 0; kx < kernel_w; kx++, w_offset += dilation_w) { - int data_col_z3 = kx * out_channel_size; - int data_col_z = data_col_z1 + data_col_z2 + data_col_z3; - int oh_begin = std::max(((pad_top - h_offset)), 0); - int oh_end = std::min(((height + pad_bottom - h_offset)), output_h); + unsigned int data_col_z3 = kx * out_channel_size; + unsigned int data_col_z = data_col_z1 + data_col_z2 + data_col_z3; + unsigned int oh_begin = std::max(((pad_top - h_offset)), 0); + unsigned int oh_end = + std::min(((height + pad_bottom - h_offset)), output_h); oh_end = std::max(oh_begin, oh_end); - int ow_begin = std::max(((pad_left - w_offset)), 0); - int ow_end = std::min(((width + pad_right - w_offset)), output_w); + unsigned int ow_begin = std::max(((pad_left - w_offset)), 0); + unsigned int ow_end = + std::min(((width + pad_right - w_offset)), output_w); ow_end = std::max(ow_begin, ow_end); - int ih = oh_begin - pad_top + h_offset; + unsigned int ih = oh_begin - pad_top + h_offset; for (int oh = oh_begin; oh < oh_end; ++oh, ++ih) { - int iw = ow_begin - pad_left + w_offset; - int ow = ow_begin; - int data_im_offset = data_im_z + ih * width; - int data_col_offset = data_col_z + oh * output_w; + unsigned int iw = ow_begin - pad_left + w_offset; + unsigned int ow = ow_begin; + unsigned int data_im_offset = data_im_z + ih * width; + unsigned int data_col_offset = data_col_z + oh * output_w; const float* data_im_ptr = data_im + data_im_offset; float* data_col_ptr = data_col + data_col_offset; #ifdef __AVX__ @@ -929,33 +934,36 @@ void im2col_s2(const float* data_im, (width + pad_left + pad_right - (dilation_w * (kernel_w - 1) + 1)) / 2 + 1; const int in_channel_size = height * width; - const int output_plane_size = output_h * output_w * kernel_h * kernel_w; - memset(data_col, 0, output_plane_size * channels * sizeof(float)); + const unsigned int output_plane_size = + output_h * output_w * kernel_h * kernel_w; + size_t tmp_size = static_cast(output_plane_size); + size_t mem_size = tmp_size * channels * sizeof(float); + memset(data_col, 0, mem_size); #pragma omp parallel for for (int c = 0; c < channels; c++) { - int data_im_z = c * in_channel_size; - int data_col_z1 = c * output_plane_size; + unsigned int data_im_z = c * in_channel_size; + unsigned int data_col_z1 = c * output_plane_size; for (int ky = 0, h_offset = 0; ky < kernel_h; ky++, h_offset += dilation_h) { - int data_col_z2 = ky * output_h * output_w * kernel_w; + unsigned int data_col_z2 = ky * output_h * output_w * kernel_w; for (int kx = 0, w_offset = 0; kx < kernel_w; kx++, w_offset += dilation_w) { - int data_col_z3 = kx * output_h * output_w; - int data_col_z = data_col_z1 + data_col_z2 + data_col_z3; - int oh_begin = std::max(((pad_top - h_offset + 1) / 2), 0); - int oh_end = + unsigned int data_col_z3 = kx * output_h * output_w; + unsigned int data_col_z = data_col_z1 + data_col_z2 + data_col_z3; + unsigned int oh_begin = std::max(((pad_top - h_offset + 1) / 2), 0); + unsigned int oh_end = std::min(((height + pad_bottom - h_offset + 1) / 2), output_h); oh_end = std::max(oh_begin, oh_end); - int ow_begin = std::max(((pad_left - w_offset + 1) / 2), 0); - int ow_end = + unsigned int ow_begin = std::max(((pad_left - w_offset + 1) / 2), 0); + unsigned int ow_end = std::min(((width + pad_right - w_offset + 1) / 2), output_w); ow_end = std::max(ow_begin, ow_end); - int ih = oh_begin * 2 - pad_top + h_offset; + unsigned int ih = oh_begin * 2 - pad_top + h_offset; for (int oh = oh_begin; oh < oh_end; ++oh, ih += 2) { - int iw = ow_begin * 2 - pad_left + w_offset; - int ow = ow_begin; - int data_im_offset = data_im_z + ih * width; - int data_col_offset = data_col_z + oh * output_w; + unsigned int iw = ow_begin * 2 - pad_left + w_offset; + unsigned int ow = ow_begin; + unsigned int data_im_offset = data_im_z + ih * width; + unsigned int data_col_offset = data_col_z + oh * output_w; const float* data_im_ptr = data_im + data_im_offset; float* data_col_ptr = data_col + data_col_offset; for (; ow + 3 < ow_end; ow += 4, iw += 8) { diff --git a/lite/kernels/x86/conv_compute.cc b/lite/kernels/x86/conv_compute.cc index a1ea4585644..0a33848f17a 100644 --- a/lite/kernels/x86/conv_compute.cc +++ b/lite/kernels/x86/conv_compute.cc @@ -118,11 +118,11 @@ void Conv2dCompute::Run() { auto& ctx = ctx_->As(); INIT_PARAM bool flag_bias = (param.bias != nullptr); - int group_size_out = m * n; - int group_size_weights = m * k; - int group_size_coldata = n * k; - int channel_in_size = chin * hin * win; - int channel_out_size = chout * hout * wout; + unsigned int group_size_out = m * n; + unsigned int group_size_weights = m * k; + unsigned int group_size_coldata = n * k; + unsigned int channel_in_size = chin * hin * win; + unsigned int channel_out_size = chout * hout * wout; auto paddings = *param.paddings; auto dilations = *param.dilations; @@ -135,9 +135,9 @@ void Conv2dCompute::Run() { float* col_data = nullptr; if (!flag_1x1gemm_) { - int col_size = group * group_size_coldata; - col_data = static_cast( - TargetMalloc(TARGET(kX86), col_size * sizeof(float))); + size_t col_size = group_size_coldata * group; + size_t col_data_size = static_cast(col_size * sizeof(float)); + col_data = static_cast(TargetMalloc(TARGET(kX86), col_data_size)); } auto act_param = param.activation_param; paddle::lite::x86::math::Blas matmul(ctx);