Skip to content

Commit

Permalink
[ARM NEON] Get rid of redundant instructions in ScalarQuantizer
Browse files (browse the repository at this point in the history)
Signed-off-by: Alexandr Guzhva <alexanderguzhva@gmail.com>
  • Loading branch information
alexanderguzhva committed May 14, 2024
1 parent 4d06d70 commit 81157d7
Showing 1 changed file with 19 additions and 32 deletions.
51 changes: 19 additions & 32 deletions faiss/impl/ScalarQuantizer.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -101,8 +101,7 @@ struct Codec8bit {
}
float32x4_t res1 = vld1q_f32(result);
float32x4_t res2 = vld1q_f32(result + 4);
float32x4x2_t res = vzipq_f32(res1, res2);
return vuzpq_f32(res.val[0], res.val[1]);
return {res1, res2};
}
#endif
};
Expand Down Expand Up @@ -153,8 +152,7 @@ struct Codec4bit {
}
float32x4_t res1 = vld1q_f32(result);
float32x4_t res2 = vld1q_f32(result + 4);
float32x4x2_t res = vzipq_f32(res1, res2);
return vuzpq_f32(res.val[0], res.val[1]);
return {res1, res2};
}
#endif
};
Expand Down Expand Up @@ -266,8 +264,7 @@ struct Codec6bit {
}
float32x4_t res1 = vld1q_f32(result);
float32x4_t res2 = vld1q_f32(result + 4);
float32x4x2_t res = vzipq_f32(res1, res2);
return vuzpq_f32(res.val[0], res.val[1]);
return {res1, res2};
}
#endif
};
Expand Down Expand Up @@ -345,16 +342,14 @@ struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
// reconstruct_8_components from the QuantizerTemplate<Codec, true, 8> hunk.
// This is a diff rendering without +/- markers, so the removed and the added
// lines are interleaved below and the text is not compilable as shown.
// Both versions decode 8 components via Codec::decode_8_components and compute
// vmin + xi * vdiff on each float32x4_t half using vfmaq_f32.
FAISS_ALWAYS_INLINE float32x4x2_t
reconstruct_8_components(const uint8_t* code, int i) const {
float32x4x2_t xi = Codec::decode_8_components(code, i);
// Pre-change opening (next two lines): the two halves were passed through
// vzipq_f32 into a temporary `res`.
float32x4x2_t res = vzipq_f32(
vfmaq_f32(
// Post-change opening: the pair is brace-initialized and returned directly.
return {vfmaq_f32(
vdupq_n_f32(this->vmin),
xi.val[0],
vdupq_n_f32(this->vdiff)),
vfmaq_f32(
vdupq_n_f32(this->vmin),
xi.val[1],
// Pre-change ending (next two lines): closes vzipq_f32, then undoes the zip
// with vuzpq_f32 before returning.
vdupq_n_f32(this->vdiff)));
return vuzpq_f32(res.val[0], res.val[1]);
// Post-change ending: closes the brace-initialized return value.
vdupq_n_f32(this->vdiff))};
}
};

Expand Down Expand Up @@ -431,10 +426,8 @@ struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {
float32x4x2_t vmin_8 = vld1q_f32_x2(this->vmin + i);
float32x4x2_t vdiff_8 = vld1q_f32_x2(this->vdiff + i);

float32x4x2_t res = vzipq_f32(
vfmaq_f32(vmin_8.val[0], xi.val[0], vdiff_8.val[0]),
vfmaq_f32(vmin_8.val[1], xi.val[1], vdiff_8.val[1]));
return vuzpq_f32(res.val[0], res.val[1]);
return {vfmaq_f32(vmin_8.val[0], xi.val[0], vdiff_8.val[0]),
vfmaq_f32(vmin_8.val[1], xi.val[1], vdiff_8.val[1])};
}
};

Expand Down Expand Up @@ -568,8 +561,7 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
}
float32x4_t res1 = vld1q_f32(result);
float32x4_t res2 = vld1q_f32(result + 4);
float32x4x2_t res = vzipq_f32(res1, res2);
return vuzpq_f32(res.val[0], res.val[1]);
return {res1, res2};
}
};

Expand Down Expand Up @@ -868,7 +860,7 @@ struct SimilarityL2<8> {
float32x4x2_t accu8;

// begin_8 from the SimilarityL2<8> hunk. This is a diff rendering without
// +/- markers, so the removed line and its replacement appear back to back.
// Zeroes both halves of accu8 and sets yi to y.
FAISS_ALWAYS_INLINE void begin_8() {
// Pre-change line: builds the all-zero pair by zipping two zero vectors.
accu8 = vzipq_f32(vdupq_n_f32(0.0f), vdupq_n_f32(0.0f));
// Post-change line: brace-initializes the two zero vectors directly.
accu8 = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)};
yi = y;
}

Expand All @@ -882,8 +874,7 @@ struct SimilarityL2<8> {
float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0);
float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1);

float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
accu8 = {accu8_0, accu8_1};
}

FAISS_ALWAYS_INLINE void add_8_components_2(
Expand All @@ -895,8 +886,7 @@ struct SimilarityL2<8> {
float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0);
float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1);

float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
accu8 = {accu8_0, accu8_1};
}

FAISS_ALWAYS_INLINE float result_8() {
Expand Down Expand Up @@ -996,7 +986,7 @@ struct SimilarityIP<8> {
float32x4x2_t accu8;

// begin_8 from the SimilarityIP<8> hunk. This is a diff rendering without
// +/- markers, so the removed line and its replacement appear back to back.
// Zeroes both halves of accu8 and sets yi to y.
FAISS_ALWAYS_INLINE void begin_8() {
// Pre-change line: builds the all-zero pair by zipping two zero vectors.
accu8 = vzipq_f32(vdupq_n_f32(0.0f), vdupq_n_f32(0.0f));
// Post-change line: brace-initializes the two zero vectors directly.
accu8 = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)};
yi = y;
}

Expand All @@ -1006,28 +996,25 @@ struct SimilarityIP<8> {

float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], yiv.val[0], x.val[0]);
float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], yiv.val[1], x.val[1]);
float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
accu8 = {accu8_0, accu8_1};
}

// add_8_components_2 from the SimilarityIP<8> hunk. This is a diff rendering
// without +/- markers, so removed and added lines are both present.
// Adds the element-wise product x1 * x2 into accu8, one vfmaq_f32 per
// float32x4_t half.
FAISS_ALWAYS_INLINE void add_8_components_2(
float32x4x2_t x1,
float32x4x2_t x2) {
float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], x1.val[0], x2.val[0]);
float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], x1.val[1], x2.val[1]);
// Pre-change lines: zip the halves and immediately unzip them again just to
// form the float32x4x2_t (the redundant instructions named in the commit
// title).
float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
// Post-change line: stores the two halves with a brace initializer.
accu8 = {accu8_0, accu8_1};
}

// result_8 from the SimilarityIP<8> hunk. This is a diff rendering without
// +/- markers, so the removed and the added lines are interleaved below and
// the text is not compilable as shown.
// Both versions apply vpaddq_f32 twice to each half of accu8 and return the
// sum of lane 0 of the two reduced halves.
FAISS_ALWAYS_INLINE float result_8() {
// Pre-change opening: first pairwise-add stage wrapped in vzipq_f32.
float32x4x2_t sum_tmp = vzipq_f32(
// Post-change opening: first stage brace-initialized directly.
float32x4x2_t sum = {
vpaddq_f32(accu8.val[0], accu8.val[0]),
// Pre-change (next three lines): closes the first zip, unzips into `sum`,
// then opens the second stage inside another vzipq_f32.
vpaddq_f32(accu8.val[1], accu8.val[1]));
float32x4x2_t sum = vuzpq_f32(sum_tmp.val[0], sum_tmp.val[1]);
float32x4x2_t sum2_tmp = vzipq_f32(
// Post-change (next three lines): closes `sum` and opens the
// brace-initialized second stage.
vpaddq_f32(accu8.val[1], accu8.val[1])};

float32x4x2_t sum2 = {
vpaddq_f32(sum.val[0], sum.val[0]),
// Pre-change (next two lines): closes the second zip and unzips into `sum2`.
vpaddq_f32(sum.val[1], sum.val[1]));
float32x4x2_t sum2 = vuzpq_f32(sum2_tmp.val[0], sum2_tmp.val[1]);
// Post-change: closes the brace-initialized `sum2`.
vpaddq_f32(sum.val[1], sum.val[1])};
return vgetq_lane_f32(sum2.val[0], 0) + vgetq_lane_f32(sum2.val[1], 0);
}
};
Expand Down

0 comments on commit 81157d7

Please sign in to comment.