Skip to content

Commit

Permalink
faiss on rocksdb demo (facebookresearch#3216)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: facebookresearch#3216

Reviewed By: mdouze

Differential Revision: D53051090

Pulled By: algoriddle

fbshipit-source-id: 13a027db36207af9be11a2f181116994b2aff2cb
  • Loading branch information
algoriddle authored and facebook-github-bot committed Jan 25, 2024
1 parent c4b91a5 commit 51b6083
Show file tree
Hide file tree
Showing 5 changed files with 278 additions and 0 deletions.
8 changes: 8 additions & 0 deletions demos/rocksdb_ivf/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
cmake_minimum_required(VERSION 3.17 FATAL_ERROR)
project (ROCKSDB_IVF)
set(CMAKE_BUILD_TYPE Debug)
find_package(faiss REQUIRED)
find_package(RocksDB REQUIRED)

add_executable(demo_rocksdb_ivf demo_rocksdb_ivf.cpp RocksDBInvertedLists.cpp)
target_link_libraries(demo_rocksdb_ivf faiss RocksDB::rocksdb)
23 changes: 23 additions & 0 deletions demos/rocksdb_ivf/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Storing Faiss inverted lists in RocksDB

Demo of storing the inverted lists of any IVF index in RocksDB or any similar key-value store which supports the prefix scan operation.

# How to build

We use conda to create the build environment for simplicity. Only tested on Linux x86.

```
conda create -n rocksdb_ivf
conda activate rocksdb_ivf
conda install pytorch::faiss-cpu conda-forge::rocksdb cmake make gxx_linux-64 sysroot_linux-64
cd ~/faiss/demos/rocksdb_ivf
cmake -B build .
make -C build -j$(nproc)
```

# Run the example

```
cd ~/faiss/demos/rocksdb_ivf/build
./rocksdb_ivf test_db
```
108 changes: 108 additions & 0 deletions demos/rocksdb_ivf/RocksDBInvertedLists.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

#include "RocksDBInvertedLists.h"

#include <faiss/impl/FaissAssert.h>

using namespace faiss;

namespace faiss_rocksdb {

RocksDBInvertedListsIterator::RocksDBInvertedListsIterator(
rocksdb::DB* db,
size_t list_no,
size_t code_size)
: InvertedListsIterator(),
it(db->NewIterator(rocksdb::ReadOptions())),
list_no(list_no),
code_size(code_size),
codes(code_size) {
it->Seek(rocksdb::Slice(
reinterpret_cast<const char*>(&list_no), sizeof(size_t)));
}

bool RocksDBInvertedListsIterator::is_available() const {
return it->Valid() &&
it->key().starts_with(rocksdb::Slice(
reinterpret_cast<const char*>(&list_no), sizeof(size_t)));
}

void RocksDBInvertedListsIterator::next() {
it->Next();
}

std::pair<idx_t, const uint8_t*> RocksDBInvertedListsIterator::
get_id_and_codes() {
idx_t id =
*reinterpret_cast<const idx_t*>(&it->key().data()[sizeof(size_t)]);
assert(code_size == it->value().size());
return {id, reinterpret_cast<const uint8_t*>(it->value().data())};
}

RocksDBInvertedLists::RocksDBInvertedLists(
const char* db_directory,
size_t nlist,
size_t code_size)
: InvertedLists(nlist, code_size) {
use_iterator = true;

rocksdb::Options options;
options.create_if_missing = true;
rocksdb::DB* db;
rocksdb::Status status = rocksdb::DB::Open(options, db_directory, &db);
db_ = std::unique_ptr<rocksdb::DB>(db);
assert(status.ok());
}

size_t RocksDBInvertedLists::list_size(size_t /*list_no*/) const {
FAISS_THROW_MSG("list_size is not supported");
}

const uint8_t* RocksDBInvertedLists::get_codes(size_t /*list_no*/) const {
FAISS_THROW_MSG("get_codes is not supported");
}

const idx_t* RocksDBInvertedLists::get_ids(size_t /*list_no*/) const {
FAISS_THROW_MSG("get_ids is not supported");
}

size_t RocksDBInvertedLists::add_entries(
size_t list_no,
size_t n_entry,
const idx_t* ids,
const uint8_t* code) {
rocksdb::WriteOptions wo;
std::vector<char> key(sizeof(size_t) + sizeof(idx_t));
memcpy(key.data(), &list_no, sizeof(size_t));
for (size_t i = 0; i < n_entry; i++) {
memcpy(key.data() + sizeof(size_t), ids + i, sizeof(idx_t));
rocksdb::Status status = db_->Put(
wo,
rocksdb::Slice(key.data(), key.size()),
rocksdb::Slice(
reinterpret_cast<const char*>(code + i * code_size),
code_size));
assert(status.ok());
}
return 0; // ignored
}

void RocksDBInvertedLists::update_entries(
size_t /*list_no*/,
size_t /*offset*/,
size_t /*n_entry*/,
const idx_t* /*ids*/,
const uint8_t* /*code*/) {
FAISS_THROW_MSG("update_entries is not supported");
}

void RocksDBInvertedLists::resize(size_t /*list_no*/, size_t /*new_size*/) {
FAISS_THROW_MSG("resize is not supported");
}

InvertedListsIterator* RocksDBInvertedLists::get_iterator(
size_t list_no) const {
return new RocksDBInvertedListsIterator(db_.get(), list_no, code_size);
}

} // namespace faiss_rocksdb
58 changes: 58 additions & 0 deletions demos/rocksdb_ivf/RocksDBInvertedLists.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

#pragma once

#include <faiss/invlists/InvertedLists.h>

#include <rocksdb/db.h>

namespace faiss_rocksdb {

struct RocksDBInvertedListsIterator : faiss::InvertedListsIterator {
RocksDBInvertedListsIterator(
rocksdb::DB* db,
size_t list_no,
size_t code_size);
virtual bool is_available() const override;
virtual void next() override;
virtual std::pair<faiss::idx_t, const uint8_t*> get_id_and_codes() override;

private:
std::unique_ptr<rocksdb::Iterator> it;
size_t list_no;
size_t code_size;
std::vector<uint8_t> codes; // buffer for returning codes in next()
};

struct RocksDBInvertedLists : faiss::InvertedLists {
RocksDBInvertedLists(
const char* db_directory,
size_t nlist,
size_t code_size);

size_t list_size(size_t list_no) const override;
const uint8_t* get_codes(size_t list_no) const override;
const faiss::idx_t* get_ids(size_t list_no) const override;

size_t add_entries(
size_t list_no,
size_t n_entry,
const faiss::idx_t* ids,
const uint8_t* code) override;

void update_entries(
size_t list_no,
size_t offset,
size_t n_entry,
const faiss::idx_t* ids,
const uint8_t* code) override;

void resize(size_t list_no, size_t new_size) override;

faiss::InvertedListsIterator* get_iterator(size_t list_no) const override;

private:
std::unique_ptr<rocksdb::DB> db_;
};

} // namespace faiss_rocksdb
81 changes: 81 additions & 0 deletions demos/rocksdb_ivf/demo_rocksdb_ivf.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

#include <exception>
#include <iostream>
#include <memory>

#include "RocksDBInvertedLists.h"

#include <faiss/IndexFlat.h>
#include <faiss/IndexIVFFlat.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissException.h>
#include <faiss/utils/random.h>

using namespace faiss;

int main(int argc, char* argv[]) {
try {
if (argc != 2) {
std::cerr << "missing db directory argument" << std::endl;
return -1;
}
size_t d = 128;
size_t nlist = 100;
IndexFlatL2 quantizer(d);
IndexIVFFlat index(&quantizer, d, nlist);
faiss_rocksdb::RocksDBInvertedLists ril(
argv[1], nlist, index.code_size);
index.replace_invlists(&ril, false);

idx_t nb = 10000;
std::vector<float> xb(d * nb);
float_rand(xb.data(), d * nb, 12345);
std::vector<idx_t> xids(nb);
std::iota(xids.begin(), xids.end(), 0);

index.train(nb, xb.data());
index.add_with_ids(nb, xb.data(), xids.data());

idx_t nq = 20; // nb;
index.nprobe = 2;

std::cout << "search" << std::endl;
idx_t k = 5;
std::vector<float> distances(nq * k);
std::vector<idx_t> labels(nq * k, -1);
index.search(
nq, xb.data(), k, distances.data(), labels.data(), nullptr);

for (idx_t iq = 0; iq < nq; iq++) {
std::cout << iq << ": ";
for (auto j = 0; j < k; j++) {
std::cout << labels[iq * k + j] << " " << distances[iq * k + j]
<< " | ";
}
std::cout << std::endl;
}

std::cout << std::endl << "range search" << std::endl;
float range = 15.0f;
RangeSearchResult result(nq);
index.range_search(nq, xb.data(), range, &result);

for (idx_t iq = 0; iq < nq; iq++) {
std::cout << iq << ": ";
for (auto j = result.lims[iq]; j < result.lims[iq + 1]; j++) {
std::cout << result.labels[j] << " " << result.distances[j]
<< " | ";
}
std::cout << std::endl;
}

} catch (FaissException& e) {
std::cerr << e.what() << '\n';
} catch (std::exception& e) {
std::cerr << e.what() << '\n';
} catch (...) {
std::cerr << "Unrecognized exception!\n";
}
return 0;
}

0 comments on commit 51b6083

Please sign in to comment.