Skip to content

Commit

Permalink
bugfixing the AVX2 Extract8+16 codes, where there's lines like `__m25…
Browse files Browse the repository at this point in the history
…6d scale01234567 = _mm256_loadu_ps(scales)`, i.e. loading float vectors into double vector types. Extract from tesseract-ocr#3490.
  • Loading branch information
GerHobbelt authored and stweil committed Jul 21, 2021
1 parent 9a6e937 commit 50d7e75
Showing 1 changed file with 20 additions and 29 deletions.
49 changes: 20 additions & 29 deletions src/arch/intsimdmatrixavx2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,54 +132,45 @@ static inline __m128i load64_to_128(const int8_t *wi_) {
}

#if defined(FAST_FLOAT)
static inline void ExtractResults8(__m256i result, const int8_t *wi, const float *scales,
float *v) {
__m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg

static inline void ExtractResults8(__m256i result, const int8_t *wi,
const float *scales, float *v) {
__m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg
__m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg
__m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
__m256d scale01234567 = _mm256_loadu_ps(scales);
//~ __m256d scale4567 = _mm256_loadu_ps(scales + 8);
__m256 scale01234567 = _mm256_loadu_ps(scales);
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
result = _mm256_add_epi32(result, w256); // result += bias * 127
__m256 res01234567 = _mm256_cvtepi32_ps(_mm256_castsi256_si128(result));
__m256 res01234567 = _mm256_cvtepi32_ps(result);
result = _mm256_permute4x64_epi64(result, 2 + (3 << 2));
__m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result));
res01234567 = _mm256_mul_pd(res01234567, scale01234567);
//~ res4567 = _mm256_mul_pd(res4567, scale4567);
res01234567 = _mm256_mul_ps(res01234567, scale01234567);
_mm256_storeu_ps(v, res01234567);
//~ _mm256_storeu_pd(v + 4, res4567);
}

static inline void ExtractResults16(__m256i result0, __m256i result1, const int8_t *&wi,
const float *&scales, float *&v) {
static inline void ExtractResults16(__m256i result0, __m256i result1,
const int8_t *&wi, const float *&scales,
float *&v) {
__m128i w8 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(wi));
// 8x8bit vals in bottom of 128bit reg
const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
const __m256i bias_scale =
_mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127);
__m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
__m256d scale0123 = _mm256_loadu_ps(scales);
__m256d scale4567 = _mm256_loadu_ps(scales + 8);
__m256 scale01234567 = _mm256_loadu_ps(scales);
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
result0 = _mm256_add_epi32(result0, w256); // result += bias * 127
__m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
__m256 res01234567 = _mm256_cvtepi32_ps(result0);
result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2));
__m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0));
res0123 = _mm256_mul_pd(res0123, scale0123);
res4567 = _mm256_mul_pd(res4567, scale4567);
_mm256_storeu_ps(v, res0123);
_mm256_storeu_ps(v + 8, res4567);
res01234567 = _mm256_mul_ps(res01234567, scale01234567);
_mm256_storeu_ps(v, res01234567);
w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2));
w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg
scale0123 = _mm256_loadu_ps(scales + 16);
scale4567 = _mm256_loadu_ps(scales + 24);
scale01234567 = _mm256_loadu_ps(scales + 8);
w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 <bias * 127>
result1 = _mm256_add_epi32(result1, w256); // result += bias * 127
res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
res01234567 = _mm256_cvtepi32_ps(result1);
result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2));
res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1));
res0123 = _mm256_mul_pd(res0123, scale0123);
res4567 = _mm256_mul_pd(res4567, scale4567);
_mm256_storeu_ps(v + 16, res0123);
_mm256_storeu_ps(v + 24, res4567);
res01234567 = _mm256_mul_ps(res01234567, scale01234567);
_mm256_storeu_ps(v + 8, res01234567);
wi += 16;
scales += 16;
v += 16;
Expand Down

0 comments on commit 50d7e75

Please sign in to comment.