diff --git a/src/common/comp/brute_force.cc b/src/common/comp/brute_force.cc index cbeafb89a..7d29e37a4 100644 --- a/src/common/comp/brute_force.cc +++ b/src/common/comp/brute_force.cc @@ -114,7 +114,12 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset int topk = cfg.k.value(); auto labels = std::make_unique(nq * topk); auto distances = std::make_unique(nq * topk); - + std::unique_ptr norms = nullptr; + if (is_cosine) { + ThreadPool::ScopedSearchOmpSetter setter(1); + norms = std::make_unique(nb); + faiss::fvec_norms_L2(norms.get(), (float*)xb, dim, nb); + } auto pool = ThreadPool::GetGlobalSearchThreadPool(); std::vector> futs; futs.reserve(nq); @@ -139,7 +144,8 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances}; if (is_cosine) { auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); - faiss::knn_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, &buf, id_selector); + faiss::knn_cosine(copied_query.get(), (const float*)xb, norms.get(), dim, 1, nb, &buf, + id_selector); } else { faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, id_selector); } @@ -248,6 +254,13 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_ auto labels = ids; auto distances = dis; + std::unique_ptr norms = nullptr; + if (is_cosine) { + ThreadPool::ScopedSearchOmpSetter setter(1); + norms = std::make_unique(nb); + faiss::fvec_norms_L2(norms.get(), (float*)xb, dim, nb); + } + auto pool = ThreadPool::GetGlobalSearchThreadPool(); std::vector> futs; futs.reserve(nq); @@ -272,7 +285,8 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_ faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances}; if (is_cosine) { auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); - faiss::knn_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, &buf, id_selector); + faiss::knn_cosine(copied_query.get(), (const float*)xb, norms.get(), dim, 1, nb, &buf, + id_selector); } else { faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, id_selector); } @@ -408,6 +422,13 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da std::vector> result_id_array(nq); std::vector> result_dist_array(nq); + std::unique_ptr norms = nullptr; + if (is_cosine) { + ThreadPool::ScopedSearchOmpSetter setter(1); + norms = std::make_unique(nb); + faiss::fvec_norms_L2(norms.get(), (float*)xb, dim, nb); + } + std::vector> futs; futs.reserve(nq); for (int i = 0; i < nq; ++i) { @@ -452,8 +473,8 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da auto cur_query = (const float*)xq + dim * index; if (is_cosine) { auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); - faiss::range_search_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, radius, - &res, id_selector); + faiss::range_search_cosine(copied_query.get(), (const float*)xb, norms.get(), dim, 1, nb, + radius, &res, id_selector); } else { faiss::range_search_inner_product(cur_query, (const float*)xb, dim, 1, nb, radius, &res, id_selector); @@ -678,6 +699,13 @@ BruteForce::AnnIterator(const DataSetPtr base_dataset, const DataSetPtr query_da faiss::MetricType faiss_metric_type = result.value(); bool is_cosine = IsMetricType(metric_str, metric::COSINE); + std::unique_ptr norms = nullptr; + if (is_cosine) { + ThreadPool::ScopedSearchOmpSetter setter(1); + norms = std::make_unique(nb); + faiss::fvec_norms_L2(norms.get(), (float*)xb, dim, nb); + } + auto pool = ThreadPool::GetGlobalSearchThreadPool(); auto vec = std::vector(nq, nullptr); std::vector> futs; @@ -703,7 +731,7 @@ BruteForce::AnnIterator(const DataSetPtr base_dataset, const DataSetPtr query_da auto cur_query = (const float*)xq + dim * index; if (is_cosine) { auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); - faiss::all_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, distances_ids, + faiss::all_cosine(copied_query.get(), (const float*)xb, norms.get(), dim, 1, nb, distances_ids, id_selector); } else { faiss::all_inner_product(cur_query, (const float*)xb, dim, 1, nb, distances_ids, id_selector); diff --git a/src/simd/distances_avx.cc b/src/simd/distances_avx.cc index a2a8fd909..a57830a59 100644 --- a/src/simd/distances_avx.cc +++ b/src/simd/distances_avx.cc @@ -461,6 +461,43 @@ ivec_L2sqr_avx(const int8_t* x, const int8_t* y, size_t d) { return res; } +float +fvec_norm_L2sqr_avx(const float* x, size_t d) { + __m256 msum_0 = _mm256_setzero_ps(); + __m256 msum_1 = _mm256_setzero_ps(); + while (d >= 16) { + auto mx_0 = _mm256_loadu_ps(x); + auto mx_1 = _mm256_loadu_ps(x + 8); + msum_0 = _mm256_fmadd_ps(mx_0, mx_0, msum_0); + msum_1 = _mm256_fmadd_ps(mx_1, mx_1, msum_1); + x += 16; + d -= 16; + } + msum_0 = msum_0 + msum_1; + if (d >= 8) { + auto mx = _mm256_loadu_ps(x); + msum_0 = _mm256_fmadd_ps(mx, mx, msum_0); + x += 8; + d -= 8; + } + if (d > 0) { + __m128 rest_0 = _mm_setzero_ps(); + __m128 rest_1 = _mm_setzero_ps(); + if (d >= 4) { + rest_0 = _mm_loadu_ps(x); + x += 4; + d -= 4; + } + if (d >= 0) { + rest_1 = masked_read(d, x); + } + auto mx = _mm256_set_m128(rest_0, rest_1); + msum_0 = _mm256_fmadd_ps(mx, mx, msum_0); + } + auto res = _mm256_reduce_add_ps(msum_0); + return res; +} + float fp16_vec_norm_L2sqr_avx(const knowhere::fp16* x, size_t d) { __m256 msum_0 = _mm256_setzero_ps(); diff --git a/src/simd/distances_avx.h b/src/simd/distances_avx.h index 060f33af9..441daee15 100644 --- a/src/simd/distances_avx.h +++ b/src/simd/distances_avx.h @@ -79,6 +79,9 @@ ivec_inner_product_avx(const int8_t* x, const int8_t* y, size_t d); int32_t ivec_L2sqr_avx(const int8_t* x, const int8_t* y, size_t d); +float +fvec_norm_L2sqr_avx(const float* x, size_t d); + float fp16_vec_norm_L2sqr_avx(const knowhere::fp16* x, size_t d); diff --git a/src/simd/distances_avx512.cc b/src/simd/distances_avx512.cc index f61bb1a48..07b74195e 100644 --- a/src/simd/distances_avx512.cc +++ b/src/simd/distances_avx512.cc @@ -487,6 +487,33 @@ ivec_L2sqr_avx512(const int8_t* x, const int8_t* y, size_t d) { return res; } +float +fvec_norm_L2sqr_avx512(const float* x, size_t d) { + __m512 m512_res = _mm512_setzero_ps(); + __m512 m512_res_0 = _mm512_setzero_ps(); + while (d >= 32) { + auto mx_0 = _mm512_loadu_ps(x); + auto mx_1 = _mm512_loadu_ps(x + 16); + m512_res = _mm512_fmadd_ps(mx_0, mx_0, m512_res); + m512_res_0 = _mm512_fmadd_ps(mx_1, mx_1, m512_res_0); + x += 32; + d -= 32; + } + m512_res = m512_res + m512_res_0; + if (d >= 16) { + auto mx = _mm512_loadu_ps(x); + m512_res = _mm512_fmadd_ps(mx, mx, m512_res); + x += 16; + d -= 16; + } + if (d > 0) { + const __mmask16 mask = (1U << d) - 1U; + auto mx = _mm512_maskz_loadu_ps(mask, x); + m512_res = _mm512_fmadd_ps(mx, mx, m512_res); + } + return _mm512_reduce_add_ps(m512_res); +} + float fp16_vec_norm_L2sqr_avx512(const knowhere::fp16* x, size_t d) { __m512 m512_res = _mm512_setzero_ps(); diff --git a/src/simd/distances_avx512.h b/src/simd/distances_avx512.h index 956331d75..36096d005 100644 --- a/src/simd/distances_avx512.h +++ b/src/simd/distances_avx512.h @@ -78,6 +78,9 @@ ivec_inner_product_avx512(const int8_t* x, const int8_t* y, size_t d); int32_t ivec_L2sqr_avx512(const int8_t* x, const int8_t* y, size_t d); +float +fvec_norm_L2sqr_avx512(const float* x, size_t d); + float fp16_vec_norm_L2sqr_avx512(const knowhere::fp16* x, size_t d); diff --git a/src/simd/hook.cc b/src/simd/hook.cc index a0167a4cb..d7cbd497c 100644 --- a/src/simd/hook.cc +++ b/src/simd/hook.cc @@ -185,7 +185,7 @@ fvec_hook(std::string& simd_type) { fvec_L1 = fvec_L1_avx512; fvec_Linf = fvec_Linf_avx512; - fvec_norm_L2sqr = fvec_norm_L2sqr_sse; + fvec_norm_L2sqr = fvec_norm_L2sqr_avx512; fvec_L2sqr_ny = fvec_L2sqr_ny_sse; fvec_inner_products_ny = fvec_inner_products_ny_sse; fvec_madd = fvec_madd_avx512; @@ -213,7 +213,7 @@ fvec_hook(std::string& simd_type) { fvec_L1 = fvec_L1_avx; fvec_Linf = fvec_Linf_avx; - fvec_norm_L2sqr = fvec_norm_L2sqr_sse; + fvec_norm_L2sqr = fvec_norm_L2sqr_avx; fvec_L2sqr_ny = fvec_L2sqr_ny_sse; fvec_inner_products_ny = fvec_inner_products_ny_sse; fvec_madd = fvec_madd_avx;