forked from facebookresearch/faiss
-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
faiss on rocksdb demo (facebookresearch#3216)
Summary: Pull Request resolved: facebookresearch#3216 Reviewed By: mdouze Differential Revision: D53051090 Pulled By: algoriddle fbshipit-source-id: 13a027db36207af9be11a2f181116994b2aff2cb
- Loading branch information
1 parent
c4b91a5
commit 51b6083
Showing
5 changed files
with
278 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Build script for the "Faiss inverted lists stored in RocksDB" demo.
cmake_minimum_required(VERSION 3.17 FATAL_ERROR)
project (ROCKSDB_IVF)

# Default to a Debug build, but do not override a build type chosen by the
# caller (-DCMAKE_BUILD_TYPE=...) or a multi-config generator.
if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
  set(CMAKE_BUILD_TYPE Debug)
endif()

find_package(faiss REQUIRED)
find_package(RocksDB REQUIRED)

add_executable(demo_rocksdb_ivf demo_rocksdb_ivf.cpp RocksDBInvertedLists.cpp)
target_link_libraries(demo_rocksdb_ivf faiss RocksDB::rocksdb)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Storing Faiss inverted lists in RocksDB | ||
|
||
Demo of storing the inverted lists of any IVF index in RocksDB or any similar key-value store which supports the prefix scan operation. | ||
|
||
# How to build | ||
|
||
We use conda to create the build environment for simplicity. Only tested on Linux x86. | ||
|
||
``` | ||
conda create -n rocksdb_ivf | ||
conda activate rocksdb_ivf | ||
conda install pytorch::faiss-cpu conda-forge::rocksdb cmake make gxx_linux-64 sysroot_linux-64 | ||
cd ~/faiss/demos/rocksdb_ivf | ||
cmake -B build . | ||
make -C build -j$(nproc) | ||
``` | ||
|
||
# Run the example | ||
|
||
``` | ||
cd ~/faiss/demos/rocksdb_ivf/build | ||
./demo_rocksdb_ivf test_db | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. | ||
|
||
#include "RocksDBInvertedLists.h" | ||
|
||
#include <faiss/impl/FaissAssert.h> | ||
|
||
using namespace faiss; | ||
|
||
namespace faiss_rocksdb { | ||
|
||
// Creates a RocksDB iterator over `db` and seeks it to the first key that is
// at or past the raw bytes of `list_no`. Keys of one inverted list all begin
// with those bytes (see add_entries), so this lands on the list's first entry
// when the list is non-empty.
//
// db:        database to read from; not owned by the iterator.
// list_no:   inverted list to iterate over.
// code_size: expected size in bytes of every stored code.
RocksDBInvertedListsIterator::RocksDBInvertedListsIterator(
        rocksdb::DB* db,
        size_t list_no,
        size_t code_size)
        : InvertedListsIterator(),
          it(db->NewIterator(rocksdb::ReadOptions())),
          list_no(list_no),
          code_size(code_size),
          codes(code_size) {
    // The member holds the same value as the parameter at this point.
    const rocksdb::Slice list_prefix(
            reinterpret_cast<const char*>(&this->list_no), sizeof(size_t));
    it->Seek(list_prefix);
}
|
||
bool RocksDBInvertedListsIterator::is_available() const { | ||
return it->Valid() && | ||
it->key().starts_with(rocksdb::Slice( | ||
reinterpret_cast<const char*>(&list_no), sizeof(size_t))); | ||
} | ||
|
||
// Advances the underlying RocksDB iterator by one entry. Whether the new
// position still belongs to this list is reported by is_available().
void RocksDBInvertedListsIterator::next() {
    it->Next();
}
|
||
std::pair<idx_t, const uint8_t*> RocksDBInvertedListsIterator:: | ||
get_id_and_codes() { | ||
idx_t id = | ||
*reinterpret_cast<const idx_t*>(&it->key().data()[sizeof(size_t)]); | ||
assert(code_size == it->value().size()); | ||
return {id, reinterpret_cast<const uint8_t*>(it->value().data())}; | ||
} | ||
|
||
// Opens (creating it if missing) the RocksDB database in `db_directory` and
// uses it as the storage for `nlist` inverted lists holding codes of
// `code_size` bytes.
//
// Sets use_iterator so that the lists are read through get_iterator() rather
// than through get_codes()/get_ids(), which this class does not support.
//
// Throws a FaissException if `db_directory` is null or the database cannot be
// opened.
RocksDBInvertedLists::RocksDBInvertedLists(
        const char* db_directory,
        size_t nlist,
        size_t code_size)
        : InvertedLists(nlist, code_size) {
    use_iterator = true;

    FAISS_THROW_IF_NOT_MSG(
            db_directory != nullptr, "db_directory must not be null");

    rocksdb::Options options;
    options.create_if_missing = true;

    rocksdb::DB* db = nullptr;
    rocksdb::Status status = rocksdb::DB::Open(options, db_directory, &db);

    // Hand the handle to db_ before checking the status so it is released
    // even when the check below throws.
    db_.reset(db);

    // Checked at runtime (not with assert) so that a failed open is also
    // reported in builds that define NDEBUG.
    FAISS_THROW_IF_NOT_FMT(
            status.ok(),
            "could not open RocksDB database '%s': %s",
            db_directory,
            status.ToString().c_str());
}
|
||
// Not supported by the RocksDB-backed lists: always throws a FaissException.
// The list number argument is ignored.
size_t RocksDBInvertedLists::list_size(size_t) const {
    FAISS_THROW_MSG("list_size is not supported");
}
|
||
// Not supported by the RocksDB-backed lists: always throws a FaissException.
// The list number argument is ignored; entries are read via get_iterator().
const uint8_t* RocksDBInvertedLists::get_codes(size_t) const {
    FAISS_THROW_MSG("get_codes is not supported");
}
|
||
// Not supported by the RocksDB-backed lists: always throws a FaissException.
// The list number argument is ignored; entries are read via get_iterator().
const idx_t* RocksDBInvertedLists::get_ids(size_t) const {
    FAISS_THROW_MSG("get_ids is not supported");
}
|
||
size_t RocksDBInvertedLists::add_entries( | ||
size_t list_no, | ||
size_t n_entry, | ||
const idx_t* ids, | ||
const uint8_t* code) { | ||
rocksdb::WriteOptions wo; | ||
std::vector<char> key(sizeof(size_t) + sizeof(idx_t)); | ||
memcpy(key.data(), &list_no, sizeof(size_t)); | ||
for (size_t i = 0; i < n_entry; i++) { | ||
memcpy(key.data() + sizeof(size_t), ids + i, sizeof(idx_t)); | ||
rocksdb::Status status = db_->Put( | ||
wo, | ||
rocksdb::Slice(key.data(), key.size()), | ||
rocksdb::Slice( | ||
reinterpret_cast<const char*>(code + i * code_size), | ||
code_size)); | ||
assert(status.ok()); | ||
} | ||
return 0; // ignored | ||
} | ||
|
||
void RocksDBInvertedLists::update_entries( | ||
size_t /*list_no*/, | ||
size_t /*offset*/, | ||
size_t /*n_entry*/, | ||
const idx_t* /*ids*/, | ||
const uint8_t* /*code*/) { | ||
FAISS_THROW_MSG("update_entries is not supported"); | ||
} | ||
|
||
// Not supported by the RocksDB-backed lists: always throws a FaissException.
// Both arguments (list_no, new_size) are ignored.
void RocksDBInvertedLists::resize(size_t, size_t) {
    FAISS_THROW_MSG("resize is not supported");
}
|
||
// Allocates a new iterator over inverted list `list_no`, reading from this
// object's database. The iterator is created with `new` and returned as a raw
// pointer, so the caller is responsible for deleting it.
InvertedListsIterator* RocksDBInvertedLists::get_iterator(
        size_t list_no) const {
    rocksdb::DB* const database = db_.get();
    return new RocksDBInvertedListsIterator(database, list_no, code_size);
}
|
||
} // namespace faiss_rocksdb |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. | ||
|
||
#pragma once | ||
|
||
#include <faiss/invlists/InvertedLists.h> | ||
|
||
#include <rocksdb/db.h> | ||
|
||
namespace faiss_rocksdb { | ||
|
||
struct RocksDBInvertedListsIterator : faiss::InvertedListsIterator { | ||
RocksDBInvertedListsIterator( | ||
rocksdb::DB* db, | ||
size_t list_no, | ||
size_t code_size); | ||
virtual bool is_available() const override; | ||
virtual void next() override; | ||
virtual std::pair<faiss::idx_t, const uint8_t*> get_id_and_codes() override; | ||
|
||
private: | ||
std::unique_ptr<rocksdb::Iterator> it; | ||
size_t list_no; | ||
size_t code_size; | ||
std::vector<uint8_t> codes; // buffer for returning codes in next() | ||
}; | ||
|
||
struct RocksDBInvertedLists : faiss::InvertedLists { | ||
RocksDBInvertedLists( | ||
const char* db_directory, | ||
size_t nlist, | ||
size_t code_size); | ||
|
||
size_t list_size(size_t list_no) const override; | ||
const uint8_t* get_codes(size_t list_no) const override; | ||
const faiss::idx_t* get_ids(size_t list_no) const override; | ||
|
||
size_t add_entries( | ||
size_t list_no, | ||
size_t n_entry, | ||
const faiss::idx_t* ids, | ||
const uint8_t* code) override; | ||
|
||
void update_entries( | ||
size_t list_no, | ||
size_t offset, | ||
size_t n_entry, | ||
const faiss::idx_t* ids, | ||
const uint8_t* code) override; | ||
|
||
void resize(size_t list_no, size_t new_size) override; | ||
|
||
faiss::InvertedListsIterator* get_iterator(size_t list_no) const override; | ||
|
||
private: | ||
std::unique_ptr<rocksdb::DB> db_; | ||
}; | ||
|
||
} // namespace faiss_rocksdb |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. | ||
|
||
#include <exception> | ||
#include <iostream> | ||
#include <memory> | ||
|
||
#include "RocksDBInvertedLists.h" | ||
|
||
#include <faiss/IndexFlat.h> | ||
#include <faiss/IndexIVFFlat.h> | ||
#include <faiss/impl/AuxIndexStructures.h> | ||
#include <faiss/impl/FaissException.h> | ||
#include <faiss/utils/random.h> | ||
|
||
using namespace faiss; | ||
|
||
int main(int argc, char* argv[]) { | ||
try { | ||
if (argc != 2) { | ||
std::cerr << "missing db directory argument" << std::endl; | ||
return -1; | ||
} | ||
size_t d = 128; | ||
size_t nlist = 100; | ||
IndexFlatL2 quantizer(d); | ||
IndexIVFFlat index(&quantizer, d, nlist); | ||
faiss_rocksdb::RocksDBInvertedLists ril( | ||
argv[1], nlist, index.code_size); | ||
index.replace_invlists(&ril, false); | ||
|
||
idx_t nb = 10000; | ||
std::vector<float> xb(d * nb); | ||
float_rand(xb.data(), d * nb, 12345); | ||
std::vector<idx_t> xids(nb); | ||
std::iota(xids.begin(), xids.end(), 0); | ||
|
||
index.train(nb, xb.data()); | ||
index.add_with_ids(nb, xb.data(), xids.data()); | ||
|
||
idx_t nq = 20; // nb; | ||
index.nprobe = 2; | ||
|
||
std::cout << "search" << std::endl; | ||
idx_t k = 5; | ||
std::vector<float> distances(nq * k); | ||
std::vector<idx_t> labels(nq * k, -1); | ||
index.search( | ||
nq, xb.data(), k, distances.data(), labels.data(), nullptr); | ||
|
||
for (idx_t iq = 0; iq < nq; iq++) { | ||
std::cout << iq << ": "; | ||
for (auto j = 0; j < k; j++) { | ||
std::cout << labels[iq * k + j] << " " << distances[iq * k + j] | ||
<< " | "; | ||
} | ||
std::cout << std::endl; | ||
} | ||
|
||
std::cout << std::endl << "range search" << std::endl; | ||
float range = 15.0f; | ||
RangeSearchResult result(nq); | ||
index.range_search(nq, xb.data(), range, &result); | ||
|
||
for (idx_t iq = 0; iq < nq; iq++) { | ||
std::cout << iq << ": "; | ||
for (auto j = result.lims[iq]; j < result.lims[iq + 1]; j++) { | ||
std::cout << result.labels[j] << " " << result.distances[j] | ||
<< " | "; | ||
} | ||
std::cout << std::endl; | ||
} | ||
|
||
} catch (FaissException& e) { | ||
std::cerr << e.what() << '\n'; | ||
} catch (std::exception& e) { | ||
std::cerr << e.what() << '\n'; | ||
} catch (...) { | ||
std::cerr << "Unrecognized exception!\n"; | ||
} | ||
return 0; | ||
} |