[X86/ARM] add gru mode for rnn (#7026)
* [X86] Add GRU for RNN, complete elementwise op, move cast from arm to host, fix precision_profile bug

* pre-commit

* [ARM] add RNN-GRU OP; Optimize RNN-GRU OP

* fix compile bug

* fix leftover elementwise problem

* merge develop

* fix windows ci

* change arm cast test to host cast test
mjp9527 authored Sep 27, 2021
1 parent ff9a54e commit c33793e
Showing 21 changed files with 2,442 additions and 478 deletions.
248 changes: 248 additions & 0 deletions lite/backends/arm/math/gru.h
@@ -0,0 +1,248 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "lite/backends/arm/math/sgemm.h"
#ifdef LITE_WITH_ARM
#include <arm_neon.h>
#endif

namespace paddle {
namespace lite {
namespace arm {
namespace math {

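// Buffers for one GRU step. gate_value holds the concatenated pre-activations
// of the reset, update, and cell gates (3 * frame_size values per batch row);
// reset_output_value receives prev_out_value * state_weight^T (filled by the
// SGEMM in RnnGruUnitFunctorV2 below when a previous output exists);
// output_value receives the new hidden state.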
template <typename T>
struct RNNGRUValue {
  const T* gate_weight;
  const T* state_weight;
  const T* reset_bias;
  T* gate_value;
  T* reset_output_value;
  T* output_value;
  const T* prev_out_value;
};

template <typename T>
void rnn_activation(const T* din,
                    T* dout,
                    int size,
                    lite_api::ActivationType act_type,
                    int threads) {
  switch (act_type) {
    case lite_api::ActivationType::kSigmoid:
      act_sigmoid(din, dout, size, threads);
      break;
    case lite_api::ActivationType::kSigmoid_v2:
      act_sigmoid(din, dout, size, threads);
      break;
    case lite_api::ActivationType::kTanh:
      act_tanh(din, dout, size, threads);
      break;
    case lite_api::ActivationType::kTanh_v2:
      act_tanh(din, dout, size, threads);
      break;
    case lite_api::ActivationType::kRelu:
      act_relu(din, dout, size, threads);
      break;
    default:
      LOG(FATAL) << "unsupported activation type: "
                 << static_cast<int>(act_type);
      break;
  }
}

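// Elementwise GRU update for one batch, applied row by row. With
// r = sigmoid(reset pre-activation) and u = sigmoid(update pre-activation):
//   candidate c = tanh(cell pre-activation + r * (reset_output + reset_bias))
//   new state h = (1 - u) * c + u * prev_out   (or (1 - u) * c without prev_out)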
template <typename T>
void compute_kernel(RNNGRUValue<T> value,
                    int frame_size,
                    int batch_size,
                    lite_api::ActivationType active_node,
                    lite_api::ActivationType active_gate) {
  auto value_reset_gate = value.gate_value;
  auto value_update_gate = value.gate_value + frame_size;
  auto value_reset_output = value.reset_output_value;
  auto value_reset_bias = value.reset_bias;
  auto cell_state_value = value.gate_value + 2 * frame_size;
  auto value_output = value.output_value;
  auto value_prev_out = value.prev_out_value;

  for (int b = 0; b < batch_size; b++) {
    rnn_activation(value_reset_gate,
                   value_reset_gate,
                   frame_size,
                   lite_api::ActivationType::kSigmoid_v2,
                   1);
    rnn_activation(value_update_gate,
                   value_update_gate,
                   frame_size,
                   lite_api::ActivationType::kSigmoid_v2,
                   1);

    for (int i = 0; i < frame_size; i++) {
      value_reset_output[i] =
          (value_reset_output[i] + value_reset_bias[i]) * value_reset_gate[i];
      cell_state_value[i] += value_reset_output[i];
    }

    rnn_activation(cell_state_value,
                   cell_state_value,
                   frame_size,
                   lite_api::ActivationType::kTanh_v2,
                   1);

    if (value.prev_out_value) {
      for (int i = 0; i < frame_size; i++) {
        value_output[i] = (1.f - value_update_gate[i]) * cell_state_value[i] +
                          value_update_gate[i] * value_prev_out[i];
      }
    } else {
      for (int i = 0; i < frame_size; i++) {
        value_output[i] = (1.f - value_update_gate[i]) * cell_state_value[i];
      }
    }

    value_reset_gate += frame_size * 3;
    value_update_gate += frame_size * 3;
    value_reset_output += frame_size;
    cell_state_value += frame_size * 3;
    value_output += frame_size;
    if (value.prev_out_value) {
      value_prev_out += frame_size;
    }
  }
}

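// float specialization: same computation as the generic kernel above, with the
// inner loops vectorized using NEON (four lanes per iteration, scalar tail).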
template <>
void compute_kernel<float>(RNNGRUValue<float> value,
                           int frame_size,
                           int batch_size,
                           lite_api::ActivationType active_node,
                           lite_api::ActivationType active_gate) {
  auto value_reset_gate = value.gate_value;
  auto value_update_gate = value.gate_value + frame_size;
  auto value_reset_output = value.reset_output_value;
  auto value_reset_bias = value.reset_bias;
  auto cell_state_value = value.gate_value + 2 * frame_size;
  auto value_output = value.output_value;
  auto value_prev_out = value.prev_out_value;
  int i = 0;
  float32x4_t vec_one = vdupq_n_f32(1.f);

  for (int b = 0; b < batch_size; b++) {
    rnn_activation(value_reset_gate,
                   value_reset_gate,
                   frame_size,
                   lite_api::ActivationType::kSigmoid_v2,
                   1);
    rnn_activation(value_update_gate,
                   value_update_gate,
                   frame_size,
                   lite_api::ActivationType::kSigmoid_v2,
                   1);

    for (i = 0; i + 3 < frame_size; i += 4) {
      float32x4_t vec_out = vld1q_f32(value_reset_output + i);
      float32x4_t vec_reset = vld1q_f32(value_reset_gate + i);
      float32x4_t vec_bias = vld1q_f32(value_reset_bias + i);
      vec_out = vmulq_f32(vaddq_f32(vec_out, vec_bias), vec_reset);
      vst1q_f32(value_reset_output + i, vec_out);
      vst1q_f32(cell_state_value + i,
                vaddq_f32(vec_out, vld1q_f32(cell_state_value + i)));
    }
    for (; i < frame_size; i++) {
      value_reset_output[i] =
          (value_reset_output[i] + value_reset_bias[i]) * value_reset_gate[i];
      cell_state_value[i] += value_reset_output[i];
    }

    rnn_activation(cell_state_value,
                   cell_state_value,
                   frame_size,
                   lite_api::ActivationType::kTanh_v2,
                   1);

    if (value.prev_out_value) {
      for (i = 0; i + 3 < frame_size; i += 4) {
        float32x4_t vec_vug = vld1q_f32(value_update_gate + i);
        float32x4_t vec_vpo = vld1q_f32(value_prev_out + i);
        float32x4_t vec_csv = vld1q_f32(cell_state_value + i);
        vec_vpo = vmulq_f32(vec_vug, vec_vpo);
        float32x4_t vec_out =
            vmlaq_f32(vec_vpo, vsubq_f32(vec_one, vec_vug), vec_csv);
        vst1q_f32(value_output + i, vec_out);
      }
      for (; i < frame_size; i++) {
        value_output[i] = (1.f - value_update_gate[i]) * cell_state_value[i] +
                          value_update_gate[i] * value_prev_out[i];
      }
    } else {
      for (i = 0; i + 3 < frame_size; i += 4) {
        float32x4_t vec_vug = vld1q_f32(value_update_gate + i);
        float32x4_t vec_csv = vld1q_f32(cell_state_value + i);
        float32x4_t vec_out = vmulq_f32(vsubq_f32(vec_one, vec_vug), vec_csv);
        vst1q_f32(value_output + i, vec_out);
      }
      for (; i < frame_size; i++) {
        value_output[i] = (1.f - value_update_gate[i]) * cell_state_value[i];
      }
    }

    value_reset_gate += frame_size * 3;
    value_update_gate += frame_size * 3;
    value_reset_output += frame_size;
    cell_state_value += frame_size * 3;
    value_output += frame_size;
    if (value.prev_out_value) {
      value_prev_out += frame_size;
    }
  }
}

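// One GRU step over the whole batch: when a previous hidden state exists,
// first compute reset_output_value = prev_out_value * state_weight^T with an
// SGEMM, then run the elementwise kernel above.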
template <typename T>
struct RnnGruUnitFunctorV2 {
  static void compute(ARMContext* ctx,
                      RNNGRUValue<T> value,
                      int frame_size,
                      int batch_size,
                      lite_api::ActivationType active_node,
                      lite_api::ActivationType active_gate) {
    if (value.prev_out_value) {
      operators::ActivationParam act_param;
      act_param.has_active = false;
      lite::arm::math::sgemm(false,
                             true,
                             batch_size,
                             frame_size,
                             frame_size,
                             1.f,
                             value.prev_out_value,
                             frame_size,
                             value.state_weight,
                             frame_size,
                             0.f,
                             value.reset_output_value,
                             frame_size,
                             nullptr,
                             false,
                             act_param,
                             ctx);
    }
    compute_kernel(value, frame_size, batch_size, active_node, active_gate);
  }
};

} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
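For readers who want the math without the NEON details, here is a minimal scalar sketch of the per-frame update that compute_kernel performs. It is standalone, plain C++ with illustrative names (gru_frame_step and reset_proj are not part of the patch):

// Scalar reference for one GRU frame step (illustrative only).
#include <cmath>
#include <vector>

// gates: [reset | update | cell] pre-activations, each of length frame_size.
// reset_proj: previous hidden state projected through the state weights.
// bias: reset bias. h_prev may be empty for the first step.
std::vector<float> gru_frame_step(const std::vector<float>& gates,
                                  const std::vector<float>& reset_proj,
                                  const std::vector<float>& bias,
                                  const std::vector<float>& h_prev,
                                  int frame_size) {
  auto sigmoid = [](float x) { return 1.f / (1.f + std::exp(-x)); };
  std::vector<float> h(frame_size);
  for (int i = 0; i < frame_size; ++i) {
    float r = sigmoid(gates[i]);
    float u = sigmoid(gates[frame_size + i]);
    float c =
        std::tanh(gates[2 * frame_size + i] + r * (reset_proj[i] + bias[i]));
    h[i] = h_prev.empty() ? (1.f - u) * c : (1.f - u) * c + u * h_prev[i];
  }
  return h;
}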
1 change: 1 addition & 0 deletions lite/backends/arm/math/lstm.cc
@@ -36,6 +36,7 @@ void add_bias_rowwise(Tensor* input,
    i_data += width;
  }
}

void vector_dot(
    float* out, const float* in, const float* v1, int size, const float* v2) {
  int loop = size >> 2;
39 changes: 21 additions & 18 deletions lite/backends/x86/math/elementwise.h
@@ -150,24 +150,27 @@ namespace math {
} \
}

// marco func add
ElementWiseFunc(Add) ElementWiseFuncBCast(Add)
// marco func sub
ElementWiseFunc(Sub) ElementWiseFuncBCast(Sub)
// marco func mul
ElementWiseFunc(Mul) ElementWiseFuncBCast(Mul)
// marco func max
ElementWiseFunc(Max) ElementWiseFuncBCast(Max)
// marco func min
ElementWiseFunc(Min) ElementWiseFuncBCast(Min)
// marco func div
ElementWiseFunc(Div) ElementWiseFuncBCast(Div)
// marco func floordiv
ElementWiseFunc(FloorDiv) ElementWiseFuncBCast(FloorDiv)
// marco func mod
ElementWiseFunc(Mod) ElementWiseFuncBCast(Mod)
// marco func pow
ElementWiseFunc(Pow) ElementWiseFuncBCast(Pow)
// clang-format off
ElementWiseFunc(Add)
ElementWiseFuncBCast(Add)
ElementWiseFunc(Sub)
ElementWiseFuncBCast(Sub)
ElementWiseFunc(Mul)
ElementWiseFuncBCast(Mul)
ElementWiseFunc(Max)
ElementWiseFuncBCast(Max)
ElementWiseFunc(Min)
ElementWiseFuncBCast(Min)
ElementWiseFunc(Div)
ElementWiseFuncBCast(Div)
ElementWiseFunc(FloorDiv)
ElementWiseFuncBCast(FloorDiv)
ElementWiseFunc(Mod)
ElementWiseFuncBCast(Mod)
ElementWiseFunc(Pow)
ElementWiseFuncBCast(Pow)
// clang-format on

} // namespace math
} // namespace x86
} // namespace lite
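Each ElementWiseFunc(Op) / ElementWiseFuncBCast(Op) invocation above expands to an elementwise kernel for that operation, and the clang-format off/on guard keeps the one-invocation-per-line layout. As a rough sketch of the general non-broadcast pattern (illustrative only, not the macros' actual generated signatures):

#include <cstdint>

// Apply a binary op lane by lane over two equally sized arrays.
template <typename T, typename Op>
void elementwise_apply(const T* x, const T* y, T* out, int64_t n, Op op) {
  for (int64_t i = 0; i < n; ++i) {
    out[i] = op(x[i], y[i]);
  }
}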
7 changes: 0 additions & 7 deletions lite/backends/x86/math/fill_bias_activate.cc
@@ -38,7 +38,6 @@ static void activate_relu_inplace(float *data, int len, float alpha, int mode) {
__m256 vec_data = _mm256_loadu_ps(data + i);
_mm256_storeu_ps(data + i, _mm256_max_ps(vec_data, vec_zero));
}
// _mm256_zeroupper();
#endif
#ifdef __SSE__
__m128 vec_zero_128 = _mm_set1_ps(0.f);
@@ -59,7 +58,6 @@ static void activate_relu_inplace(float *data, int len, float alpha, int mode) {
_mm256_storeu_ps(
data + i, _mm256_min_ps(_mm256_max_ps(vec_data, vec_zero), vec_alph));
}
// _mm256_zeroupper();
#endif
#ifdef __SSE__
__m128 vec_zero_128 = _mm_set1_ps(0.f);
@@ -112,7 +110,6 @@ static void activate_relu_inplace_bias(float *data,
vec_data = _mm256_add_ps(vec_bias, vec_data);
_mm256_storeu_ps(tmp_data + i, _mm256_max_ps(vec_data, vec_zero));
}
// _mm256_zeroupper();
#endif
#ifdef __SSE__
vec_bias_128 = _mm_set1_ps(bias[j]);
@@ -140,7 +137,6 @@ static void activate_relu_inplace_bias(float *data,
tmp_data + i,
_mm256_min_ps(_mm256_max_ps(vec_data, vec_zero), vec_alph));
}
// _mm256_zeroupper();
#endif
#ifdef __SSE__
vec_bias_128 = _mm_set1_ps(bias[j]);
@@ -174,7 +170,6 @@ static void activate_lrelu_inplace(float *data, int len, float alpha) {
__m256 vec_mask = _mm256_cmp_ps(vec_data, vec_zero, cmp_le_os);
_mm256_storeu_ps(data + i, _mm256_blendv_ps(vec_data, vec_lr, vec_mask));
}
// _mm256_zeroupper();
#endif
#ifdef __SSE4_1__ // blendv need 4.1
__m128 vec_zero_128 = _mm_set1_ps(0.f);
@@ -226,7 +221,6 @@ static void activate_lrelu_inplace_bias(float *data,
_mm256_storeu_ps(tmp_data + i,
_mm256_blendv_ps(vec_data, vec_lr, vec_mask));
}
// _mm256_zeroupper();
#endif
#ifdef __SSE4_1__
vec_bias_128 = _mm_set1_ps(bias[j]);
@@ -273,7 +267,6 @@ static void activate_none_inplace_bias(float *data,
vec_data = _mm256_add_ps(vec_bias, vec_data);
_mm256_storeu_ps(tmp_data + i, vec_data);
}
// _mm256_zeroupper();
#endif
#ifdef __SSE__
vec_bias_128 = _mm_set1_ps(bias[j]);
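The hunks above show only fragments around the deleted commented-out _mm256_zeroupper() lines; for context, the AVX ReLU-in-place pattern they sit in looks roughly like this (a sketch under the same __AVX__ guard, not the file's exact code):

#include <immintrin.h>

// Clamp a float buffer at zero in place: 8 lanes at a time with AVX,
// then a scalar tail for the remainder.
static void relu_inplace_sketch(float* data, int len) {
  int i = 0;
#ifdef __AVX__
  __m256 vec_zero = _mm256_set1_ps(0.f);
  for (; i + 7 < len; i += 8) {
    __m256 v = _mm256_loadu_ps(data + i);
    _mm256_storeu_ps(data + i, _mm256_max_ps(v, vec_zero));
  }
#endif
  for (; i < len; i++) {
    data[i] = data[i] > 0.f ? data[i] : 0.f;
  }
}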
