Skip to content

Commit

Permalink
fix softmax arm fp16s sum error, fix #5340 (#5393)
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui authored Mar 30, 2024
1 parent 6595743 commit 167501f
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/layer/arm/softmax_arm_asimdhp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ int Softmax_arm::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt)
float16x8_t _ss01 = vpaddq_f16(_p0, _p1);
float16x8_t _ss23 = vpaddq_f16(_p2, _p3);
float16x8_t _ss2 = vpaddq_f16(_ss01, _ss23);
_sum = vadd_f16(_sum, vpmax_f16(vget_low_f16(_ss2), vget_high_f16(_ss2)));
_sum = vadd_f16(_sum, vpadd_f16(vget_low_f16(_ss2), vget_high_f16(_ss2)));
vst1_f16(sumptr, _sum);
ptr += 32;
maxptr += 4;
Expand Down Expand Up @@ -292,7 +292,7 @@ int Softmax_arm::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt)
vst1q_f16(ptr, _p0);
vst1q_f16(ptr + 8, _p1);
float16x8_t _ss2 = vpaddq_f16(_p0, _p1);
_sum = vadd_f16(_sum, vpmax_f16(vget_low_f16(_ss2), vget_high_f16(_ss2)));
_sum = vadd_f16(_sum, vpadd_f16(vget_low_f16(_ss2), vget_high_f16(_ss2)));
vst1_f16(sumptr, _sum);
ptr += 16;
maxptr += 4;
Expand Down Expand Up @@ -743,7 +743,7 @@ int Softmax_arm::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt)
float16x8_t _ss01 = vpaddq_f16(_p0, _p1);
float16x8_t _ss23 = vpaddq_f16(_p2, _p3);
float16x8_t _ss2 = vpaddq_f16(_ss01, _ss23);
_sum = vadd_f16(_sum, vpmax_f16(vget_low_f16(_ss2), vget_high_f16(_ss2)));
_sum = vadd_f16(_sum, vpadd_f16(vget_low_f16(_ss2), vget_high_f16(_ss2)));
vst1_f16(sumptr, _sum);
ptr += 32;
sumptr += 4;
Expand All @@ -768,7 +768,7 @@ int Softmax_arm::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt)
float16x8_t _p1 = vld1q_f16(ptr + 8);
float16x4_t _sum = vld1_f16(sumptr);
float16x8_t _ss2 = vpaddq_f16(_p0, _p1);
_sum = vadd_f16(_sum, vpmax_f16(vget_low_f16(_ss2), vget_high_f16(_ss2)));
_sum = vadd_f16(_sum, vpadd_f16(vget_low_f16(_ss2), vget_high_f16(_ss2)));
vst1_f16(sumptr, _sum);
ptr += 16;
sumptr += 4;
Expand Down

0 comments on commit 167501f

Please sign in to comment.