Skip to content

Commit

Permalink
HNSW speedup + Distance 4 points (facebookresearch#2841)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#2841

* Add virtual void DistanceComputer::distances_batch_4(), which computes the distances from the current query to four stored vectors in a single call
* Add the infrastructure
* HNSW::search() uses DistanceComputer::distances_batch_4()
* Add IndexFlatL2::sync_l2norms() and IndexFlatL2::clear_l2norms(), which allow precomputing a cache of L2 norms for the stored vectors and computing L2 distances via dot products
* Add downcasting of IndexFlatL2 and IndexFlatIP in swig
* Add general-purpose prefetch utilities

Reviewed By: mdouze

Differential Revision: D45427064

fbshipit-source-id: d23b34fe080dbff951d34cdc1323813bd3b828e0
  • Loading branch information
Alexandr Guzhva authored and facebook-github-bot committed May 5, 2023
1 parent f276c47 commit 5b17225
Show file tree
Hide file tree
Showing 12 changed files with 598 additions and 31 deletions.
1 change: 1 addition & 0 deletions faiss/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ set(FAISS_HEADERS
utils/hamming.h
utils/ordered_key_value.h
utils/partitioning.h
utils/prefetch.h
utils/quantize_lut.h
utils/random.h
utils/simdlib.h
Expand Down
198 changes: 195 additions & 3 deletions faiss/IndexFlat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <faiss/utils/Heap.h>
#include <faiss/utils/distances.h>
#include <faiss/utils/extra_distances.h>
#include <faiss/utils/prefetch.h>
#include <faiss/utils/sorting.h>
#include <faiss/utils/utils.h>
#include <cstring>
Expand Down Expand Up @@ -122,6 +123,39 @@ struct FlatL2Dis : FlatCodesDistanceComputer {
void set_query(const float* x) override {
q = x;
}

// compute four distances
void distances_batch_4(
const idx_t idx0,
const idx_t idx1,
const idx_t idx2,
const idx_t idx3,
float& dis0,
float& dis1,
float& dis2,
float& dis3) final override {
ndis += 4;

// compute first, assign next
const float* __restrict y0 =
reinterpret_cast<const float*>(codes + idx0 * code_size);
const float* __restrict y1 =
reinterpret_cast<const float*>(codes + idx1 * code_size);
const float* __restrict y2 =
reinterpret_cast<const float*>(codes + idx2 * code_size);
const float* __restrict y3 =
reinterpret_cast<const float*>(codes + idx3 * code_size);

float dp0 = 0;
float dp1 = 0;
float dp2 = 0;
float dp3 = 0;
fvec_L2sqr_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
dis0 = dp0;
dis1 = dp1;
dis2 = dp2;
dis3 = dp3;
}
};

struct FlatIPDis : FlatCodesDistanceComputer {
Expand All @@ -131,13 +165,13 @@ struct FlatIPDis : FlatCodesDistanceComputer {
const float* b;
size_t ndis;

float symmetric_dis(idx_t i, idx_t j) override {
float symmetric_dis(idx_t i, idx_t j) final override {
return fvec_inner_product(b + j * d, b + i * d, d);
}

float distance_to_code(const uint8_t* code) final {
float distance_to_code(const uint8_t* code) final override {
ndis++;
return fvec_inner_product(q, (float*)code, d);
return fvec_inner_product(q, (const float*)code, d);
}

explicit FlatIPDis(const IndexFlat& storage, const float* q = nullptr)
Expand All @@ -153,6 +187,39 @@ struct FlatIPDis : FlatCodesDistanceComputer {
void set_query(const float* x) override {
q = x;
}

// compute four distances
void distances_batch_4(
const idx_t idx0,
const idx_t idx1,
const idx_t idx2,
const idx_t idx3,
float& dis0,
float& dis1,
float& dis2,
float& dis3) final override {
ndis += 4;

// compute first, assign next
const float* __restrict y0 =
reinterpret_cast<const float*>(codes + idx0 * code_size);
const float* __restrict y1 =
reinterpret_cast<const float*>(codes + idx1 * code_size);
const float* __restrict y2 =
reinterpret_cast<const float*>(codes + idx2 * code_size);
const float* __restrict y3 =
reinterpret_cast<const float*>(codes + idx3 * code_size);

float dp0 = 0;
float dp1 = 0;
float dp2 = 0;
float dp3 = 0;
fvec_inner_product_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
dis0 = dp0;
dis1 = dp1;
dis2 = dp2;
dis3 = dp3;
}
};

} // namespace
Expand Down Expand Up @@ -184,6 +251,131 @@ void IndexFlat::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
}
}

/***************************************************
* IndexFlatL2
***************************************************/

namespace {
/// Distance computer for IndexFlatL2 that uses the index's cached norms.
///
/// Query-to-vector distances are evaluated as
///     query_l2norm + l2norms[i] - 2 * <q, y_i>
/// so the per-vector work is one inner product plus a cache lookup.
/// Because of rounding, results may differ slightly from a direct
/// fvec_L2sqr() evaluation.
///
/// The computer keeps raw pointers into the index (codes, cached norms);
/// it does not own them.
struct FlatL2WithNormsDis : FlatCodesDistanceComputer {
    size_t d;
    idx_t nb;
    const float* q;
    const float* b;
    /// number of query-to-vector distances computed so far
    size_t ndis;

    /// cached per-vector norms taken from IndexFlatL2::cached_l2norms
    const float* l2norms;
    /// norm of the current query, as returned by fvec_norm_L2sqr()
    float query_l2norm;

    /// Distance from the query to an arbitrary code; no cached norm is
    /// available for it, so the distance is computed directly.
    float distance_to_code(const uint8_t* code) final override {
        ndis++;
        return fvec_L2sqr(q, reinterpret_cast<const float*>(code), d);
    }

    /// Distance from the query to stored vector i, using the norms cache.
    float operator()(const idx_t i) final override {
        // counted here as well, so that ndis is consistent with
        // distance_to_code() and distances_batch_4()
        ndis++;

        const float* __restrict y =
                reinterpret_cast<const float*>(codes + i * code_size);

        // request the cached norm early; it is only needed after the
        // inner product has been computed
        prefetch_L2(l2norms + i);
        const float dp0 = fvec_inner_product(q, y, d);
        return query_l2norm + l2norms[i] - 2 * dp0;
    }

    /// Distance between stored vectors i and j, using the norms cache.
    float symmetric_dis(idx_t i, idx_t j) final override {
        const float* __restrict yi =
                reinterpret_cast<const float*>(codes + i * code_size);
        const float* __restrict yj =
                reinterpret_cast<const float*>(codes + j * code_size);

        prefetch_L2(l2norms + i);
        prefetch_L2(l2norms + j);
        const float dp0 = fvec_inner_product(yi, yj, d);
        return l2norms[i] + l2norms[j] - 2 * dp0;
    }

    /// @param storage index providing the codes and the norms cache
    /// @param q       optional initial query; may be nullptr, in which
    ///                case set_query() must be called before use
    explicit FlatL2WithNormsDis(
            const IndexFlatL2& storage,
            const float* q = nullptr)
            : FlatCodesDistanceComputer(
                      storage.codes.data(),
                      storage.code_size),
              d(storage.d),
              nb(storage.ntotal),
              q(q),
              b(storage.get_xb()),
              ndis(0),
              l2norms(storage.cached_l2norms.data()),
              // a query passed at construction needs its norm right away,
              // otherwise every distance would be computed with a zero
              // query norm until set_query() is called
              query_l2norm(
                      q != nullptr ? fvec_norm_L2sqr(q, storage.d) : 0.0f) {}

    /// Set the current query and refresh its cached norm.
    /// Only the pointer is stored; the query data is not copied.
    void set_query(const float* x) override {
        q = x;
        query_l2norm = fvec_norm_L2sqr(q, d);
    }

    /// Batched variant of operator(): distances from the query to the
    /// four stored vectors idx0..idx3, returned through dis0..dis3.
    void distances_batch_4(
            const idx_t idx0,
            const idx_t idx1,
            const idx_t idx2,
            const idx_t idx3,
            float& dis0,
            float& dis1,
            float& dis2,
            float& dis3) final override {
        ndis += 4;

        // compute first, assign next
        const float* __restrict y0 =
                reinterpret_cast<const float*>(codes + idx0 * code_size);
        const float* __restrict y1 =
                reinterpret_cast<const float*>(codes + idx1 * code_size);
        const float* __restrict y2 =
                reinterpret_cast<const float*>(codes + idx2 * code_size);
        const float* __restrict y3 =
                reinterpret_cast<const float*>(codes + idx3 * code_size);

        // request the four cached norms before running the inner
        // products; they are only read afterwards
        prefetch_L2(l2norms + idx0);
        prefetch_L2(l2norms + idx1);
        prefetch_L2(l2norms + idx2);
        prefetch_L2(l2norms + idx3);

        float dp0 = 0;
        float dp1 = 0;
        float dp2 = 0;
        float dp3 = 0;
        fvec_inner_product_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
        dis0 = query_l2norm + l2norms[idx0] - 2 * dp0;
        dis1 = query_l2norm + l2norms[idx1] - 2 * dp1;
        dis2 = query_l2norm + l2norms[idx2] - 2 * dp2;
        dis3 = query_l2norm + l2norms[idx3] - 2 * dp3;
    }
};

} // namespace

// (Re)build the norms cache: one entry per currently stored vector,
// filled by fvec_norms_L2sqr() from the raw float data held in `codes`.
void IndexFlatL2::sync_l2norms() {
    // size the cache first so that the output pointer taken below is valid
    cached_l2norms.resize(ntotal);

    const float* stored_vectors = reinterpret_cast<const float*>(codes.data());
    float* norms_out = cached_l2norms.data();
    fvec_norms_L2sqr(norms_out, stored_vectors, d, ntotal);
}

// Drop the norms cache and give its memory back. Afterwards the cache is
// empty.
void IndexFlatL2::clear_l2norms() {
    // swapping with an empty temporary leaves the member empty and frees
    // the old buffer when the temporary is destroyed
    std::vector<float>().swap(cached_l2norms);
}

// Returns a newly allocated distance computer for this index; the caller
// takes ownership of the returned pointer.
//
// When the metric is L2 and the norms cache covers every stored vector, a
// FlatL2WithNormsDis is returned, which computes distances from dot
// products and the cached norms. In every other case the base-class
// computer is used.
FlatCodesDistanceComputer* IndexFlatL2::get_FlatCodesDistanceComputer() const {
    // FlatL2WithNormsDis indexes the cache by vector id, so a cache whose
    // size no longer matches ntotal (e.g. vectors were added after
    // sync_l2norms()) must not be used: it would be read out of bounds.
    const bool cache_usable = metric_type == METRIC_L2 &&
            !cached_l2norms.empty() &&
            cached_l2norms.size() == static_cast<size_t>(ntotal);

    if (cache_usable) {
        return new FlatL2WithNormsDis(*this);
    }

    return IndexFlat::get_FlatCodesDistanceComputer();
}

/***************************************************
* IndexFlat1D
***************************************************/
Expand Down
14 changes: 14 additions & 0 deletions faiss/IndexFlat.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,22 @@ struct IndexFlatIP : IndexFlat {
};

// Flat index constructed with METRIC_L2, extended with an optional cache
// of per-vector L2 norms.
struct IndexFlatL2 : IndexFlat {
// Special cache for L2 norms; empty unless sync_l2norms() has been called.
// sync_l2norms() sizes it to one entry per stored vector (ntotal).
// While the cache is populated and the metric is METRIC_L2,
// get_FlatCodesDistanceComputer() returns a special distance computer
// that computes the distance using dot products and these cached norms.
// NOTE(review): nothing in this declaration refreshes the cache when
// vectors are added or removed; presumably sync_l2norms() has to be
// called again afterwards -- confirm against the callers.
std::vector<float> cached_l2norms;

// Builds an index of dimension d with the L2 metric.
explicit IndexFlatL2(idx_t d) : IndexFlat(d, METRIC_L2) {}
IndexFlatL2() {}

// override for l2 norms cache: selects the cache-based distance computer
// when the cache is available, otherwise defers to IndexFlat.
FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override;

// compute L2 norms of all stored vectors and store them in cached_l2norms
void sync_l2norms();
// clear L2 norms: empties cached_l2norms and requests release of its memory
void clear_l2norms();
};

/// optimized version for 1D "vectors".
Expand Down
5 changes: 4 additions & 1 deletion faiss/IndexHNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -872,7 +872,10 @@ IndexHNSWFlat::IndexHNSWFlat() {
}

IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric)
: IndexHNSW(new IndexFlat(d, metric), M) {
: IndexHNSW(
(metric == METRIC_L2) ? new IndexFlatL2(d)
: new IndexFlat(d, metric),
M) {
own_fields = true;
is_trained = true;
}
Expand Down
25 changes: 24 additions & 1 deletion faiss/impl/DistanceComputer.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,29 @@ struct DistanceComputer {
/// compute distance of vector i to current query
virtual float operator()(idx_t i) = 0;

/// compute distances of current query to 4 stored vectors.
/// certain DistanceComputer implementations may benefit
/// heavily from this.
virtual void distances_batch_4(
const idx_t idx0,
const idx_t idx1,
const idx_t idx2,
const idx_t idx3,
float& dis0,
float& dis1,
float& dis2,
float& dis3) {
// compute first, assign next
const float d0 = this->operator()(idx0);
const float d1 = this->operator()(idx1);
const float d2 = this->operator()(idx2);
const float d3 = this->operator()(idx3);
dis0 = d0;
dis1 = d1;
dis2 = d2;
dis3 = d3;
}

/// compute distance between two stored vectors
virtual float symmetric_dis(idx_t i, idx_t j) = 0;

Expand All @@ -49,7 +72,7 @@ struct FlatCodesDistanceComputer : DistanceComputer {

FlatCodesDistanceComputer() : codes(nullptr), code_size(0) {}

float operator()(idx_t i) final {
float operator()(idx_t i) override {
return distance_to_code(codes + i * code_size);
}

Expand Down
Loading

0 comments on commit 5b17225

Please sign in to comment.