Skip to content

Commit

Permalink
FAISS_HNSW_SQ
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandr Guzhva <alexanderguzhva@gmail.com>
  • Loading branch information
alexanderguzhva committed Aug 1, 2024
1 parent 35f2e3a commit ac1cbbe
Show file tree
Hide file tree
Showing 8 changed files with 370 additions and 9 deletions.
1 change: 1 addition & 0 deletions include/knowhere/comp/index_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ constexpr const char* INDEX_HNSW_SQ8_REFINE = "HNSW_SQ8_REFINE";
constexpr const char* INDEX_DISKANN = "DISKANN";

constexpr const char* INDEX_FAISS_HNSW_FLAT = "FAISS_HNSW_FLAT";
constexpr const char* INDEX_FAISS_HNSW_SQ = "FAISS_HNSW_SQ";

constexpr const char* INDEX_SPARSE_INVERTED_INDEX = "SPARSE_INVERTED_INDEX";
constexpr const char* INDEX_SPARSE_WAND = "SPARSE_WAND";
Expand Down
2 changes: 2 additions & 0 deletions include/knowhere/index/index_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ static std::set<std::pair<std::string, VecType>> legal_knowhere_index = {
{IndexEnum::INDEX_HNSW_SQ8_REFINE, VecType::VECTOR_BFLOAT16},
// faiss hnsw
{IndexEnum::INDEX_FAISS_HNSW_FLAT, VecType::VECTOR_FLOAT},

{IndexEnum::INDEX_FAISS_HNSW_SQ, VecType::VECTOR_FLOAT},
// diskann
{IndexEnum::INDEX_DISKANN, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_DISKANN, VecType::VECTOR_FLOAT16},
Expand Down
159 changes: 158 additions & 1 deletion src/index/hnsw/faiss_hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
// yes, wrap both base and refine index
knowhere::IndexWrapperCosine cosine_wrapper(
index_refine->refine_index,
dynamic_cast<faiss::HasInverseL2Norms*>(index_hnsw)->get_inverse_l2_norms());
dynamic_cast<faiss::HasInverseL2Norms*>(index_hnsw->storage)->get_inverse_l2_norms());

// create a temporary refine index which does not own
faiss::IndexRefine tmp_refine(base_wrapper, &cosine_wrapper);
Expand Down Expand Up @@ -625,6 +625,163 @@ class BaseFaissRegularIndexHNSWFlatNode : public BaseFaissRegularIndexHNSWNode {
}
};

//
template<typename DataType>
class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode {
public:
BaseFaissRegularIndexHNSWSQNode(const int32_t& version, const Object& object) :
BaseFaissRegularIndexHNSWNode(version, object) {}

std::unique_ptr<BaseConfig>
CreateConfig() const override {
return std::make_unique<FaissHnswSqConfig>();
}

std::string
Type() const override {
return knowhere::IndexEnum::INDEX_FAISS_HNSW_SQ;
}

protected:
Status TrainInternal(const DataSetPtr dataset, const Config& cfg) override {
// number of rows
auto rows = dataset->GetRows();
// dimensionality of the data
auto dim = dataset->GetDim();
// data
auto data = dataset->GetTensor();

// config
auto hnsw_cfg = static_cast<const FaissHnswSqConfig&>(cfg);

auto metric = Str2FaissMetricType(hnsw_cfg.metric_type.value());
if (!metric.has_value()) {
LOG_KNOWHERE_ERROR_ << "Invalid metric type: " << hnsw_cfg.metric_type.value();
return Status::invalid_metric_type;
}

// parse a ScalarQuantizer type
auto sq_type = get_sq_quantizer_type(hnsw_cfg.sq_type.value());
if (!sq_type.has_value()) {
LOG_KNOWHERE_ERROR_ << "Invalid scalar quantizer type: " << hnsw_cfg.sq_type.value();
return Status::invalid_args;
}

// create an index
const bool is_cosine = IsMetricType(hnsw_cfg.metric_type.value(), metric::COSINE);

std::unique_ptr<faiss::IndexHNSW> hnsw_index;
if (is_cosine) {
hnsw_index = std::make_unique<faiss::IndexHNSWSQCosine>(
dim, sq_type.value(), hnsw_cfg.M.value());
} else {
hnsw_index = std::make_unique<faiss::IndexHNSWSQ>(
dim, sq_type.value(), hnsw_cfg.M.value(), metric.value());
}

hnsw_index->hnsw.efConstruction = hnsw_cfg.efConstruction.value();

// should refine be used?
std::unique_ptr<faiss::Index> final_index;
if (hnsw_cfg.refine.value_or(false)) {
// yes

// grab a type of a refine index
bool is_fp32_flat = false;
if (!hnsw_cfg.refine_type.has_value()) {
is_fp32_flat = true;
} else {
// todo: tolower
if (hnsw_cfg.refine_type.value() == "FP32" || hnsw_cfg.refine_type.value() == "FLAT") {
is_fp32_flat = true;
} else {
// parse
auto refine_sq_type = get_sq_quantizer_type(hnsw_cfg.refine_type.value());
if (!refine_sq_type.has_value()) {
LOG_KNOWHERE_ERROR_ << "Invalid refine type: " << hnsw_cfg.refine_type.value();
return Status::invalid_args;
}

// recognized one, it is not fp32 flat
is_fp32_flat = false;
}
}

// either build flat or sq
if (is_fp32_flat) {
// build IndexFlat as a refine
auto refine_index = std::make_unique<faiss::IndexRefineFlat>(hnsw_index.get());

// let refine_index to own everything
refine_index->own_fields = true;
hnsw_index.release();

// reassign
final_index = std::move(refine_index);
} else {
// being IndexScalarQuantizer as a refine
auto refine_sq_type = get_sq_quantizer_type(hnsw_cfg.refine_type.value());

// a redundant check
if (!refine_sq_type.has_value()) {
LOG_KNOWHERE_ERROR_ << "Invalid refine type: " << hnsw_cfg.refine_type.value();
return Status::invalid_args;
}

// create an sq
auto sq_refine = std::make_unique<faiss::IndexScalarQuantizer>(
dim, refine_sq_type.value(), metric.value()
);

auto refine_index = std::make_unique<faiss::IndexRefine>(hnsw_index.get(), sq_refine.get());

// let refine_index to own everything
refine_index->own_refine_index = true;
refine_index->own_fields = true;
hnsw_index.release();
sq_refine.release();

// reassign
final_index = std::move(refine_index);
}
} else {
// no refine

// reassign
final_index = std::move(hnsw_index);
}

// train
final_index->train(rows, (const float*)data);

// done
index = std::move(final_index);
return Status::success;
}

private:
expected<faiss::ScalarQuantizer::QuantizerType>
static get_sq_quantizer_type(const std::string& sq_type) {
std::map<std::string, faiss::ScalarQuantizer::QuantizerType> sq_types = {
{"SQ6", faiss::ScalarQuantizer::QT_6bit},
{"SQ8", faiss::ScalarQuantizer::QT_8bit},
{"FP16", faiss::ScalarQuantizer::QT_fp16},
{"BF16", faiss::ScalarQuantizer::QT_bf16}
};

// todo: tolower
auto itr = sq_types.find(sq_type);
if (itr == sq_types.cend()) {
return expected<faiss::ScalarQuantizer::QuantizerType>::Err(
Status::invalid_args, fmt::format("invalid scalar quantizer type ({})", sq_type));
}

return itr->second;
}
};


KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_FLAT, BaseFaissRegularIndexHNSWFlatNode, fp32);
KNOWHERE_SIMPLE_REGISTER_GLOBAL(FAISS_HNSW_SQ, BaseFaissRegularIndexHNSWSQNode, fp32);

} // namespace knowhere
91 changes: 90 additions & 1 deletion src/index/hnsw/faiss_hnsw_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class FaissHnswConfig : public BaseConfig {
CFG_BOOL refine;
// undefined value leads to a search without a refine
CFG_FLOAT refine_k;
// type of refine
CFG_STRING refine_type;

KNOHWERE_DECLARE_CONFIG(FaissHnswConfig) {
KNOWHERE_CONFIG_DECLARE_FIELD(M).description("hnsw M").set_default(30).set_range(2, 2048).for_train();
Expand Down Expand Up @@ -78,6 +80,10 @@ class FaissHnswConfig : public BaseConfig {
.allow_empty_without_default()
.set_range(1, std::numeric_limits<CFG_FLOAT::value_type>::max())
.for_search();
KNOWHERE_CONFIG_DECLARE_FIELD(refine_type)
.description("the type of a refine index")
.allow_empty_without_default()
.for_train();
}

Status
Expand Down Expand Up @@ -107,6 +113,23 @@ class FaissHnswConfig : public BaseConfig {
}
return Status::success;
}

protected:
bool WhetherAcceptableRefineType(const std::string& refine_type) {
// 'flat' is identical to 'fp32'
std::vector<std::string> allowed_list = {
"SQ8", "FP16", "BF16", "FP32", "FLAT"};

// todo: tolower()

for (const auto& allowed : allowed_list) {
if (refine_type == allowed) {
return true;
}
}

return false;
}
};

class FaissHnswFlatConfig : public FaissHnswConfig {
Expand All @@ -121,14 +144,80 @@ class FaissHnswFlatConfig : public FaissHnswConfig {

// check our parameters
if (param_type == PARAM_TYPE::TRAIN) {
// prohibit refine
if (refine.value_or(false)) {
*err_msg = "refine is not currently supported for this index";
*err_msg = "refine is not supported for this index";
LOG_KNOWHERE_ERROR_ << *err_msg;
return Status::invalid_value_in_json;
}

if (refine_type.has_value()) {
*err_msg = "refine is not supported for this index";
LOG_KNOWHERE_ERROR_ << *err_msg;
return Status::invalid_value_in_json;
}
}
return Status::success;
}
};

class FaissHnswSqConfig : public FaissHnswConfig {
public:
// user can use quant_type to control quantizer type.
// we have fp16, bf16, etc, so '8', '4' and '6' is insufficient
CFG_STRING sq_type;
KNOHWERE_DECLARE_CONFIG(FaissHnswSqConfig) {
KNOWHERE_CONFIG_DECLARE_FIELD(sq_type)
.set_default("SQ8")
.description("scalar quantizer type")
.for_train();
};

Status
CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override {
// check the base class
const auto base_status = FaissHnswConfig::CheckAndAdjust(param_type, err_msg);
if (base_status != Status::success) {
return base_status;
}

// check our parameters
if (param_type == PARAM_TYPE::TRAIN) {
auto sq_type_v = sq_type.value();
if (!WhetherAcceptableQuantType(sq_type_v)) {
*err_msg = "invalid scalar quantizer type";
LOG_KNOWHERE_ERROR_ << *err_msg;
return Status::invalid_value_in_json;
}

// check refine
if (refine_type.has_value()) {
if (!WhetherAcceptableRefineType(refine_type.value())) {
*err_msg = "invalid refine type type";
LOG_KNOWHERE_ERROR_ << *err_msg;
return Status::invalid_value_in_json;
}
}
}
return Status::success;
}

private:
bool WhetherAcceptableQuantType(const std::string& sq_type) {
// todo: add more
std::vector<std::string> allowed_list = {
"SQ6", "SQ8", "FP16", "BF16"};

// todo: tolower()

for (const auto& allowed : allowed_list) {
if (sq_type == allowed) {
return true;
}
}

return false;
}
};

} // namespace knowhere
Expand Down
61 changes: 61 additions & 0 deletions thirdparty/faiss/faiss/IndexCosine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,52 @@ FlatCodesDistanceComputer* IndexFlatCosine::get_FlatCodesDistanceComputer() cons
}


//////////////////////////////////////////////////////////////////////////////////

IndexScalarQuantizerCosine::IndexScalarQuantizerCosine(
int d,
ScalarQuantizer::QuantizerType qtype)
: IndexScalarQuantizer(d, qtype, MetricType::METRIC_INNER_PRODUCT) {
is_cosine = true;
}

IndexScalarQuantizerCosine::IndexScalarQuantizerCosine() : IndexScalarQuantizer() {
metric_type = MetricType::METRIC_INNER_PRODUCT;
is_cosine = true;
}

void IndexScalarQuantizerCosine::add(idx_t n, const float* x) {
FAISS_THROW_IF_NOT(is_trained);
if (n == 0) {
return;
}

// todo aguzhva:
// it is a tricky situation at this moment, because IndexScalarQuantizerCosine
// contains duplicate norms (one in IndexFlatCodes and one in HasInverseL2Norms).
// Norms in IndexFlatCodes are going to be removed in the future.
IndexScalarQuantizer::add(n, x);
inverse_norms_storage.add(x, n, d);
}

void IndexScalarQuantizerCosine::reset() {
IndexScalarQuantizer::reset();
inverse_norms_storage.reset();
}

const float* IndexScalarQuantizerCosine::get_inverse_l2_norms() const {
return inverse_norms_storage.inverse_l2_norms.data();
}

DistanceComputer* IndexScalarQuantizerCosine::get_distance_computer() const {
return new WithCosineNormDistanceComputer(
this->get_inverse_l2_norms(),
this->d,
std::unique_ptr<faiss::DistanceComputer>(IndexScalarQuantizer::get_FlatCodesDistanceComputer())
);
}


//////////////////////////////////////////////////////////////////////////////////

//
Expand All @@ -314,6 +360,21 @@ IndexHNSWFlatCosine::IndexHNSWFlatCosine(int d, int M) :
is_trained = true;
}

//////////////////////////////////////////////////////////////////////////////////

//
IndexHNSWSQCosine::IndexHNSWSQCosine() = default;

IndexHNSWSQCosine::IndexHNSWSQCosine(
int d,
ScalarQuantizer::QuantizerType qtype,
int M) :
IndexHNSW(new IndexScalarQuantizerCosine(d, qtype), M)
{
is_trained = this->storage->is_trained;
own_fields = true;
}


}

Loading

0 comments on commit ac1cbbe

Please sign in to comment.