diff --git a/lite/backends/arm/math/gru.h b/lite/backends/arm/math/gru.h
new file mode 100644
index 00000000000..1492c57d6a2
--- /dev/null
+++ b/lite/backends/arm/math/gru.h
@@ -0,0 +1,248 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "lite/backends/arm/math/sgemm.h"
+#ifdef LITE_WITH_ARM
+#include <arm_neon.h>
+#endif
+
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+
+template <typename T>
+struct RNNGRUValue {
+  const T* gate_weight;
+  const T* state_weight;
+  const T* reset_bias;
+  T* gate_value;
+  T* reset_output_value;
+  T* output_value;
+  const T* prev_out_value;
+};
+
+template <typename T>
+void rnn_activation(const T* din,
+                    T* dout,
+                    int size,
+                    lite_api::ActivationType act_type,
+                    int threads) {
+  switch (act_type) {
+    case lite_api::ActivationType::kSigmoid:
+      act_sigmoid(din, dout, size, threads);
+      break;
+    case lite_api::ActivationType::kSigmoid_v2:
+      act_sigmoid(din, dout, size, threads);
+      break;
+    case lite_api::ActivationType::kTanh:
+      act_tanh(din, dout, size, threads);
+      break;
+    case lite_api::ActivationType::kTanh_v2:
+      act_tanh(din, dout, size, threads);
+      break;
+    case lite_api::ActivationType::kRelu:
+      act_relu(din, dout, size, threads);
+      break;
+    default:
+      LOG(FATAL) << "unsupported activation type: "
+                 << static_cast<int>(act_type);
+      break;
+  }
+}
+
+template <typename T>
+void compute_kernel(RNNGRUValue<T> value,
+                    int frame_size,
+                    int batch_size,
+                    lite_api::ActivationType active_node,
+                    lite_api::ActivationType active_gate) {
+  auto value_reset_gate = value.gate_value;
+  auto value_update_gate = value.gate_value + frame_size;
+  auto value_reset_output = value.reset_output_value;
+  auto value_reset_bias = value.reset_bias;
+  auto cell_state_value = value.gate_value + 2 * frame_size;
+  auto value_output = value.output_value;
+  auto value_prev_out = value.prev_out_value;
+
+  for (int b = 0; b < batch_size; b++) {
+    rnn_activation(value_reset_gate,
+                   value_reset_gate,
+                   frame_size,
+                   lite_api::ActivationType::kSigmoid_v2,
+                   1);
+    rnn_activation(value_update_gate,
+                   value_update_gate,
+                   frame_size,
+                   lite_api::ActivationType::kSigmoid_v2,
+                   1);
+
+    for (int i = 0; i < frame_size; i++) {
+      value_reset_output[i] =
+          (value_reset_output[i] + value_reset_bias[i]) * value_reset_gate[i];
+      cell_state_value[i] += value_reset_output[i];
+    }
+
+    rnn_activation(cell_state_value,
+                   cell_state_value,
+                   frame_size,
+                   lite_api::ActivationType::kTanh_v2,
+                   1);
+
+    if (value.prev_out_value) {
+      for (int i = 0; i < frame_size; i++) {
+        value_output[i] = (1.f - value_update_gate[i]) * cell_state_value[i] +
+                          value_update_gate[i] * value_prev_out[i];
+      }
+    } else {
+      for (int i = 0; i < frame_size; i++) {
+        value_output[i] = (1.f - value_update_gate[i]) * cell_state_value[i];
+      }
+    }
+
+    value_reset_gate += frame_size * 3;
+    value_update_gate += frame_size * 3;
+    value_reset_output += frame_size;
+    cell_state_value += frame_size * 3;
+    value_output += frame_size;
+    if (value.prev_out_value) {
+      value_prev_out += frame_size;
+    }
+  }
+}
+
+template <>
+void compute_kernel(RNNGRUValue<float> value,
+                    int frame_size,
+                    int batch_size,
+                    lite_api::ActivationType active_node,
+                    lite_api::ActivationType active_gate) {
+  auto value_reset_gate = value.gate_value;
+  auto value_update_gate = value.gate_value + frame_size;
+  auto value_reset_output = value.reset_output_value;
+  auto value_reset_bias = value.reset_bias;
+  auto cell_state_value = value.gate_value + 2 * frame_size;
+  auto value_output = value.output_value;
+  auto value_prev_out = value.prev_out_value;
+  int i = 0;
+  float32x4_t vec_one = vdupq_n_f32(1.f);
+
+  for (int b = 0; b < batch_size; b++) {
+    rnn_activation(value_reset_gate,
+                   value_reset_gate,
+                   frame_size,
+                   lite_api::ActivationType::kSigmoid_v2,
+                   1);
+    rnn_activation(value_update_gate,
+                   value_update_gate,
+                   frame_size,
+                   lite_api::ActivationType::kSigmoid_v2,
+                   1);
+
+    for (i = 0; i + 3 < frame_size; i += 4) {
+      float32x4_t vec_out = vld1q_f32(value_reset_output + i);
+      float32x4_t vec_reset = vld1q_f32(value_reset_gate + i);
+      float32x4_t vec_bias = vld1q_f32(value_reset_bias + i);
+      vec_out = vmulq_f32(vaddq_f32(vec_out, vec_bias), vec_reset);
+      vst1q_f32(value_reset_output + i, vec_out);
+      vst1q_f32(cell_state_value + i,
+                vaddq_f32(vec_out, vld1q_f32(cell_state_value + i)));
+    }
+    for (; i < frame_size; i++) {
+      value_reset_output[i] =
+          (value_reset_output[i] + value_reset_bias[i]) * value_reset_gate[i];
+      cell_state_value[i] += value_reset_output[i];
+    }
+
+    rnn_activation(cell_state_value,
+                   cell_state_value,
+                   frame_size,
+                   lite_api::ActivationType::kTanh_v2,
+                   1);
+
+    if (value.prev_out_value) {
+      for (i = 0; i + 3 < frame_size; i += 4) {
+        float32x4_t vec_vug = vld1q_f32(value_update_gate + i);
+        float32x4_t vec_vpo = vld1q_f32(value_prev_out + i);
+        float32x4_t vec_csv = vld1q_f32(cell_state_value + i);
+        vec_vpo = vmulq_f32(vec_vug, vec_vpo);
+        float32x4_t vec_out =
+            vmlaq_f32(vec_vpo, vsubq_f32(vec_one, vec_vug), vec_csv);
+        vst1q_f32(value_output + i, vec_out);
+      }
+      for (; i < frame_size; i++) {
+        value_output[i] = (1.f - value_update_gate[i]) * cell_state_value[i] +
+                          value_update_gate[i] * value_prev_out[i];
+      }
+    } else {
+      for (i = 0; i + 3 < frame_size; i += 4) {
+        float32x4_t vec_vug = vld1q_f32(value_update_gate + i);
+        float32x4_t vec_csv = vld1q_f32(cell_state_value + i);
+        float32x4_t vec_out = vmulq_f32(vsubq_f32(vec_one, vec_vug), vec_csv);
+        vst1q_f32(value_output + i, vec_out);
+      }
+      for (; i < frame_size; i++) {
+        value_output[i] = (1.f - value_update_gate[i]) * cell_state_value[i];
+      }
+    }
+
+    value_reset_gate += frame_size * 3;
+    value_update_gate += frame_size * 3;
+    value_reset_output += frame_size;
+    cell_state_value += frame_size * 3;
+    value_output += frame_size;
+    if (value.prev_out_value) {
+      value_prev_out += frame_size;
+    }
+  }
+}
+
+template <typename T>
+struct RnnGruUnitFunctorV2 {
+  static void compute(ARMContext* ctx,
+                      RNNGRUValue<T> value,
+                      int frame_size,
+                      int batch_size,
+                      lite_api::ActivationType active_node,
+                      lite_api::ActivationType active_gate) {
+    if (value.prev_out_value) {
+      operators::ActivationParam act_param;
+      act_param.has_active = false;
+      lite::arm::math::sgemm(false,
+                             true,
+                             batch_size,
+                             frame_size,
+                             frame_size,
+                             1.f,
+                             value.prev_out_value,
+                             frame_size,
+                             value.state_weight,
+                             frame_size,
+                             0.f,
+                             value.reset_output_value,
+                             frame_size,
+                             nullptr,
+                             false,
+                             act_param,
+                             ctx);
+    }
+    compute_kernel(value, frame_size, batch_size, active_node, active_gate);
+  }
+};
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}
// namespace paddle diff --git a/lite/backends/arm/math/lstm.cc b/lite/backends/arm/math/lstm.cc index cd8e012a287..bf096ed04c0 100644 --- a/lite/backends/arm/math/lstm.cc +++ b/lite/backends/arm/math/lstm.cc @@ -36,6 +36,7 @@ void add_bias_rowwise(Tensor* input, i_data += width; } } + void vector_dot( float* out, const float* in, const float* v1, int size, const float* v2) { int loop = size >> 2; diff --git a/lite/backends/x86/math/elementwise.h b/lite/backends/x86/math/elementwise.h index 6a5c8cf92f6..2de0b831783 100644 --- a/lite/backends/x86/math/elementwise.h +++ b/lite/backends/x86/math/elementwise.h @@ -150,24 +150,27 @@ namespace math { } \ } -// marco func add -ElementWiseFunc(Add) ElementWiseFuncBCast(Add) - // marco func sub - ElementWiseFunc(Sub) ElementWiseFuncBCast(Sub) - // marco func mul - ElementWiseFunc(Mul) ElementWiseFuncBCast(Mul) - // marco func max - ElementWiseFunc(Max) ElementWiseFuncBCast(Max) - // marco func min - ElementWiseFunc(Min) ElementWiseFuncBCast(Min) - // marco func div - ElementWiseFunc(Div) ElementWiseFuncBCast(Div) - // marco func floordiv - ElementWiseFunc(FloorDiv) ElementWiseFuncBCast(FloorDiv) - // marco func mod - ElementWiseFunc(Mod) ElementWiseFuncBCast(Mod) - // marco func pow - ElementWiseFunc(Pow) ElementWiseFuncBCast(Pow) +// clang-format off +ElementWiseFunc(Add) +ElementWiseFuncBCast(Add) +ElementWiseFunc(Sub) +ElementWiseFuncBCast(Sub) +ElementWiseFunc(Mul) +ElementWiseFuncBCast(Mul) +ElementWiseFunc(Max) +ElementWiseFuncBCast(Max) +ElementWiseFunc(Min) +ElementWiseFuncBCast(Min) +ElementWiseFunc(Div) +ElementWiseFuncBCast(Div) +ElementWiseFunc(FloorDiv) +ElementWiseFuncBCast(FloorDiv) +ElementWiseFunc(Mod) +ElementWiseFuncBCast(Mod) +ElementWiseFunc(Pow) +ElementWiseFuncBCast(Pow) +// clang-format on + } // namespace math } // namespace x86 } // namespace lite diff --git a/lite/backends/x86/math/fill_bias_activate.cc b/lite/backends/x86/math/fill_bias_activate.cc index 31685b90778..126fe1f382c 100644 --- a/lite/backends/x86/math/fill_bias_activate.cc +++ b/lite/backends/x86/math/fill_bias_activate.cc @@ -38,7 +38,6 @@ static void activate_relu_inplace(float *data, int len, float alpha, int mode) { __m256 vec_data = _mm256_loadu_ps(data + i); _mm256_storeu_ps(data + i, _mm256_max_ps(vec_data, vec_zero)); } -// _mm256_zeroupper(); #endif #ifdef __SSE__ __m128 vec_zero_128 = _mm_set1_ps(0.f); @@ -59,7 +58,6 @@ static void activate_relu_inplace(float *data, int len, float alpha, int mode) { _mm256_storeu_ps( data + i, _mm256_min_ps(_mm256_max_ps(vec_data, vec_zero), vec_alph)); } -// _mm256_zeroupper(); #endif #ifdef __SSE__ __m128 vec_zero_128 = _mm_set1_ps(0.f); @@ -112,7 +110,6 @@ static void activate_relu_inplace_bias(float *data, vec_data = _mm256_add_ps(vec_bias, vec_data); _mm256_storeu_ps(tmp_data + i, _mm256_max_ps(vec_data, vec_zero)); } -// _mm256_zeroupper(); #endif #ifdef __SSE__ vec_bias_128 = _mm_set1_ps(bias[j]); @@ -140,7 +137,6 @@ static void activate_relu_inplace_bias(float *data, tmp_data + i, _mm256_min_ps(_mm256_max_ps(vec_data, vec_zero), vec_alph)); } -// _mm256_zeroupper(); #endif #ifdef __SSE__ vec_bias_128 = _mm_set1_ps(bias[j]); @@ -174,7 +170,6 @@ static void activate_lrelu_inplace(float *data, int len, float alpha) { __m256 vec_mask = _mm256_cmp_ps(vec_data, vec_zero, cmp_le_os); _mm256_storeu_ps(data + i, _mm256_blendv_ps(vec_data, vec_lr, vec_mask)); } -// _mm256_zeroupper(); #endif #ifdef __SSE4_1__ // blendv need 4.1 __m128 vec_zero_128 = _mm_set1_ps(0.f); @@ -226,7 +221,6 @@ static void 
activate_lrelu_inplace_bias(float *data, _mm256_storeu_ps(tmp_data + i, _mm256_blendv_ps(vec_data, vec_lr, vec_mask)); } -// _mm256_zeroupper(); #endif #ifdef __SSE4_1__ vec_bias_128 = _mm_set1_ps(bias[j]); @@ -273,7 +267,6 @@ static void activate_none_inplace_bias(float *data, vec_data = _mm256_add_ps(vec_bias, vec_data); _mm256_storeu_ps(tmp_data + i, vec_data); } -// _mm256_zeroupper(); #endif #ifdef __SSE__ vec_bias_128 = _mm_set1_ps(bias[j]); diff --git a/lite/backends/x86/math/rnn.cc b/lite/backends/x86/math/rnn.cc deleted file mode 100644 index bddcf5228e4..00000000000 --- a/lite/backends/x86/math/rnn.cc +++ /dev/null @@ -1,172 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifdef __AVX__ -#include -#endif -#ifdef __SSE__ -#include -#endif - -#include "lite/backends/x86/math/activation_functions.h" -#include "lite/backends/x86/math/rnn.h" - -namespace paddle { -namespace lite { -namespace x86 { -namespace math { - -void vector_dot( - float* out, const float* in, const float* v1, int size, const float* v2) { -#if defined(__AVX__) - __m256 vec_in, vec_v1, vec_v2; -#endif -#if defined(__SSE__) - __m128 vec_in_128, vec_v1_128, vec_v2_128; -#endif - - int i = 0; - if (nullptr == v2) { - i = 0; - -// in_out * v1 -#if defined(__AVX__) - for (; i + 7 < size; i += 8) { - vec_in = _mm256_loadu_ps(in + i); - vec_v1 = _mm256_loadu_ps(v1 + i); - _mm256_storeu_ps(out + i, _mm256_mul_ps(vec_in, vec_v1)); - } -// _mm256_zeroupper(); -#endif -#if defined(__SSE__) - for (; i + 3 < size; i += 4) { - vec_in_128 = _mm_loadu_ps(in + i); - vec_v1_128 = _mm_loadu_ps(v1 + i); - _mm_storeu_ps(out + i, _mm_mul_ps(vec_in_128, vec_v1_128)); - } -#endif - for (; i < size; i++) { - out[i] = in[i] * v1[i]; - } - } else { - i = 0; - -// in_out + v1 * v2 -#if defined(__AVX__) && defined(__FMA__) - for (; i + 7 < size; i += 8) { - vec_in = _mm256_loadu_ps(in + i); - vec_v1 = _mm256_loadu_ps(v1 + i); - vec_v2 = _mm256_loadu_ps(v2 + i); - _mm256_storeu_ps(out + i, _mm256_fmadd_ps(vec_v2, vec_v1, vec_in)); - } - for (; i + 3 < size; i += 4) { - vec_in_128 = _mm_loadu_ps(in + i); - vec_v1_128 = _mm_loadu_ps(v1 + i); - vec_v2_128 = _mm_loadu_ps(v2 + i); - _mm_storeu_ps(out + i, _mm_fmadd_ps(vec_v2_128, vec_v1_128, vec_in_128)); - } -#endif - for (; i < size; i++) { - out[i] = in[i] + v1[i] * v2[i]; - } - } -} - -template <> -void act_relu(const float* din, float* dout, int size, int threads) { - int i = 0; - -#ifdef __AVX__ - for (; i + 7 < size; i += 8) { - __m256 a = _mm256_loadu_ps(din + i); - _mm256_storeu_ps(dout + i, lite::x86::math::detail::forward::avx::Relu(a)); - } -#endif - for (; i < size; i++) { - dout[i] = lite::x86::math::detail::forward::Relu(din[i]); - } -} - -template <> -void act_sigmoid(const float* din, float* dout, int size, int threads) { - int i = 0; - -#ifdef __AVX__ - for (; i + 7 < size; i += 8) { - __m256 a = _mm256_loadu_ps(din + i); - _mm256_storeu_ps(dout + i, - 
lite::x86::math::detail::forward::avx::Sigmoid(a)); - } -#endif - for (; i < size; i++) { - dout[i] = lite::x86::math::detail::forward::Sigmoid(din[i]); - } -} - -template <> -void act_tanh(const float* din, float* dout, int size, int threads) { - int i = 0; - -#ifdef __AVX__ - for (; i + 7 < size; i += 8) { - __m256 a = _mm256_loadu_ps(din + i); - _mm256_storeu_ps(dout + i, lite::x86::math::detail::forward::avx::Tanh(a)); - } -#endif - for (; i < size; i++) { - dout[i] = lite::x86::math::detail::forward::Tanh(din[i]); - } -} - -void fill_bias_fc(float* out, const float* bias, int num, int channel) { -#ifdef __AVX__ - __m256 vec_bias = {0.f}; - __m256 vec_data = {0.f}; -#endif -#ifdef __SSE__ - __m128 vec_bias_128 = {0.f}; - __m128 vec_data_128 = {0.f}; -#endif - int i = 0; - - for (int j = 0; j < num; j++) { - float* ptr = out + j * channel; - const float* pbias = bias; - i = 0; - -#ifdef __AVX__ - for (; i + 7 < channel; i += 8) { - vec_bias = _mm256_loadu_ps(pbias + i); - vec_data = _mm256_loadu_ps(ptr + i); - _mm256_storeu_ps(ptr + i, _mm256_add_ps(vec_data, vec_bias)); - } -// _mm256_zeroupper(); -#endif -#ifdef __SSE__ - for (; i + 3 < channel; i += 4) { - vec_bias_128 = _mm_loadu_ps(pbias + i); - vec_data_128 = _mm_loadu_ps(ptr + i); - _mm_storeu_ps(ptr + i, _mm_add_ps(vec_data_128, vec_bias_128)); - } -#endif - for (; i < channel; i++) { - *(ptr + i) = pbias[i] + ptr[i]; - } - } -} - -} // namespace math -} // namespace x86 -} // namespace lite -} // namespace paddle diff --git a/lite/backends/x86/math/rnn.h b/lite/backends/x86/math/rnn.h index 82d467fcad5..620ebe8c23d 100644 --- a/lite/backends/x86/math/rnn.h +++ b/lite/backends/x86/math/rnn.h @@ -15,14 +15,33 @@ #pragma once #include +#include "lite/backends/x86/math/activation_functions.h" +#include "lite/backends/x86/math/blas.h" #include "lite/core/tensor.h" #include "lite/utils/log/logging.h" +#ifdef __AVX__ +#include +#endif +#ifdef __SSE__ +#include +#endif + +#ifndef __FMA__ +#define _mm256_fmadd_ps(a, b, c) _mm256_add_ps((c), _mm256_mul_ps((a), (b))) +#define _mm_fmadd_ps(a, b, c) _mm_add_ps((c), _mm_mul_ps((a), (b))) +#endif + namespace paddle { namespace lite { namespace x86 { namespace math { +namespace x86_forward = paddle::lite::x86::math::detail::forward; + +//************************************** +// Class Def +//************************************** template struct LstmMetaValue { T* gate_value; @@ -35,25 +54,183 @@ struct LstmMetaValue { T* check_og; }; +template +struct GRUMetaValue { + const T* gate_weight; + const T* state_weight; + const T* reset_bias; + T* gate_value; + T* reset_output_value; + T* output_value; + const T* prev_out_value; +}; + +//********************************* +// Inline Function +//********************************* // if v2 isn't null: out[i] = in[i] + v1[i] * v2[i]; // if v2 is null: out[i] = in[i] * v1[i]; -void vector_dot(float* out, - const float* in, - const float* v1, - int size, - const float* v2 = nullptr); +inline void vector_dot(float* out, + const float* in, + const float* v1, + int size, + const float* v2 = nullptr) { +#if defined(__AVX__) + __m256 vec_in, vec_v1, vec_v2; +#endif +#if defined(__SSE__) + __m128 vec_in_128, vec_v1_128, vec_v2_128; +#endif + + int i = 0; + if (nullptr == v2) { + i = 0; + +// in_out * v1 +#if defined(__AVX__) + for (; i + 7 < size; i += 8) { + vec_in = _mm256_loadu_ps(in + i); + vec_v1 = _mm256_loadu_ps(v1 + i); + _mm256_storeu_ps(out + i, _mm256_mul_ps(vec_in, vec_v1)); + } +#endif +#if defined(__SSE__) + for (; i + 3 < size; i += 4) { + 
vec_in_128 = _mm_loadu_ps(in + i); + vec_v1_128 = _mm_loadu_ps(v1 + i); + _mm_storeu_ps(out + i, _mm_mul_ps(vec_in_128, vec_v1_128)); + } +#endif + for (; i < size; i++) { + out[i] = in[i] * v1[i]; + } + } else { + i = 0; + +// in_out + v1 * v2 +#if defined(__AVX__) + for (; i + 7 < size; i += 8) { + vec_in = _mm256_loadu_ps(in + i); + vec_v1 = _mm256_loadu_ps(v1 + i); + vec_v2 = _mm256_loadu_ps(v2 + i); + _mm256_storeu_ps(out + i, _mm256_fmadd_ps(vec_v2, vec_v1, vec_in)); + } +#endif +#if defined(__SSE__) + for (; i + 3 < size; i += 4) { + vec_in_128 = _mm_loadu_ps(in + i); + vec_v1_128 = _mm_loadu_ps(v1 + i); + vec_v2_128 = _mm_loadu_ps(v2 + i); + _mm_storeu_ps(out + i, _mm_fmadd_ps(vec_v2_128, vec_v1_128, vec_in_128)); + } +#endif + for (; i < size; i++) { + out[i] = in[i] + v1[i] * v2[i]; + } + } +} + +inline void fill_bias_fc(float* out, const float* bias, int num, int channel) { +#ifdef __AVX__ + __m256 vec_bias = {0.f}; + __m256 vec_data = {0.f}; +#endif +#ifdef __SSE__ + __m128 vec_bias_128 = {0.f}; + __m128 vec_data_128 = {0.f}; +#endif + int i = 0; -// only add bias -void fill_bias_fc(float* out, const float* bias, int num, int channel); + for (int j = 0; j < num; j++) { + float* ptr = out + j * channel; + const float* pbias = bias; + i = 0; +#ifdef __AVX__ + for (; i + 7 < channel; i += 8) { + vec_bias = _mm256_loadu_ps(pbias + i); + vec_data = _mm256_loadu_ps(ptr + i); + _mm256_storeu_ps(ptr + i, _mm256_add_ps(vec_data, vec_bias)); + } +#endif +#ifdef __SSE__ + for (; i + 3 < channel; i += 4) { + vec_bias_128 = _mm_loadu_ps(pbias + i); + vec_data_128 = _mm_loadu_ps(ptr + i); + _mm_storeu_ps(ptr + i, _mm_add_ps(vec_data_128, vec_bias_128)); + } +#endif + for (; i < channel; i++) { + *(ptr + i) = pbias[i] + ptr[i]; + } + } +} + +//******************************* +// Template Func +//******************************* template -void act_relu(const T* din, T* dout, int size, int threads); +void act_relu(const T* din, T* dout, int size, int threads) { + for (int i = 0; i < size; i++) { + dout[i] = x86_forward::Relu(din[i]); + } +} template -void act_sigmoid(const T* din, T* dout, int size, int threads); +void act_sigmoid(const T* din, T* dout, int size, int threads) { + for (int i = 0; i < size; i++) { + dout[i] = x86_forward::Sigmoid(din[i]); + } +} template -void act_tanh(const T* din, T* dout, int size, int threads); +void act_tanh(const T* din, T* dout, int size, int threads) { + for (int i = 0; i < size; i++) { + dout[i] = x86_forward::Tanh(din[i]); + } +} + +template <> +void act_relu(const float* din, float* dout, int size, int threads) { + int i = 0; +#ifdef __AVX__ + for (; i + 7 < size; i += 8) { + __m256 a = _mm256_loadu_ps(din + i); + _mm256_storeu_ps(dout + i, x86_forward::avx::Relu(a)); + } +#endif + for (; i < size; i++) { + dout[i] = x86_forward::Relu(din[i]); + } +} + +template <> +void act_sigmoid(const float* din, float* dout, int size, int threads) { + int i = 0; +#ifdef __AVX__ + for (; i + 7 < size; i += 8) { + __m256 a = _mm256_loadu_ps(din + i); + _mm256_storeu_ps(dout + i, x86_forward::avx::Sigmoid(a)); + } +#endif + for (; i < size; i++) { + dout[i] = x86_forward::Sigmoid(din[i]); + } +} + +template <> +void act_tanh(const float* din, float* dout, int size, int threads) { + int i = 0; +#ifdef __AVX__ + for (; i + 7 < size; i += 8) { + __m256 a = _mm256_loadu_ps(din + i); + _mm256_storeu_ps(dout + i, x86_forward::avx::Tanh(a)); + } +#endif + for (; i < size; i++) { + dout[i] = x86_forward::Tanh(din[i]); + } +} template void activation( @@ -97,6 +274,9 @@ void 
activation(const T* din, } } +//*********************************** +// LSTM MODE +//*********************************** template struct RnnLstmUnitFunctor { static void compute(LstmMetaValue value, @@ -163,6 +343,237 @@ struct RnnLstmUnitFunctor { } }; +//************************************ +// GRU MODE +//************************************ +template +void GruRnnComputeKernel(GRUMetaValue value, + int frame_size, + int batch_size, + lite_api::ActivationType active_node, + lite_api::ActivationType active_gate) { + auto value_reset_gate = value.gate_value; + auto value_update_gate = value.gate_value + frame_size; + auto value_reset_output = value.reset_output_value; + auto value_reset_bias = value.reset_bias; + auto cell_state_value = value.gate_value + 2 * frame_size; + auto value_output = value.output_value; + auto value_prev_out = value.prev_out_value; + + for (int b = 0; b < batch_size; b++) { + activation(value_reset_gate, + value_reset_gate, + frame_size, + lite_api::ActivationType::kSigmoid_v2, + 1); + + activation(value_update_gate, + value_update_gate, + frame_size, + lite_api::ActivationType::kSigmoid_v2, + 1); + + for (int i = 0; i < frame_size; i++) { + value_reset_output[i] = + (value_reset_output[i] + value_reset_bias[i]) * value_reset_gate[i]; + cell_state_value[i] += value_reset_output[i]; + } + + activation(cell_state_value, + cell_state_value, + frame_size, + lite_api::ActivationType::kTanh_v2, + 1); + + if (value.prev_out_value) { + for (int i = 0; i < frame_size; i++) { + value_output[i] = (1.f - value_update_gate[i]) * cell_state_value[i] + + value_update_gate[i] * value_prev_out[i]; + } + } else { + for (int i = 0; i < frame_size; i++) { + value_output[i] = (1.f - value_update_gate[i]) * cell_state_value[i]; + } + } + + value_reset_gate += frame_size * 3; + value_update_gate += frame_size * 3; + value_reset_output += frame_size; + cell_state_value += frame_size * 3; + value_output += frame_size; + if (value.prev_out_value) { + value_prev_out += frame_size; + } + } +} + +template <> +void GruRnnComputeKernel(GRUMetaValue value, + int frame_size, + int batch_size, + lite_api::ActivationType active_node, + lite_api::ActivationType active_gate) { + auto value_reset_gate = value.gate_value; + auto value_update_gate = value.gate_value + frame_size; + auto value_reset_output = value.reset_output_value; + auto value_reset_bias = value.reset_bias; + auto cell_state_value = value.gate_value + 2 * frame_size; + auto value_output = value.output_value; + auto value_prev_out = value.prev_out_value; + int i = 0; + +#ifdef __AVX__ + __m256 vec_one_256 = _mm256_set1_ps(1.0f); +#endif +#ifdef __SSE__ + __m128 vec_one_128 = _mm_set1_ps(1.0f); +#endif + + for (int b = 0; b < batch_size; b++) { + activation(value_reset_gate, + value_reset_gate, + frame_size, + lite_api::ActivationType::kSigmoid_v2, + 1); + activation(value_update_gate, + value_update_gate, + frame_size, + lite_api::ActivationType::kSigmoid_v2, + 1); + i = 0; +#ifdef __AVX__ + for (; i + 7 < frame_size; i += 8) { + __m256 vec_out = _mm256_loadu_ps(value_reset_output + i); + __m256 vec_reset = _mm256_loadu_ps(value_reset_gate + i); + __m256 vec_bias = _mm256_loadu_ps(value_reset_bias + i); + vec_out = _mm256_mul_ps(_mm256_add_ps(vec_out, vec_bias), vec_reset); + _mm256_storeu_ps(value_reset_output + i, vec_out); + _mm256_storeu_ps( + cell_state_value + i, + _mm256_add_ps(vec_out, _mm256_loadu_ps(cell_state_value + i))); + } +#endif +#ifdef __SSE__ + for (; i + 3 < frame_size; i += 4) { + __m128 vec_out = 
_mm_loadu_ps(value_reset_output + i); + __m128 vec_reset = _mm_loadu_ps(value_reset_gate + i); + __m128 vec_bias = _mm_loadu_ps(value_reset_bias + i); + vec_out = _mm_mul_ps(_mm_add_ps(vec_out, vec_bias), vec_reset); + _mm_storeu_ps(value_reset_output + i, vec_out); + _mm_storeu_ps(cell_state_value + i, + _mm_add_ps(vec_out, _mm_loadu_ps(cell_state_value + i))); + } +#endif + for (; i < frame_size; i++) { + value_reset_output[i] = + (value_reset_output[i] + value_reset_bias[i]) * value_reset_gate[i]; + cell_state_value[i] += value_reset_output[i]; + } + + activation(cell_state_value, + cell_state_value, + frame_size, + lite_api::ActivationType::kTanh_v2, + 1); + + if (value.prev_out_value) { + i = 0; +#ifdef __AVX__ + for (; i + 7 < frame_size; i += 8) { + __m256 vec_vug = _mm256_loadu_ps(value_update_gate + i); + __m256 vec_vpo = _mm256_loadu_ps(value_prev_out + i); + __m256 vec_csv = _mm256_loadu_ps(cell_state_value + i); + vec_vpo = _mm256_mul_ps(vec_vug, vec_vpo); +#ifdef __FMA__ + __m256 vec_out = _mm256_fmadd_ps( + vec_csv, _mm256_sub_ps(vec_one_256, vec_vug), vec_vpo); +#else + __m256 vec_out = _mm256_add_ps( + _mm256_mul_ps(vec_csv, _mm256_sub_ps(vec_one_256, vec_vug)), + vec_vpo); +#endif + _mm256_storeu_ps(value_output + i, vec_out); + } +#endif +#ifdef __SSE__ + for (; i + 3 < frame_size; i += 4) { + __m128 vec_vug = _mm_loadu_ps(value_update_gate + i); + __m128 vec_vpo = _mm_loadu_ps(value_prev_out + i); + __m128 vec_csv = _mm_loadu_ps(cell_state_value + i); + vec_vpo = _mm_mul_ps(vec_vug, vec_vpo); + __m128 vec_out = _mm_add_ps( + _mm_mul_ps(vec_csv, _mm_sub_ps(vec_one_128, vec_vug)), vec_vpo); + _mm_storeu_ps(value_output + i, vec_out); + } +#endif + for (; i < frame_size; i++) { + value_output[i] = (1.f - value_update_gate[i]) * cell_state_value[i] + + value_update_gate[i] * value_prev_out[i]; + } + } else { + i = 0; +#ifdef __AVX__ + for (; i + 7 < frame_size; i += 8) { + __m256 vec_vug = _mm256_loadu_ps(value_update_gate + i); + __m256 vec_csv = _mm256_loadu_ps(cell_state_value + i); + __m256 vec_out = + _mm256_mul_ps(_mm256_sub_ps(vec_one_256, vec_vug), vec_csv); + _mm256_storeu_ps(value_output + i, vec_out); + } +#endif +#ifdef __SSE__ + for (; i + 3 < frame_size; i += 4) { + __m128 vec_vug = _mm_loadu_ps(value_update_gate + i); + __m128 vec_csv = _mm_loadu_ps(cell_state_value + i); + __m128 vec_out = _mm_mul_ps(_mm_sub_ps(vec_one_128, vec_vug), vec_csv); + _mm_storeu_ps(value_output + i, vec_out); + } +#endif + for (; i < frame_size; i++) { + value_output[i] = (1.f - value_update_gate[i]) * cell_state_value[i]; + } + } + + value_reset_gate += frame_size * 3; + value_update_gate += frame_size * 3; + value_reset_output += frame_size; + cell_state_value += frame_size * 3; + value_output += frame_size; + if (value.prev_out_value) { + value_prev_out += frame_size; + } + } +} + +template +struct RnnGruUnitFunctorV2 { + static void compute(X86Context* ctx, + GRUMetaValue value, + int frame_size, + int batch_size, + lite_api::ActivationType active_node, + lite_api::ActivationType active_gate) { + if (value.prev_out_value) { + lite::x86::math::Blas matmul(*ctx); + matmul.GEMM(false, + true, + batch_size, + frame_size, + frame_size, + 1.f, + value.prev_out_value, + frame_size, + value.state_weight, + frame_size, + 0.f, + value.reset_output_value, + frame_size); + } + GruRnnComputeKernel( + value, frame_size, batch_size, active_node, active_gate); + } +}; + } // namespace math } // namespace x86 } // namespace lite diff --git a/lite/core/profile/precision_profiler.h 
b/lite/core/profile/precision_profiler.h index 20587eba9c8..f34f4a14e54 100644 --- a/lite/core/profile/precision_profiler.h +++ b/lite/core/profile/precision_profiler.h @@ -539,8 +539,8 @@ class PrecisionProfiler { std::string out_arg_name; op->op_info()->GetOutputArgname(out_name, &out_arg_name); auto type = kernel->GetOutputDeclType(out_arg_name); - - if (type->IsTensor()) { + auto tmp = op_scope->FindVar(out_name); + if (tmp->IsType()) { const Tensor* tout = op_scope->FindVar(out_name)->GetMutable(); double mean = -999999; @@ -577,7 +577,7 @@ class PrecisionProfiler { << " " << setw(15) << left << mean_str << " " << setw(15) << left << std_dev_str << " " << setw(15) << left << ave_grow_rate_str << std::endl; - } else if (type->IsTensorList()) { + } else if (tmp->IsType>()) { auto touts = op_scope->FindVar(out_name)->GetMutable>(); for (auto t : *touts) { diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt index f282d7ee3df..28a6954531c 100644 --- a/lite/kernels/arm/CMakeLists.txt +++ b/lite/kernels/arm/CMakeLists.txt @@ -40,7 +40,6 @@ add_kernel(conv_transpose_compute_arm ARM basic SRCS conv_transpose_compute.cc) add_kernel(interpolate_compute_arm ARM basic SRCS interpolate_compute.cc) add_kernel(box_coder_compute_arm ARM basic SRCS box_coder_compute.cc) add_kernel(slice_compute_arm ARM basic SRCS slice_compute.cc) -add_kernel(cast_compute_arm ARM basic SRCS cast_compute.cc) add_kernel(reduce_mean_compute_arm ARM basic SRCS reduce_mean_compute.cc) add_kernel(affine_channel_compute_arm ARM basic SRCS affine_channel_compute.cc) add_kernel(affine_grid_compute_arm ARM basic SRCS affine_grid_compute.cc) diff --git a/lite/kernels/arm/rnn_compute.cc b/lite/kernels/arm/rnn_compute.cc index c405218eeb4..ced9f2d6087 100644 --- a/lite/kernels/arm/rnn_compute.cc +++ b/lite/kernels/arm/rnn_compute.cc @@ -19,6 +19,7 @@ #include #include "lite/backends/arm/math/concat.h" #include "lite/backends/arm/math/funcs.h" +#include "lite/backends/arm/math/gru.h" #include "lite/backends/arm/math/lstm.h" #include "lite/backends/arm/math/sgemm.h" #include "lite/backends/host/math/split.h" @@ -28,6 +29,23 @@ namespace lite { namespace kernels { namespace arm { +// layer, output_tensor, is_bidirection, offset +#define RUN_RNN_LAYER(x, y, z, w) \ + RunRnnLayer(&ctx, \ + input_temp_holder, \ + parameter_lists[x], \ + init_h_unbind, \ + init_c_unbind, \ + sequence_length, \ + &last_h_unbind, \ + &last_c_unbind, \ + y, \ + x, \ + &gate_value, \ + z, \ + w, \ + mode) + void reset_parameter_vector(const std::vector& raw_params_vec, const int& num_layers, const int& gate_num, @@ -66,25 +84,36 @@ void SwapPoniter(Tensor** a, Tensor** b) { *b = c; } -void preprocess(ARMContext* ctx, - const Tensor* input, - const Tensor& weight, - const Tensor& bias_ih, - const Tensor& bias_hh, - Tensor* cache_input, - bool is_test) { +/****************************************************** +input: + ctx:context, + input:(3D)time_step, batch, input_size, + weight:(2D)hidden_size, input_size, + bias_ih, + bias_hh, + mode:LSTM, GRU +output: + cache_input:(3D)time_step, batch, hidden_size +*******************************************************/ +static void preprocess(ARMContext* ctx, + const Tensor* input, + const Tensor& weight, + const Tensor& bias_ih, + const Tensor& bias_hh, + std::string mode, + Tensor* cache_input) { const int& hidden_size = weight.dims()[0]; int time_step = input->dims()[0]; int batch = input->dims()[1]; + std::vector cache_input_dim = {time_step, batch, hidden_size}; DDim gate_dim; 
gate_dim.ConstructFrom(cache_input_dim); cache_input->Resize(gate_dim); - cache_input->mutable_data(); + auto* i_data = input->data(); auto* w_data = weight.data(); auto* o_data = cache_input->mutable_data(); - bool flag_act = false; operators::ActivationParam act_param; act_param.has_active = false; @@ -112,19 +141,173 @@ void preprocess(ARMContext* ctx, act_param, ctx); lite::arm::math::fill_bias_fc(o_data, bias_ih.data(), m, n, flag_act); - lite::arm::math::fill_bias_fc(o_data, bias_hh.data(), m, n, flag_act); + + if ("GRU" == mode) { + Tensor bias_tmp_hh; + bias_tmp_hh.Resize(bias_hh.dims()); + auto bias_ptr = bias_tmp_hh.mutable_data(); + auto bias_src = bias_hh.data(); + int bias_offt = bias_hh.numel() / 3 * 2; + std::memcpy(bias_ptr, bias_src, bias_offt * sizeof(float)); + std::memset( + bias_ptr + bias_offt, 0, (bias_hh.numel() - bias_offt) * sizeof(float)); + lite::arm::math::fill_bias_fc( + o_data, bias_tmp_hh.data(), m, n, flag_act); + } else { + lite::arm::math::fill_bias_fc( + o_data, bias_hh.data(), m, n, flag_act); + } +} + +/****************************************************** +input: + ctx:context, + init_h:(2D), + init_c:(2D), + mask_tensor:(1D)input->dims()[1], + mode:LSTM, GRU +output: + output:(2D)output->dims()[1], output->dims()[2], + last_h:(2D), + last_c:(2D) +*******************************************************/ +static void postprocess(ARMContext* ctx, + Tensor* output, + const Tensor* init_h, + const Tensor* init_c, + Tensor* last_h, + Tensor* last_c, + const Tensor& mask_tensor, + std::string mode) { + Tensor mask_broadcast_1; + mask_broadcast_1.Resize(mask_tensor.dims()); + auto mask_ptr_1 = mask_broadcast_1.mutable_data(); + auto mask_ptr = mask_tensor.data(); + auto out_ptr = output->mutable_data(); + auto cur_h_ptr = last_h->mutable_data(); + auto pre_h_ptr = init_h->data(); + int offset = 0; + + // out = out * mask_broadcast + // curr_h = out * mask_broadcast + pre_h * (1 - mask_broadcast); + for (int i = 0; i < output->dims()[0]; i++) { + mask_ptr_1[i] = 1 - mask_ptr[i]; + for (int j = 0; j < output->dims()[1]; j++) { + offset = i * output->dims()[1] + j; + out_ptr[offset] *= mask_ptr[i]; + cur_h_ptr[offset] = out_ptr[offset] + pre_h_ptr[offset] * mask_ptr_1[i]; + } + } + if ("LSTM" == mode) { + auto pre_c_ptr = init_c->data(); + auto cur_c_ptr = last_c->mutable_data(); + + // curr_c = curr_c * mask_broadcast + pre_c * (1 - mask_broadcast); + for (int i = 0; i < output->dims()[0]; i++) { + for (int j = 0; j < output->dims()[1]; j++) { + offset = i * output->dims()[1] + j; + cur_c_ptr[offset] = + cur_c_ptr[offset] * mask_ptr[i] + pre_c_ptr[offset] * mask_ptr_1[i]; + } + } + } +} + +static DDim get_stride(const DDim& ddim) { + DDim strides; + strides[ddim.size() - 1] = 1; + for (int i = ddim.size() - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * ddim[i + 1]; + } + return strides; +} + +template +static void TransposeNormal(const Tensor& in, + Tensor* out, + const std::vector& axis) { + const int rank = axis.size(); + auto in_stride = get_stride(in.dims()); + auto out_stride = get_stride(out->dims()); + const T* in_ptr = in.data(); + T* out_ptr = out->mutable_data(); + + auto transpose_helper = [&](int64_t beg, int64_t end) { + for (int64_t out_idx = beg; out_idx < end; ++out_idx) { + int64_t in_idx = 0; + int64_t tmp_idx = out_idx; + // calculate the input index + for (int i = 0; i < rank; ++i) { + const int64_t coordinate = tmp_idx / out_stride[i]; + tmp_idx -= coordinate * out_stride[i]; + in_idx += coordinate * in_stride[axis[i]]; + } + 
out_ptr[out_idx] = in_ptr[in_idx]; + } + }; + transpose_helper(0, out->numel()); +} + +/****************************************************** +input: + sequence_length, + is_reverse +output: + mask_matrix, + min_seq_len +******************************************************/ +static void create_mask_matrix(const Tensor* sequence_length, + Tensor* mask_matrix, + const bool& is_reverse, + int* min_seq_len) { + // Tensor to vector + std::vector seq_len_vec; + seq_len_vec.resize(sequence_length->numel()); + std::memcpy(&seq_len_vec[0], + sequence_length->data(), + sequence_length->numel() * sizeof(int)); + + const int& table_width = mask_matrix->dims()[0]; + Tensor temp; + DDimLite dims( + std::vector{mask_matrix->dims()[1], mask_matrix->dims()[0]}); + temp.Resize(dims); + float* data_temp = temp.mutable_data(); + std::fill(data_temp, data_temp + mask_matrix->numel(), 1.f); + *min_seq_len = table_width; + for (unsigned int i = 0; i < seq_len_vec.size(); i++) { + // reset the mask matrix + *min_seq_len = std::min(seq_len_vec[i], *min_seq_len); + if (seq_len_vec[i] == table_width) { + continue; + } + if (is_reverse) { + std::fill(data_temp + i * table_width, + data_temp + (i + 1) * table_width - seq_len_vec[i], + 0.f); + } else { + std::fill(data_temp + i * table_width + seq_len_vec[i], + data_temp + (i + 1) * table_width, + 0.f); + } + } + mask_matrix->mutable_data(); + std::vector trans_vec; + trans_vec.emplace_back(1); + trans_vec.emplace_back(0); + TransposeNormal(temp, mask_matrix, trans_vec); } -void cell(ARMContext* ctx, - Tensor* input, - Tensor* weight_hh, - Tensor* init_h, - Tensor* init_c, - Tensor* last_h, - Tensor* last_c, - Tensor* last_c_act, - Tensor* output, - const Tensor* bias_hh) { +static void lstm_cell(ARMContext* ctx, + Tensor* input, + Tensor* weight_hh, + Tensor* init_h, + Tensor* init_c, + Tensor* last_h, + Tensor* last_c, + Tensor* last_c_act, + Tensor* output, + const Tensor* bias_hh) { bool flag_act = false; operators::ActivationParam act_param; act_param.has_active = false; @@ -201,35 +384,88 @@ void cell(ARMContext* ctx, ctx->threads()); } -// layer, output_tensor, is_bidirection, offset -#define RUN_LSTM_LAYER(x, y, z, w) \ - runLSTMLayer(&ctx, \ - input_temp_holder, \ - parameter_lists[x], \ - init_h_unbind, \ - init_c_unbind, \ - sequence_length, \ - &last_h_unbind, \ - &last_c_unbind, \ - y, \ - x, \ - &gate_value, \ - z, \ - w) - -void runLSTMLayer(ARMContext* ctx, - const Tensor* input, - std::vector vec, - std::vector init_h, - std::vector init_c, - const Tensor* sequence_length, - std::vector* last_h_ptr, - std::vector* last_c_ptr, - Tensor* output, - int layer_idx, - Tensor* gate_value, - bool is_bidirect, - int offset) { +static void gru_cell(ARMContext* ctx, + Tensor* input, + Tensor* weight_hh, + Tensor* init_h, + Tensor* init_c, + Tensor* last_h, + Tensor* last_c, + Tensor* last_c_act, + Tensor* output, + const Tensor* bias_hh, + Tensor* weight_hh_gru) { + bool flag_act = false; + operators::ActivationParam act_param; + act_param.has_active = false; + auto h_dims = init_h->dims(); + auto weight_gru_dims = weight_hh_gru->dims(); + int m = h_dims[0]; + int k = h_dims[1]; + int n = weight_gru_dims[0]; + auto i_data = input->data(); + auto w_gru = weight_hh_gru->data(); + auto h_data = init_h->data(); + + Tensor tmp_gate; + tmp_gate.Resize(input->dims()); + auto tmp_data = tmp_gate.mutable_data(); + lite::arm::math::sgemm(false, + true, + m, + n, + k, + 1.f, + h_data, + k, + w_gru, + k, + 0.f, + tmp_data, + n, + nullptr, + false, + act_param, + ctx); + 
for (int i = 0; i < input->dims()[0] * input->dims()[1]; i++) { + tmp_data[i] += i_data[i]; + } + + size_t frame_size = init_h->dims()[1]; + size_t batch_size = init_h->dims()[0]; + + lite::arm::math::RNNGRUValue gru_value; + gru_value.gate_weight = weight_hh->data(); + gru_value.state_weight = + weight_hh->data() + 2 * frame_size * frame_size; + gru_value.reset_bias = bias_hh->data() + 2 * frame_size; + + gru_value.gate_value = tmp_data; + gru_value.reset_output_value = last_c->mutable_data(); + gru_value.output_value = output->mutable_data(); + gru_value.prev_out_value = init_h->data(); + + lite_api::ActivationType gate_act = lite_api::ActivationType::kSigmoid_v2; + lite_api::ActivationType cand_act = lite_api::ActivationType::kTanh_v2; + + lite::arm::math::RnnGruUnitFunctorV2::compute( + ctx, gru_value, frame_size, batch_size, cand_act, gate_act); +} + +static void RunRnnLayer(ARMContext* ctx, + const Tensor* input, + std::vector vec, + std::vector init_h, + std::vector init_c, + const Tensor* sequence_length, + std::vector* last_h_ptr, + std::vector* last_c_ptr, + Tensor* output, + int layer_idx, + Tensor* gate_value, + bool is_bidirect, + int offset, + std::string mode) { bool is_reverse = false; if (is_bidirect) { layer_idx = 2 * layer_idx + offset; @@ -243,13 +479,16 @@ void runLSTMLayer(ARMContext* ctx, vec[0 + offset * 4], vec[2 + offset * 4], vec[3 + offset * 4], - gate_value, - true); + mode, + gate_value); + std::vector input_tensors, output_tensors; std::vector input_tensors_t, output_tensors_t; - std::vector stride1, stride2; + std::vector stride1, stride2, stride3; input_tensors.resize(gate_value->dims()[0]); output_tensors.resize(output->dims()[0]); + + // unbind for (int i = 0; i < gate_value->dims()[0]; i++) { stride1.push_back(1); int dim1 = gate_value->dims()[1]; @@ -272,68 +511,140 @@ void runLSTMLayer(ARMContext* ctx, auto sd = output->mutable_data(); if (is_reverse) { + // don't need to reverse input_tensors_t becauese of unuseful std::reverse(input_tensors.begin(), input_tensors.end()); } bool has_sequence_length = false; + if (sequence_length != nullptr) { + has_sequence_length = true; + } + // unbind + Tensor mask_matrix; + std::vector mask_vec; + std::vector mask_tensor_list; + int mask_min_length = time_step; + /* - TODO has_sequence_length + to be verifying! */ - - int mask_min_length = time_step; + if (has_sequence_length) { + mask_matrix.Resize(DDimLite({time_step, input->dims()[1]})); + create_mask_matrix( + sequence_length, &mask_matrix, is_reverse, &mask_min_length); + for (int i = 0; i < time_step; i++) { + stride3.push_back(1); + DDimLite ddims(std::vector{input->dims()[1]}); + mask_vec[i].Resize(ddims); + mask_tensor_list.push_back(&mask_vec[i]); + } + lite::host::math::split( + mask_matrix.data(), mask_tensor_list, 0, stride3); + } if (is_reverse) { mask_min_length = mask_min_length - time_step + 1; } + bool has_allocate_mem_c = false; bool has_use_last_h_holder = false; const int& reverse_flag = is_reverse ? 
-1 : 1; + + // define the init_h holder for the swap Tensor init_h_temp; + init_h_temp.Resize(init_h[layer_idx].dims()); init_h_temp.CopyDataFrom(init_h[layer_idx]); Tensor* init_h_holder = &init_h_temp; Tensor* last_h_holder = nullptr; - if (0 < mask_min_length) { last_h_holder = &(output_tensors[0]); } else { last_h_holder = &(*last_h_ptr)[layer_idx]; has_use_last_h_holder = true; } + Tensor* init_c_holder = nullptr; Tensor* init_c_temp_holder = nullptr; Tensor init_c_temp; Tensor* last_c_holder = nullptr; Tensor last_c_temp; - last_c_holder = &(*last_c_ptr)[layer_idx]; - init_c_temp_holder = &init_c[layer_idx]; + + if ("LSTM" == mode) { + last_c_holder = &(*last_c_ptr)[layer_idx]; + init_c_temp_holder = &init_c[layer_idx]; + } else if ("GRU" == mode) { + // for reset output value + last_c_temp.Resize(init_h[layer_idx].dims()); + last_c_temp.mutable_data(); + last_c_holder = &last_c_temp; + } + + Tensor weight_hh_tmp; // for gru + std::vector weight_hh_tmp_ubind; + std::vector weight_hh_tmp_ubind_t; + std::vector stride_w; + if ("GRU" == mode) { + weight_hh_tmp.Resize(vec[1 + offset * 4].dims()); + weight_hh_tmp.mutable_data(); + weight_hh_tmp.CopyDataFrom(vec[1 + offset * 4]); + int size = weight_hh_tmp.numel() / 3; + std::memset(weight_hh_tmp.mutable_data() + size * 2, + 0, + size * sizeof(float)); + } + for (int i = 0; i < time_step; i++) { bool in_mask = (reverse_flag * i) >= mask_min_length; if (i > 0) { if (!has_allocate_mem_c) { - init_c_temp.Resize(init_h[layer_idx].dims()); - init_c_temp.mutable_data(); - init_c_holder = &init_c_temp; + if (("LSTM" == mode) || ("GRU" == mode)) { + init_c_temp.Resize(init_h[layer_idx].dims()); + init_c_temp.mutable_data(); + init_c_holder = &init_c_temp; + } has_allocate_mem_c = true; } SwapPoniter(&init_c_holder, &last_c_holder); init_c_temp_holder = init_c_holder; } - // LSTMCELL - cell(ctx, - &input_tensors[i], - &vec[1 + offset * 4], - init_h_holder, - init_c_temp_holder, - last_h_holder, - last_c_holder, - nullptr, - &output_tensors[i], - &vec[3 + offset * 4]); + if ("LSTM" == mode) { + lstm_cell(ctx, + &input_tensors[i], + &vec[1 + offset * 4], + init_h_holder, + init_c_temp_holder, + last_h_holder, + last_c_holder, + nullptr, + &output_tensors[i], + &vec[3 + offset * 4]); + } else if ("GRU" == mode) { + gru_cell(ctx, + &input_tensors[i], + &vec[1 + offset * 4], + init_h_holder, + init_c_temp_holder, + last_h_holder, + last_c_holder, + nullptr, + &output_tensors[i], + &vec[3 + offset * 4], + &weight_hh_tmp); + } + /* + to be verifying! 
+ */ if (in_mask) { - /* - TODO in_mask - */ + postprocess(ctx, + &output_tensors[i], + init_h_holder, + init_c_temp_holder, + last_h_holder, + last_c_holder, + mask_vec[i], + mode); } + // prepare next step if (i + 1 < time_step) { bool next_step_mask = (reverse_flag * (i + 1)) >= mask_min_length; @@ -347,6 +658,7 @@ void runLSTMLayer(ARMContext* ctx, SwapPoniter(&init_h_holder, &last_h_holder); } } + if (is_reverse) { std::reverse(output_tensors.begin(), output_tensors.end()); } @@ -364,7 +676,7 @@ void runLSTMLayer(ARMContext* ctx, } else { (*last_h_ptr)[layer_idx].CopyDataFrom(output_tensors[time_step - 1]); } - if (time_step % 2 == 0) { + if ((0 == (time_step % 2)) && ("LSTM" == mode)) { (*last_c_ptr)[layer_idx].CopyDataFrom(*last_c_holder); } } @@ -375,25 +687,30 @@ void RnnCompute::Run() { std::string mode = param.mode; auto input = param.Input; auto weight_list = param.WeightList; - auto reserve_data = param.Reserve; auto pre_state = param.PreState; auto state = param.State; - auto dropout_state = param.DropoutState; auto output = param.Out; bool is_bidirec = param.is_bidirec; int num_layers = param.num_layers; - int input_size = param.input_size; - int hidden_size = param.hidden_size; - bool is_test = param.is_test; - float dropout_prob = param.dropout_prob; - int seed = param.seed; const Tensor* sequence_length = param.SequenceLength; + int gate_num = 0; + + if ("LSTM" == mode) { + gate_num = 4; + } else if ("GRU" == mode) { + gate_num = 3; + } else { + LOG(FATAL) << "X86 RNN ERROR: unsupport mode except gru and lstm," + " present mode is " + << mode; + return; + } state[0]->mutable_data(); - state[1]->mutable_data(); + if ("LSTM" == mode) { + state[1]->mutable_data(); + } - // lstmCell begin - int gate_num = 4; std::vector> parameter_lists; parameter_lists.reserve(num_layers); reset_parameter_vector( @@ -407,31 +724,56 @@ void RnnCompute::Run() { last_c_unbind; std::vector init_h_unbind_t, init_c_unbind_t, last_h_unbind_t, last_c_unbind_t; - init_h_unbind.resize(4); - init_c_unbind.resize(pre_state[1]->dims()[0]); + init_h_unbind.resize(pre_state[0]->dims()[0]); last_h_unbind.resize(state[0]->dims()[0]); - last_c_unbind.resize(state[1]->dims()[0]); - std::vector stride; + if ("LSTM" == mode) { + init_c_unbind.resize(pre_state[1]->dims()[0]); + last_c_unbind.resize(state[1]->dims()[0]); + } + + std::vector stride1, stride2; + // unbind for (int i = 0; i < pre_state[0]->dims()[0]; i++) { - stride.push_back(1); + stride1.push_back(1); int dim1 = pre_state[0]->dims()[1]; int dim2 = pre_state[0]->dims()[2]; DDimLite dims(std::vector{dim1, dim2}); init_h_unbind[i].Resize(dims); - init_c_unbind[i].Resize(dims); last_h_unbind[i].Resize(dims); - last_c_unbind[i].Resize(dims); init_h_unbind_t.push_back(&init_h_unbind[i]); - init_c_unbind_t.push_back(&init_c_unbind[i]); last_h_unbind_t.push_back(&last_h_unbind[i]); - last_c_unbind_t.push_back(&last_c_unbind[i]); } lite::host::math::split( - pre_state[0]->data(), init_h_unbind_t, 0, stride); - lite::host::math::split( - pre_state[1]->data(), init_c_unbind_t, 0, stride); - lite::host::math::split(state[0]->data(), last_h_unbind_t, 0, stride); - lite::host::math::split(state[1]->data(), last_c_unbind_t, 0, stride); + pre_state[0]->data(), init_h_unbind_t, 0, stride1); + lite::host::math::split(state[0]->data(), last_h_unbind_t, 0, stride1); + + if ("LSTM" == mode) { + for (int i = 0; i < pre_state[1]->dims()[0]; i++) { + stride2.push_back(1); + int dim1 = pre_state[1]->dims()[1]; + int dim2 = pre_state[1]->dims()[2]; + DDimLite 
dims(std::vector{dim1, dim2}); + init_c_unbind[i].Resize(dims); + last_c_unbind[i].Resize(dims); + init_c_unbind_t.push_back(&init_c_unbind[i]); + last_c_unbind_t.push_back(&last_c_unbind[i]); + } + lite::host::math::split( + pre_state[1]->data(), init_c_unbind_t, 0, stride2); + lite::host::math::split( + state[1]->data(), last_c_unbind_t, 0, stride2); + } + + std::vector output_vec(2); + int time_step = input->dims()[0]; + int batch_size = input->dims()[1]; + int hidden_size = output->dims()[2]; + if (is_bidirec) { + for (int i = 0; i < 2; ++i) { + output_vec[i].Resize({time_step, batch_size, hidden_size / 2}); + output_vec[i].mutable_data(); + } + } for (int i = 0; i < num_layers; i++) { if (i > 0) { @@ -450,24 +792,18 @@ void RnnCompute::Run() { } if (is_bidirec) { - std::vector output_vec(2); - int time_step = input->dims()[0]; - int batch_size = input->dims()[1]; - int hidden_size = output->dims()[2]; - for (int i = 0; i < 2; ++i) { - output_vec[i].Resize({time_step, batch_size, hidden_size / 2}); - output_vec[i].mutable_data(); - } - - RUN_LSTM_LAYER(i, &output_vec[0], true, 0); - RUN_LSTM_LAYER(i, &output_vec[1], true, 1); - + RUN_RNN_LAYER(i, &output_vec[0], true, 0); + RUN_RNN_LAYER(i, &output_vec[1], true, 1); std::vector output_vec_t = {&output_vec[0], &output_vec[1]}; - lite::arm::math::concat_func(output_vec_t, 2, output); + lite::arm::math::concat_func(output_vec_t, 2, output_holder); } else { - RUN_LSTM_LAYER(i, output_holder, false, 0); + RUN_RNN_LAYER(i, output_holder, false, 0); } } + // output_holder != output + if (num_layers % 2 == 0) { + output->CopyDataFrom(*output_holder); + } } } // namespace arm } // namespace kernels diff --git a/lite/kernels/host/CMakeLists.txt b/lite/kernels/host/CMakeLists.txt index 439f6e5722f..8ddcc0d0819 100644 --- a/lite/kernels/host/CMakeLists.txt +++ b/lite/kernels/host/CMakeLists.txt @@ -28,6 +28,7 @@ add_kernel(argmax_compute_host Host basic SRCS argmax_compute.cc) add_kernel(assign_value_compute_host Host basic SRCS assign_value_compute.cc) add_kernel(yolo_box_compute_host Host basic SRCS yolo_box_compute.cc) add_kernel(write_back_compute_host Host basic SRCS write_back_compute.cc) +add_kernel(cast_compute_host Host basic SRCS cast_compute.cc) # extra kernels add_kernel(reverse_compute_host Host extra SRCS reverse_compute.cc) diff --git a/lite/kernels/arm/cast_compute.cc b/lite/kernels/host/cast_compute.cc similarity index 91% rename from lite/kernels/arm/cast_compute.cc rename to lite/kernels/host/cast_compute.cc index 874bbea0665..edddbc79b57 100644 --- a/lite/kernels/arm/cast_compute.cc +++ b/lite/kernels/host/cast_compute.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "lite/kernels/arm/cast_compute.h" +#include "lite/kernels/host/cast_compute.h" #include #ifdef ENABLE_ARM_FP16 #include "lite/backends/arm/math/fp16/funcs_fp16.h" @@ -20,7 +20,7 @@ namespace paddle { namespace lite { namespace kernels { -namespace arm { +namespace host { template out_type TransOp(in_type in) { @@ -30,7 +30,6 @@ out_type TransOp(in_type in) { void CastCompute::PrepareForRun() {} void CastCompute::Run() { - auto& ctx = this->ctx_->template As(); auto& param = this->Param(); auto input_dims = param.X->dims(); if (param.X->precision() == PrecisionType::kFloat) { @@ -128,7 +127,7 @@ void CastCompute::Run() { int32_t* out_data = param.Out->mutable_data(); std::transform( x_data_begin, x_data_end, out_data, TransOp); -#ifdef ENABLE_ARM_FP16 +#if defined(ENABLE_ARM_FP16) && defined(LITE_WITH_ARM) } else if (param.in_dtype == 4 && param.out_dtype == 5) { // float16 -> float32 const float16_t* in_data = param.X->data(); @@ -146,21 +145,13 @@ void CastCompute::Run() { } } -} // namespace arm +} // namespace host } // namespace kernels } // namespace lite } // namespace paddle REGISTER_LITE_KERNEL( - cast, kARM, kAny, kNCHW, paddle::lite::kernels::arm::CastCompute, def) - .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) - .Finalize(); - -#ifdef LITE_BUILD_EXTRA -REGISTER_LITE_KERNEL( - cast, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::CastCompute, def) - .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) + cast, kHost, kAny, kNCHW, paddle::lite::kernels::host::CastCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))}) .Finalize(); -#endif // LITE_BUILD_EXTRA diff --git a/lite/kernels/arm/cast_compute.h b/lite/kernels/host/cast_compute.h similarity index 90% rename from lite/kernels/arm/cast_compute.h rename to lite/kernels/host/cast_compute.h index cd23d62198a..f344b84e8b2 100644 --- a/lite/kernels/arm/cast_compute.h +++ b/lite/kernels/host/cast_compute.h @@ -20,9 +20,9 @@ namespace paddle { namespace lite { namespace kernels { -namespace arm { +namespace host { -class CastCompute : public KernelLite { +class CastCompute : public KernelLite { public: using param_t = operators::CastParam; @@ -35,7 +35,7 @@ class CastCompute : public KernelLite { private: }; -} // namespace arm +} // namespace host } // namespace kernels } // namespace lite } // namespace paddle diff --git a/lite/kernels/host/compare_compute.cc b/lite/kernels/host/compare_compute.cc index 1548b901693..755e699d82c 100644 --- a/lite/kernels/host/compare_compute.cc +++ b/lite/kernels/host/compare_compute.cc @@ -607,10 +607,10 @@ REGISTER_LITE_KERNEL( .Finalize(); using greater_equal_int64 = paddle::lite::kernels::host::CompareCompute< - PRECISION(kInt64), + PRECISION(kFloat), paddle::lite::kernels::host::_GreaterEqualFunctor>; REGISTER_LITE_KERNEL( - greater_equal, kHost, kInt64, kAny, greater_equal_float, def) + greater_equal, kHost, kFloat, kAny, greater_equal_float, def_int64) .BindInput("X", {LiteType::GetTensorTy( TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kAny), -1)}) diff --git a/lite/kernels/host/expand_v2_compute.cc b/lite/kernels/host/expand_v2_compute.cc index 928363f104e..d0e0e2ffdb3 100644 --- a/lite/kernels/host/expand_v2_compute.cc +++ b/lite/kernels/host/expand_v2_compute.cc @@ -140,8 
+140,8 @@ REGISTER_LITE_KERNEL(expand_v2, kHost, kInt32, kAny, expand_v2_int32, def) .Finalize(); using expand_v2_int64 = - paddle::lite::kernels::host::ExpandV2Compute; -REGISTER_LITE_KERNEL(expand_v2, kHost, kInt64, kAny, expand_v2_int64, def) + paddle::lite::kernels::host::ExpandV2Compute; +REGISTER_LITE_KERNEL(expand_v2, kHost, kFloat, kAny, expand_v2_int64, def_int64) .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64), diff --git a/lite/kernels/x86/CMakeLists.txt b/lite/kernels/x86/CMakeLists.txt index 94e8cd6b598..2ec9a87ca47 100755 --- a/lite/kernels/x86/CMakeLists.txt +++ b/lite/kernels/x86/CMakeLists.txt @@ -10,7 +10,7 @@ endif() add_kernel(activation_compute_x86 X86 basic SRCS activation_compute.cc) add_kernel(scale_compute_x86 X86 basic SRCS scale_compute.cc) -add_kernel(cast_compute_x86 X86 basic SRCS cast_compute.cc) +#add_kernel(cast_compute_x86 X86 basic SRCS cast_compute.cc) add_kernel(slice_compute_x86 X86 basic SRCS slice_compute.cc) if(WITH_AVX AND AVX_FOUND) add_kernel(conv_depthwise_x86 X86 basic SRCS conv_depthwise.cc) @@ -84,7 +84,7 @@ lite_cc_test(test_softmax_compute_x86 SRCS softmax_compute_test.cc) lite_cc_test(test_sequence_expand_as_compute_x86 SRCS sequence_expand_as_compute_test.cc) lite_cc_test(test_gru_compute_x86 SRCS gru_compute_test.cc) lite_cc_test(test_matmul_compute_x86 SRCS matmul_compute_test.cc) -lite_cc_test(test_cast_compute_x86 SRCS cast_compute_test.cc) +#lite_cc_test(test_cast_compute_x86 SRCS cast_compute_test.cc) lite_cc_test(test_pool2d_compute_x86 SRCS pool_compute_test.cc) lite_cc_test(test_layer_norm_compute_x86 SRCS layer_norm_compute_test.cc) lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc) diff --git a/lite/kernels/x86/elementwise_compute.cc b/lite/kernels/x86/elementwise_compute.cc index 61e1ab87985..2aa948cc708 100644 --- a/lite/kernels/x86/elementwise_compute.cc +++ b/lite/kernels/x86/elementwise_compute.cc @@ -392,6 +392,7 @@ REGISTER_LITE_KERNEL(elementwise_add, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_add, kX86, kFloat, @@ -402,6 +403,7 @@ REGISTER_LITE_KERNEL(elementwise_add, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_add, kX86, kFloat, @@ -412,6 +414,7 @@ REGISTER_LITE_KERNEL(elementwise_add, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .Finalize(); + REGISTER_LITE_KERNEL( fusion_elementwise_add_activation, kX86, @@ -423,6 +426,7 @@ REGISTER_LITE_KERNEL( .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_sub, kX86, kFloat, @@ -433,6 +437,7 @@ REGISTER_LITE_KERNEL(elementwise_sub, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_sub, kX86, kFloat, @@ -443,6 +448,7 @@ REGISTER_LITE_KERNEL(elementwise_sub, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .Finalize(); + 
REGISTER_LITE_KERNEL(elementwise_sub, kX86, kFloat, @@ -453,6 +459,7 @@ REGISTER_LITE_KERNEL(elementwise_sub, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .Finalize(); + REGISTER_LITE_KERNEL( fusion_elementwise_sub_activation, kX86, @@ -464,6 +471,7 @@ REGISTER_LITE_KERNEL( .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_mul, kX86, kFloat, @@ -474,6 +482,7 @@ REGISTER_LITE_KERNEL(elementwise_mul, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_mul, kX86, kFloat, @@ -484,6 +493,7 @@ REGISTER_LITE_KERNEL(elementwise_mul, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_mul, kX86, kFloat, @@ -494,6 +504,7 @@ REGISTER_LITE_KERNEL(elementwise_mul, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .Finalize(); + REGISTER_LITE_KERNEL( fusion_elementwise_mul_activation, kX86, @@ -505,6 +516,7 @@ REGISTER_LITE_KERNEL( .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_div, kX86, kFloat, @@ -515,6 +527,7 @@ REGISTER_LITE_KERNEL(elementwise_div, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_div, kX86, kFloat, @@ -525,6 +538,7 @@ REGISTER_LITE_KERNEL(elementwise_div, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_div, kX86, kFloat, @@ -535,6 +549,7 @@ REGISTER_LITE_KERNEL(elementwise_div, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .Finalize(); + REGISTER_LITE_KERNEL( fusion_elementwise_div_activation, kX86, @@ -546,6 +561,7 @@ REGISTER_LITE_KERNEL( .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .Finalize(); + REGISTER_LITE_KERNEL( elementwise_floordiv, kX86, @@ -557,6 +573,7 @@ REGISTER_LITE_KERNEL( .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .Finalize(); + REGISTER_LITE_KERNEL( elementwise_floordiv, kX86, @@ -568,6 +585,7 @@ REGISTER_LITE_KERNEL( .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .Finalize(); + REGISTER_LITE_KERNEL( elementwise_floordiv, kX86, @@ -579,6 +597,7 @@ REGISTER_LITE_KERNEL( .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .Finalize(); + 
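Each REGISTER_LITE_KERNEL call in this chain registers one kernel instance keyed by op, place, and a trailing alias string; that alias is what lets the int64-binding variants of greater_equal and expand_v2 move under the kFloat slot as def_int64 while their tensor bindings stay kInt64. A toy registry sketch of that idea (simplified; this is not Lite's actual KernelRegistry, and the key layout and names here are assumptions):

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <tuple>

// Stand-in for a kernel object; real Lite kernels carry far more state.
struct FakeKernel {
  std::string name;
};

// Key mirrors the arguments of REGISTER_LITE_KERNEL(op, target, precision, ..., alias).
using KernelKey = std::tuple<std::string, std::string, std::string, std::string>;

std::map<KernelKey, std::function<std::unique_ptr<FakeKernel>()>>& Registry() {
  static std::map<KernelKey, std::function<std::unique_ptr<FakeKernel>()>> r;
  return r;
}

void Register(const KernelKey& key, const std::string& impl_name) {
  Registry()[key] = [impl_name] {
    return std::unique_ptr<FakeKernel>(new FakeKernel{impl_name});
  };
}

int main() {
  // Two registrations of the same op under the same precision slot,
  // distinguished only by the alias -- analogous to "def" vs "def_int64".
  Register({"expand_v2", "kHost", "kFloat", "def"}, "float variant");
  Register({"expand_v2", "kHost", "kFloat", "def_int64"}, "int64 variant");

  auto k = Registry()[{"expand_v2", "kHost", "kFloat", "def_int64"}]();
  std::cout << k->name << "\n";  // prints: int64 variant
  return 0;
}
```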
REGISTER_LITE_KERNEL(elementwise_pow, kX86, kFloat, @@ -589,6 +608,7 @@ REGISTER_LITE_KERNEL(elementwise_pow, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_pow, kX86, kFloat, @@ -599,6 +619,7 @@ REGISTER_LITE_KERNEL(elementwise_pow, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_pow, kX86, kFloat, @@ -609,6 +630,7 @@ REGISTER_LITE_KERNEL(elementwise_pow, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_mod, kX86, kFloat, @@ -619,6 +641,7 @@ REGISTER_LITE_KERNEL(elementwise_mod, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_mod, kX86, kFloat, @@ -629,6 +652,7 @@ REGISTER_LITE_KERNEL(elementwise_mod, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_max, kX86, kFloat, @@ -639,6 +663,7 @@ REGISTER_LITE_KERNEL(elementwise_max, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_max, kX86, kFloat, @@ -649,6 +674,7 @@ REGISTER_LITE_KERNEL(elementwise_max, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_max, kX86, kFloat, @@ -659,6 +685,7 @@ REGISTER_LITE_KERNEL(elementwise_max, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))}) .Finalize(); + REGISTER_LITE_KERNEL( fusion_elementwise_max_activation, kX86, @@ -670,6 +697,7 @@ REGISTER_LITE_KERNEL( .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .Finalize(); + REGISTER_LITE_KERNEL( fusion_elementwise_min_activation, kX86, @@ -681,6 +709,7 @@ REGISTER_LITE_KERNEL( .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_min, kX86, kFloat, @@ -691,6 +720,7 @@ REGISTER_LITE_KERNEL(elementwise_min, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_min, kX86, kFloat, @@ -701,6 +731,7 @@ REGISTER_LITE_KERNEL(elementwise_min, .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))}) .Finalize(); + REGISTER_LITE_KERNEL(elementwise_min, kX86, kFloat, diff --git a/lite/kernels/x86/elementwise_op_function.h b/lite/kernels/x86/elementwise_op_function.h new file mode 100644 index 00000000000..5ab94ff88bb --- /dev/null +++ 
b/lite/kernels/x86/elementwise_op_function.h @@ -0,0 +1,802 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include "lite/backends/x86/fluid/eigen.h" +#include "lite/backends/x86/fluid/for_range.h" +#include "lite/backends/x86/fluid/transform.h" +#include "lite/backends/x86/math/math_function.h" +#include "lite/utils/log/cp_logging.h" +#include "lite/utils/variant.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +/* + * Out = X ⊙ Y + * If Y's shape does not match X' shape, they will be reshaped. + * For example: + * 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1 + * pre=2, n=3*4, post=5 + * x.shape(2, 12, 5) * y.shape(1, 12, 1).broadcast(2, 12, 5) + * 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5) + * pre=2*3, n=4*5, post=1 + * x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1) + * + * New parameter: *mid_flag* is added to solve m*n*k & m*1*k + * broadcast cases. + */ +inline void get_mid_dims(const lite::DDim &x_dims, + const lite::DDim &y_dims, + const int axis, + int *pre, + int *n, + int *post, + int *mid_flag = NULL) { + *pre = 1; + *n = 1; + *post = 1; + if (mid_flag != NULL) { + *mid_flag = 0; + for (int i = 0; i < axis; ++i) { + (*pre) *= x_dims[i]; + } + for (size_t i = 0; i < y_dims.size(); ++i) { + if (x_dims[i + axis] != y_dims[i]) { + CHECK_EQ(y_dims[i] == 1 || x_dims[i + axis] == 1, true) + << "Broadcast y or x dimension is not 1."; + *mid_flag = 1; + return; + } + (*n) *= y_dims[i]; + } + for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) { + (*post) *= x_dims[i]; + } + } else { // for fused_elementwise_activation_op. keep the old version. 
+ for (int i = 0; i < axis; ++i) { + (*pre) *= x_dims[i]; + } + + for (size_t i = 0; i < y_dims.size(); ++i) { + CHECK_EQ(x_dims[i + axis], y_dims[i]) << "Broadcast dimension mismatch."; + (*n) *= y_dims[i]; + } + + for (size_t i = axis + y_dims.size(); i < x_dims.size(); ++i) { + (*post) *= x_dims[i]; + } + } +} + +inline lite::DDim trim_trailing_singular_dims(const lite::DDim &dims) { + // Remove trailing dimensions of size 1 for y + auto actual_dims_size = dims.size(); + for (; actual_dims_size != 0; --actual_dims_size) { + if (dims[actual_dims_size - 1] != 1) break; + } + if (actual_dims_size == dims.size()) return dims; + std::vector trim_dims; + trim_dims.resize(actual_dims_size); + for (size_t i = 0; i < actual_dims_size; ++i) { + trim_dims[i] = dims[i]; + } + if (trim_dims.size() == 0) { + return lite::DDim(); + } + lite::DDim actual_dims = lite::DDim(trim_dims); + return actual_dims; +} + +template +class RowwiseTransformIterator; + +template +class MidWiseTransformIterator; + +// NOTE(dzhwinter): ptrdiff_t in iterator is deperecated in c++17 +template +class RowwiseTransformIterator + : public std::iterator { + public: + RowwiseTransformIterator(const T *ptr, int n) : ptr_(ptr), i_(0), n_(n) {} + + RowwiseTransformIterator &operator++() { + ++i_; + if (UNLIKELY(i_ == n_)) { + i_ = 0; + } + return *this; + } + + RowwiseTransformIterator &operator+(int n) { + while (n-- > 0) { + ++i_; + if (UNLIKELY(i_ == n_)) { + i_ = 0; + } + } + + return *this; + } + + bool operator==( + const RowwiseTransformIterator &rhs) const { + return (ptr_ + i_) == &(*rhs); + } + + bool operator!=( + const RowwiseTransformIterator &rhs) const { + return (ptr_ + i_) != &(*rhs); + } + + const T &operator*() { return ptr_[i_]; } + + private: + const T *ptr_; + int i_; + int64_t n_; +}; + +template +class MidWiseTransformIterator + : public std::iterator { + public: + MidWiseTransformIterator(const T *ptr, int n, int post) + : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {} + + MidWiseTransformIterator &operator++() { + ++j_; + if (UNLIKELY(j_ == post_)) { + ++i_; + j_ = 0; + if (UNLIKELY(i_ == n_)) { + i_ = 0; + } + } + return *this; + } + + MidWiseTransformIterator &operator+(int n) { + while (n-- > 0) { + ++j_; + if (UNLIKELY(j_ == post_)) { + ++i_; + j_ = 0; + if (UNLIKELY(i_ == n_)) { + i_ = 0; + } + } + } + return *this; + } + + bool operator==( + const MidWiseTransformIterator &rhs) const { + return (ptr_ + i_) == &(*rhs); + } + + bool operator!=( + const MidWiseTransformIterator &rhs) const { + return (ptr_ + i_) != &(*rhs); + } + + const T &operator*() { return ptr_[i_]; } + + private: + const T *ptr_; + int64_t i_; + int64_t j_; + int64_t n_; + int64_t post_; +}; + +template +class TransformFunctor { + public: + TransformFunctor(const lite::Tensor *x, + const lite::Tensor *y, + lite::Tensor *z, + const lite::Context &ctx, + Functor func, + const bool is_xsize_larger = true) + : x_(x->template data()), + y_(y->template data()), + z_(z->mutable_data()), + nx_(x->numel()), + ctx_(ctx), + func_(func), + is_xsize_larger_(is_xsize_larger) { + if (is_xsize_larger_ == false) { + nx_ = y->numel(); + } + } + + inline void Run() const { + lite::fluid::Transform trans; + trans(ctx_, x_, x_ + nx_, y_, z_, func_); + } + + inline void RunRowWise(int n, int pre) const { + lite::fluid::Transform trans; + if (is_xsize_larger_) { + trans(ctx_, + x_, + x_ + nx_, + RowwiseTransformIterator(y_, n), + z_, + func_); + } else { + trans(ctx_, + y_, + y_ + nx_, + RowwiseTransformIterator(x_, n), + z_, + func_); + } + } + + 
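RowwiseTransformIterator above simply walks Y cyclically with period n, so a rank-1 Y is paired against every row of X when RunRowWise hands it to the transform as the second input range. The same semantics written with explicit indices (a sketch only, not the Lite iterator):

```cpp
#include <cstdio>
#include <vector>

// Row-wise broadcast: z[i][j] = x[i][j] + y[j], which is what feeding a
// RowwiseTransformIterator over y into a binary transform achieves.
int main() {
  const int pre = 2, n = 3;                  // x is (pre, n), y is (n)
  std::vector<float> x = {1, 2, 3, 4, 5, 6};
  std::vector<float> y = {10, 20, 30};
  std::vector<float> z(pre * n);

  int yi = 0;                                // cyclic index, wraps at n
  for (size_t i = 0; i < x.size(); ++i) {
    z[i] = x[i] + y[yi];
    if (++yi == n) yi = 0;                   // same wrap rule as operator++
  }

  for (float v : z) std::printf("%g ", v);   // prints: 11 22 33 14 25 36
  std::printf("\n");
  return 0;
}
```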
inline void RunMidWise(int n, int pre, int post) const { + lite::fluid::Transform trans; + if (is_xsize_larger_) { + trans(ctx_, + x_, + x_ + nx_, + MidWiseTransformIterator(y_, n, post), + z_, + func_); + } else { + trans(ctx_, + y_, + y_ + nx_, + MidWiseTransformIterator(x_, n, post), + z_, + func_); + } + } + + private: + const T *x_; + const T *y_; + OutType *z_; + int64_t nx_; + const lite::Context &ctx_; + Functor func_; + bool is_xsize_larger_; +}; + +inline void GetBroadcastDimsArrays(const DDim &x_dims, + const DDim &y_dims, + int *x_dims_array, + int *y_dims_array, + int *out_dims_array, + const int max_dim, + const int axis) { + CHECK_GE(axis, 0) << "Axis should be great than or equal to 0."; + CHECK_LT(axis, max_dim) << "Axis should be less than max(x_dim, y_dim)."; + + if (x_dims.size() > y_dims.size()) { + std::fill(y_dims_array, y_dims_array + axis, 1); + if (axis + y_dims.size() < max_dim) { + std::fill(y_dims_array + axis + y_dims.size(), y_dims_array + max_dim, 1); + } + for (int i = 0; i < x_dims.size(); i++) x_dims_array[i] = x_dims[i]; + for (int i = 0; i < y_dims.size(); i++) + *(y_dims_array + axis + i) = y_dims[i]; + } else { + std::fill(x_dims_array, x_dims_array + axis, 1); + if (axis + x_dims.size() < max_dim) { + std::fill(x_dims_array + axis + x_dims.size(), x_dims_array + max_dim, 1); + } + for (int i = 0; i < x_dims.size(); i++) + *(x_dims_array + axis + i) = x_dims[i]; + for (int i = 0; i < y_dims.size(); i++) *(y_dims_array + i) = y_dims[i]; + } + + for (int i = 0; i < max_dim; i++) { + CHECK_EQ(x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 || + y_dims_array[i] <= 1, + true) + << "Broadcast dimension mismatch. Operands could not be broadcast."; + + if ((x_dims_array[i] > 1 || y_dims_array[i] > 1) || + (x_dims_array[i] == 1 && y_dims_array[i] == 1)) { + out_dims_array[i] = std::max(x_dims_array[i], y_dims_array[i]); + } else { + out_dims_array[i] = -1; + } + } +} + +inline int GetElementwiseIndex(const int *x_dims_array, + const int max_dim, + const int *index_array) { + int index_ = 0; + for (int i = 0; i < max_dim; i++) { + if (x_dims_array[i] > 1) { + index_ = index_ * x_dims_array[i] + index_array[i]; + } + } + return index_; +} + +inline void UpdateElementwiseIndexArray(const int *out_dims_array, + const int max_dim, + int *index_array) { + for (int i = max_dim - 1; i >= 0; --i) { + ++index_array[i]; + if (index_array[i] >= out_dims_array[i]) { + index_array[i] -= out_dims_array[i]; + } else { + break; + } + } +} + +template +void CommonForwardBroadcastCPU(const Tensor *x, + const Tensor *y, + Tensor *z, + int *x_dims_array, + int *y_dims_array, + int *out_dims_array, + int max_dim, + Functor func, + const bool is_xsize_larger = true) { + std::vector index_array(max_dim, 0); + const T *x_data = x->data(); + const T *y_data = y->data(); + CHECK_EQ(x_data != nullptr, true) << "The input X should not be empty."; + CHECK_EQ(y_data != nullptr, true) << "The input Y should not be empty."; + + OutType *out_data = z->mutable_data(); + const int out_size = std::accumulate( + out_dims_array, out_dims_array + max_dim, 1, std::multiplies()); + int x_index, y_index; + for (int out_index = 0; out_index < out_size; ++out_index) { + x_index = GetElementwiseIndex(x_dims_array, max_dim, index_array.data()); + y_index = GetElementwiseIndex(y_dims_array, max_dim, index_array.data()); + if (is_xsize_larger) { + out_data[out_index] = func(x_data[x_index], y_data[y_index]); + } else { + out_data[out_index] = func(y_data[y_index], x_data[x_index]); + } + + 
UpdateElementwiseIndexArray(out_dims_array, max_dim, index_array.data()); + } +} + +template +void CommonElementwiseBroadcastForward(const Tensor *x, + const Tensor *y, + Tensor *z, + const DDim &x_dims, + const DDim &y_dims, + Functor func, + int axis, + const bool is_xsize_larger = true) { + int max_dim = std::max(x_dims.size(), y_dims.size()); + axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis); + CHECK_GE(axis, 0) << "Axis should be great than or equal to 0."; + CHECK_LT(axis, max_dim) << "Axis should be less than max(x_dim, y_dim)."; + + std::vector x_dims_array(max_dim); + std::vector y_dims_array(max_dim); + std::vector out_dims_array(max_dim); + GetBroadcastDimsArrays(x_dims, + y_dims, + x_dims_array.data(), + y_dims_array.data(), + out_dims_array.data(), + max_dim, + axis); + + CommonForwardBroadcastCPU(x, + y, + z, + x_dims_array.data(), + y_dims_array.data(), + out_dims_array.data(), + max_dim, + func, + is_xsize_larger); +} + +template +void ElementwiseComputeEx(const lite::Context &ctx, + const lite::Tensor *x, + const lite::Tensor *y, + int axis, + Functor func, + lite::Tensor *z) { + auto x_dims = x->dims(); + auto y_dims = y->dims(); + bool is_xsize_larger = true; + int max_dim = x_dims.size(); + if (x_dims.size() < y_dims.size()) { + is_xsize_larger = false; + max_dim = y_dims.size(); + } + TransformFunctor functor( + x, y, z, ctx, func, is_xsize_larger); + if (x_dims == y_dims) { + functor.Run(); + return; + } + + int tmp = std::abs(static_cast(x_dims.size()) - + static_cast(y_dims.size())); + axis = (axis == static_cast(-1) ? tmp : axis); + + CHECK_GE(axis, 0) << "Axis should be great than or equal to 0."; + CHECK_LT(axis, max_dim) << "Axis should be less than max(x_dim, y_dim)."; + + int pre, n, post, is_run_common_broadcast, axis_trim = 0; + if (is_xsize_larger) { + auto y_dims_trimed = trim_trailing_singular_dims(y_dims); + axis_trim = (y_dims_trimed.size() == 0) ? x_dims.size() : axis; + get_mid_dims(x_dims, + y_dims_trimed, + axis_trim, + &pre, + &n, + &post, + &is_run_common_broadcast); + } else { + auto x_dims_trimed = trim_trailing_singular_dims(x_dims); + axis_trim = (x_dims_trimed.size() == 0) ? y_dims.size() : axis; + get_mid_dims(y_dims, + x_dims_trimed, + axis_trim, + &pre, + &n, + &post, + &is_run_common_broadcast); + } + // special case for common implementation. + // case 1: x=[2,3,1,5], y=[2,1,4,1] + // case 2: x=[2,3,4], y=[1,1,4] + if (is_run_common_broadcast == 1) { + CommonElementwiseBroadcastForward( + x, y, z, x_dims, y_dims, func, axis, is_xsize_larger); + return; + } + if (post == 1) { + functor.RunRowWise(n, pre); + return; + } else { + functor.RunMidWise(n, pre, post); + return; + } +} + +// FusedElemwiseAndAct +// --- forward +template +struct FusedElemwiseAndActNoBroadcast { + HOSTDEVICE void operator()(size_t i) { + T y_val = y_[i]; + T x_val = x_[i]; + if (KeepIntermediateOut) { + T intermeidiate_out = compound_functor_.GetIntermediateOut(x_val, y_val); + intermediate_out_[i] = intermeidiate_out; + out_[i] = + compound_functor_.GetOutUseIntermediateOut(x_val, intermeidiate_out); + } else { + out_[i] = compound_functor_.GetOut(x_val, y_val); + } + } + + const T *x_; + const T *y_; + CompoundFunctor compound_functor_; + T *out_; + T *intermediate_out_; +}; + +// FusedElemwiseAndActBroadcast1: +// In this case, X and Y can be reshaped to a matrix. 
+// For example shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5) and axis = -1 or 2, +// X can be reshaped to (6, 20) and Y can be reshaped to (1, 20) +template +static void FusedElemwiseAndActBroadcast1CPU(const T *x, + const T *y, + CompoundFunctor compound_functor, + int h, + int w, + T *out, + T *intermediate_out) { + for (int i = 0; i < h; ++i) { + for (int j = 0; j < w; ++j) { + int offset = i * w + j; + + T y_val = BcastY ? y[j] : y[offset]; + T x_val = BcastY ? x[offset] : x[j]; + int64_t intermediate_out_offset; + if (KeepIntermediateOut) { + T intermeidiate_out = compound_functor.GetIntermediateOut(x_val, y_val); + + if (SameShapeOfIntermediateOutAndOut) { + // for the case of f1(f2(x, y)) + intermediate_out_offset = offset; + } else if (BcastY) { + intermediate_out_offset = j; + } else { + intermediate_out_offset = offset; + } + + intermediate_out[intermediate_out_offset] = intermeidiate_out; + out[offset] = + compound_functor.GetOutUseIntermediateOut(x_val, intermeidiate_out); + } else { + out[offset] = compound_functor.GetOut(x_val, y_val); + } + } + } +} + +// FusedElemwiseAndActBroadcast2 +// In this case, X and Y can be reshaped to a matrix. +// For example shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4) and axis = 1, +// X can be reshaped to (2, 12, 5) and Y can be reshaped to (1, 12, 1) +// pre = 2, n = 12, post = 5 +template +static void FusedElemwiseAndActBroadcast2CPU(const T *x, + const T *y, + int pre, + int n, + int post, + CompoundFunctor compound_functor, + T *out, + T *intermediate_out) { + for (int i = 0; i < pre; ++i) { + for (int j = 0; j < n; ++j) { + for (int k = 0; k < post; ++k) { + int offset = i * n * post + j * post + k; + + T y_val = BcastY ? y[j] : y[offset]; + T x_val = BcastY ? x[offset] : x[j]; + int64_t intermediate_out_offset; + + if (KeepIntermediateOut) { + T intermeidiate_out = + compound_functor.GetIntermediateOut(x_val, y_val); + + if (SameShapeOfIntermediateOutAndOut) { + // for the case of f1(f2(x, y)) + intermediate_out_offset = offset; + } else if (BcastY) { + intermediate_out_offset = j; + } else { + intermediate_out_offset = offset; + } + + intermediate_out[intermediate_out_offset] = intermeidiate_out; + out[offset] = compound_functor.GetOutUseIntermediateOut( + x_val, intermeidiate_out); + } else { + out[offset] = compound_functor.GetOut(x_val, y_val); + } + } + } + } +} + +template +void FusedElemwiseAndActComputeNoBroadcast(const lite::Context &ctx, + const lite::DDim &x_dim, + const lite::Tensor &x, + const lite::Tensor &y, + CompoundFunctor compound_functor, + lite::Tensor *out, + lite::Tensor *intermediate_out) { + size_t N = static_cast(x_dim.production()); + + lite::fluid::ForRange for_range(ctx, N); + + for_range( + FusedElemwiseAndActNoBroadcast{ + x.data(), + y.data(), + compound_functor, + out->template mutable_data(), + intermediate_out == nullptr + ? nullptr + : intermediate_out->template mutable_data()}); +} + +template +void FusedElemwiseAndActComputeWithBroadcast(const lite::Context &ctx, + const lite::DDim &x_dim, + const lite::DDim &y_dim_untrimed, + const lite::Tensor &x, + const lite::Tensor &y, + CompoundFunctor compound_functor, + int axis, + lite::Tensor *out, + lite::Tensor *intermediate_out) { + axis = (axis == -1 ? x_dim.size() - y_dim_untrimed.size() : axis); + auto y_dim = trim_trailing_singular_dims(y_dim_untrimed); + axis = (y_dim.size() == 0) ? 
x_dim.size() : axis; + + int pre, n, post; + get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post); + + if (post == 1) { + int h = pre; + int w = n; + FusedElemwiseAndActBroadcast1CPU( + x.data(), + y.data(), + compound_functor, + h, + w, + out->template mutable_data(), + intermediate_out == nullptr + ? nullptr + : intermediate_out->template mutable_data()); + + } else { + FusedElemwiseAndActBroadcast2CPU( + x.data(), + y.data(), + pre, + n, + post, + compound_functor, + out->template mutable_data(), + intermediate_out == nullptr + ? nullptr + : intermediate_out->template mutable_data()); + } +} + +template +void FusedElemwiseAndActComputeEx(const lite::Context &ctx, + const lite::Tensor &x, + const lite::Tensor &y, + int axis, + CompoundFunctor compound_functor, + lite::Tensor *out, + lite::Tensor *intermediate_out) { + if (KeepIntermediateOut) { + CHECK(intermediate_out) << "The save_intermediate_out is opened, " + "intermediate_out should not be nullptr."; + } + + const lite::DDim &x_dim = x.dims(); + const lite::DDim &y_dim = y.dims(); + if (x.dims() == y.dims()) { + FusedElemwiseAndActComputeNoBroadcast( + ctx, x_dim, x, y, compound_functor, out, intermediate_out); + } else { + // Whether the shape of Y is a continuous subsequence of X, + // For more information please refer to the op's introduction. + bool bcast_y = x.dims().size() >= y.dims().size(); + if (x.dims().size() == y.dims().size()) { + for (int i = 0; i < x.dims().size(); ++i) { + if (x.dims()[i] < y.dims()[i]) { + bcast_y = false; + break; + } + } + } + + // z = f1(x, f2(y)) + // z = f1(f2(x, y)) + if (bcast_y) { // Y should be broadcast. + // In this case, + // for 'f2(y)', the shape of intermediate_out should be equal to the + // shape + // of Y. + // for 'f2(x, y)', the shape of intermediate_out should be equal to the + // shape of Out. + // the shape of Out should be equal to the shape of X. + FusedElemwiseAndActComputeWithBroadcast( + ctx, + x_dim /*OutShape*/, + y_dim, + x, + y, + compound_functor, + axis, + out, + intermediate_out); + } else { + // In this case, + // for 'f2(y)', the shape of intermediate_out should be equal to the + // shape + // of Out. + // for 'f2(x, y)', the shape of intermediate_out should be equal to the + // shape of Out. + // the shape of Out should be equal to the shape of Y. + FusedElemwiseAndActComputeWithBroadcast( + ctx, + y_dim /*OutShape*/, + x_dim, + x, + y, + compound_functor, + axis, + out, + intermediate_out); + } + } +} + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/rnn_compute.cc b/lite/kernels/x86/rnn_compute.cc index 9fb996366b9..438c08f246e 100644 --- a/lite/kernels/x86/rnn_compute.cc +++ b/lite/kernels/x86/rnn_compute.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
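The broadcast helpers added in elementwise_op_function.h walk a single multi-dimensional output index, odometer style, and project it into X and Y while skipping any dimension of size 1, which is the core of CommonForwardBroadcastCPU, GetElementwiseIndex, and UpdateElementwiseIndexArray. A compact standalone version of that index walk (illustrative only; the shapes and values are made up):

```cpp
#include <cstdio>
#include <vector>

// Project the output multi-index into a (possibly size-1 broadcast) operand,
// in the spirit of GetElementwiseIndex.
int FlatIndex(const std::vector<int>& dims, const std::vector<int>& index) {
  int flat = 0;
  for (size_t i = 0; i < dims.size(); ++i)
    if (dims[i] > 1) flat = flat * dims[i] + index[i];
  return flat;
}

int main() {
  // x: (2, 3), y: (1, 3) -> out: (2, 3); broadcast add.
  std::vector<int> x_dims = {2, 3}, y_dims = {1, 3}, out_dims = {2, 3};
  std::vector<float> x = {1, 2, 3, 4, 5, 6}, y = {10, 20, 30}, out(6);

  std::vector<int> index(out_dims.size(), 0);
  for (size_t o = 0; o < out.size(); ++o) {
    out[o] = x[FlatIndex(x_dims, index)] + y[FlatIndex(y_dims, index)];
    // Advance the multi-index the way UpdateElementwiseIndexArray does.
    for (int i = static_cast<int>(out_dims.size()) - 1; i >= 0; --i) {
      if (++index[i] < out_dims[i]) break;
      index[i] = 0;
    }
  }

  for (float v : out) std::printf("%g ", v);  // prints: 11 22 33 14 25 36
  std::printf("\n");
  return 0;
}
```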
-#include "lite/kernels/x86/rnn_compute.h" +#include "lite/backends/x86/math/rnn.h" #include #include #include @@ -20,7 +20,6 @@ #include "lite/backends/host/math/split.h" #include "lite/backends/x86/math/blas.h" #include "lite/backends/x86/math/concat_and_split.h" -#include "lite/backends/x86/math/rnn.h" #include "lite/kernels/x86/rnn_compute.h" namespace paddle { @@ -28,26 +27,28 @@ namespace lite { namespace kernels { namespace x86 { -#define RUN_LSTM_LAYER(x, y, z, w) \ - runLSTMLayer(&ctx, \ - input_temp_holder, \ - parameter_lists[x], \ - init_h_unbind, \ - init_c_unbind, \ - sequence_length, \ - &last_h_unbind, \ - &last_c_unbind, \ - y, \ - x, \ - &gate_value, \ - z, \ - w) - -void reset_parameter_vector(const std::vector& raw_params_vec, - const int& num_layers, - const int& gate_num, - const bool& is_bidirec, - std::vector>* params_vec) { +#define RUN_RNN_LAYER(x, y, z, w) \ + RunRnnLayer(&ctx, \ + input_temp_holder, \ + parameter_lists[x], \ + init_h_unbind, \ + init_c_unbind, \ + sequence_length, \ + &last_h_unbind, \ + &last_c_unbind, \ + y, \ + x, \ + &gate_value, \ + z, \ + w, \ + mode) + +static void reset_parameter_vector( + const std::vector& raw_params_vec, + const int& num_layers, + const int& gate_num, + const bool& is_bidirec, + std::vector>* params_vec) { // the parameter raw seuquence is [FWhi, FWhh, BWhi, BWhh] * num_layers // + [FBhi, FBhh, BBhi, BBhh] * num_layers, we will reset the parameter to // ([FWhi, FWhh, FBhi, FBhh] + [BWhi, BWhh, BBhi, BBhh]) * num_layers @@ -75,19 +76,30 @@ void reset_parameter_vector(const std::vector& raw_params_vec, } } -void SwapPoniter(Tensor** a, Tensor** b) { +static void SwapPoniter(Tensor** a, Tensor** b) { Tensor* c = *a; *a = *b; *b = c; } -void preprocess(X86Context* ctx, - const Tensor* input, - const Tensor& weight, - const Tensor& bias_ih, - const Tensor& bias_hh, - Tensor* cache_input, - bool is_test) { +/****************************************************** +input: + ctx:context, + input:(3D)time_step, batch, input_size, + weight:(2D)hidden_size, input_size, + bias_ih, + bias_hh, + mode:LSTM, GRU +output: + cache_input:(3D)time_step, batch, hidden_size +*******************************************************/ +static void preprocess(X86Context* ctx, + const Tensor* input, + const Tensor& weight, + const Tensor& bias_ih, + const Tensor& bias_hh, + std::string mode, + Tensor* cache_input) { const int& hidden_size = weight.dims()[0]; int time_step = input->dims()[0]; int batch = input->dims()[1]; @@ -106,23 +118,175 @@ void preprocess(X86Context* ctx, int k = input_dims[2]; int n = weight_input_dims[0]; - paddle::lite::x86::math::Blas matmul(*ctx); + lite::x86::math::Blas matmul(*ctx); matmul.GEMM( false, true, m, n, k, 1.f, i_data, k, w_data, k, 0.f, o_data, n); lite::x86::math::fill_bias_fc(o_data, bias_ih.data(), m, n); - lite::x86::math::fill_bias_fc(o_data, bias_hh.data(), m, n); + + if ("GRU" == mode) { + Tensor bias_tmp_hh; + bias_tmp_hh.Resize(bias_hh.dims()); + auto bias_ptr = bias_tmp_hh.mutable_data(); + auto bias_src = bias_hh.data(); + int bias_offt = bias_hh.numel() / 3 * 2; + std::memcpy(bias_ptr, bias_src, bias_offt * sizeof(float)); + std::memset( + bias_ptr + bias_offt, 0, (bias_hh.numel() - bias_offt) * sizeof(float)); + lite::x86::math::fill_bias_fc(o_data, bias_tmp_hh.data(), m, n); + } else { + lite::x86::math::fill_bias_fc(o_data, bias_hh.data(), m, n); + } +} + +/****************************************************** +input: + ctx:context, + init_h:(2D), + init_c:(2D), + 
mask_tensor:(1D)input->dims()[1], + mode:LSTM, GRU +output: + output:(2D)output->dims()[1], output->dims()[2], + last_h:(2D), + last_c:(2D) +*******************************************************/ +static void postprocess(X86Context* ctx, + Tensor* output, + const Tensor* init_h, + const Tensor* init_c, + Tensor* last_h, + Tensor* last_c, + const Tensor& mask_tensor, + std::string mode) { + Tensor mask_broadcast_1; + mask_broadcast_1.Resize(mask_tensor.dims()); + auto mask_ptr_1 = mask_broadcast_1.mutable_data(); + auto mask_ptr = mask_tensor.data(); + auto out_ptr = output->mutable_data(); + auto cur_h_ptr = last_h->mutable_data(); + auto pre_h_ptr = init_h->data(); + int offset = 0; + + // out = out * mask_broadcast + // curr_h = out * mask_broadcast + pre_h * (1 - mask_broadcast); + for (int i = 0; i < output->dims()[0]; i++) { + mask_ptr_1[i] = 1 - mask_ptr[i]; + for (int j = 0; j < output->dims()[1]; j++) { + offset = i * output->dims()[1] + j; + out_ptr[offset] *= mask_ptr[i]; + cur_h_ptr[offset] = out_ptr[offset] + pre_h_ptr[offset] * mask_ptr_1[i]; + } + } + if ("LSTM" == mode) { + auto pre_c_ptr = init_c->data(); + auto cur_c_ptr = last_c->mutable_data(); + + // curr_c = curr_c * mask_broadcast + pre_c * (1 - mask_broadcast); + for (int i = 0; i < output->dims()[0]; i++) { + for (int j = 0; j < output->dims()[1]; j++) { + offset = i * output->dims()[1] + j; + cur_c_ptr[offset] = + cur_c_ptr[offset] * mask_ptr[i] + pre_c_ptr[offset] * mask_ptr_1[i]; + } + } + } +} + +static DDim get_stride(const DDim& ddim) { + DDim strides; + strides[ddim.size() - 1] = 1; + for (int i = ddim.size() - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * ddim[i + 1]; + } + return strides; +} + +template +static void TransposeNormal(const Tensor& in, + Tensor* out, + const std::vector& axis) { + const int rank = axis.size(); + auto in_stride = get_stride(in.dims()); + auto out_stride = get_stride(out->dims()); + const T* in_ptr = in.data(); + T* out_ptr = out->mutable_data(); + + auto transpose_helper = [&](int64_t beg, int64_t end) { + for (int64_t out_idx = beg; out_idx < end; ++out_idx) { + int64_t in_idx = 0; + int64_t tmp_idx = out_idx; + // calculate the input index + for (int i = 0; i < rank; ++i) { + const int64_t coordinate = tmp_idx / out_stride[i]; + tmp_idx -= coordinate * out_stride[i]; + in_idx += coordinate * in_stride[axis[i]]; + } + out_ptr[out_idx] = in_ptr[in_idx]; + } + }; + transpose_helper(0, out->numel()); +} + +/****************************************************** +input: + sequence_length, + is_reverse +output: + mask_matrix, + min_seq_len +******************************************************/ +static void create_mask_matrix(const Tensor* sequence_length, + Tensor* mask_matrix, + const bool& is_reverse, + int* min_seq_len) { + // Tensor to vector + std::vector seq_len_vec; + seq_len_vec.resize(sequence_length->numel()); + std::memcpy(&seq_len_vec[0], + sequence_length->data(), + sequence_length->numel() * sizeof(int)); + + const int& table_width = mask_matrix->dims()[0]; + Tensor temp; + DDimLite dims( + std::vector{mask_matrix->dims()[1], mask_matrix->dims()[0]}); + temp.Resize(dims); + float* data_temp = temp.mutable_data(); + std::fill(data_temp, data_temp + mask_matrix->numel(), 1.f); + *min_seq_len = table_width; + for (unsigned int i = 0; i < seq_len_vec.size(); i++) { + // reset the mask matrix + *min_seq_len = std::min(seq_len_vec[i], *min_seq_len); + if (seq_len_vec[i] == table_width) { + continue; + } + if (is_reverse) { + std::fill(data_temp + i * table_width, + 
data_temp + (i + 1) * table_width - seq_len_vec[i], + 0.f); + } else { + std::fill(data_temp + i * table_width + seq_len_vec[i], + data_temp + (i + 1) * table_width, + 0.f); + } + } + mask_matrix->mutable_data(); + std::vector trans_vec; + trans_vec.emplace_back(1); + trans_vec.emplace_back(0); + TransposeNormal(temp, mask_matrix, trans_vec); } -void cell(X86Context* ctx, - Tensor* input, - Tensor* weight_hh, - Tensor* init_h, - Tensor* init_c, - Tensor* last_h, - Tensor* last_c, - Tensor* last_c_act, - Tensor* output, - const Tensor* bias_hh) { +static void lstm_cell(X86Context* ctx, + Tensor* input, + Tensor* weight_hh, + Tensor* init_h, + Tensor* init_c, + Tensor* last_h, + Tensor* last_c, + Tensor* last_c_act, + Tensor* output, + const Tensor* bias_hh) { auto h_dims = init_h->dims(); auto weight_input_dims = weight_hh->dims(); int m = h_dims[0]; @@ -136,7 +300,7 @@ void cell(X86Context* ctx, tmp_gate.Resize(input->dims()); auto tmp_data = tmp_gate.mutable_data(); - paddle::lite::x86::math::Blas matmul(*ctx); + lite::x86::math::Blas matmul(*ctx); matmul.GEMM( false, true, m, n, k, 1.f, h_data, k, w_data, k, 0.f, tmp_data, n); for (int i = 0; i < input->dims()[0] * input->dims()[1]; i++) { @@ -183,19 +347,72 @@ void cell(X86Context* ctx, 1); } -void runLSTMLayer(X86Context* ctx, - const Tensor* input, - std::vector vec, - std::vector init_h, - std::vector init_c, - const Tensor* sequence_length, - std::vector* last_h_ptr, - std::vector* last_c_ptr, - Tensor* output, - int layer_idx, - Tensor* gate_value, - bool is_bidirect, - int offset) { +static void gru_cell(X86Context* ctx, + Tensor* input, + Tensor* weight_hh, + Tensor* init_h, + Tensor* init_c, + Tensor* last_h, + Tensor* last_c, + Tensor* last_c_act, + Tensor* output, + const Tensor* bias_hh, + Tensor* weight_hh_gru) { + auto h_dims = init_h->dims(); + auto weight_gru_dims = weight_hh_gru->dims(); + int m = h_dims[0]; + int k = h_dims[1]; + int n = weight_gru_dims[0]; + auto i_data = input->data(); + auto w_gru = weight_hh_gru->data(); + auto h_data = init_h->data(); + + Tensor tmp_gate; + tmp_gate.Resize(input->dims()); + auto tmp_data = tmp_gate.mutable_data(); + + lite::x86::math::Blas matmul(*ctx); + matmul.GEMM( + false, true, m, n, k, 1.f, h_data, k, w_gru, k, 0.f, tmp_data, n); + for (int i = 0; i < input->dims()[0] * input->dims()[1]; i++) { + tmp_data[i] += i_data[i]; + } + + size_t frame_size = init_h->dims()[1]; + size_t batch_size = init_h->dims()[0]; + + lite::x86::math::GRUMetaValue gru_value; + gru_value.gate_weight = weight_hh->data(); + gru_value.state_weight = + weight_hh->data() + 2 * frame_size * frame_size; + gru_value.reset_bias = bias_hh->data() + 2 * frame_size; + + gru_value.gate_value = tmp_data; + gru_value.reset_output_value = last_c->mutable_data(); + gru_value.output_value = output->mutable_data(); + gru_value.prev_out_value = init_h->data(); + + auto gate_act = lite_api::ActivationType::kSigmoid_v2; + auto cand_act = lite_api::ActivationType::kTanh_v2; + + lite::x86::math::RnnGruUnitFunctorV2::compute( + ctx, gru_value, frame_size, batch_size, cand_act, gate_act); +} + +static void RunRnnLayer(X86Context* ctx, + const Tensor* input, + std::vector vec, + std::vector init_h, + std::vector init_c, + const Tensor* sequence_length, + std::vector* last_h_ptr, + std::vector* last_c_ptr, + Tensor* output, + int layer_idx, + Tensor* gate_value, + bool is_bidirect, + int offset, + std::string mode) { bool is_reverse = false; if (is_bidirect) { layer_idx = 2 * layer_idx + offset; @@ -210,25 +427,24 @@ void 
runLSTMLayer(X86Context* ctx, vec[0 + offset * 4], vec[2 + offset * 4], vec[3 + offset * 4], - gate_value, - true); + mode, + gate_value); + std::vector input_tensors, output_tensors; std::vector input_tensors_t, output_tensors_t; - std::vector stride1, stride2; - input_tensors.resize(gate_value->dims()[0]); // time_step + std::vector stride1, stride2, stride3; + input_tensors.resize(gate_value->dims()[0]); output_tensors.resize(output->dims()[0]); - // alloc input + // unbind for (int i = 0; i < gate_value->dims()[0]; i++) { stride1.push_back(1); - int dim1 = gate_value->dims()[1]; // batch - int dim2 = gate_value->dims()[2]; // hidden + int dim1 = gate_value->dims()[1]; + int dim2 = gate_value->dims()[2]; DDimLite dims(std::vector{dim1, dim2}); input_tensors[i].Resize(dims); input_tensors_t.push_back(&input_tensors[i]); } - - // alloc output for (int i = 0; i < output->dims()[0]; i++) { stride2.push_back(1); int dim1 = output->dims()[1]; @@ -237,32 +453,55 @@ void runLSTMLayer(X86Context* ctx, output_tensors[i].Resize(dims); output_tensors_t.push_back(&output_tensors[i]); } - lite::host::math::split( gate_value->data(), input_tensors_t, 0, stride1); lite::host::math::split(output->data(), output_tensors_t, 0, stride2); auto sd = output->mutable_data(); if (is_reverse) { + // don't need to reverse input_tensors_t becauese of unuseful std::reverse(input_tensors.begin(), input_tensors.end()); } bool has_sequence_length = false; + if (sequence_length != nullptr) { + has_sequence_length = true; + } + // unbind + Tensor mask_matrix; + std::vector mask_vec; + std::vector mask_tensor_list; + int mask_min_length = time_step; + /* - TODO has_sequence_length + to be verifying! */ - - int mask_min_length = time_step; + if (has_sequence_length) { + mask_matrix.Resize(DDimLite({time_step, input->dims()[1]})); + create_mask_matrix( + sequence_length, &mask_matrix, is_reverse, &mask_min_length); + for (int i = 0; i < time_step; i++) { + stride3.push_back(1); + DDimLite ddims(std::vector{input->dims()[1]}); + mask_vec[i].Resize(ddims); + mask_tensor_list.push_back(&mask_vec[i]); + } + lite::host::math::split( + mask_matrix.data(), mask_tensor_list, 0, stride3); + } if (is_reverse) { mask_min_length = mask_min_length - time_step + 1; } - bool has_allocate_mem_c = false; + bool has_use_last_h_holder = false; const int& reverse_flag = is_reverse ? 
-1 : 1; + bool has_allocate_mem_c = false; + + // define the init_h holder for the swap Tensor init_h_temp; + init_h_temp.Resize(init_h[layer_idx].dims()); init_h_temp.CopyDataFrom(init_h[layer_idx]); Tensor* init_h_holder = &init_h_temp; Tensor* last_h_holder = nullptr; - if (0 < mask_min_length) { last_h_holder = &(output_tensors[0]); } else { @@ -275,38 +514,83 @@ void runLSTMLayer(X86Context* ctx, Tensor init_c_temp; Tensor* last_c_holder = nullptr; Tensor last_c_temp; - last_c_holder = &(*last_c_ptr)[layer_idx]; - init_c_temp_holder = &init_c[layer_idx]; + + if ("LSTM" == mode) { + last_c_holder = &(*last_c_ptr)[layer_idx]; + init_c_temp_holder = &init_c[layer_idx]; + } else if ("GRU" == mode) { + // for reset output value + last_c_temp.Resize(init_h[layer_idx].dims()); + last_c_temp.mutable_data(); + last_c_holder = &last_c_temp; + } + + Tensor weight_hh_tmp; // for gru + std::vector weight_hh_tmp_ubind; + std::vector weight_hh_tmp_ubind_t; + std::vector stride_w; + if ("GRU" == mode) { + weight_hh_tmp.Resize(vec[1 + offset * 4].dims()); + weight_hh_tmp.mutable_data(); + weight_hh_tmp.CopyDataFrom(vec[1 + offset * 4]); + int size = weight_hh_tmp.numel() / 3; + std::memset(weight_hh_tmp.mutable_data() + size * 2, + 0, + size * sizeof(float)); + } for (int i = 0; i < time_step; i++) { bool in_mask = (reverse_flag * i) >= mask_min_length; if (i > 0) { if (!has_allocate_mem_c) { - init_c_temp.Resize(init_h[layer_idx].dims()); - init_c_temp.mutable_data(); - init_c_holder = &init_c_temp; + if (("LSTM" == mode) || ("GRU" == mode)) { + init_c_temp.Resize(init_h[layer_idx].dims()); + init_c_temp.mutable_data(); + init_c_holder = &init_c_temp; + } has_allocate_mem_c = true; } SwapPoniter(&init_c_holder, &last_c_holder); init_c_temp_holder = init_c_holder; } - // LSTMCELL - cell(ctx, - &input_tensors[i], - &vec[1 + offset * 4], - init_h_holder, - init_c_temp_holder, - last_h_holder, - last_c_holder, - nullptr, - &output_tensors[i], - &vec[3 + offset * 4]); + if ("LSTM" == mode) { + lstm_cell(ctx, + &input_tensors[i], + &vec[1 + offset * 4], + init_h_holder, + init_c_temp_holder, + last_h_holder, + last_c_holder, + nullptr, + &output_tensors[i], + &vec[3 + offset * 4]); + } else if ("GRU" == mode) { + gru_cell(ctx, + &input_tensors[i], + &vec[1 + offset * 4], + init_h_holder, + init_c_temp_holder, + last_h_holder, + last_c_holder, + nullptr, + &output_tensors[i], + &vec[3 + offset * 4], + &weight_hh_tmp); + } + /* + to be verifying! 
+ */ if (in_mask) { - /* - TODO in_mask - */ + postprocess(ctx, + &output_tensors[i], + init_h_holder, + init_c_temp_holder, + last_h_holder, + last_c_holder, + mask_vec[i], + mode); } // prepare next step @@ -322,6 +606,9 @@ void runLSTMLayer(X86Context* ctx, SwapPoniter(&init_h_holder, &last_h_holder); } } + + // unbind vector and source are not in the same address, need copy + // different from paddle if (is_reverse) { std::reverse(output_tensors.begin(), output_tensors.end()); } @@ -339,7 +626,7 @@ void runLSTMLayer(X86Context* ctx, } else { (*last_h_ptr)[layer_idx].CopyDataFrom(output_tensors[time_step - 1]); } - if (time_step % 2 == 0) { + if ((0 == (time_step % 2)) && ("LSTM" == mode)) { (*last_c_ptr)[layer_idx].CopyDataFrom(*last_c_holder); } } @@ -356,12 +643,24 @@ void RnnCompute::Run() { bool is_bidirec = param.is_bidirec; int num_layers = param.num_layers; const Tensor* sequence_length = param.SequenceLength; + int gate_num = 0; - state[0]->mutable_data(); - state[1]->mutable_data(); + if ("LSTM" == mode) { + gate_num = 4; + } else if ("GRU" == mode) { + gate_num = 3; + } else { + LOG(FATAL) << "X86 RNN ERROR: unsupport mode except gru and lstm," + " present mode is " + << mode; + return; + } - // lstmCell begin - int gate_num = 4; + state[0]->mutable_data(); + if ("LSTM" == mode) { + state[1]->mutable_data(); + } + // reset the parameter to sorted order and allocate the memory std::vector> parameter_lists; parameter_lists.reserve(num_layers); reset_parameter_vector( @@ -375,31 +674,57 @@ void RnnCompute::Run() { last_c_unbind; std::vector init_h_unbind_t, init_c_unbind_t, last_h_unbind_t, last_c_unbind_t; - init_h_unbind.resize(4); - init_c_unbind.resize(pre_state[1]->dims()[0]); + init_h_unbind.resize(pre_state[0]->dims()[0]); last_h_unbind.resize(state[0]->dims()[0]); - last_c_unbind.resize(state[1]->dims()[0]); - std::vector stride; + + if ("LSTM" == mode) { + init_c_unbind.resize(pre_state[1]->dims()[0]); + last_c_unbind.resize(state[1]->dims()[0]); + } + std::vector stride1, stride2; + + // unbind for (int i = 0; i < pre_state[0]->dims()[0]; i++) { - stride.push_back(1); + stride1.push_back(1); int dim1 = pre_state[0]->dims()[1]; int dim2 = pre_state[0]->dims()[2]; DDimLite dims(std::vector{dim1, dim2}); init_h_unbind[i].Resize(dims); - init_c_unbind[i].Resize(dims); last_h_unbind[i].Resize(dims); - last_c_unbind[i].Resize(dims); init_h_unbind_t.push_back(&init_h_unbind[i]); - init_c_unbind_t.push_back(&init_c_unbind[i]); last_h_unbind_t.push_back(&last_h_unbind[i]); - last_c_unbind_t.push_back(&last_c_unbind[i]); } lite::host::math::split( - pre_state[0]->data(), init_h_unbind_t, 0, stride); - lite::host::math::split( - pre_state[1]->data(), init_c_unbind_t, 0, stride); - lite::host::math::split(state[0]->data(), last_h_unbind_t, 0, stride); - lite::host::math::split(state[1]->data(), last_c_unbind_t, 0, stride); + pre_state[0]->data(), init_h_unbind_t, 0, stride1); + lite::host::math::split(state[0]->data(), last_h_unbind_t, 0, stride1); + + if ("LSTM" == mode) { + for (int i = 0; i < pre_state[1]->dims()[0]; i++) { + stride2.push_back(1); + int dim1 = pre_state[1]->dims()[1]; + int dim2 = pre_state[1]->dims()[2]; + DDimLite dims(std::vector{dim1, dim2}); + init_c_unbind[i].Resize(dims); + last_c_unbind[i].Resize(dims); + init_c_unbind_t.push_back(&init_c_unbind[i]); + last_c_unbind_t.push_back(&last_c_unbind[i]); + } + lite::host::math::split( + pre_state[1]->data(), init_c_unbind_t, 0, stride2); + lite::host::math::split( + state[1]->data(), last_c_unbind_t, 0, 
stride2); + } + + std::vector output_vec(2); + int time_step = input->dims()[0]; + int batch_size = input->dims()[1]; + int hidden_size = output->dims()[2]; + if (is_bidirec) { + for (int i = 0; i < 2; ++i) { + output_vec[i].Resize({time_step, batch_size, hidden_size / 2}); + output_vec[i].mutable_data(); + } + } for (int i = 0; i < num_layers; i++) { if (i > 0) { @@ -418,25 +743,18 @@ void RnnCompute::Run() { } if (is_bidirec) { - std::vector output_vec(2); - int time_step = input->dims()[0]; - int batch_size = input->dims()[1]; - int hidden_size = output->dims()[2]; - for (int i = 0; i < 2; ++i) { - output_vec[i].Resize({time_step, batch_size, hidden_size / 2}); - output_vec[i].mutable_data(); - } - - RUN_LSTM_LAYER(i, &output_vec[0], true, 0); - RUN_LSTM_LAYER(i, &output_vec[1], true, 1); - - paddle::lite::x86::math::ConcatFunctor - concat_x86; - concat_x86(ctx, output_vec, 2, output); + RUN_RNN_LAYER(i, &output_vec[0], true, 0); + RUN_RNN_LAYER(i, &output_vec[1], true, 1); + lite::x86::math::ConcatFunctor concat_x86; + concat_x86(ctx, output_vec, 2, output_holder); } else { - RUN_LSTM_LAYER(i, output_holder, false, 0); + RUN_RNN_LAYER(i, output_holder, false, 0); } } + // output_holder != output + if (num_layers % 2 == 0) { + output->CopyDataFrom(*output_holder); + } } } // namespace x86 diff --git a/lite/kernels/x86/rnn_compute.h b/lite/kernels/x86/rnn_compute.h index e4dff2a3014..70d43a3c5bc 100644 --- a/lite/kernels/x86/rnn_compute.h +++ b/lite/kernels/x86/rnn_compute.h @@ -29,7 +29,7 @@ class RnnCompute : public KernelLite { virtual ~RnnCompute() = default; }; -} // namespace arm +} // namespace x86 } // namespace kernels } // namespace lite } // namespace paddle diff --git a/lite/tests/kernels/cast_compute_test.cc b/lite/tests/kernels/cast_compute_test.cc index 38bf62f42de..363eea8f144 100644 --- a/lite/tests/kernels/cast_compute_test.cc +++ b/lite/tests/kernels/cast_compute_test.cc @@ -145,7 +145,9 @@ TEST(Cast, precision) { #elif defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL) place = TARGET(kXPU); #elif defined(LITE_WITH_ARM) - place = TARGET(kARM); + place = TARGET(kHost); +#elif defined(LITE_WITH_X86) + place = TARGET(kHost); #else return; #endif diff --git a/lite/tests/kernels/elementwise_common_broadcast_test.cc b/lite/tests/kernels/elementwise_common_broadcast_test.cc index 6d8b93cb354..9b982df6039 100644 --- a/lite/tests/kernels/elementwise_common_broadcast_test.cc +++ b/lite/tests/kernels/elementwise_common_broadcast_test.cc @@ -506,4 +506,4 @@ TEST(elementwise_broadcast, compute_i64) { } } -#endif // LITE_WITH_ARM +#endif // LITE_WITH_X86