Commit

opt

nihui committed Apr 30, 2024
1 parent abbbb7b commit 5fd9ab3

Showing 2 changed files with 131 additions and 54 deletions.
30 changes: 30 additions & 0 deletions src/layer/arm/rnn_arm_vfpv4.cpp
@@ -0,0 +1,30 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "cpu.h"
#include "mat.h"
#include "layer.h"
#include "arm_activation.h"
#include "arm_usability.h"

namespace ncnn {

#include "rnn_int8.h"

void rnn_int8_gate_output_vfpv4(const Mat& gates, Mat& hidden_state, Mat& top_blob, int ti, int elemtype, const Option& opt)
{
    rnn_int8_gate_output(gates, hidden_state, top_blob, ti, elemtype, opt);
}

} // namespace ncnn
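
Note on the new file: this is ncnn's usual runtime-dispatch pattern. rnn_int8.h is included into a translation unit that the build compiles with VFPv4 floating-point flags, so the shared rnn_int8_gate_output body picks up native fp16 conversion instructions here, while the generic copy (built without those flags) tests CPU support at runtime and forwards to this one. The !(__ARM_FP & 2) guard in the header keeps this VFPv4 copy from re-dispatching to itself, since __ARM_FP bit 1 (fp16 conversion support) is set whenever the TU is built with VFPv4 enabled. A minimal sketch of the pattern under those assumptions, with hypothetical names throughout (the real wiring lives in ncnn's CMake and cpu.h):

// dispatch_sketch.h -- shared kernel body, included by every per-ISA source
// file. Hypothetical names; compile kernel_vfpv4.cpp with -mfpu=neon-vfpv4
// so that KERNEL_HAS_VFPV4 reflects the better ISA.
bool cpu_supports_vfpv4(); // stand-in for ncnn::cpu_support_arm_vfpv4()

#if DISPATCH_RUNTIME && !KERNEL_HAS_VFPV4
void kernel_vfpv4(const float* in, float* out, int n); // lives in the vfpv4 TU
#endif

static void kernel(const float* in, float* out, int n)
{
#if DISPATCH_RUNTIME && !KERNEL_HAS_VFPV4
    if (cpu_supports_vfpv4())
    {
        kernel_vfpv4(in, out, n); // same source, built with better flags
        return;
    }
#endif
    for (int i = 0; i < n; i++)
        out[i] = in[i]; // portable body; the vfpv4 TU compiles this too
}

// kernel_vfpv4.cpp -- built with VFPv4 enabled, so KERNEL_HAS_VFPV4 is set,
// the runtime check above compiles out, and no self-recursion is possible:
//
//     #include "dispatch_sketch.h"
//     void kernel_vfpv4(const float* in, float* out, int n)
//     {
//         kernel(in, out, n);
//     }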
155 changes: 101 additions & 54 deletions src/layer/arm/rnn_int8.h
@@ -17,6 +17,10 @@ void rnn_transform_weight_int8_asimddp(const Mat& weight_xc, const Mat& weight_x
void rnn_int8_asimddp(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int elemtype, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, Mat& hidden_state, const Option& opt);
#endif

#if NCNN_RUNTIME_CPU && NCNN_VFPV4 && __ARM_NEON && !(__ARM_FP & 2)
void rnn_int8_gate_output_vfpv4(const Mat& gates, Mat& hidden_state, Mat& top_blob, int ti, int elemtype, const Option& opt);
#endif

static void rnn_transform_weight_int8(const Mat& weight_xc, const Mat& weight_xc_int8_scales, const Mat& weight_hc, const Mat& weight_hc_int8_scales, const Mat& bias_c, Mat& weight_data_tm, Mat& weight_data_tm_int8_descales, Mat& bias_c_tm, int size, int num_output, int num_directions, const Option& opt)
{
// TODO dispatch for __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
@@ -210,6 +214,102 @@ static void rnn_transform_weight_int8(const Mat& weight_xc, const Mat& weight_xc
}
}

static void rnn_int8_gate_output(const Mat& gates, Mat& hidden_state, Mat& top_blob, int ti, int elemtype, const Option& opt)
{
#if NCNN_RUNTIME_CPU && NCNN_VFPV4 && __ARM_NEON && !(__ARM_FP & 2)
    if (ncnn::cpu_support_arm_vfpv4())
    {
        rnn_int8_gate_output_vfpv4(gates, hidden_state, top_blob, ti, elemtype, opt);
        return;
    }
#endif

    const int num_output = top_blob.w;

    float* output_data = top_blob.row(ti);

    float* hidden_ptr = hidden_state;

    int remain_num_output_start = 0;
#if __ARM_NEON
    int nn_num_output = num_output >> 2;
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int qq = 0; qq < nn_num_output; qq++)
    {
        int q = qq * 4;

        float32x4_t _rnn_H = vld1q_f32((const float*)gates + q);

        vst1q_f32(hidden_ptr + q, _rnn_H);

        if (elemtype == 1)
        {
            // fp32
            vst1q_f32(output_data + q, _rnn_H);
        }
        if (elemtype == 2)
        {
            // fp16
            unsigned short* outptr = (unsigned short*)output_data + q;
#if (__ARM_FP & 2)
#if NCNN_GNU_INLINE_ASM
#if __aarch64__
            asm volatile(
                "fcvtn  v0.4h, %2.4s        \n"
                "st1    {v0.4h}, [%0]       \n"
                : "=r"(outptr) // %0
                : "0"(outptr),
                "w"(_rnn_H)
                : "memory", "v0");
#else // __aarch64__
            asm volatile(
                "vcvt.f16.f32 d0, %q2       \n"
                "vst1.u16   {d0}, [%0]      \n"
                : "=r"(outptr) // %0
                : "0"(outptr),
                "w"(_rnn_H)
                : "memory", "q0");
#endif // __aarch64__
#else // NCNN_GNU_INLINE_ASM
            vst1_u16(outptr, (uint16x4_t)vcvt_f16_f32(_rnn_H));
#endif // NCNN_GNU_INLINE_ASM
#else
            outptr[0] = float32_to_float16(hidden_ptr[q]);
            outptr[1] = float32_to_float16(hidden_ptr[q + 1]);
            outptr[2] = float32_to_float16(hidden_ptr[q + 2]);
            outptr[3] = float32_to_float16(hidden_ptr[q + 3]);
#endif // (__ARM_FP & 2)
        }
        if (elemtype == 4)
        {
            // bf16
            vst1_u16((unsigned short*)output_data + q, float2bfloat(_rnn_H));
        }
    }
    remain_num_output_start += nn_num_output << 2;
#endif // __ARM_NEON
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int q = remain_num_output_start; q < num_output; q++)
    {
        float H = gates[q];

        hidden_ptr[q] = H;

        if (elemtype == 1)
        {
            output_data[q] = H;
        }
        if (elemtype == 2)
        {
            ((unsigned short*)output_data)[q] = float32_to_float16(H);
        }
        if (elemtype == 4)
        {
            ((unsigned short*)output_data)[q] = float32_to_bfloat16(H);
        }
    }
}

static void rnn_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int elemtype, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, Mat& hidden_state, const Option& opt)
{
// TODO dispatch for __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
@@ -490,59 +590,6 @@ static void rnn_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
            gates[q] = H;
        }

        float* output_data = top_blob.row(ti);

        float* hidden_ptr = hidden_state;

#if __ARM_NEON
        nn_num_output = num_output >> 2;
        remain_num_output_start = nn_num_output << 2;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int qq = 0; qq < nn_num_output; qq++)
        {
            int q = qq * 4;

            float32x4_t _rnn_H = vld1q_f32((float*)gates + q);

            vst1q_f32(hidden_ptr + q, _rnn_H);

            if (elemtype == 1)
            {
                // fp32
                vst1q_f32(output_data + q, _rnn_H);
            }
            if (elemtype == 2)
            {
                // fp16
                vst1_u16((unsigned short*)output_data + q, (uint16x4_t)vcvt_f16_f32(_rnn_H));
            }
            if (elemtype == 4)
            {
                // bf16
                vst1_u16((unsigned short*)output_data + q, float2bfloat(_rnn_H));
            }
        }
#endif // __ARM_NEON
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = remain_num_output_start; q < num_output; q++)
        {
            float H = gates[q];

            hidden_ptr[q] = H;

            if (elemtype == 1)
            {
                output_data[q] = H;
            }
            if (elemtype == 2)
            {
                ((unsigned short*)output_data)[q] = float32_to_float16(H);
            }
            if (elemtype == 4)
            {
                ((unsigned short*)output_data)[q] = float32_to_bfloat16(H);
            }
        }
        rnn_int8_gate_output(gates, hidden_state, top_blob, ti, elemtype, opt);
    }
}
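
For reference, the elemtype values threaded through rnn_int8_gate_output select the output storage type: 1 writes fp32, 2 writes IEEE fp16, and 4 writes bfloat16. bf16 narrowing is cheap because it simply keeps the high half of the binary32 bit pattern; a hedged scalar sketch follows (illustrative only: f32_to_bf16_sketch is a hypothetical name, and ncnn's own float32_to_bfloat16 helper may differ in rounding):

#include <cstring>

// Truncating float32 -> bfloat16: keep the top 16 bits (sign, 8 exponent
// bits, 7 mantissa bits) of the IEEE-754 binary32 representation.
static inline unsigned short f32_to_bf16_sketch(float x)
{
    unsigned int u;
    std::memcpy(&u, &x, sizeof(u)); // bit-cast without strict-aliasing UB
    return (unsigned short)(u >> 16);
}

fp16, by contrast, uses a different exponent bias and width, so the scalar fallback needs a genuine re-round; that is why the vector path above prefers the single-instruction fcvtn / vcvt.f16.f32 forms in inline assembly.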
