From 1ac95d234e0aa92f7680eed65495e41a5ed26dbf Mon Sep 17 00:00:00 2001 From: Yash Patel <47032340+yashpatel007@users.noreply.github.com> Date: Thu, 22 Jun 2023 16:38:59 -0400 Subject: [PATCH] patelyash/index factory (#340) * gi# This is a combination of 2 commits. remove _u, _s typedefs * added some seed files * add seed files * New distance metric hierarchy * Refactoring changes * Fixing compile errors in refactored code * Fixing compile errors * DiskANN Builds with initial refactoring changes * Saving changes for Ravi * More refactoring * Refactor * Fixed most of the bugs related to _data * add seed files * gi# This is a combination of 2 commits. remove _u, _s typedefs * added some seed files * New distance metric hierarchy * Refactoring changes * Fixing compile errors in refactored code * Fixing compile errors * DiskANN Builds with initial refactoring changes * Saving changes for Ravi * More refactoring * Refactor * Fixed most of the bugs related to _data * Post merge with main * Refactored version which compiles on Windows * now compiles on linux * minor clean-up * minor bug fix * minor bug * clang format fix + build error fix * clang format fix * minor changes * added back the fast_l2 feature * added back set_start_points in index.cpp * Version for review * Incorporating Harsha's comments - 2 * move implementation of abstract data store methods to a cpp file * clang format * clang format * Added slot manager file (empty) and fixed compile errors * fixed a linux compile error * clang * debugging workflow failure * clang * more debug * more debug * debug for workflow * remove slot manager * Removed the #ifdef WINDOWS directive from class definitions * Refactoring alignment factor into distance hierarchy * Fixing cosine distance * Ensuring we call preprocess_query always * Fixed distance invocations * fixed cosine bug, clang-formatted * cleaned up and added comments * clang-formatted * more clang-format * clang-format 3 * remove deleted code in scratch.cpp * 
reverted clang to Microsoft * small change * Removed slot_manager from this PR * newline at EOF in_mem_Graph_store.cpp * rename distance_metric to distance_fn * resolving PR comments * minor bug fix for initialization * creating index_factory * using index factory to build inmem index * clang format fix * minor bug fix * fixing build error * replacing mem_store with abstract_mem_store + injecting data_store to Index * minor fix * clang format fix * commenting data_store injection to prevent double invocation and mem leak (for now) * fixing the build for fiters * moving abstract index to abstract_index.h * IndexBuildParamsbuilder to build IndexBuildParams properly with error checking * fixing build errors * fixing minor error * refactoring index search to be simple * clang format fix * refactoring search_mem_index to use index factory * clang fix * minor fix * minor fix for build * optimize for fast l2 restore * removing comments * removing comments * adding templating to IndexFactory (can't avoide it anymore) * fixing build error * fixing ubuntu build error * ubuntu build exception fix * passing num_pq_bytes * giving one more shot to config dricen arch with boost::any (type erasure) * clang fix * modifying search to use boost::any * fixing ubuntu build errors/warning * created indexconfigbuilder and fixed a typo * fixing error in pq build * some comments + lazy_delete impl * bumping to std c++17 & replacing boost::any with std::any * clang fix * c++ std 17 for ubuntu * minor fix * converting search to batch_search + A vector wrapper using std::any to store vector as a shared ptr * adding AnyVector to encapsulate vector in std::any + adding basic yaml parser(WIP) * adding wrapper code for vector and set, checked with Andrija * fixinh ubuntu build error * trying to resolve ubuntu build error * testing test streaming index with IndexFactory * fixing ubuntu build error * fixing search for test insert delete consolidate * refactored test_streaming_scenario * refactored 
test_insert_delete_consolidate to use AbstractIndex and Indexfactory * fixing ubuntu build error * making build method in abstract index consistent * some code cleanup + abstract_cpp to add implementation * remoing coments and code cleanup * build error fix * fixing -Wreorder warning * separating build structs to their header + refactor search and remove batch search * fixing ubuntu build errors * resolving segfault error from search_mem_index * fixing query_result_tag allocation * minor update * search fix * trying to fix windows latest build for dynamic index * ading temp loggin to debug windows latest build issue * removing logging for debug * fixning windows latest build error for dynamix index search * moving any wrappers to separate file + organizing code * fixing check error * updating private vsr naming convention * minor update * unravelig search methods in abstract index. Iteraton 1 * minor fix * unused vars remove * returning a unique_ptr to Abstract Index from index factory * adding implementation from abstract_index.h to abstract_index.cpp * making abstract index api to be more explicit (expriment) * some code cleanup * removing detected memory leaks (free up index) * separtaing enums for data and graph stratagy * Index ctor(config) now uses injected datastore from IndexFactory * distance in index population in new config ctor * resolving some comments from Andrija * Resolving some restructuring comments by Andrija * minor fix * fixing ubuntu build error * warning fix * simplified get() in anywrappers * making index config a unique ptr and owned by IndexFactory * removing complex if/else calling recursively + added unimplemented TagT to AbsIdx * renaming get_instance to create_instance * clang format fix * removing const_cast from any_wrapper * fixing andrija's comments * removing warnings --------- Co-authored-by: harsha vardhan simhadri Co-authored-by: Gopal Srinivasa Co-authored-by: ravishankar Co-authored-by: Harsha Vardhan Simhadri --- 
apps/build_memory_index.cpp | 77 +++--- apps/search_memory_index.cpp | 67 ++--- apps/test_insert_deletes_consolidate.cpp | 86 +++--- apps/test_streaming_scenario.cpp | 57 ++-- include/abstract_data_store.h | 3 + include/abstract_index.h | 118 ++++++++ include/any_wrappers.h | 53 ++++ include/in_mem_data_store.h | 2 +- include/index.h | 65 +++-- include/index_build_params.h | 72 +++++ include/index_config.h | 224 +++++++++++++++ include/index_factory.h | 37 +++ include/types.h | 11 +- include/utils.h | 59 ++++ src/CMakeLists.txt | 2 +- src/abstract_index.cpp | 280 +++++++++++++++++++ src/dll/CMakeLists.txt | 3 +- src/index.cpp | 329 +++++++++++++++++++++-- src/index_factory.cpp | 150 +++++++++++ 19 files changed, 1520 insertions(+), 175 deletions(-) create mode 100644 include/abstract_index.h create mode 100644 include/any_wrappers.h create mode 100644 include/index_build_params.h create mode 100644 include/index_config.h create mode 100644 include/index_factory.h create mode 100644 src/abstract_index.cpp create mode 100644 src/index_factory.cpp diff --git a/apps/build_memory_index.cpp b/apps/build_memory_index.cpp index 3712350c3..8d483f5c4 100644 --- a/apps/build_memory_index.cpp +++ b/apps/build_memory_index.cpp @@ -17,6 +17,7 @@ #include "memory_mapper.h" #include "ann_exception.h" +#include "index_factory.h" namespace po = boost::program_options; @@ -155,46 +156,42 @@ int main(int argc, char **argv) { diskann::cout << "Starting index build with R: " << R << " Lbuild: " << L << " alpha: " << alpha << " #threads: " << num_threads << std::endl; - if (label_file != "" && label_type == "ushort") - { - if (data_type == std::string("int8")) - return build_in_memory_index( - metric, data_path, R, L, alpha, index_path_prefix, num_threads, use_pq_build, build_PQ_bytes, - use_opq, label_file, universal_label, Lf); - else if (data_type == std::string("uint8")) - return build_in_memory_index( - metric, data_path, R, L, alpha, index_path_prefix, num_threads, use_pq_build, 
build_PQ_bytes, - use_opq, label_file, universal_label, Lf); - else if (data_type == std::string("float")) - return build_in_memory_index( - metric, data_path, R, L, alpha, index_path_prefix, num_threads, use_pq_build, build_PQ_bytes, - use_opq, label_file, universal_label, Lf); - else - { - std::cout << "Unsupported type. Use one of int8, uint8 or float." << std::endl; - return -1; - } - } - else - { - if (data_type == std::string("int8")) - return build_in_memory_index(metric, data_path, R, L, alpha, index_path_prefix, num_threads, - use_pq_build, build_PQ_bytes, use_opq, label_file, universal_label, - Lf); - else if (data_type == std::string("uint8")) - return build_in_memory_index(metric, data_path, R, L, alpha, index_path_prefix, num_threads, - use_pq_build, build_PQ_bytes, use_opq, label_file, - universal_label, Lf); - else if (data_type == std::string("float")) - return build_in_memory_index(metric, data_path, R, L, alpha, index_path_prefix, num_threads, - use_pq_build, build_PQ_bytes, use_opq, label_file, universal_label, - Lf); - else - { - std::cout << "Unsupported type. Use one of int8, uint8 or float." 
<< std::endl; - return -1; - } - } + + size_t data_num, data_dim; + diskann::get_bin_metadata(data_path, data_num, data_dim); + + auto config = diskann::IndexConfigBuilder() + .with_metric(metric) + .with_dimension(data_dim) + .with_max_points(data_num) + .with_data_load_store_strategy(diskann::MEMORY) + .with_data_type(data_type) + .with_label_type(label_type) + .is_dynamic_index(false) + .is_enable_tags(false) + .is_use_opq(use_opq) + .is_pq_dist_build(use_pq_build) + .with_num_pq_chunks(build_PQ_bytes) + .build(); + + auto index_build_params = diskann::IndexWriteParametersBuilder(L, R) + .with_filter_list_size(Lf) + .with_alpha(alpha) + .with_saturate_graph(false) + .with_num_threads(num_threads) + .build(); + + auto build_params = diskann::IndexBuildParamsBuilder(index_build_params) + .with_universal_label(universal_label) + .with_label_file(label_file) + .with_save_path_prefix(index_path_prefix) + .build(); + auto index_factory = diskann::IndexFactory(config); + auto index = index_factory.create_instance(); + index->build(data_path, data_num, build_params); + index->save(index_path_prefix.c_str()); + index.reset(); + return 0; } catch (const std::exception &e) { diff --git a/apps/search_memory_index.cpp b/apps/search_memory_index.cpp index bd5c867a0..4af0266d6 100644 --- a/apps/search_memory_index.cpp +++ b/apps/search_memory_index.cpp @@ -20,6 +20,7 @@ #include "index.h" #include "memory_mapper.h" #include "utils.h" +#include "index_factory.h" namespace po = boost::program_options; @@ -30,6 +31,7 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, const bool dynamic, const bool tags, const bool show_qps_per_thread, const std::vector &query_filters, const float fail_if_recall_below) { + using TagT = uint32_t; // Load the query file T *query = nullptr; uint32_t *gt_ids = nullptr; @@ -37,7 +39,6 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, size_t query_num, query_dim, query_aligned_dim, gt_num, 
gt_dim; diskann::load_aligned_bin(query_file, query, query_num, query_dim, query_aligned_dim); - // Check for ground truth bool calc_recall_flag = false; if (truthset_file != std::string("null") && file_exists(truthset_file)) { @@ -66,18 +67,32 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, } } - using TagT = uint32_t; - const bool concurrent = false, pq_dist_build = false, use_opq = false; - const size_t num_pq_chunks = 0; - using IndexType = diskann::Index; - const size_t num_frozen_pts = IndexType::get_graph_num_frozen_points(index_path); - IndexType index(metric, query_dim, 0, dynamic, tags, concurrent, pq_dist_build, num_pq_chunks, use_opq, - num_frozen_pts); - std::cout << "Index class instantiated" << std::endl; - index.load(index_path.c_str(), num_threads, *(std::max_element(Lvec.begin(), Lvec.end()))); + const size_t num_frozen_pts = diskann::get_graph_num_frozen_points(index_path); + + auto config = diskann::IndexConfigBuilder() + .with_metric(metric) + .with_dimension(query_dim) + .with_max_points(0) + .with_data_load_store_strategy(diskann::MEMORY) + .with_data_type(diskann_type_to_name()) + .with_label_type(diskann_type_to_name()) + .with_tag_type(diskann_type_to_name()) + .is_dynamic_index(dynamic) + .is_enable_tags(tags) + .is_concurrent_consolidate(false) + .is_pq_dist_build(false) + .is_use_opq(false) + .with_num_pq_chunks(0) + .with_num_frozen_pts(num_frozen_pts) + .build(); + + auto index_factory = diskann::IndexFactory(config); + auto index = index_factory.create_instance(); + index->load(index_path.c_str(), num_threads, *(std::max_element(Lvec.begin(), Lvec.end()))); std::cout << "Index loaded" << std::endl; + if (metric == diskann::FAST_L2) - index.optimize_index_layout(); + index->optimize_index_layout(); std::cout << "Using " << num_threads << " threads to search" << std::endl; std::cout.setf(std::ios_base::fixed, std::ios_base::floatfield); @@ -148,29 +163,22 @@ int search_memory_index(diskann::Metric 
&metric, const std::string &index_path, auto qs = std::chrono::high_resolution_clock::now(); if (filtered_search) { - LabelT filter_label_as_num; - if (query_filters.size() == 1) - { - filter_label_as_num = index.get_converted_label(query_filters[0]); - } - else - { - filter_label_as_num = index.get_converted_label(query_filters[i]); - } - auto retval = index.search_with_filters(query + i * query_aligned_dim, filter_label_as_num, recall_at, - L, query_result_ids[test_id].data() + i * recall_at, - query_result_dists[test_id].data() + i * recall_at); + std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i]; + + auto retval = index->search_with_filters(query + i * query_aligned_dim, raw_filter, recall_at, L, + query_result_ids[test_id].data() + i * recall_at, + query_result_dists[test_id].data() + i * recall_at); cmp_stats[i] = retval.second; } else if (metric == diskann::FAST_L2) { - index.search_with_optimized_layout(query + i * query_aligned_dim, recall_at, L, - query_result_ids[test_id].data() + i * recall_at); + index->search_with_optimized_layout(query + i * query_aligned_dim, recall_at, L, + query_result_ids[test_id].data() + i * recall_at); } else if (tags) { - index.search_with_tags(query + i * query_aligned_dim, recall_at, L, - query_result_tags.data() + i * recall_at, nullptr, res); + index->search_with_tags(query + i * query_aligned_dim, recall_at, L, + query_result_tags.data() + i * recall_at, nullptr, res); for (int64_t r = 0; r < (int64_t)recall_at; r++) { query_result_ids[test_id][recall_at * i + r] = query_result_tags[recall_at * i + r]; @@ -179,8 +187,8 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, else { cmp_stats[i] = index - .search(query + i * query_aligned_dim, recall_at, L, - query_result_ids[test_id].data() + i * recall_at) + ->search(query + i * query_aligned_dim, recall_at, L, + query_result_ids[test_id].data() + i * recall_at) .second; } auto qe = 
std::chrono::high_resolution_clock::now(); @@ -245,7 +253,6 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, } diskann::aligned_free(query); - return best_recall >= fail_if_recall_below ? 0 : -1; } diff --git a/apps/test_insert_deletes_consolidate.cpp b/apps/test_insert_deletes_consolidate.cpp index ebfd7cabe..4d64de3a5 100644 --- a/apps/test_insert_deletes_consolidate.cpp +++ b/apps/test_insert_deletes_consolidate.cpp @@ -11,6 +11,7 @@ #include #include "utils.h" +#include "index_factory.h" #ifndef _WINDOWS #include @@ -90,8 +91,8 @@ std::string get_save_filename(const std::string &save_path, size_t points_to_ski } template -void insert_till_next_checkpoint(diskann::Index &index, size_t start, size_t end, int32_t thread_count, - T *data, size_t aligned_dim) +void insert_till_next_checkpoint(diskann::AbstractIndex &index, size_t start, size_t end, int32_t thread_count, T *data, + size_t aligned_dim) { diskann::Timer insert_timer; @@ -106,7 +107,7 @@ void insert_till_next_checkpoint(diskann::Index &index, size_t start, s } template -void delete_from_beginning(diskann::Index &index, diskann::IndexWriteParameters &delete_params, +void delete_from_beginning(diskann::AbstractIndex &index, diskann::IndexWriteParameters &delete_params, size_t points_to_skip, size_t points_to_delete_from_beginning) { try @@ -135,26 +136,39 @@ void delete_from_beginning(diskann::Index &index, diskann::IndexWritePa } template -void build_incremental_index(const std::string &data_path, const uint32_t L, const uint32_t R, const float alpha, - const uint32_t thread_count, size_t points_to_skip, size_t max_points_to_insert, - size_t beginning_index_size, float start_point_norm, uint32_t num_start_pts, - size_t points_per_checkpoint, size_t checkpoints_per_snapshot, +void build_incremental_index(const std::string &data_path, diskann::IndexWriteParameters &params, size_t points_to_skip, + size_t max_points_to_insert, size_t beginning_index_size, float 
start_point_norm, + uint32_t num_start_pts, size_t points_per_checkpoint, size_t checkpoints_per_snapshot, const std::string &save_path, size_t points_to_delete_from_beginning, size_t start_deletes_after, bool concurrent) { - diskann::IndexWriteParameters params = diskann::IndexWriteParametersBuilder(L, R) - .with_max_occlusion_size(500) // C = 500 - .with_alpha(alpha) - .with_num_threads(thread_count) - .with_num_frozen_points(num_start_pts) - .build(); - size_t dim, aligned_dim; size_t num_points; - diskann::get_bin_metadata(data_path, num_points, dim); aligned_dim = ROUND_UP(dim, 8); + bool enable_tags = true; + using TagT = uint32_t; + auto data_type = diskann_type_to_name(); + auto tag_type = diskann_type_to_name(); + diskann::IndexConfig index_config = diskann::IndexConfigBuilder() + .with_metric(diskann::L2) + .with_dimension(dim) + .with_max_points(max_points_to_insert) + .is_dynamic_index(true) + .with_index_write_params(params) + .with_search_threads(params.num_threads) + .with_initial_search_list_size(params.search_list_size) + .with_data_type(data_type) + .with_tag_type(tag_type) + .with_data_load_store_strategy(diskann::MEMORY) + .is_enable_tags(enable_tags) + .is_concurrent_consolidate(concurrent) + .build(); + + diskann::IndexFactory index_factory = diskann::IndexFactory(index_config); + auto index = index_factory.create_instance(); + if (points_to_skip > num_points) { throw diskann::ANNException("Asked to skip more points than in data file", -1, __FUNCSIG__, __FILE__, __LINE__); @@ -172,12 +186,6 @@ void build_incremental_index(const std::string &data_path, const uint32_t L, con << " points since the data file has only that many" << std::endl; } - using TagT = uint32_t; - const bool enable_tags = true; - - diskann::Index index(diskann::L2, dim, max_points_to_insert, true, params, L, thread_count, enable_tags, - concurrent); - size_t current_point_offset = points_to_skip; const size_t last_point_threshold = points_to_skip + max_points_to_insert; @@ 
-206,13 +214,11 @@ void build_incremental_index(const std::string &data_path, const uint32_t L, con if (beginning_index_size > 0) { - index.build(data, beginning_index_size, params, tags); - index.enable_delete(); + index->build(data, beginning_index_size, params, tags); } else { - index.set_start_points_at_random(static_cast(start_point_norm)); - index.enable_delete(); + index->set_start_points_at_random(static_cast(start_point_norm)); } const double elapsedSeconds = timer.elapsed() / 1000000.0; @@ -230,7 +236,7 @@ void build_incremental_index(const std::string &data_path, const uint32_t L, con if (concurrent) { - int32_t sub_threads = (thread_count + 1) / 2; + int32_t sub_threads = (params.num_threads + 1) / 2; bool delete_launched = false; std::future delete_task; @@ -244,7 +250,7 @@ void build_incremental_index(const std::string &data_path, const uint32_t L, con auto insert_task = std::async(std::launch::async, [&]() { load_aligned_bin_part(data_path, data, start, end - start); - insert_till_next_checkpoint(index, start, end, sub_threads, data, aligned_dim); + insert_till_next_checkpoint(*index, start, end, sub_threads, data, aligned_dim); }); insert_task.wait(); @@ -256,7 +262,8 @@ void build_incremental_index(const std::string &data_path, const uint32_t L, con diskann::IndexWriteParametersBuilder(params).with_num_threads(sub_threads).build(); delete_task = std::async(std::launch::async, [&]() { - delete_from_beginning(index, delete_params, points_to_skip, points_to_delete_from_beginning); + delete_from_beginning(*index, delete_params, points_to_skip, + points_to_delete_from_beginning); }); } } @@ -265,7 +272,7 @@ void build_incremental_index(const std::string &data_path, const uint32_t L, con std::cout << "Time Elapsed " << timer.elapsed() / 1000 << "ms\n"; const auto save_path_inc = get_save_filename(save_path + ".after-concurrent-delete-", points_to_skip, points_to_delete_from_beginning, last_point_threshold); - index.save(save_path_inc.c_str(), true); + 
index->save(save_path_inc.c_str(), true); } else { @@ -279,7 +286,7 @@ void build_incremental_index(const std::string &data_path, const uint32_t L, con std::cout << std::endl << "Inserting from " << start << " to " << end << std::endl; load_aligned_bin_part(data_path, data, start, end - start); - insert_till_next_checkpoint(index, start, end, (int32_t)thread_count, data, aligned_dim); + insert_till_next_checkpoint(*index, start, end, (int32_t)params.num_threads, data, aligned_dim); if (checkpoints_per_snapshot > 0 && --num_checkpoints_till_snapshot == 0) { @@ -287,7 +294,7 @@ void build_incremental_index(const std::string &data_path, const uint32_t L, con const auto save_path_inc = get_save_filename(save_path + ".inc-", points_to_skip, points_to_delete_from_beginning, end); - index.save(save_path_inc.c_str(), false); + index->save(save_path_inc.c_str(), false); const double elapsedSeconds = save_timer.elapsed() / 1000000.0; const size_t points_saved = end - points_to_skip; @@ -310,11 +317,11 @@ void build_incremental_index(const std::string &data_path, const uint32_t L, con if (points_to_delete_from_beginning > 0) { - delete_from_beginning(index, params, points_to_skip, points_to_delete_from_beginning); + delete_from_beginning(*index, params, points_to_skip, points_to_delete_from_beginning); } const auto save_path_inc = get_save_filename(save_path + ".after-delete-", points_to_skip, points_to_delete_from_beginning, last_point_threshold); - index.save(save_path_inc.c_str(), true); + index->save(save_path_inc.c_str(), true); } diskann::aligned_free(data); @@ -398,18 +405,25 @@ int main(int argc, char **argv) try { + diskann::IndexWriteParameters params = diskann::IndexWriteParametersBuilder(L, R) + .with_max_occlusion_size(500) + .with_alpha(alpha) + .with_num_threads(num_threads) + .with_num_frozen_points(num_start_pts) + .build(); + if (data_type == std::string("int8")) - build_incremental_index(data_path, L, R, alpha, num_threads, points_to_skip, 
max_points_to_insert, + build_incremental_index(data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm, num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix, points_to_delete_from_beginning, start_deletes_after, concurrent); else if (data_type == std::string("uint8")) - build_incremental_index(data_path, L, R, alpha, num_threads, points_to_skip, max_points_to_insert, + build_incremental_index(data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm, num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix, points_to_delete_from_beginning, start_deletes_after, concurrent); else if (data_type == std::string("float")) - build_incremental_index(data_path, L, R, alpha, num_threads, points_to_skip, max_points_to_insert, + build_incremental_index(data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm, num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix, points_to_delete_from_beginning, start_deletes_after, concurrent); diff --git a/apps/test_streaming_scenario.cpp b/apps/test_streaming_scenario.cpp index 3281a0573..c48c74843 100644 --- a/apps/test_streaming_scenario.cpp +++ b/apps/test_streaming_scenario.cpp @@ -9,6 +9,8 @@ #include #include #include +#include +#include #include "utils.h" @@ -81,8 +83,8 @@ std::string get_save_filename(const std::string &save_path, size_t active_window return final_path; } -template -void insert_next_batch(diskann::Index &index, size_t start, size_t end, size_t insert_threads, T *data, +template +void insert_next_batch(diskann::AbstractIndex &index, size_t start, size_t end, size_t insert_threads, T *data, size_t aligned_dim) { try @@ -113,9 +115,9 @@ void insert_next_batch(diskann::Index &index, size_t start, siz } } -template -void delete_and_consolidate(diskann::Index &index, diskann::IndexWriteParameters &delete_params, - size_t 
start, size_t end) +template +void delete_and_consolidate(diskann::AbstractIndex &index, diskann::IndexWriteParameters &delete_params, size_t start, + size_t end) { try { @@ -149,7 +151,7 @@ void delete_and_consolidate(diskann::Index &index, diskann::Ind } auto points_processed = report._active_points + report._slots_released; auto deletion_rate = points_processed / report._time; - std::cout << "#active points: " << report._active_points << std::endl + std::cout << "#active points: " << report._active_points << std::endl << "max points: " << report._max_points << std::endl << "empty slots: " << report._empty_slots << std::endl << "deletes processed: " << report._slots_released << std::endl @@ -172,6 +174,8 @@ void build_incremental_index(const std::string &data_path, const uint32_t L, con { const uint32_t C = 500; const bool saturate_graph = false; + using TagT = uint32_t; + using LabelT = uint32_t; diskann::IndexWriteParameters params = diskann::IndexWriteParametersBuilder(L, R) .with_max_occlusion_size(C) @@ -196,6 +200,27 @@ void build_incremental_index(const std::string &data_path, const uint32_t L, con << std::endl; aligned_dim = ROUND_UP(dim, 8); + auto index_config = diskann::IndexConfigBuilder() + .with_metric(diskann::L2) + .with_dimension(dim) + .with_max_points(active_window + 4 * consolidate_interval) + .is_dynamic_index(true) + .is_enable_tags(true) + .is_use_opq(false) + .with_num_pq_chunks(0) + .is_pq_dist_build(false) + .with_search_threads(insert_threads) + .with_initial_search_list_size(L) + .with_tag_type(diskann_type_to_name()) + .with_label_type(diskann_type_to_name()) + .with_data_type(diskann_type_to_name()) + .with_index_write_params(params) + .with_data_load_store_strategy(diskann::MEMORY) + .build(); + + diskann::IndexFactory index_factory = diskann::IndexFactory(index_config); + auto index = index_factory.create_instance(); + if (max_points_to_insert == 0) { max_points_to_insert = num_points; @@ -214,14 +239,7 @@ void 
build_incremental_index(const std::string &data_path, const uint32_t L, con if (consolidate_interval < max_points_to_insert / 1000) throw diskann::ANNException("ERROR: consolidate_interval is too small", -1, __FUNCSIG__, __FILE__, __LINE__); - using TagT = uint32_t; - using LabelT = uint32_t; - const bool enable_tags = true; - - diskann::Index index(diskann::L2, dim, active_window + 4 * consolidate_interval, true, params, L, - insert_threads, enable_tags, true); - index.set_start_points_at_random(static_cast(start_point_norm)); - index.enable_delete(); + index->set_start_points_at_random(static_cast(start_point_norm)); T *data = nullptr; diskann::alloc_aligned((void **)&data, std::max(consolidate_interval, active_window) * aligned_dim * sizeof(T), @@ -236,7 +254,7 @@ void build_incremental_index(const std::string &data_path, const uint32_t L, con auto insert_task = std::async(std::launch::async, [&]() { load_aligned_bin_part(data_path, data, 0, active_window); - insert_next_batch(index, 0, active_window, insert_threads, data, aligned_dim); + insert_next_batch(*index, (size_t)0, active_window, params.num_threads, data, aligned_dim); }); insert_task.wait(); @@ -246,7 +264,7 @@ void build_incremental_index(const std::string &data_path, const uint32_t L, con auto end = std::min(start + consolidate_interval, max_points_to_insert); auto insert_task = std::async(std::launch::async, [&]() { load_aligned_bin_part(data_path, data, start, end - start); - insert_next_batch(index, start, end, insert_threads, data, aligned_dim); + insert_next_batch(*index, start, end, params.num_threads, data, aligned_dim); }); insert_task.wait(); @@ -257,8 +275,9 @@ void build_incremental_index(const std::string &data_path, const uint32_t L, con auto start_del = start - active_window - consolidate_interval; auto end_del = start - active_window; - delete_tasks.emplace_back(std::async( - std::launch::async, [&]() { delete_and_consolidate(index, delete_params, start_del, end_del); })); + 
delete_tasks.emplace_back(std::async(std::launch::async, [&]() { + delete_and_consolidate(*index, delete_params, (size_t)start_del, (size_t)end_del); + })); } } if (delete_tasks.size() > 0) @@ -267,7 +286,7 @@ void build_incremental_index(const std::string &data_path, const uint32_t L, con std::cout << "Time Elapsed " << timer.elapsed() / 1000 << "ms\n"; const auto save_path_inc = get_save_filename(save_path + ".after-streaming-", active_window, consolidate_interval, max_points_to_insert); - index.save(save_path_inc.c_str(), true); + index->save(save_path_inc.c_str(), true); diskann::aligned_free(data); } diff --git a/include/abstract_data_store.h b/include/abstract_data_store.h index 71ce319fc..2e0266814 100644 --- a/include/abstract_data_store.h +++ b/include/abstract_data_store.h @@ -8,6 +8,7 @@ #include "types.h" #include "windows_customizations.h" +#include "distance.h" namespace diskann { @@ -87,6 +88,8 @@ template class AbstractDataStore // in the dataset virtual location_t calculate_medoid() const = 0; + virtual Distance *get_dist_fn() = 0; + // search helpers // if the base data is aligned per the request of the metric, this will tell // how to align the query vector in a consistent manner diff --git a/include/abstract_index.h b/include/abstract_index.h new file mode 100644 index 000000000..1a32bf8da --- /dev/null +++ b/include/abstract_index.h @@ -0,0 +1,118 @@ +#pragma once +#include "distance.h" +#include "parameters.h" +#include "utils.h" +#include "types.h" +#include "index_config.h" +#include "index_build_params.h" +#include + +namespace diskann +{ +struct consolidation_report +{ + enum status_code + { + SUCCESS = 0, + FAIL = 1, + LOCK_FAIL = 2, + INCONSISTENT_COUNT_ERROR = 3 + }; + status_code _status; + size_t _active_points, _max_points, _empty_slots, _slots_released, _delete_set_size, _num_calls_to_process_delete; + double _time; + + consolidation_report(status_code status, size_t active_points, size_t max_points, size_t empty_slots, + size_t 
slots_released, size_t delete_set_size, size_t num_calls_to_process_delete, + double time_secs) + : _status(status), _active_points(active_points), _max_points(max_points), _empty_slots(empty_slots), + _slots_released(slots_released), _delete_set_size(delete_set_size), + _num_calls_to_process_delete(num_calls_to_process_delete), _time(time_secs) + { + } +}; + +/* A templated independent class for interaction with Index. Uses Type Erasure to add virtual implementation of methods +that can take any type (using std::any) and provides a clean API that can be inherited by different types of Index. +*/ +class AbstractIndex +{ + public: + AbstractIndex() = default; + virtual ~AbstractIndex() = default; + + virtual void build(const std::string &data_file, const size_t num_points_to_load, + IndexBuildParams &build_params) = 0; + + template + void build(const data_type *data, const size_t num_points_to_load, const IndexWriteParameters &parameters, + const std::vector &tags); + + virtual void save(const char *filename, bool compact_before_save = false) = 0; + +#ifdef EXEC_ENV_OLS + virtual void load(AlignedFileReader &reader, uint32_t num_threads, uint32_t search_l) = 0; +#else + virtual void load(const char *index_file, uint32_t num_threads, uint32_t search_l) = 0; +#endif + + // For FastL2 search on optimized layout + template + void search_with_optimized_layout(const data_type *query, size_t K, size_t L, uint32_t *indices); + + // Initialize space for res_vectors before calling. 
+ template + size_t search_with_tags(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags, + float *distances, std::vector &res_vectors); + + // Added search overload that takes L as parameter, so that we + // can customize L on a per-query basis without tampering with "Parameters" + // IDtype is either uint32_t or uint64_t + template + std::pair search(const data_type *query, const size_t K, const uint32_t L, IDType *indices, + float *distances = nullptr); + + // Filter support search + // IndexType is either uint32_t or uint64_t + template + std::pair search_with_filters(const DataType &query, const std::string &raw_label, + const size_t K, const uint32_t L, IndexType *indices, + float *distances); + + template int insert_point(const data_type *point, const tag_type tag); + + template int lazy_delete(const tag_type &tag); + + template + void lazy_delete(const std::vector &tags, std::vector &failed_tags); + + template void get_active_tags(tsl::robin_set &active_tags); + + template void set_start_points_at_random(data_type radius, uint32_t random_seed = 0); + + virtual consolidation_report consolidate_deletes(const IndexWriteParameters ¶meters) = 0; + + virtual void optimize_index_layout() = 0; + + // memory should be allocated for vec before calling this function + template int get_vector_by_tag(tag_type &tag, data_type *vec); + + private: + virtual void _build(const DataType &data, const size_t num_points_to_load, const IndexWriteParameters ¶meters, + TagVector &tags) = 0; + virtual std::pair _search(const DataType &query, const size_t K, const uint32_t L, + std::any &indices, float *distances = nullptr) = 0; + virtual std::pair _search_with_filters(const DataType &query, const std::string &filter_label, + const size_t K, const uint32_t L, std::any &indices, + float *distances) = 0; + virtual int _insert_point(const DataType &data_point, const TagType tag) = 0; + virtual int _lazy_delete(const TagType &tag) = 0; + virtual void 
_lazy_delete(TagVector &tags, TagVector &failed_tags) = 0; + virtual void _get_active_tags(TagRobinSet &active_tags) = 0; + virtual void _set_start_points_at_random(DataType radius, uint32_t random_seed = 0) = 0; + virtual int _get_vector_by_tag(TagType &tag, DataType &vec) = 0; + virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, + float *distances, DataVector &res_vectors) = 0; + virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) = 0; +}; +} // namespace diskann diff --git a/include/any_wrappers.h b/include/any_wrappers.h new file mode 100644 index 000000000..da9005cfb --- /dev/null +++ b/include/any_wrappers.h @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include +#include +#include "tsl/robin_set.h" + +namespace AnyWrapper +{ + +/* + * Base Struct to hold reference to the data. + * Note: No memory management, caller needs to keep object alive.
+ */ +struct AnyReference +{ + template AnyReference(Ty &reference) : _data(&reference) + { + } + + template Ty &get() + { + auto ptr = std::any_cast(_data); + return *ptr; + } + + private: + std::any _data; +}; +struct AnyRobinSet : public AnyReference +{ + template AnyRobinSet(const tsl::robin_set &robin_set) : AnyReference(robin_set) + { + } + template AnyRobinSet(tsl::robin_set &robin_set) : AnyReference(robin_set) + { + } +}; + +struct AnyVector : public AnyReference +{ + template AnyVector(const std::vector &vector) : AnyReference(vector) + { + } + template AnyVector(std::vector &vector) : AnyReference(vector) + { + } +}; +} // namespace AnyWrapper diff --git a/include/in_mem_data_store.h b/include/in_mem_data_store.h index 70d1fa28f..e9372c43a 100644 --- a/include/in_mem_data_store.h +++ b/include/in_mem_data_store.h @@ -50,7 +50,7 @@ template class InMemDataStore : public AbstractDataStore *get_dist_fn(); + virtual Distance *get_dist_fn() override; virtual size_t get_alignment_factor() const override; diff --git a/include/index.h b/include/index.h index 9e8ab645a..b6738df52 100644 --- a/include/index.h +++ b/include/index.h @@ -19,6 +19,7 @@ #include "windows_customizations.h" #include "scratch.h" #include "in_mem_data_store.h" +#include "abstract_index.h" #define OVERHEAD_FACTOR 1.1 #define EXPAND_IF_FULL 0 @@ -37,30 +38,7 @@ inline double estimate_ram_usage(size_t size, uint32_t dim, uint32_t datasize, u return OVERHEAD_FACTOR * (size_of_data + size_of_graph + size_of_locks + size_of_outer_vector); } -struct consolidation_report -{ - enum status_code - { - SUCCESS = 0, - FAIL = 1, - LOCK_FAIL = 2, - INCONSISTENT_COUNT_ERROR = 3 - }; - status_code _status; - size_t _active_points, _max_points, _empty_slots, _slots_released, _delete_set_size, _num_calls_to_process_delete; - double _time; - - consolidation_report(status_code status, size_t active_points, size_t max_points, size_t empty_slots, - size_t slots_released, size_t delete_set_size, size_t 
num_calls_to_process_delete, - double time_secs) - : _status(status), _active_points(active_points), _max_points(max_points), _empty_slots(empty_slots), - _slots_released(slots_released), _delete_set_size(delete_set_size), - _num_calls_to_process_delete(num_calls_to_process_delete), _time(time_secs) - { - } -}; - -template class Index +template class Index : public AbstractIndex { /************************************************************************** * @@ -76,7 +54,8 @@ template clas DISKANN_DLLEXPORT Index(Metric m, const size_t dim, const size_t max_points = 1, const bool dynamic_index = false, const bool enable_tags = false, const bool concurrent_consolidate = false, const bool pq_dist_build = false, const size_t num_pq_chunks = 0, - const bool use_opq = false, const size_t num_frozen_pts = 0); + const bool use_opq = false, const size_t num_frozen_pts = 0, + const bool init_data_store = true); // Constructor for incremental index DISKANN_DLLEXPORT Index(Metric m, const size_t dim, const size_t max_points, const bool dynamic_index, @@ -85,6 +64,9 @@ template clas const bool concurrent_consolidate = false, const bool pq_dist_build = false, const size_t num_pq_chunks = 0, const bool use_opq = false); + DISKANN_DLLEXPORT Index(const IndexConfig &index_config, std::unique_ptr> data_store + /* std::unique_ptr graph_store*/); + DISKANN_DLLEXPORT ~Index(); // Saves graph, data, metadata and associated tags. 
@@ -120,6 +102,9 @@ template clas DISKANN_DLLEXPORT void build(const T *data, const size_t num_points_to_load, const IndexWriteParameters ¶meters, const std::vector &tags); + DISKANN_DLLEXPORT void build(const std::string &data_file, const size_t num_points_to_load, + IndexBuildParams &build_params); + // Filtered Support DISKANN_DLLEXPORT void build_filtered_index(const char *filename, const std::string &label_file, const size_t num_points_to_load, IndexWriteParameters ¶meters, @@ -213,6 +198,34 @@ template clas // ******************************** protected: + // overload of abstract index virtual methods + virtual void _build(const DataType &data, const size_t num_points_to_load, const IndexWriteParameters ¶meters, + TagVector &tags) override; + + virtual std::pair _search(const DataType &query, const size_t K, const uint32_t L, + std::any &indices, float *distances = nullptr) override; + virtual std::pair _search_with_filters(const DataType &query, + const std::string &filter_label_raw, const size_t K, + const uint32_t L, std::any &indices, + float *distances) override; + + virtual int _insert_point(const DataType &data_point, const TagType tag) override; + + virtual int _lazy_delete(const TagType &tag) override; + + virtual void _lazy_delete(TagVector &tags, TagVector &failed_tags) override; + + virtual void _get_active_tags(TagRobinSet &active_tags) override; + + virtual void _set_start_points_at_random(DataType radius, uint32_t random_seed = 0) override; + + virtual int _get_vector_by_tag(TagType &tag, DataType &vec) override; + + virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) override; + + virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, + float *distances, DataVector &res_vectors) override; + // No copy/assign. 
Index(const Index &) = delete; Index &operator=(const Index &) = delete; @@ -319,7 +332,7 @@ template clas std::shared_ptr> _distance; // Data - std::unique_ptr> _data_store; + std::unique_ptr> _data_store; char *_opt_graph = nullptr; // Graph related data structures diff --git a/include/index_build_params.h b/include/index_build_params.h new file mode 100644 index 000000000..ff68c5001 --- /dev/null +++ b/include/index_build_params.h @@ -0,0 +1,72 @@ +#include "common_includes.h" +#include "parameters.h" + +namespace diskann +{ +struct IndexBuildParams +{ + public: + diskann::IndexWriteParameters index_write_params; + std::string save_path_prefix; + std::string label_file; + std::string universal_label; + uint32_t filter_threshold = 0; + + private: + IndexBuildParams(const IndexWriteParameters &index_write_params, const std::string &save_path_prefix, + const std::string &label_file, const std::string &universal_label, uint32_t filter_threshold) + : index_write_params(index_write_params), save_path_prefix(save_path_prefix), label_file(label_file), + universal_label(universal_label), filter_threshold(filter_threshold) + { + } + + friend class IndexBuildParamsBuilder; +}; +class IndexBuildParamsBuilder +{ + public: + IndexBuildParamsBuilder(const diskann::IndexWriteParameters ¶s) : _index_write_params(paras){}; + + IndexBuildParamsBuilder &with_save_path_prefix(const std::string &save_path_prefix) + { + if (save_path_prefix.empty() || save_path_prefix == "") + throw ANNException("Error: save_path_prefix can't be empty", -1); + this->_save_path_prefix = save_path_prefix; + return *this; + } + + IndexBuildParamsBuilder &with_label_file(const std::string &label_file) + { + this->_label_file = label_file; + return *this; + } + + IndexBuildParamsBuilder &with_universal_label(const std::string &univeral_label) + { + this->_universal_label = univeral_label; + return *this; + } + + IndexBuildParamsBuilder &with_filter_threshold(const std::uint32_t &filter_threshold) + { + 
this->_filter_threshold = filter_threshold; + return *this; + } + + IndexBuildParams build() + { + return IndexBuildParams(_index_write_params, _save_path_prefix, _label_file, _universal_label, + _filter_threshold); + } + + IndexBuildParamsBuilder(const IndexBuildParamsBuilder &) = delete; + IndexBuildParamsBuilder &operator=(const IndexBuildParamsBuilder &) = delete; + + private: + diskann::IndexWriteParameters _index_write_params; + std::string _save_path_prefix; + std::string _label_file; + std::string _universal_label; + uint32_t _filter_threshold = 0; +}; +} // namespace diskann diff --git a/include/index_config.h b/include/index_config.h new file mode 100644 index 000000000..b291c744d --- /dev/null +++ b/include/index_config.h @@ -0,0 +1,224 @@ +#include "common_includes.h" +#include "parameters.h" + +namespace diskann +{ +enum DataStoreStrategy +{ + MEMORY +}; + +enum GraphStoreStrategy +{ +}; +struct IndexConfig +{ + DataStoreStrategy data_strategy; + GraphStoreStrategy graph_strategy; + + Metric metric; + size_t dimension; + size_t max_points; + + bool dynamic_index; + bool enable_tags; + bool pq_dist_build; + bool concurrent_consolidate; + bool use_opq; + + size_t num_pq_chunks; + size_t num_frozen_pts; + + std::string label_type; + std::string tag_type; + std::string data_type; + + std::shared_ptr index_write_params; + + uint32_t search_threads; + uint32_t initial_search_list_size; + + private: + IndexConfig(DataStoreStrategy data_strategy, GraphStoreStrategy graph_strategy, Metric metric, size_t dimension, + size_t max_points, size_t num_pq_chunks, size_t num_frozen_points, bool dynamic_index, bool enable_tags, + bool pq_dist_build, bool concurrent_consolidate, bool use_opq, const std::string &data_type, + const std::string &tag_type, const std::string &label_type, + std::shared_ptr index_write_params, uint32_t search_threads, + uint32_t initial_search_list_size) + : data_strategy(data_strategy), graph_strategy(graph_strategy), metric(metric), 
dimension(dimension), + max_points(max_points), dynamic_index(dynamic_index), enable_tags(enable_tags), pq_dist_build(pq_dist_build), + concurrent_consolidate(concurrent_consolidate), use_opq(use_opq), num_pq_chunks(num_pq_chunks), + num_frozen_pts(num_frozen_points), label_type(label_type), tag_type(tag_type), data_type(data_type), + index_write_params(index_write_params), search_threads(search_threads), + initial_search_list_size(initial_search_list_size) + { + } + + friend class IndexConfigBuilder; +}; + +class IndexConfigBuilder +{ + public: + IndexConfigBuilder() + { + } + + IndexConfigBuilder &with_metric(Metric m) + { + this->_metric = m; + return *this; + } + + IndexConfigBuilder &with_graph_load_store_strategy(GraphStoreStrategy graph_strategy) + { + this->_graph_strategy = graph_strategy; + return *this; + } + + IndexConfigBuilder &with_data_load_store_strategy(DataStoreStrategy data_strategy) + { + this->_data_strategy = data_strategy; + return *this; + } + + IndexConfigBuilder &with_dimension(size_t dimension) + { + this->_dimension = dimension; + return *this; + } + + IndexConfigBuilder &with_max_points(size_t max_points) + { + this->_max_points = max_points; + return *this; + } + + IndexConfigBuilder &is_dynamic_index(bool dynamic_index) + { + this->_dynamic_index = dynamic_index; + return *this; + } + + IndexConfigBuilder &is_enable_tags(bool enable_tags) + { + this->_enable_tags = enable_tags; + return *this; + } + + IndexConfigBuilder &is_pq_dist_build(bool pq_dist_build) + { + this->_pq_dist_build = pq_dist_build; + return *this; + } + + IndexConfigBuilder &is_concurrent_consolidate(bool concurrent_consolidate) + { + this->_concurrent_consolidate = concurrent_consolidate; + return *this; + } + + IndexConfigBuilder &is_use_opq(bool use_opq) + { + this->_use_opq = use_opq; + return *this; + } + + IndexConfigBuilder &with_num_pq_chunks(size_t num_pq_chunks) + { + this->_num_pq_chunks = num_pq_chunks; + return *this; + } + + IndexConfigBuilder 
&with_num_frozen_pts(size_t num_frozen_pts) + { + this->_num_frozen_pts = num_frozen_pts; + return *this; + } + + IndexConfigBuilder &with_label_type(const std::string &label_type) + { + this->_label_type = label_type; + return *this; + } + + IndexConfigBuilder &with_tag_type(const std::string &tag_type) + { + this->_tag_type = tag_type; + return *this; + } + + IndexConfigBuilder &with_data_type(const std::string &data_type) + { + this->_data_type = data_type; + return *this; + } + + IndexConfigBuilder &with_index_write_params(IndexWriteParameters &index_write_params) + { + this->_index_write_params = std::make_shared(index_write_params); + return *this; + } + + IndexConfigBuilder &with_search_threads(uint32_t search_threads) + { + this->_search_threads = search_threads; + return *this; + } + + IndexConfigBuilder &with_initial_search_list_size(uint32_t search_list_size) + { + this->_initial_search_list_size = search_list_size; + return *this; + } + + IndexConfig build() + { + if (_data_type == "" || _data_type.empty()) + throw ANNException("Error: data_type can not be empty", -1); + + if (_dynamic_index && _index_write_params != nullptr) + { + if (_search_threads == 0) + throw ANNException("Error: please pass search_threads for building dynamic index.", -1); + + if (_initial_search_list_size == 0) + throw ANNException("Error: please pass initial_search_list_size for building dynamic index.", -1); + } + + return IndexConfig(_data_strategy, _graph_strategy, _metric, _dimension, _max_points, _num_pq_chunks, + _num_frozen_pts, _dynamic_index, _enable_tags, _pq_dist_build, _concurrent_consolidate, + _use_opq, _data_type, _tag_type, _label_type, _index_write_params, _search_threads, + _initial_search_list_size); + } + + IndexConfigBuilder(const IndexConfigBuilder &) = delete; + IndexConfigBuilder &operator=(const IndexConfigBuilder &) = delete; + + private: + DataStoreStrategy _data_strategy; + GraphStoreStrategy _graph_strategy; + + Metric _metric; + size_t _dimension; 
+ size_t _max_points; + + bool _dynamic_index = false; + bool _enable_tags = false; + bool _pq_dist_build = false; + bool _concurrent_consolidate = false; + bool _use_opq = false; + + size_t _num_pq_chunks = 0; + size_t _num_frozen_pts = 0; + + std::string _label_type = "uint32"; + std::string _tag_type = "uint32"; + std::string _data_type; + + std::shared_ptr _index_write_params; + + uint32_t _search_threads; + uint32_t _initial_search_list_size; +}; +} // namespace diskann diff --git a/include/index_factory.h b/include/index_factory.h new file mode 100644 index 000000000..3d1eb7992 --- /dev/null +++ b/include/index_factory.h @@ -0,0 +1,37 @@ +#include "index.h" +#include "abstract_graph_store.h" +#include "in_mem_graph_store.h" + +namespace diskann +{ +class IndexFactory +{ + public: + DISKANN_DLLEXPORT explicit IndexFactory(const IndexConfig &config); + DISKANN_DLLEXPORT std::unique_ptr create_instance(); + + private: + void check_config(); + + template + std::unique_ptr> construct_datastore(DataStoreStrategy stratagy, size_t num_points, + size_t dimension); + + std::unique_ptr construct_graphstore(GraphStoreStrategy stratagy, size_t size); + + template + std::unique_ptr create_instance(); + + std::unique_ptr create_instance(const std::string &data_type, const std::string &tag_type, + const std::string &label_type); + + template + std::unique_ptr create_instance(const std::string &tag_type, const std::string &label_type); + + template + std::unique_ptr create_instance(const std::string &label_type); + + std::unique_ptr _config; +}; + +} // namespace diskann diff --git a/include/types.h b/include/types.h index ea04cd34d..b95848869 100644 --- a/include/types.h +++ b/include/types.h @@ -5,8 +5,17 @@ #include #include +#include +#include "any_wrappers.h" namespace diskann { typedef uint32_t location_t; -} // namespace diskann \ No newline at end of file + +using DataType = std::any; +using TagType = std::any; +using LabelType = std::any; +using TagVector = 
AnyWrapper::AnyVector; +using DataVector = AnyWrapper::AnyVector; +using TagRobinSet = AnyWrapper::AnyRobinSet; +} // namespace diskann diff --git a/include/utils.h b/include/utils.h index b484c2aeb..58bb52a3b 100644 --- a/include/utils.h +++ b/include/utils.h @@ -27,6 +27,7 @@ typedef int FileHandle; #include "windows_customizations.h" #include "tsl/robin_set.h" #include "types.h" +#include #ifdef EXEC_ENV_OLS #include "content_buf.h" @@ -337,6 +338,26 @@ inline void get_bin_metadata(const std::string &bin_file, size_t &nrows, size_t } // get_bin_metadata functions END +#ifndef EXEC_ENV_OLS +inline size_t get_graph_num_frozen_points(const std::string &graph_file) +{ + size_t expected_file_size; + uint32_t max_observed_degree, start; + size_t file_frozen_pts; + + std::ifstream in; + in.exceptions(std::ios::badbit | std::ios::failbit); + + in.open(graph_file, std::ios::binary); + in.read((char *)&expected_file_size, sizeof(size_t)); + in.read((char *)&max_observed_degree, sizeof(uint32_t)); + in.read((char *)&start, sizeof(uint32_t)); + in.read((char *)&file_frozen_pts, sizeof(size_t)); + + return file_frozen_pts; +} +#endif + template inline std::string getValues(T *data, size_t num) { std::stringstream stream; @@ -1092,6 +1113,44 @@ inline void clean_up_artifacts(tsl::robin_set paths_to_clean, tsl:: } } +template inline const char *diskann_type_to_name() = delete; +template <> inline const char *diskann_type_to_name() +{ + return "float"; +} +template <> inline const char *diskann_type_to_name() +{ + return "uint8"; +} +template <> inline const char *diskann_type_to_name() +{ + return "int8"; +} +template <> inline const char *diskann_type_to_name() +{ + return "uint16"; +} +template <> inline const char *diskann_type_to_name() +{ + return "int16"; +} +template <> inline const char *diskann_type_to_name() +{ + return "uint32"; +} +template <> inline const char *diskann_type_to_name() +{ + return "int32"; +} +template <> inline const char *diskann_type_to_name() +{ 
+ return "uint64"; +} +template <> inline const char *diskann_type_to_name() +{ + return "int64"; +} + #ifdef _WINDOWS #include #include diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 723ac9aca..2206a01f7 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -13,7 +13,7 @@ else() linux_aligned_file_reader.cpp math_utils.cpp natural_number_map.cpp in_mem_data_store.cpp in_mem_graph_store.cpp natural_number_set.cpp memory_mapper.cpp partition.cpp pq.cpp - pq_flash_index.cpp scratch.cpp logger.cpp utils.cpp filter_utils.cpp) + pq_flash_index.cpp scratch.cpp logger.cpp utils.cpp filter_utils.cpp index_factory.cpp abstract_index.cpp) if (RESTAPI) list(APPEND CPP_SOURCES restapi/search_wrapper.cpp restapi/server.cpp) endif() diff --git a/src/abstract_index.cpp b/src/abstract_index.cpp new file mode 100644 index 000000000..518f8b7dd --- /dev/null +++ b/src/abstract_index.cpp @@ -0,0 +1,280 @@ +#include "common_includes.h" +#include "windows_customizations.h" +#include "abstract_index.h" + +namespace diskann +{ + +template +void AbstractIndex::build(const data_type *data, const size_t num_points_to_load, + const IndexWriteParameters ¶meters, const std::vector &tags) +{ + auto any_data = std::any(data); + auto any_tags_vec = TagVector(tags); + this->_build(any_data, num_points_to_load, parameters, any_tags_vec); +} + +template +std::pair AbstractIndex::search(const data_type *query, const size_t K, const uint32_t L, + IDType *indices, float *distances) +{ + auto any_indices = std::any(indices); + auto any_query = std::any(query); + return _search(any_query, K, L, any_indices, distances); +} + +template +size_t AbstractIndex::search_with_tags(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags, + float *distances, std::vector &res_vectors) +{ + auto any_query = std::any(query); + auto any_tags = std::any(tags); + auto any_res_vectors = DataVector(res_vectors); + return this->_search_with_tags(any_query, K, L, any_tags, distances, 
any_res_vectors); +} + +template +std::pair AbstractIndex::search_with_filters(const DataType &query, const std::string &raw_label, + const size_t K, const uint32_t L, IndexType *indices, + float *distances) +{ + auto any_indices = std::any(indices); + return _search_with_filters(query, raw_label, K, L, any_indices, distances); +} + +template +void AbstractIndex::search_with_optimized_layout(const data_type *query, size_t K, size_t L, uint32_t *indices) +{ + auto any_query = std::any(query); + this->_search_with_optimized_layout(any_query, K, L, indices); +} + +template +int AbstractIndex::insert_point(const data_type *point, const tag_type tag) +{ + auto any_point = std::any(point); + auto any_tag = std::any(tag); + return this->_insert_point(any_point, any_tag); +} + +template int AbstractIndex::lazy_delete(const tag_type &tag) +{ + auto any_tag = std::any(tag); + return this->_lazy_delete(any_tag); +} + +template +void AbstractIndex::lazy_delete(const std::vector &tags, std::vector &failed_tags) +{ + auto any_tags = TagVector(tags); + auto any_failed_tags = TagVector(failed_tags); + this->_lazy_delete(any_tags, any_failed_tags); +} + +template void AbstractIndex::get_active_tags(tsl::robin_set &active_tags) +{ + auto any_active_tags = TagRobinSet(active_tags); + this->_get_active_tags(any_active_tags); +} + +template void AbstractIndex::set_start_points_at_random(data_type radius, uint32_t random_seed) +{ + auto any_radius = std::any(radius); + this->_set_start_points_at_random(any_radius, random_seed); +} + +template int AbstractIndex::get_vector_by_tag(tag_type &tag, data_type *vec) +{ + auto any_tag = std::any(tag); + auto any_data_ptr = std::any(vec); + return this->_get_vector_by_tag(any_tag, any_data_ptr); +} + +// exports +template DISKANN_DLLEXPORT void AbstractIndex::build(const float *data, const size_t num_points_to_load, + const IndexWriteParameters ¶meters, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const 
int8_t *data, + const size_t num_points_to_load, + const IndexWriteParameters ¶meters, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const uint8_t *data, + const size_t num_points_to_load, + const IndexWriteParameters ¶meters, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const float *data, + const size_t num_points_to_load, + const IndexWriteParameters ¶meters, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const int8_t *data, + const size_t num_points_to_load, + const IndexWriteParameters ¶meters, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const uint8_t *data, + const size_t num_points_to_load, + const IndexWriteParameters ¶meters, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const float *data, const size_t num_points_to_load, + const IndexWriteParameters ¶meters, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const int8_t *data, + const size_t num_points_to_load, + const IndexWriteParameters ¶meters, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const uint8_t *data, + const size_t num_points_to_load, + const IndexWriteParameters ¶meters, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const float *data, + const size_t num_points_to_load, + const IndexWriteParameters ¶meters, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const int8_t *data, + const size_t num_points_to_load, + const IndexWriteParameters ¶meters, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const uint8_t *data, + const size_t num_points_to_load, + const IndexWriteParameters ¶meters, + const std::vector &tags); + +template DISKANN_DLLEXPORT std::pair AbstractIndex::search( + const float *query, const size_t K, const uint32_t L, uint32_t 
*indices, float *distances); +template DISKANN_DLLEXPORT std::pair AbstractIndex::search( + const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair AbstractIndex::search( + const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + +template DISKANN_DLLEXPORT std::pair AbstractIndex::search( + const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair AbstractIndex::search( + const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair AbstractIndex::search( + const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + +template DISKANN_DLLEXPORT std::pair AbstractIndex::search_with_filters( + const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint32_t *indices, + float *distances); + +template DISKANN_DLLEXPORT std::pair AbstractIndex::search_with_filters( + const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint64_t *indices, + float *distances); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const float *query, const uint64_t K, + const uint32_t L, int32_t *tags, + float *distances, + std::vector &res_vectors); + +template DISKANN_DLLEXPORT size_t +AbstractIndex::search_with_tags(const uint8_t *query, const uint64_t K, const uint32_t L, + int32_t *tags, float *distances, std::vector &res_vectors); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const int8_t *query, + const uint64_t K, const uint32_t L, + int32_t *tags, float *distances, + std::vector &res_vectors); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const float *query, const uint64_t K, + const uint32_t L, uint32_t *tags, + float *distances, + std::vector &res_vectors); + +template 
DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const uint8_t *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances, + std::vector &res_vectors); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const int8_t *query, + const uint64_t K, const uint32_t L, + uint32_t *tags, float *distances, + std::vector &res_vectors); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const float *query, const uint64_t K, + const uint32_t L, int64_t *tags, + float *distances, + std::vector &res_vectors); + +template DISKANN_DLLEXPORT size_t +AbstractIndex::search_with_tags(const uint8_t *query, const uint64_t K, const uint32_t L, + int64_t *tags, float *distances, std::vector &res_vectors); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const int8_t *query, + const uint64_t K, const uint32_t L, + int64_t *tags, float *distances, + std::vector &res_vectors); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const float *query, const uint64_t K, + const uint32_t L, uint64_t *tags, + float *distances, + std::vector &res_vectors); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const uint8_t *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances, + std::vector &res_vectors); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const int8_t *query, + const uint64_t K, const uint32_t L, + uint64_t *tags, float *distances, + std::vector &res_vectors); + +template DISKANN_DLLEXPORT void AbstractIndex::search_with_optimized_layout(const float *query, size_t K, + size_t L, uint32_t *indices); +template DISKANN_DLLEXPORT void AbstractIndex::search_with_optimized_layout(const uint8_t *query, size_t K, + size_t L, uint32_t *indices); +template DISKANN_DLLEXPORT void AbstractIndex::search_with_optimized_layout(const int8_t *query, size_t K, + size_t L, uint32_t *indices); + +template DISKANN_DLLEXPORT int 
AbstractIndex::insert_point(const float *point, const int32_t tag); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const uint8_t *point, const int32_t tag); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const int8_t *point, const int32_t tag); + +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const float *point, const uint32_t tag); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const uint8_t *point, const uint32_t tag); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const int8_t *point, const uint32_t tag); + +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const float *point, const int64_t tag); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const uint8_t *point, const int64_t tag); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const int8_t *point, const int64_t tag); + +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const float *point, const uint64_t tag); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const uint8_t *point, const uint64_t tag); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const int8_t *point, const uint64_t tag); + +template DISKANN_DLLEXPORT int AbstractIndex::lazy_delete(const int32_t &tag); +template DISKANN_DLLEXPORT int AbstractIndex::lazy_delete(const uint32_t &tag); +template DISKANN_DLLEXPORT int AbstractIndex::lazy_delete(const int64_t &tag); +template DISKANN_DLLEXPORT int AbstractIndex::lazy_delete(const uint64_t &tag); + +template DISKANN_DLLEXPORT void AbstractIndex::lazy_delete(const std::vector &tags, + std::vector &failed_tags); +template DISKANN_DLLEXPORT void AbstractIndex::lazy_delete(const std::vector &tags, + std::vector &failed_tags); +template DISKANN_DLLEXPORT void AbstractIndex::lazy_delete(const std::vector &tags, + std::vector &failed_tags); +template DISKANN_DLLEXPORT void AbstractIndex::lazy_delete(const std::vector &tags, + std::vector &failed_tags); + +template 
DISKANN_DLLEXPORT void AbstractIndex::get_active_tags(tsl::robin_set &active_tags); +template DISKANN_DLLEXPORT void AbstractIndex::get_active_tags(tsl::robin_set &active_tags); +template DISKANN_DLLEXPORT void AbstractIndex::get_active_tags(tsl::robin_set &active_tags); +template DISKANN_DLLEXPORT void AbstractIndex::get_active_tags(tsl::robin_set &active_tags); + +template DISKANN_DLLEXPORT void AbstractIndex::set_start_points_at_random(float radius, uint32_t random_seed); +template DISKANN_DLLEXPORT void AbstractIndex::set_start_points_at_random(uint8_t radius, + uint32_t random_seed); +template DISKANN_DLLEXPORT void AbstractIndex::set_start_points_at_random(int8_t radius, uint32_t random_seed); + +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(int32_t &tag, float *vec); +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(int32_t &tag, uint8_t *vec); +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(int32_t &tag, int8_t *vec); +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(uint32_t &tag, float *vec); +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(uint32_t &tag, uint8_t *vec); +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(uint32_t &tag, int8_t *vec); + +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(int64_t &tag, float *vec); +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(int64_t &tag, uint8_t *vec); +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(int64_t &tag, int8_t *vec); +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(uint64_t &tag, float *vec); +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(uint64_t &tag, uint8_t *vec); +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(uint64_t &tag, int8_t *vec); + +} // namespace diskann diff --git a/src/dll/CMakeLists.txt b/src/dll/CMakeLists.txt index e02996f32..d00cfeb95 100644 --- a/src/dll/CMakeLists.txt +++ 
b/src/dll/CMakeLists.txt @@ -4,9 +4,10 @@ add_library(${PROJECT_NAME} SHARED dllmain.cpp ../abstract_data_store.cpp ../partition.cpp ../pq.cpp ../pq_flash_index.cpp ../logger.cpp ../utils.cpp ../windows_aligned_file_reader.cpp ../distance.cpp ../memory_mapper.cpp ../index.cpp ../in_mem_data_store.cpp ../in_mem_graph_store.cpp ../math_utils.cpp ../disk_utils.cpp ../filter_utils.cpp - ../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp) + ../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp ../index_factory.cpp ../abstract_index.cpp) set(TARGET_DIR "$<$:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}>$<$:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}>") + set(DISKANN_DLL_IMPLIB "${TARGET_DIR}/${PROJECT_NAME}.lib") target_compile_definitions(${PROJECT_NAME} PRIVATE _USRDLL _WINDLL) diff --git a/src/index.cpp b/src/index.cpp index 9a6751d0c..c30e44ac2 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -34,6 +34,10 @@ Index::Index(Metric m, const size_t dim, const size_t max_point : Index(m, dim, max_points, dynamic_index, enable_tags, concurrent_consolidate, pq_dist_build, num_pq_chunks, use_opq, indexParams.num_frozen_points) { + if (dynamic_index) + { + this->enable_delete(); + } _indexingQueueSize = indexParams.search_list_size; _indexingRange = indexParams.max_degree; _indexingMaxC = indexParams.max_occlusion_size; @@ -50,7 +54,8 @@ Index::Index(Metric m, const size_t dim, const size_t max_point template Index::Index(Metric m, const size_t dim, const size_t max_points, const bool dynamic_index, const bool enable_tags, const bool concurrent_consolidate, const bool pq_dist_build, - const size_t num_pq_chunks, const bool use_opq, const size_t num_frozen_pts) + const size_t num_pq_chunks, const bool use_opq, const size_t num_frozen_pts, + const bool init_data_store) : _dist_metric(m), _dim(dim), _max_points(max_points), _num_frozen_pts(num_frozen_pts), _dynamic_index(dynamic_index), _enable_tags(enable_tags), 
_indexingMaxC(DEFAULT_MAXC), _query_scratch(nullptr), _pq_dist(pq_dist_build), _use_opq(use_opq), _num_pq_chunks(num_pq_chunks), @@ -98,24 +103,27 @@ Index::Index(Metric m, const size_t dim, const size_t max_point _final_graph.resize(total_internal_points); - // This should come from a factory. - if (m == diskann::Metric::COSINE && std::is_floating_point::value) - { - // This is safe because T is float inside the if block. - this->_distance.reset((Distance *)new AVXNormalizedCosineDistanceFloat()); - this->_normalize_vecs = true; - diskann::cout << "Normalizing vectors and using L2 for cosine " - "AVXNormalizedCosineDistanceFloat()." - << std::endl; - } - else + if (init_data_store) { - this->_distance.reset((Distance *)get_distance_function(m)); + // Issue #374: data_store is injected from index factory. Keeping this for backward compatibility. + // distance is owned by data_store + if (m == diskann::Metric::COSINE && std::is_floating_point::value) + { + // This is safe because T is float inside the if block. + this->_distance.reset((Distance *)new AVXNormalizedCosineDistanceFloat()); + this->_normalize_vecs = true; + diskann::cout << "Normalizing vectors and using L2 for cosine " + "AVXNormalizedCosineDistanceFloat()." + << std::endl; + } + else + { + this->_distance.reset((Distance *)get_distance_function(m)); + } + // Note: moved this to factory, keeping this for backward compatibility. + _data_store = + std::make_unique>((location_t)total_internal_points, _dim, this->_distance); } - // REFACTOR: TODO This should move to a factory method. 
- - _data_store = - std::make_unique>((location_t)total_internal_points, _dim, this->_distance); _locks = std::vector(total_internal_points); @@ -126,6 +134,37 @@ Index::Index(Metric m, const size_t dim, const size_t max_point } } +template +Index::Index(const IndexConfig &index_config, std::unique_ptr> data_store) + : Index(index_config.metric, index_config.dimension, index_config.max_points, index_config.dynamic_index, + index_config.enable_tags, index_config.concurrent_consolidate, index_config.pq_dist_build, + index_config.num_pq_chunks, index_config.use_opq, index_config.num_frozen_pts, false) +{ + + _data_store = std::move(data_store); + _distance.reset(_data_store->get_dist_fn()); + + // enable delete by default for dynamic index + if (_dynamic_index) + { + this->enable_delete(); + } + if (_dynamic_index && index_config.index_write_params != nullptr) + { + _indexingQueueSize = index_config.index_write_params->search_list_size; + _indexingRange = index_config.index_write_params->max_degree; + _indexingMaxC = index_config.index_write_params->max_occlusion_size; + _indexingAlpha = index_config.index_write_params->alpha; + _filterIndexingQueueSize = index_config.index_write_params->filter_list_size; + + uint32_t num_threads_indx = index_config.index_write_params->num_threads; + uint32_t num_scratch_spaces = index_config.search_threads + num_threads_indx; + + initialize_query_scratch(num_scratch_spaces, index_config.initial_search_list_size, _indexingQueueSize, + _indexingRange, _indexingMaxC, _data_store->get_dims()); + } +} + template Index::~Index() { // Ensure that no other activity is happening before dtor() @@ -791,6 +830,25 @@ size_t Index::load_graph(std::string filename, size_t expected_ return nodes_read; } +template +int Index::_get_vector_by_tag(TagType &tag, DataType &vec) +{ + try + { + TagT tag_val = std::any_cast(tag); + T *vec_val = std::any_cast(vec); + return this->get_vector_by_tag(tag_val, vec_val); + } + catch (const std::bad_any_cast &e) + 
{ + throw ANNException("Error: bad any cast while performing _get_vector_by_tags() " + std::string(e.what()), -1); + } + catch (const std::exception &e) + { + throw ANNException("Error: " + std::string(e.what()), -1); + } +} + template int Index::get_vector_by_tag(TagT &tag, T *vec) { std::shared_lock lock(_tag_lock); @@ -1558,6 +1616,25 @@ void Index::set_start_points(const T *data, size_t data_count) diskann::cout << "Index start points set: #" << _num_frozen_pts << std::endl; } +template +void Index::_set_start_points_at_random(DataType radius, uint32_t random_seed) +{ + try + { + T radius_to_use = std::any_cast(radius); + this->set_start_points_at_random(radius_to_use, random_seed); + } + catch (const std::bad_any_cast &e) + { + throw ANNException( + "Error: bad any cast while performing _set_start_points_at_random() " + std::string(e.what()), -1); + } + catch (const std::exception &e) + { + throw ANNException("Error: " + std::string(e.what()), -1); + } +} + template void Index::set_start_points_at_random(T radius, uint32_t random_seed) { @@ -1642,7 +1719,24 @@ void Index::build_with_data_populated(const IndexWriteParameter _max_observed_degree = std::max((uint32_t)max, _max_observed_degree); _has_built = true; } - +template +void Index::_build(const DataType &data, const size_t num_points_to_load, + const IndexWriteParameters ¶meters, TagVector &tags) +{ + try + { + this->build(std::any_cast(data), num_points_to_load, parameters, + tags.get>()); + } + catch (const std::bad_any_cast &e) + { + throw ANNException("Error: bad any cast in while building index. 
" + std::string(e.what()), -1); + } + catch (const std::exception &e) + { + throw ANNException("Error" + std::string(e.what()), -1); + } +} template void Index::build(const T *data, const size_t num_points_to_load, const IndexWriteParameters ¶meters, const std::vector &tags) @@ -1683,7 +1777,11 @@ template void Index::build(const char *filename, const size_t num_points_to_load, const IndexWriteParameters ¶meters, const std::vector &tags) { + // idealy this should call build_filtered_index based on params passed + std::unique_lock ul(_update_lock); + + // error checks if (num_points_to_load == 0) throw ANNException("Do not call build with 0 points", -1, __FUNCSIG__, __FILE__, __LINE__); @@ -1813,6 +1911,42 @@ void Index::build(const char *filename, const size_t num_points build(filename, num_points_to_load, parameters, tags); } +template +void Index::build(const std::string &data_file, const size_t num_points_to_load, + IndexBuildParams &build_params) +{ + std::string labels_file_to_use = build_params.save_path_prefix + "_label_formatted.txt"; + std::string mem_labels_int_map_file = build_params.save_path_prefix + "_labels_map.txt"; + + size_t points_to_load = num_points_to_load == 0 ? 
_max_points : num_points_to_load; + + auto s = std::chrono::high_resolution_clock::now(); + if (build_params.label_file == "") + { + this->build(data_file.c_str(), points_to_load, build_params.index_write_params); + } + else + { + // TODO: this should ideally happen in save() + convert_labels_string_to_int(build_params.label_file, labels_file_to_use, mem_labels_int_map_file, + build_params.universal_label); + if (build_params.universal_label != "") + { + LabelT unv_label_as_num = 0; + this->set_universal_label(unv_label_as_num); + } + this->build_filtered_index(data_file.c_str(), labels_file_to_use, points_to_load, + build_params.index_write_params); + } + std::chrono::duration diff = std::chrono::high_resolution_clock::now() - s; + std::cout << "Indexing time: " << diff.count() << "\n"; + // cleanup + if (build_params.label_file != "") + { + // clean_up_artifacts({labels_file_to_use, mem_labels_int_map_file}, {}); + } +} + template std::unordered_map Index::load_label_map(const std::string &labels_map_file) { @@ -1909,10 +2043,11 @@ void Index::build_filtered_index(const char *filename, const st const size_t num_points_to_load, IndexWriteParameters ¶meters, const std::vector &tags) { - _labels_file = label_file; + _labels_file = label_file; // original label file _filtered_index = true; _label_to_medoid_id.clear(); size_t num_points_labels = 0; + parse_label_file(label_file, num_points_labels); // determines medoid for each label and identifies // the points to label mapping @@ -1973,6 +2108,38 @@ void Index::build_filtered_index(const char *filename, const st this->build(filename, num_points_to_load, parameters, tags); } +template +std::pair Index::_search(const DataType &query, const size_t K, const uint32_t L, + std::any &indices, float *distances) +{ + try + { + auto typed_query = std::any_cast(query); + if (typeid(uint32_t *) == indices.type()) + { + auto u32_ptr = std::any_cast(indices); + return this->search(typed_query, K, L, u32_ptr, distances); + } + 
else if (typeid(uint64_t *) == indices.type()) + { + auto u64_ptr = std::any_cast(indices); + return this->search(typed_query, K, L, u64_ptr, distances); + } + else + { + throw ANNException("Error: indices type can only be uint64_t or uint32_t.", -1); + } + } + catch (const std::bad_any_cast &e) + { + throw ANNException("Error: bad any cast while searching. " + std::string(e.what()), -1); + } + catch (const std::exception &e) + { + throw ANNException("Error: " + std::string(e.what()), -1); + } +} + template template std::pair Index::search(const T *query, const size_t K, const uint32_t L, @@ -2036,6 +2203,29 @@ std::pair Index::search(const T *query, con return retval; } +template +std::pair Index::_search_with_filters(const DataType &query, + const std::string &raw_label, const size_t K, + const uint32_t L, std::any &indices, + float *distances) +{ + auto converted_label = this->get_converted_label(raw_label); + if (typeid(uint64_t *) == indices.type()) + { + auto ptr = std::any_cast(indices); + return this->search_with_filters(std::any_cast(query), converted_label, K, L, ptr, distances); + } + else if (typeid(uint32_t *) == indices.type()) + { + auto ptr = std::any_cast(indices); + return this->search_with_filters(std::any_cast(query), converted_label, K, L, ptr, distances); + } + else + { + throw ANNException("Error: Id type can only be uint64_t or uint32_t.", -1); + } +} + template template std::pair Index::search_with_filters(const T *query, const LabelT &filter_label, @@ -2114,6 +2304,25 @@ std::pair Index::search_with_filters(const return retval; } +template +size_t Index::_search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, + const TagType &tags, float *distances, DataVector &res_vectors) +{ + try + { + return this->search_with_tags(std::any_cast(query), K, L, std::any_cast(tags), distances, + res_vectors.get>()); + } + catch (const std::bad_any_cast &e) + { + throw ANNException("Error: bad any cast while performing 
_search_with_tags() " + std::string(e.what()), -1); + } + catch (const std::exception &e) + { + throw ANNException("Error: " + std::string(e.what()), -1); + } +} + template size_t Index::search_with_tags(const T *query, const uint64_t K, const uint32_t L, TagT *tags, float *distances, std::vector &res_vectors) @@ -2709,6 +2918,23 @@ template void Index(stop - start).count() << "s" << std::endl; } +template +int Index::_insert_point(const DataType &point, const TagType tag) +{ + try + { + return this->insert_point(std::any_cast(point), std::any_cast(tag)); + } + catch (const std::bad_any_cast &anycast_e) + { + throw new ANNException("Error:Trying to insert invalid data type" + std::string(anycast_e.what()), -1); + } + catch (const std::exception &e) + { + throw new ANNException("Error:" + std::string(e.what()), -1); + } +} + template int Index::insert_point(const T *point, const TagT tag) { @@ -2822,6 +3048,35 @@ int Index::insert_point(const T *point, const TagT tag) return 0; } +template int Index::_lazy_delete(const TagType &tag) +{ + try + { + return lazy_delete(std::any_cast(tag)); + } + catch (const std::bad_any_cast &e) + { + throw ANNException(std::string("Error: ") + e.what(), -1); + } +} + +template +void Index::_lazy_delete(TagVector &tags, TagVector &failed_tags) +{ + try + { + this->lazy_delete(tags.get>(), failed_tags.get>()); + } + catch (const std::bad_any_cast &e) + { + throw ANNException("Error: bad any cast while performing _lazy_delete() " + std::string(e.what()), -1); + } + catch (const std::exception &e) + { + throw ANNException("Error: " + std::string(e.what()), -1); + } +} + template int Index::lazy_delete(const TagT &tag) { std::shared_lock ul(_update_lock); @@ -2877,6 +3132,23 @@ template bool Index +void Index::_get_active_tags(TagRobinSet &active_tags) +{ + try + { + this->get_active_tags(active_tags.get>()); + } + catch (const std::bad_any_cast &e) + { + throw ANNException("Error: bad_any cast while performing _get_active_tags() " + 
std::string(e.what()), -1); + } + catch (const std::exception &e) + { + throw ANNException("Error :" + std::string(e.what()), -1); + } +} + template void Index::get_active_tags(tsl::robin_set &active_tags) { @@ -2998,6 +3270,24 @@ template void Index +void Index::_search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) +{ + try + { + return this->search_with_optimized_layout(std::any_cast(query), K, L, indices); + } + catch (const std::bad_any_cast &e) + { + throw ANNException( + "Error: bad any cast while performing _search_with_optimized_layout() " + std::string(e.what()), -1); + } + catch (const std::exception &e) + { + throw ANNException("Error: " + std::string(e.what()), -1); + } +} + template void Index::search_with_optimized_layout(const T *query, size_t K, size_t L, uint32_t *indices) { @@ -3238,5 +3528,4 @@ template DISKANN_DLLEXPORT std::pair Index Index::search_with_filters< uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, float *distances); - } // namespace diskann diff --git a/src/index_factory.cpp b/src/index_factory.cpp new file mode 100644 index 000000000..c5607f4a0 --- /dev/null +++ b/src/index_factory.cpp @@ -0,0 +1,150 @@ +#include "index_factory.h" + +namespace diskann +{ + +IndexFactory::IndexFactory(const IndexConfig &config) : _config(std::make_unique(config)) +{ + check_config(); +} + +std::unique_ptr IndexFactory::create_instance() +{ + return create_instance(_config->data_type, _config->tag_type, _config->label_type); +} + +void IndexFactory::check_config() +{ + if (_config->dynamic_index && !_config->enable_tags) + { + throw ANNException("ERROR: Dynamic Indexing must have tags enabled.", -1, __FUNCSIG__, __FILE__, __LINE__); + } + + if (_config->pq_dist_build) + { + if (_config->dynamic_index) + throw ANNException("ERROR: Dynamic Indexing not supported with PQ distance based " + "index construction", + -1, __FUNCSIG__, __FILE__, __LINE__); 
+ if (_config->metric == diskann::Metric::INNER_PRODUCT) + throw ANNException("ERROR: Inner product metrics not yet supported " + "with PQ distance " + "base index", + -1, __FUNCSIG__, __FILE__, __LINE__); + } + + if (_config->data_type != "float" && _config->data_type != "uint8" && _config->data_type != "int8") + { + throw ANNException("ERROR: invalid data type : + " + _config->data_type + + " is not supported. please select from [float, int8, uint8]", + -1); + } + + if (_config->tag_type != "int32" && _config->tag_type != "uint32" && _config->tag_type != "int64" && + _config->tag_type != "uint64") + { + throw ANNException("ERROR: invalid data type : + " + _config->tag_type + + " is not supported. please select from [int32, uint32, int64, uint64]", + -1); + } +} + +template +std::unique_ptr> IndexFactory::construct_datastore(DataStoreStrategy strategy, size_t num_points, + size_t dimension) +{ + const size_t total_internal_points = num_points + _config->num_frozen_pts; + std::shared_ptr> distance; + switch (strategy) + { + case MEMORY: + if (_config->metric == diskann::Metric::COSINE && std::is_same::value) + { + distance.reset((Distance *)new AVXNormalizedCosineDistanceFloat()); + return std::make_unique>((location_t)total_internal_points, dimension, distance); + } + else + { + distance.reset((Distance *)get_distance_function(_config->metric)); + return std::make_unique>((location_t)total_internal_points, dimension, distance); + } + break; + default: + break; + } + return nullptr; +} + +std::unique_ptr IndexFactory::construct_graphstore(GraphStoreStrategy, size_t size) +{ + return std::make_unique(size); +} + +template +std::unique_ptr IndexFactory::create_instance() +{ + size_t num_points = _config->max_points; + size_t dim = _config->dimension; + // auto graph_store = construct_graphstore(_config->graph_strategy, num_points); + auto data_store = construct_datastore(_config->data_strategy, num_points, dim); + return std::make_unique>(*_config, 
std::move(data_store)); +} + +std::unique_ptr IndexFactory::create_instance(const std::string &data_type, const std::string &tag_type, + const std::string &label_type) +{ + if (data_type == std::string("float")) + { + return create_instance(tag_type, label_type); + } + else if (data_type == std::string("uint8")) + { + return create_instance(tag_type, label_type); + } + else if (data_type == std::string("int8")) + { + return create_instance(tag_type, label_type); + } + else + throw ANNException("Error: unsupported data_type please choose from [float/int8/uint8]", -1); +} + +template +std::unique_ptr IndexFactory::create_instance(const std::string &tag_type, const std::string &label_type) +{ + if (tag_type == std::string("int32")) + { + return create_instance(label_type); + } + else if (tag_type == std::string("uint32")) + { + return create_instance(label_type); + } + else if (tag_type == std::string("int64")) + { + return create_instance(label_type); + } + else if (tag_type == std::string("uint64")) + { + return create_instance(label_type); + } + else + throw ANNException("Error: unsupported tag_type please choose from [int32/uint32/int64/uint64]", -1); +} + +template +std::unique_ptr IndexFactory::create_instance(const std::string &label_type) +{ + if (label_type == std::string("uint16") || label_type == std::string("ushort")) + { + return create_instance(); + } + else if (label_type == std::string("uint32") || label_type == std::string("uint")) + { + return create_instance(); + } + else + throw ANNException("Error: unsupported label_type please choose from [uint/ushort]", -1); +} + +} // namespace diskann