From 26b500973b894c9f094c4aeed434439d87e8e523 Mon Sep 17 00:00:00 2001 From: Cai Yudong Date: Wed, 11 Sep 2024 15:18:05 +0800 Subject: [PATCH] Roll back to 6b2fe566 faiss-1.7.2 Signed-off-by: Cai Yudong --- cmake/libs/RAPIDS.cmake | 73 + cmake/libs/libcardinal.cmake | 74 - cmake/libs/libcutlass.cmake | 107 + cmake/libs/libdiskann.cmake | 16 +- cmake/libs/libfaiss.cmake | 84 +- cmake/libs/libraft.cmake | 45 +- cmake/utils/compile_flags.cmake | 4 - .../fetch_rapids.cmake} | 14 +- cmake/utils/platform_check.cmake | 7 +- include/knowhere/binaryset.h | 12 - include/knowhere/bitsetview.h | 14 +- include/knowhere/bitsetview_idselector.h | 33 - include/knowhere/cluster/cluster.h | 149 -- include/knowhere/cluster/cluster_factory.h | 62 - include/knowhere/cluster/cluster_node.h | 52 - include/knowhere/comp/brute_force.h | 25 +- include/knowhere/comp/index_param.h | 87 +- include/knowhere/comp/knowhere_check.h | 135 - include/knowhere/comp/knowhere_config.h | 21 - include/knowhere/comp/materialized_view.h | 47 - include/knowhere/comp/task.h | 44 - include/knowhere/comp/thread_pool.h | 243 +- include/knowhere/config.h | 415 ++-- include/knowhere/dataset.h | 139 +- include/knowhere/device_bitset.h | 90 + include/knowhere/expected.h | 20 +- include/knowhere/factory.h | 43 + include/knowhere/{index => }/index.h | 36 +- include/knowhere/index/index_factory.h | 88 - include/knowhere/index/index_node.h | 583 ----- .../index/index_node_data_mock_wrapper.h | 106 - include/knowhere/index/index_table.h | 133 - include/knowhere/index_node.h | 115 + .../index_node_thread_pool_wrapper.h | 12 +- include/knowhere/object.h | 29 - include/knowhere/operands.h | 168 -- include/knowhere/prometheus_client.h | 105 +- include/knowhere/sparse_utils.h | 277 --- include/knowhere/tolower.h | 29 - include/knowhere/tracer.h | 70 - include/knowhere/utils.h | 144 +- include/knowhere/version.h | 2 +- python/knowhere/__init__.py | 105 +- python/knowhere/knowhere.i | 274 +-- python/setup.py | 2 +- src/cluster/cluster.cc | 68 - src/cluster/cluster_factory.cc | 76 - src/common/comp/brute_force.cc | 657 +---- src/common/comp/knowhere_config.cc | 60 +- src/common/comp/materialized_view.cc | 47 - src/common/config.cc | 93 +- src/common/factory.cc | 45 + src/common/index.cc | 248 ++ .../index_node_thread_pool_wrapper.cc | 8 +- src/common/prometheus_client.cc | 105 +- .../raft/integration/brute_force_index.cu | 24 - src/common/raft/integration/cagra_index.cu | 30 - .../raft/integration/cagra_instantiations.cu | 27 - src/common/raft/integration/ivf_flat_index.cu | 30 - .../integration/ivf_flat_instantiations.cu | 27 - src/common/raft/integration/ivf_pq_index.cu | 30 - .../raft/integration/ivf_pq_instantiations.cu | 27 - .../raft/integration/raft_initialization.cc | 79 - .../raft/integration/raft_initialization.hpp | 31 - .../raft/integration/raft_knowhere_config.hpp | 125 - .../raft/integration/raft_knowhere_index.cuh | 784 ------ .../raft/integration/raft_knowhere_index.hpp | 84 - src/common/raft/integration/type_mappers.hpp | 74 - .../proto/filtered_search_instantiation.cuh | 54 - src/common/raft/proto/raft_index.cuh | 472 ---- src/common/raft/proto/raft_index_kind.hpp | 21 - src/common/raft/raft.cu | 39 + src/common/raft/raft_utils.cc | 46 + src/common/raft/raft_utils.h | 200 ++ src/common/range_util.cc | 78 +- {include/knowhere => src/common}/range_util.h | 19 +- src/common/thread/thread.cc | 80 - src/common/tracer.cc | 166 -- src/common/utils.cc | 91 +- src/index/cagra/cagra.cu | 210 ++ src/index/cagra/cagra_config.h | 47 + src/index/diskann/diskann.cc | 207 +- src/index/diskann/diskann_config.h | 72 +- src/index/flat/flat.cc | 210 +- src/index/gpu/flat_gpu/flat_gpu.cc | 34 +- src/index/gpu/ivf_gpu/ivf_gpu.cc | 43 +- src/index/gpu_raft/gpu_raft.h | 272 --- src/index/gpu_raft/gpu_raft_brute_force.cc | 37 - .../gpu_raft/gpu_raft_brute_force_config.h | 40 - src/index/gpu_raft/gpu_raft_cagra.cc | 173 -- src/index/gpu_raft/gpu_raft_cagra_config.h | 174 -- src/index/gpu_raft/gpu_raft_ivf_flat.cc | 36 - src/index/gpu_raft/gpu_raft_ivf_flat_config.h | 81 - src/index/gpu_raft/gpu_raft_ivf_pq.cc | 43 - src/index/hnsw/faiss_hnsw.cc | 1544 ------------ src/index/hnsw/faiss_hnsw_config.h | 243 -- src/index/hnsw/hnsw.cc | 345 +-- src/index/hnsw/hnsw_config.h | 53 +- src/index/hnsw/impl/FederVisitor.h | 49 - src/index/hnsw/impl/IndexBruteForceWrapper.cc | 110 - src/index/hnsw/impl/IndexBruteForceWrapper.h | 29 - src/index/hnsw/impl/IndexHNSWWrapper.cc | 243 -- src/index/hnsw/impl/IndexHNSWWrapper.h | 51 - src/index/hnsw/impl/IndexWrapperCosine.cc | 29 - src/index/hnsw/impl/IndexWrapperCosine.h | 32 - src/index/index.cc | 336 --- src/index/index_factory.cc | 123 - src/index/index_node_data_mock_wrapper.cc | 86 - src/index/ivf/ivf.cc | 927 +++---- src/index/ivf/ivf_config.h | 117 +- src/index/ivf_raft/ivf_raft.cu | 45 + src/index/ivf_raft/ivf_raft.cuh | 631 +++++ .../ivf_raft_config.h} | 102 +- src/index/sparse/sparse_index_node.cc | 321 --- src/index/sparse/sparse_inverted_index.h | 636 ----- .../sparse/sparse_inverted_index_config.h | 68 - src/io/memory_io.h | 10 - src/simd/distances_avx.cc | 409 +--- src/simd/distances_avx.h | 52 - src/simd/distances_avx512.cc | 436 +--- src/simd/distances_avx512.h | 52 - src/simd/distances_neon.cc | 398 --- src/simd/distances_neon.h | 27 - src/simd/distances_ref.cc | 274 +-- src/simd/distances_ref.h | 74 - src/simd/distances_sse.cc | 87 - src/simd/distances_sse.h | 18 - src/simd/hook.cc | 174 +- src/simd/hook.h | 71 - src/simd/simd_util.h | 123 - tests/faiss/CMakeLists.txt | 72 - .../cmake/utils/platform_check.cmake | 12 - tests/faiss_isolated/cmake/utils/utils.cmake | 60 - tests/python/conftest.py | 21 - tests/python/test_index_load_and_save.py | 76 +- tests/python/test_index_random.py | 88 - tests/ut/CMakeLists.txt | 24 +- tests/ut/test_binaryset.cc | 34 - tests/ut/test_bruteforce.cc | 126 +- tests/ut/test_cluster.cc | 91 - tests/ut/test_config.cc | 635 +---- tests/ut/test_diskann.cc | 116 +- tests/ut/test_distances.cc | 11 - tests/ut/test_faiss_hnsw.cc | 885 ------- tests/ut/test_feder.cc | 22 +- tests/ut/test_get_vector.cc | 40 +- tests/ut/test_gpu_search.cc | 117 +- tests/ut/test_index_check.cc | 283 --- tests/ut/test_iterator.cc | 419 +--- tests/ut/test_ivfflat_cc.cc | 80 +- tests/ut/test_knowhere_init.cc | 19 +- .../ut/test_materialized_view_search_info.cc | 152 -- tests/ut/test_mmap.cc | 437 ++++ tests/ut/test_range_util.cc | 54 +- tests/ut/test_search.cc | 393 +-- tests/ut/test_simd.cc | 112 +- tests/ut/test_sparse.cc | 525 ---- tests/ut/test_tracer.cc | 66 - tests/ut/test_type.cc | 126 - tests/ut/test_utils.cc | 143 +- tests/ut/utils.h | 159 +- .../include/diskann/aio_context_pool.h | 4 + .../include/diskann/aligned_file_reader.h | 40 +- .../DiskANN/include/diskann/ann_exception.h | 20 +- .../DiskANN/include/diskann/aux_utils.h | 102 +- .../DiskANN/include/diskann/cached_io.h | 28 +- thirdparty/DiskANN/include/diskann/distance.h | 4 +- thirdparty/DiskANN/include/diskann/index.h | 140 +- .../diskann/linux_aligned_file_reader.h | 15 +- thirdparty/DiskANN/include/diskann/logger.h | 6 +- .../DiskANN/include/diskann/logger_impl.h | 42 +- .../DiskANN/include/diskann/math_utils.h | 24 +- .../DiskANN/include/diskann/memory_mapper.h | 10 + .../include/diskann/partition_and_pq.h | 18 +- .../include/diskann/percentile_stats.h | 4 +- .../DiskANN/include/diskann/pq_flash_index.h | 143 +- thirdparty/DiskANN/include/diskann/pq_table.h | 65 +- .../DiskANN/include/diskann/simd_utils.h | 9 +- thirdparty/DiskANN/include/diskann/utils.h | 274 ++- .../diskann/windows_aligned_file_reader.h | 51 + .../include/diskann/windows_customizations.h | 16 + thirdparty/DiskANN/src/aux_utils.cpp | 261 +- thirdparty/DiskANN/src/distance.cpp | 80 +- thirdparty/DiskANN/src/dll/CMakeLists.txt | 28 + thirdparty/DiskANN/src/dll/dllmain.cpp | 14 + thirdparty/DiskANN/src/index.cpp | 348 +-- thirdparty/DiskANN/src/logger.cpp | 29 +- thirdparty/DiskANN/src/math_utils.cpp | 125 +- thirdparty/DiskANN/src/memory_mapper.cpp | 56 + thirdparty/DiskANN/src/partition_and_pq.cpp | 249 +- thirdparty/DiskANN/src/pq_flash_index.cpp | 698 +++--- thirdparty/DiskANN/src/utils.cpp | 12 +- .../src/windows_aligned_file_reader.cpp | 158 ++ thirdparty/faiss/.circleci/Dockerfile.cpu | 11 + .../faiss/.circleci/Dockerfile.faiss_gpu | 28 + thirdparty/faiss/.circleci/config.yml | 443 +++- .../.github/actions/build_cmake/action.yml | 105 - .../.github/actions/build_conda/action.yml | 96 - thirdparty/faiss/.github/workflows/build.yml | 244 -- .../faiss/.github/workflows/nightly.yml | 139 -- thirdparty/faiss/CHANGELOG.md | 85 +- thirdparty/faiss/CMakeLists.txt | 49 +- thirdparty/faiss/CONTRIBUTING.md | 4 +- thirdparty/faiss/Doxyfile | 2 +- thirdparty/faiss/INSTALL.md | 45 +- thirdparty/faiss/README.md | 50 +- thirdparty/faiss/benchs/CMakeLists.txt | 11 - thirdparty/faiss/benchs/README.md | 39 +- thirdparty/faiss/benchs/bench_6bit_codec.cpp | 39 +- .../faiss/benchs/bench_all_ivf/README.md | 2 +- .../benchs/bench_all_ivf/bench_all_ivf.py | 596 ++--- .../benchs/bench_all_ivf/cmp_with_scann.py | 66 +- .../{datasets_oss.py => datasets.py} | 1 + .../faiss/benchs/bench_big_batch_ivf.py | 109 - .../benchs/bench_cppcontrib_sa_decode.cpp | 1702 ------------- thirdparty/faiss/benchs/bench_fw/__init__.py | 0 thirdparty/faiss/benchs/bench_fw/benchmark.py | 1178 --------- .../faiss/benchs/bench_fw/benchmark_io.py | 272 --- .../faiss/benchs/bench_fw/descriptors.py | 325 --- thirdparty/faiss/benchs/bench_fw/index.py | 1086 --------- thirdparty/faiss/benchs/bench_fw/optimize.py | 335 --- thirdparty/faiss/benchs/bench_fw/utils.py | 248 -- thirdparty/faiss/benchs/bench_fw_codecs.py | 146 -- thirdparty/faiss/benchs/bench_fw_ivf.py | 125 - .../faiss/benchs/bench_fw_notebook.ipynb | 532 ---- thirdparty/faiss/benchs/bench_fw_optimize.py | 58 - thirdparty/faiss/benchs/bench_fw_range.py | 85 - thirdparty/faiss/benchs/bench_gpu_1bn.py | 2 +- thirdparty/faiss/benchs/bench_gpu_sift1m.py | 3 +- .../faiss/benchs/bench_hamming_computer.cpp | 222 -- thirdparty/faiss/benchs/bench_hamming_knn.py | 29 - thirdparty/faiss/benchs/bench_hnsw.py | 2 +- .../faiss/benchs/bench_hnsw_knowhere.cpp | 206 -- .../faiss/benchs/bench_hybrid_cpu_gpu.py | 599 ----- thirdparty/faiss/benchs/bench_ivf_fastscan.py | 112 - .../benchs/bench_ivf_fastscan_single_query.py | 122 - .../faiss/benchs/bench_ivf_selector.cpp | 145 -- .../faiss/benchs/bench_polysemous_1bn.py | 2 +- .../bench_pq_transposed_centroid_table.py | 136 -- thirdparty/faiss/benchs/bench_quantizer.py | 35 +- .../faiss/benchs/distributed_ondisk/README.md | 147 +- .../distributed_ondisk/distributed_kmeans.py | 201 +- .../distributed_ondisk/make_index_vslice.py | 2 +- .../distributed_ondisk/merge_to_ondisk.py | 2 +- .../faiss/benchs/distributed_ondisk/rpc.py | 252 ++ .../distributed_ondisk/search_server.py | 9 +- .../faiss/benchs/link_and_code/README.md | 135 +- .../link_and_code/bench_link_and_code.py | 303 +++ .../faiss/benchs/link_and_code/datasets.py | 236 ++ .../benchs/link_and_code/neighbor_codec.py | 241 ++ .../faiss/c_api/IndexScalarQuantizer_c.h | 3 - thirdparty/faiss/c_api/clone_index_c.cpp | 12 - thirdparty/faiss/c_api/clone_index_c.h | 4 - thirdparty/faiss/c_api/index_factory_c.cpp | 16 +- thirdparty/faiss/c_api/index_factory_c.h | 11 +- thirdparty/faiss/c_api/utils/distances_c.h | 1 - thirdparty/faiss/conda/Dockerfile.cpu | 19 + thirdparty/faiss/conda/Dockerfile.cuda10.2 | 18 + thirdparty/faiss/conda/Dockerfile.cuda11.3 | 18 + .../faiss/conda/conda_build_config.yaml | 9 +- .../faiss/conda/faiss-gpu-raft/build-lib.sh | 26 - .../faiss/conda/faiss-gpu-raft/build-pkg.sh | 24 - .../faiss/conda/faiss-gpu-raft/meta.yaml | 125 - .../conda/faiss-gpu-raft/test_cpu_dispatch.sh | 11 - thirdparty/faiss/conda/faiss-gpu/build-lib.sh | 13 +- thirdparty/faiss/conda/faiss-gpu/build-pkg.sh | 7 +- .../faiss/conda/faiss-gpu/install-cmake.sh | 10 + thirdparty/faiss/conda/faiss-gpu/meta.yaml | 60 +- .../conda/faiss-gpu/test_cpu_dispatch.sh | 5 +- .../faiss/conda/faiss/build-lib-arm64.sh | 22 - thirdparty/faiss/conda/faiss/build-lib-osx.sh | 27 - thirdparty/faiss/conda/faiss/build-lib.sh | 6 +- .../faiss/conda/faiss/build-pkg-arm64.sh | 22 - thirdparty/faiss/conda/faiss/build-pkg-osx.sh | 26 - thirdparty/faiss/conda/faiss/build-pkg.sh | 6 +- thirdparty/faiss/conda/faiss/install-cmake.sh | 10 + thirdparty/faiss/conda/faiss/meta.yaml | 39 +- .../faiss/conda/faiss/test_cpu_dispatch.sh | 5 +- thirdparty/faiss/contrib/README.md | 17 +- thirdparty/faiss/contrib/big_batch_search.py | 515 ---- thirdparty/faiss/contrib/client_server.py | 2 +- thirdparty/faiss/contrib/clustering.py | 399 --- thirdparty/faiss/contrib/datasets.py | 72 +- thirdparty/faiss/contrib/evaluation.py | 254 +- thirdparty/faiss/contrib/exhaustive_search.py | 130 +- thirdparty/faiss/contrib/factory_tools.py | 42 +- thirdparty/faiss/contrib/inspect_tools.py | 27 - thirdparty/faiss/contrib/ivf_tools.py | 87 +- thirdparty/faiss/contrib/ondisk.py | 20 +- thirdparty/faiss/contrib/rpc.py | 55 +- thirdparty/faiss/contrib/torch_utils.py | 27 +- thirdparty/faiss/contrib/vecs_io.py | 12 +- thirdparty/faiss/demos/CMakeLists.txt | 3 - thirdparty/faiss/demos/demo_imi_flat.cpp | 4 +- thirdparty/faiss/demos/demo_imi_pq.cpp | 9 +- .../faiss/demos/demo_ivfpq_indexing.cpp | 4 +- thirdparty/faiss/demos/demo_nndescent.cpp | 4 +- .../faiss/demos/demo_residual_quantizer.cpp | 297 --- thirdparty/faiss/demos/demo_sift1M.cpp | 8 +- .../faiss/demos/demo_weighted_kmeans.cpp | 9 +- thirdparty/faiss/demos/offline_ivf/README.md | 52 - .../faiss/demos/offline_ivf/__init__.py | 0 .../faiss/demos/offline_ivf/config_ssnpp.yaml | 110 - .../offline_ivf/create_sharded_ssnpp_files.py | 63 - thirdparty/faiss/demos/offline_ivf/dataset.py | 173 -- .../demos/offline_ivf/generate_config.py | 45 - .../faiss/demos/offline_ivf/offline_ivf.py | 948 -------- thirdparty/faiss/demos/offline_ivf/run.py | 218 -- .../offline_ivf/tests/test_iterate_input.py | 132 - .../offline_ivf/tests/test_offline_ivf.py | 288 --- .../demos/offline_ivf/tests/testing_utils.py | 180 -- thirdparty/faiss/demos/offline_ivf/utils.py | 94 - .../faiss/demos/rocksdb_ivf/CMakeLists.txt | 8 - thirdparty/faiss/demos/rocksdb_ivf/README.md | 23 - .../rocksdb_ivf/RocksDBInvertedLists.cpp | 109 - .../demos/rocksdb_ivf/RocksDBInvertedLists.h | 60 - .../demos/rocksdb_ivf/demo_rocksdb_ivf.cpp | 81 - thirdparty/faiss/faiss/AutoTune.cpp | 46 +- thirdparty/faiss/faiss/AutoTune.h | 2 + thirdparty/faiss/faiss/CMakeLists.txt | 45 +- thirdparty/faiss/faiss/Clustering.cpp | 73 +- thirdparty/faiss/faiss/Clustering.h | 63 +- thirdparty/faiss/faiss/FaissHook.cpp | 11 - thirdparty/faiss/faiss/FaissHook.h | 14 +- thirdparty/faiss/faiss/IVFlib.cpp | 140 +- thirdparty/faiss/faiss/IVFlib.h | 30 +- thirdparty/faiss/faiss/Index.cpp | 47 +- thirdparty/faiss/faiss/Index.h | 82 +- thirdparty/faiss/faiss/Index2Layer.cpp | 26 +- thirdparty/faiss/faiss/Index2Layer.h | 10 +- .../faiss/faiss/IndexAdditiveQuantizer.cpp | 404 +-- .../faiss/faiss/IndexAdditiveQuantizer.h | 76 +- .../faiss/IndexAdditiveQuantizerFastScan.cpp | 299 --- .../faiss/IndexAdditiveQuantizerFastScan.h | 199 -- thirdparty/faiss/faiss/IndexBinary.cpp | 31 +- thirdparty/faiss/faiss/IndexBinary.h | 46 +- thirdparty/faiss/faiss/IndexBinaryFlat.cpp | 53 +- thirdparty/faiss/faiss/IndexBinaryFlat.h | 10 +- .../faiss/faiss/IndexBinaryFromFloat.cpp | 7 +- thirdparty/faiss/faiss/IndexBinaryFromFloat.h | 2 +- thirdparty/faiss/faiss/IndexBinaryHNSW.cpp | 60 +- thirdparty/faiss/faiss/IndexBinaryHNSW.h | 2 +- thirdparty/faiss/faiss/IndexBinaryHash.cpp | 100 +- thirdparty/faiss/faiss/IndexBinaryHash.h | 8 +- thirdparty/faiss/faiss/IndexBinaryIVF.cpp | 813 +++---- thirdparty/faiss/faiss/IndexBinaryIVF.h | 116 +- .../faiss/faiss/IndexBinaryIVFThreadSafe.cpp | 819 +++++++ thirdparty/faiss/faiss/IndexCosine.cpp | 517 ---- thirdparty/faiss/faiss/IndexCosine.h | 217 -- thirdparty/faiss/faiss/IndexFastScan.cpp | 562 ----- thirdparty/faiss/faiss/IndexFastScan.h | 153 -- thirdparty/faiss/faiss/IndexFlat.cpp | 308 +-- thirdparty/faiss/faiss/IndexFlat.h | 52 +- thirdparty/faiss/faiss/IndexFlatCodes.cpp | 52 +- thirdparty/faiss/faiss/IndexFlatCodes.h | 25 +- thirdparty/faiss/faiss/IndexFlatElkan.cpp | 83 - thirdparty/faiss/faiss/IndexFlatElkan.h | 60 - thirdparty/faiss/faiss/IndexHNSW.cpp | 773 +++--- thirdparty/faiss/faiss/IndexHNSW.h | 118 +- thirdparty/faiss/faiss/IndexIDMap.cpp | 288 --- thirdparty/faiss/faiss/IndexIDMap.h | 129 - thirdparty/faiss/faiss/IndexIVF.cpp | 903 +++---- thirdparty/faiss/faiss/IndexIVF.h | 367 +-- .../faiss/faiss/IndexIVFAdditiveQuantizer.cpp | 129 +- .../faiss/faiss/IndexIVFAdditiveQuantizer.h | 69 +- .../IndexIVFAdditiveQuantizerFastScan.cpp | 570 ----- .../faiss/IndexIVFAdditiveQuantizerFastScan.h | 172 -- thirdparty/faiss/faiss/IndexIVFFastScan.cpp | 1706 ------------- thirdparty/faiss/faiss/IndexIVFFastScan.h | 295 --- thirdparty/faiss/faiss/IndexIVFFlat.cpp | 390 +-- thirdparty/faiss/faiss/IndexIVFFlat.h | 21 +- .../faiss/IndexIVFIndependentQuantizer.cpp | 172 -- .../faiss/IndexIVFIndependentQuantizer.h | 56 - thirdparty/faiss/faiss/IndexIVFPQ.cpp | 636 ++--- thirdparty/faiss/faiss/IndexIVFPQ.h | 20 +- thirdparty/faiss/faiss/IndexIVFPQFastScan.cpp | 1239 +++++++++- thirdparty/faiss/faiss/IndexIVFPQFastScan.h | 159 +- thirdparty/faiss/faiss/IndexIVFPQR.cpp | 87 +- thirdparty/faiss/faiss/IndexIVFPQR.h | 12 +- .../faiss/faiss/IndexIVFScalarQuantizerCC.cpp | 137 -- .../faiss/faiss/IndexIVFScalarQuantizerCC.h | 47 - .../faiss/faiss/IndexIVFSpectralHash.cpp | 72 +- thirdparty/faiss/faiss/IndexIVFSpectralHash.h | 15 +- thirdparty/faiss/faiss/IndexIVFThreadSafe.cpp | 162 ++ thirdparty/faiss/faiss/IndexLSH.cpp | 43 +- thirdparty/faiss/faiss/IndexLSH.h | 2 +- thirdparty/faiss/faiss/IndexLattice.cpp | 10 +- thirdparty/faiss/faiss/IndexLattice.h | 3 +- thirdparty/faiss/faiss/IndexNNDescent.cpp | 45 +- thirdparty/faiss/faiss/IndexNNDescent.h | 3 +- thirdparty/faiss/faiss/IndexNSG.cpp | 66 +- thirdparty/faiss/faiss/IndexNSG.h | 43 +- thirdparty/faiss/faiss/IndexPQ.cpp | 316 ++- thirdparty/faiss/faiss/IndexPQ.h | 25 +- thirdparty/faiss/faiss/IndexPQFastScan.cpp | 461 +++- thirdparty/faiss/faiss/IndexPQFastScan.h | 82 +- thirdparty/faiss/faiss/IndexPreTransform.cpp | 105 +- thirdparty/faiss/faiss/IndexPreTransform.h | 16 +- thirdparty/faiss/faiss/IndexRefine.cpp | 113 +- thirdparty/faiss/faiss/IndexRefine.h | 11 +- thirdparty/faiss/faiss/IndexReplicas.cpp | 57 +- thirdparty/faiss/faiss/IndexReplicas.h | 3 +- thirdparty/faiss/faiss/IndexRowwiseMinMax.cpp | 445 ---- thirdparty/faiss/faiss/IndexRowwiseMinMax.h | 99 - thirdparty/faiss/faiss/IndexScaNN.cpp | 68 +- thirdparty/faiss/faiss/IndexScaNN.h | 17 +- .../faiss/faiss/IndexScalarQuantizer.cpp | 99 +- thirdparty/faiss/faiss/IndexScalarQuantizer.h | 30 +- thirdparty/faiss/faiss/IndexShards.cpp | 187 +- thirdparty/faiss/faiss/IndexShards.h | 7 +- thirdparty/faiss/faiss/IndexShardsIVF.cpp | 245 -- thirdparty/faiss/faiss/IndexShardsIVF.h | 42 - thirdparty/faiss/faiss/MatrixStats.cpp | 49 +- thirdparty/faiss/faiss/MatrixStats.h | 30 +- thirdparty/faiss/faiss/MetaIndexes.cpp | 336 ++- thirdparty/faiss/faiss/MetaIndexes.h | 112 +- thirdparty/faiss/faiss/MetricType.h | 21 +- thirdparty/faiss/faiss/VectorTransform.cpp | 96 +- thirdparty/faiss/faiss/VectorTransform.h | 44 +- thirdparty/faiss/faiss/clone_index.cpp | 275 +-- thirdparty/faiss/faiss/clone_index.h | 10 - .../faiss/faiss/cppcontrib/SaDecodeKernels.h | 322 --- .../faiss/cppcontrib/detail/CoarseBitType.h | 31 - .../faiss/cppcontrib/detail/UintReader.h | 351 --- .../knowhere/IndexBruteForceWrapper.cpp | 123 - .../knowhere/IndexBruteForceWrapper.h | 39 - .../cppcontrib/knowhere/IndexHNSWWrapper.cpp | 231 -- .../cppcontrib/knowhere/IndexHNSWWrapper.h | 56 - .../cppcontrib/knowhere/IndexWrapper.cpp | 65 - .../faiss/cppcontrib/knowhere/IndexWrapper.h | 51 - .../cppcontrib/knowhere/impl/Bruteforce.h | 73 - .../cppcontrib/knowhere/impl/HnswSearcher.h | 415 ---- .../faiss/cppcontrib/knowhere/impl/Neighbor.h | 229 -- .../faiss/cppcontrib/knowhere/utils/Bitset.h | 115 - .../cppcontrib/sa_decode/Level2-avx2-inl.h | 2072 ---------------- .../faiss/cppcontrib/sa_decode/Level2-inl.h | 467 ---- .../cppcontrib/sa_decode/Level2-neon-inl.h | 2161 ----------------- .../faiss/cppcontrib/sa_decode/MinMax-inl.h | 467 ---- .../cppcontrib/sa_decode/MinMaxFP16-inl.h | 472 ---- .../faiss/cppcontrib/sa_decode/PQ-avx2-inl.h | 1625 ------------- .../faiss/faiss/cppcontrib/sa_decode/PQ-inl.h | 257 -- .../faiss/cppcontrib/sa_decode/PQ-neon-inl.h | 1460 ----------- thirdparty/faiss/faiss/gpu/GpuIcmEncoder.cu | 12 +- .../faiss/faiss/gpu/impl/IcmEncoder.cuh | 2 +- .../faiss/faiss/impl/AdditiveQuantizer.cpp | 173 +- .../faiss/faiss/impl/AdditiveQuantizer.h | 79 +- .../faiss/faiss/impl/AuxIndexStructures.cpp | 103 +- .../faiss/faiss/impl/AuxIndexStructures.h | 96 +- thirdparty/faiss/faiss/impl/CodePacker.cpp | 67 - thirdparty/faiss/faiss/impl/CodePacker.h | 71 - .../faiss/faiss/impl/DistanceComputer.h | 131 - thirdparty/faiss/faiss/impl/FaissAssert.h | 6 +- thirdparty/faiss/faiss/impl/FaissException.h | 47 +- thirdparty/faiss/faiss/impl/HNSW.cpp | 828 +------ thirdparty/faiss/faiss/impl/HNSW.h | 96 +- thirdparty/faiss/faiss/impl/IDSelector.cpp | 125 - thirdparty/faiss/faiss/impl/IDSelector.h | 173 -- .../faiss/faiss/impl/LocalSearchQuantizer.cpp | 169 +- .../faiss/faiss/impl/LocalSearchQuantizer.h | 30 +- .../faiss/faiss/impl/LookupTableScaler.h | 77 - thirdparty/faiss/faiss/impl/NNDescent.cpp | 78 +- thirdparty/faiss/faiss/impl/NNDescent.h | 20 +- thirdparty/faiss/faiss/impl/NSG.cpp | 38 +- thirdparty/faiss/faiss/impl/NSG.h | 13 +- .../faiss/faiss/impl/PolysemousTraining.cpp | 45 +- .../faiss/faiss/impl/PolysemousTraining.h | 21 +- .../faiss/impl/ProductAdditiveQuantizer.cpp | 376 --- .../faiss/impl/ProductAdditiveQuantizer.h | 154 -- .../faiss/faiss/impl/ProductQuantizer.cpp | 457 ++-- .../faiss/faiss/impl/ProductQuantizer.h | 51 +- thirdparty/faiss/faiss/impl/Quantizer.h | 46 - .../faiss/faiss/impl/ResidualQuantizer.cpp | 740 ++++-- .../faiss/faiss/impl/ResidualQuantizer.h | 142 +- thirdparty/faiss/faiss/impl/ResultHandler.h | 516 ++-- .../faiss/faiss/impl/ScalarQuantizer.cpp | 125 +- thirdparty/faiss/faiss/impl/ScalarQuantizer.h | 276 ++- .../faiss/faiss/impl/ScalarQuantizerCodec.h | 325 +-- .../faiss/impl/ScalarQuantizerCodec_avx.h | 380 +-- .../faiss/impl/ScalarQuantizerCodec_avx512.h | 561 ++--- .../faiss/impl/ScalarQuantizerCodec_neon.h | 761 ------ .../faiss/faiss/impl/ScalarQuantizerDC.cpp | 11 +- .../faiss/faiss/impl/ScalarQuantizerDC.h | 9 +- .../faiss/impl/ScalarQuantizerDC_avx.cpp | 13 +- .../faiss/faiss/impl/ScalarQuantizerDC_avx.h | 9 +- .../faiss/impl/ScalarQuantizerDC_avx512.cpp | 17 +- .../faiss/impl/ScalarQuantizerDC_avx512.h | 9 +- .../faiss/impl/ScalarQuantizerDC_neon.cpp | 69 - .../faiss/faiss/impl/ScalarQuantizerDC_neon.h | 37 - .../faiss/faiss/impl/ScalarQuantizerOp.cpp | 116 +- .../faiss/faiss/impl/ScalarQuantizerOp.h | 56 +- .../faiss/faiss/impl/ScalarQuantizerScanner.h | 339 --- .../faiss/faiss/impl/ThreadedIndex-inl.h | 6 +- thirdparty/faiss/faiss/impl/ThreadedIndex.h | 8 +- .../impl/code_distance/code_distance-avx2.h | 534 ---- .../impl/code_distance/code_distance-avx512.h | 248 -- .../code_distance/code_distance-generic.h | 81 - .../faiss/impl/code_distance/code_distance.h | 133 - thirdparty/faiss/faiss/impl/index_read.cpp | 416 +--- thirdparty/faiss/faiss/impl/index_write.cpp | 440 +--- thirdparty/faiss/faiss/impl/io.cpp | 36 +- thirdparty/faiss/faiss/impl/io.h | 8 +- thirdparty/faiss/faiss/impl/kmeans1d.cpp | 11 +- thirdparty/faiss/faiss/impl/kmeans1d.h | 6 +- thirdparty/faiss/faiss/impl/lattice_Zn.cpp | 7 +- thirdparty/faiss/faiss/impl/platform_macros.h | 115 +- thirdparty/faiss/faiss/impl/pq4_fast_scan.cpp | 135 +- thirdparty/faiss/faiss/impl/pq4_fast_scan.h | 59 +- .../faiss/impl/pq4_fast_scan_search_1.cpp | 137 +- .../faiss/impl/pq4_fast_scan_search_qbs.cpp | 217 +- .../impl/residual_quantizer_encode_steps.cpp | 962 -------- .../impl/residual_quantizer_encode_steps.h | 176 -- .../faiss/faiss/impl/simd_result_handlers.h | 829 +++---- thirdparty/faiss/faiss/index_factory.cpp | 260 +- thirdparty/faiss/faiss/index_io.h | 22 +- .../faiss/invlists/BlockInvertedLists.cpp | 73 +- .../faiss/faiss/invlists/BlockInvertedLists.h | 13 +- thirdparty/faiss/faiss/invlists/DirectMap.cpp | 13 +- thirdparty/faiss/faiss/invlists/DirectMap.h | 5 +- .../faiss/faiss/invlists/InvertedLists.cpp | 239 +- .../faiss/faiss/invlists/InvertedLists.h | 81 +- .../faiss/invlists/OnDiskInvertedLists.cpp | 39 +- .../faiss/invlists/OnDiskInvertedLists.h | 5 +- thirdparty/faiss/faiss/python/CMakeLists.txt | 58 +- thirdparty/faiss/faiss/python/loader.py | 49 +- .../faiss/faiss/python/python_callbacks.cpp | 2 +- thirdparty/faiss/faiss/python/setup.py | 19 +- thirdparty/faiss/faiss/python/swigfaiss.swig | 387 +-- thirdparty/faiss/faiss/utils/AlignedTable.h | 4 +- thirdparty/faiss/faiss/utils/Heap.cpp | 144 +- thirdparty/faiss/faiss/utils/Heap.h | 187 +- thirdparty/faiss/faiss/utils/WorkerThread.h | 1 - .../faiss/utils/approx_topk/approx_topk.h | 84 - .../faiss/faiss/utils/approx_topk/avx2-inl.h | 196 -- .../faiss/faiss/utils/approx_topk/generic.h | 138 -- .../faiss/faiss/utils/approx_topk/mode.h | 34 - .../approx_topk_hamming/approx_topk_hamming.h | 367 --- thirdparty/faiss/faiss/utils/bf16.h | 36 - .../faiss/faiss/utils/binary_distances.cpp | 74 +- .../faiss/faiss/utils/binary_distances.h | 10 +- thirdparty/faiss/faiss/utils/bit_table.cpp | 1 - .../faiss/faiss/utils/data_backup_file.cpp | 76 - .../faiss/faiss/utils/data_backup_file.h | 31 - thirdparty/faiss/faiss/utils/distances.cpp | 962 ++------ thirdparty/faiss/faiss/utils/distances.h | 203 +- .../faiss/utils/distances_fused/avx512.cpp | 346 --- .../faiss/utils/distances_fused/avx512.h | 36 - .../utils/distances_fused/distances_fused.cpp | 42 - .../utils/distances_fused/distances_fused.h | 40 - .../utils/distances_fused/simdlib_based.cpp | 352 --- .../utils/distances_fused/simdlib_based.h | 32 - thirdparty/faiss/faiss/utils/distances_if.h | 573 ----- .../faiss/faiss/utils/distances_simd.cpp | 2 +- .../faiss/faiss/utils/extra_distances-inl.h | 52 +- .../faiss/faiss/utils/extra_distances.cpp | 68 +- .../faiss/faiss/utils/extra_distances.h | 13 +- thirdparty/faiss/faiss/utils/fp16-arm.h | 29 - thirdparty/faiss/faiss/utils/fp16-fp16c.h | 28 - thirdparty/faiss/faiss/utils/fp16-inl.h | 108 - thirdparty/faiss/faiss/utils/fp16.h | 20 - thirdparty/faiss/faiss/utils/hamming-inl.h | 293 ++- thirdparty/faiss/faiss/utils/hamming.cpp | 495 ++-- thirdparty/faiss/faiss/utils/hamming.h | 100 +- .../faiss/utils/hamming_distance/avx2-inl.h | 462 ---- .../faiss/utils/hamming_distance/common.h | 49 - .../utils/hamming_distance/generic-inl.h | 446 ---- .../faiss/utils/hamming_distance/hamdis-inl.h | 83 - .../faiss/utils/hamming_distance/neon-inl.h | 524 ---- thirdparty/faiss/faiss/utils/jaccard-inl.h | 4 - .../faiss/faiss/utils/ordered_key_value.h | 10 - thirdparty/faiss/faiss/utils/partitioning.cpp | 53 +- .../faiss/faiss/utils/partitioning_avx2.cpp | 6 +- thirdparty/faiss/faiss/utils/prefetch.h | 77 - thirdparty/faiss/faiss/utils/quantize_lut.cpp | 76 +- thirdparty/faiss/faiss/utils/quantize_lut.h | 20 - thirdparty/faiss/faiss/utils/random.cpp | 96 - thirdparty/faiss/faiss/utils/random.h | 30 - thirdparty/faiss/faiss/utils/simdlib.h | 12 +- thirdparty/faiss/faiss/utils/simdlib_avx2.h | 609 ++--- thirdparty/faiss/faiss/utils/simdlib_avx512.h | 296 --- .../faiss/faiss/utils/simdlib_emulated.h | 397 +-- thirdparty/faiss/faiss/utils/simdlib_neon.h | 792 +----- thirdparty/faiss/faiss/utils/simdlib_ppc64.h | 1084 --------- thirdparty/faiss/faiss/utils/sorting.cpp | 827 ------- thirdparty/faiss/faiss/utils/sorting.h | 101 - .../utils/transpose/transpose-avx2-inl.h | 165 -- thirdparty/faiss/faiss/utils/utils.cpp | 285 ++- thirdparty/faiss/faiss/utils/utils.h | 85 +- thirdparty/faiss/tests/CMakeLists.txt | 67 +- thirdparty/faiss/tests/common_faiss_tests.py | 1 + thirdparty/faiss/tests/test_RCQ_cropping.cpp | 131 - thirdparty/faiss/tests/test_approx_topk.cpp | 225 -- thirdparty/faiss/tests/test_binary_flat.cpp | 2 +- .../faiss/tests/test_binary_hashindex.py | 10 + thirdparty/faiss/tests/test_build_blocks.py | 234 +- thirdparty/faiss/tests/test_callback.cpp | 37 - thirdparty/faiss/tests/test_callback_py.py | 32 - thirdparty/faiss/tests/test_clone.py | 88 - thirdparty/faiss/tests/test_clustering.py | 3 + thirdparty/faiss/tests/test_code_distance.cpp | 240 -- .../tests/test_common_ivf_empty_index.cpp | 144 -- thirdparty/faiss/tests/test_contrib.py | 375 +-- .../faiss/tests/test_contrib_with_scipy.py | 87 - .../faiss/tests/test_cppcontrib_sa_decode.cpp | 1306 ---------- .../tests/test_cppcontrib_uintreader.cpp | 114 - .../faiss/tests/test_dealloc_invlists.cpp | 4 +- .../tests/test_disable_pq_sdc_tables.cpp | 61 - thirdparty/faiss/tests/test_distances_if.cpp | 141 -- .../faiss/tests/test_distances_simd.cpp | 110 - .../faiss/tests/test_extra_distances.py | 39 - thirdparty/faiss/tests/test_factory.py | 57 +- thirdparty/faiss/tests/test_fast_scan.py | 320 +-- thirdparty/faiss/tests/test_fast_scan_ivf.py | 476 +--- thirdparty/faiss/tests/test_fastscan_perf.cpp | 66 - thirdparty/faiss/tests/test_graph_based.py | 465 ---- thirdparty/faiss/tests/test_heap.cpp | 53 - thirdparty/faiss/tests/test_hnsw.cpp | 192 -- thirdparty/faiss/tests/test_index.py | 348 ++- thirdparty/faiss/tests/test_index_accuracy.py | 324 +-- thirdparty/faiss/tests/test_index_binary.py | 33 +- .../faiss/tests/test_index_composite.py | 400 +-- thirdparty/faiss/tests/test_io.py | 96 +- thirdparty/faiss/tests/test_ivf_index.cpp | 251 -- thirdparty/faiss/tests/test_ivflib.py | 7 +- thirdparty/faiss/tests/test_ivfpq_codec.cpp | 5 +- .../faiss/tests/test_ivfpq_indexing.cpp | 4 +- thirdparty/faiss/tests/test_lowlevel_ivf.cpp | 18 +- ..._local_search_quantizer.py => test_lsq.py} | 210 +- thirdparty/faiss/tests/test_mem_leak.cpp | 4 +- thirdparty/faiss/tests/test_merge.cpp | 75 +- thirdparty/faiss/tests/test_merge_index.py | 289 --- thirdparty/faiss/tests/test_meta_index.py | 197 +- thirdparty/faiss/tests/test_ondisk_ivf.cpp | 22 +- .../faiss/tests/test_pairs_decoding.cpp | 10 +- .../faiss/tests/test_params_override.cpp | 66 +- thirdparty/faiss/tests/test_partition.py | 16 +- thirdparty/faiss/tests/test_partitioning.cpp | 33 - thirdparty/faiss/tests/test_pq_encoding.cpp | 52 - .../faiss/tests/test_product_quantizer.py | 40 +- thirdparty/faiss/tests/test_refine.py | 86 +- .../faiss/tests/test_residual_quantizer.py | 335 +-- thirdparty/faiss/tests/test_rowwise_minmax.py | 55 - thirdparty/faiss/tests/test_search_params.py | 507 ---- thirdparty/faiss/tests/test_simdlib.cpp | 264 -- thirdparty/faiss/tests/test_sliding_ivf.cpp | 4 +- .../faiss/tests/test_standalone_codec.py | 50 +- .../faiss/tests/test_threaded_index.cpp | 24 +- .../faiss/tests/test_transfer_invlists.cpp | 2 +- thirdparty/faiss/tests/test_util.h | 39 - thirdparty/faiss/tests/torch_test_contrib.py | 5 +- thirdparty/faiss/tutorial/cpp/1-Flat.cpp | 6 +- thirdparty/faiss/tutorial/cpp/2-IVFFlat.cpp | 9 +- thirdparty/faiss/tutorial/cpp/3-IVFPQ.cpp | 2 +- thirdparty/faiss/tutorial/cpp/5-GPU.cpp | 234 ++ thirdparty/faiss/tutorial/cpp/6-GPU.cpp | 255 ++ thirdparty/faiss/tutorial/cpp/6-HNSW.cpp | 73 - thirdparty/faiss/tutorial/cpp/6-RUN.cpp | 247 ++ thirdparty/faiss/tutorial/cpp/7-GPU.cpp | 347 +++ .../faiss/tutorial/cpp/7-PQFastScan.cpp | 75 - thirdparty/faiss/tutorial/cpp/8-GPU.cpp | 479 ++++ .../faiss/tutorial/cpp/8-PQFastScanRefine.cpp | 84 - .../faiss/tutorial/cpp/9-BinaryFlat.cpp | 115 + .../faiss/tutorial/cpp/9-RefineComparison.cpp | 104 - thirdparty/faiss/tutorial/cpp/CMakeLists.txt | 25 +- .../tutorial/cpp/tutorial_faiss_test.cpp | 378 +++ .../faiss/tutorial/python/7-PQFastScan.py | 35 - .../tutorial/python/8-PQFastScanRefine.py | 38 - .../tutorial/python/9-RefineComparison.py | 42 - thirdparty/hnswlib/hnswlib/hnswalg.h | 668 ++--- thirdparty/hnswlib/hnswlib/hnswlib.h | 84 +- thirdparty/hnswlib/hnswlib/neighbor.h | 178 +- thirdparty/hnswlib/hnswlib/space_cosine.h | 41 +- thirdparty/hnswlib/hnswlib/space_hamming.h | 2 +- thirdparty/hnswlib/hnswlib/space_ip.h | 46 +- thirdparty/hnswlib/hnswlib/space_jaccard.h | 2 +- thirdparty/hnswlib/hnswlib/space_l2.h | 63 +- 675 files changed, 22551 insertions(+), 88215 deletions(-) create mode 100644 cmake/libs/RAPIDS.cmake delete mode 100644 cmake/libs/libcardinal.cmake create mode 100644 cmake/libs/libcutlass.cmake rename cmake/{libs/librapids.cmake => utils/fetch_rapids.cmake} (71%) delete mode 100644 include/knowhere/bitsetview_idselector.h delete mode 100644 include/knowhere/cluster/cluster.h delete mode 100644 include/knowhere/cluster/cluster_factory.h delete mode 100644 include/knowhere/cluster/cluster_node.h delete mode 100644 include/knowhere/comp/knowhere_check.h delete mode 100644 include/knowhere/comp/materialized_view.h delete mode 100644 include/knowhere/comp/task.h create mode 100644 include/knowhere/device_bitset.h create mode 100644 include/knowhere/factory.h rename include/knowhere/{index => }/index.h (81%) delete mode 100644 include/knowhere/index/index_factory.h delete mode 100644 include/knowhere/index/index_node.h delete mode 100644 include/knowhere/index/index_node_data_mock_wrapper.h delete mode 100644 include/knowhere/index/index_table.h create mode 100644 include/knowhere/index_node.h rename include/knowhere/{index => }/index_node_thread_pool_wrapper.h (85%) delete mode 100644 include/knowhere/operands.h delete mode 100644 include/knowhere/sparse_utils.h delete mode 100644 include/knowhere/tolower.h delete mode 100644 include/knowhere/tracer.h delete mode 100644 src/cluster/cluster.cc delete mode 100644 src/cluster/cluster_factory.cc delete mode 100644 src/common/comp/materialized_view.cc create mode 100644 src/common/factory.cc create mode 100644 src/common/index.cc rename src/{index => common}/index_node_thread_pool_wrapper.cc (83%) delete mode 100644 src/common/raft/integration/brute_force_index.cu delete mode 100644 src/common/raft/integration/cagra_index.cu delete mode 100644 src/common/raft/integration/cagra_instantiations.cu delete mode 100644 src/common/raft/integration/ivf_flat_index.cu delete mode 100644 src/common/raft/integration/ivf_flat_instantiations.cu delete mode 100644 src/common/raft/integration/ivf_pq_index.cu delete mode 100644 src/common/raft/integration/ivf_pq_instantiations.cu delete mode 100644 src/common/raft/integration/raft_initialization.cc delete mode 100644 src/common/raft/integration/raft_initialization.hpp delete mode 100644 src/common/raft/integration/raft_knowhere_config.hpp delete mode 100644 src/common/raft/integration/raft_knowhere_index.cuh delete mode 100644 src/common/raft/integration/raft_knowhere_index.hpp delete mode 100644 src/common/raft/integration/type_mappers.hpp delete mode 100644 src/common/raft/proto/filtered_search_instantiation.cuh delete mode 100644 src/common/raft/proto/raft_index.cuh delete mode 100644 src/common/raft/proto/raft_index_kind.hpp create mode 100644 src/common/raft/raft.cu create mode 100644 src/common/raft/raft_utils.cc create mode 100644 src/common/raft/raft_utils.h rename {include/knowhere => src/common}/range_util.h (79%) delete mode 100644 src/common/thread/thread.cc delete mode 100644 src/common/tracer.cc create mode 100644 src/index/cagra/cagra.cu create mode 100644 src/index/cagra/cagra_config.h delete mode 100644 src/index/gpu_raft/gpu_raft.h delete mode 100644 src/index/gpu_raft/gpu_raft_brute_force.cc delete mode 100644 src/index/gpu_raft/gpu_raft_brute_force_config.h delete mode 100644 src/index/gpu_raft/gpu_raft_cagra.cc delete mode 100644 src/index/gpu_raft/gpu_raft_cagra_config.h delete mode 100644 src/index/gpu_raft/gpu_raft_ivf_flat.cc delete mode 100644 src/index/gpu_raft/gpu_raft_ivf_flat_config.h delete mode 100644 src/index/gpu_raft/gpu_raft_ivf_pq.cc delete mode 100644 src/index/hnsw/faiss_hnsw.cc delete mode 100644 src/index/hnsw/faiss_hnsw_config.h delete mode 100644 src/index/hnsw/impl/FederVisitor.h delete mode 100644 src/index/hnsw/impl/IndexBruteForceWrapper.cc delete mode 100644 src/index/hnsw/impl/IndexBruteForceWrapper.h delete mode 100644 src/index/hnsw/impl/IndexHNSWWrapper.cc delete mode 100644 src/index/hnsw/impl/IndexHNSWWrapper.h delete mode 100644 src/index/hnsw/impl/IndexWrapperCosine.cc delete mode 100644 src/index/hnsw/impl/IndexWrapperCosine.h delete mode 100644 src/index/index.cc delete mode 100644 src/index/index_factory.cc delete mode 100644 src/index/index_node_data_mock_wrapper.cc create mode 100644 src/index/ivf_raft/ivf_raft.cu create mode 100644 src/index/ivf_raft/ivf_raft.cuh rename src/index/{gpu_raft/gpu_raft_ivf_pq_config.h => ivf_raft/ivf_raft_config.h} (52%) delete mode 100644 src/index/sparse/sparse_index_node.cc delete mode 100644 src/index/sparse/sparse_inverted_index.h delete mode 100644 src/index/sparse/sparse_inverted_index_config.h delete mode 100644 src/simd/simd_util.h delete mode 100644 tests/faiss/CMakeLists.txt delete mode 100644 tests/faiss_isolated/cmake/utils/platform_check.cmake delete mode 100644 tests/faiss_isolated/cmake/utils/utils.cmake delete mode 100644 tests/ut/test_binaryset.cc delete mode 100644 tests/ut/test_cluster.cc delete mode 100644 tests/ut/test_faiss_hnsw.cc delete mode 100644 tests/ut/test_index_check.cc delete mode 100644 tests/ut/test_materialized_view_search_info.cc create mode 100644 tests/ut/test_mmap.cc delete mode 100644 tests/ut/test_sparse.cc delete mode 100644 tests/ut/test_tracer.cc delete mode 100644 tests/ut/test_type.cc create mode 100644 thirdparty/DiskANN/include/diskann/windows_aligned_file_reader.h create mode 100644 thirdparty/DiskANN/include/diskann/windows_customizations.h create mode 100644 thirdparty/DiskANN/src/dll/CMakeLists.txt create mode 100644 thirdparty/DiskANN/src/dll/dllmain.cpp create mode 100644 thirdparty/DiskANN/src/windows_aligned_file_reader.cpp create mode 100644 thirdparty/faiss/.circleci/Dockerfile.cpu create mode 100644 thirdparty/faiss/.circleci/Dockerfile.faiss_gpu delete mode 100644 thirdparty/faiss/.github/actions/build_cmake/action.yml delete mode 100644 thirdparty/faiss/.github/actions/build_conda/action.yml delete mode 100644 thirdparty/faiss/.github/workflows/build.yml delete mode 100644 thirdparty/faiss/.github/workflows/nightly.yml delete mode 100644 thirdparty/faiss/benchs/CMakeLists.txt rename thirdparty/faiss/benchs/bench_all_ivf/{datasets_oss.py => datasets.py} (99%) delete mode 100644 thirdparty/faiss/benchs/bench_big_batch_ivf.py delete mode 100644 thirdparty/faiss/benchs/bench_cppcontrib_sa_decode.cpp delete mode 100644 thirdparty/faiss/benchs/bench_fw/__init__.py delete mode 100644 thirdparty/faiss/benchs/bench_fw/benchmark.py delete mode 100644 thirdparty/faiss/benchs/bench_fw/benchmark_io.py delete mode 100644 thirdparty/faiss/benchs/bench_fw/descriptors.py delete mode 100644 thirdparty/faiss/benchs/bench_fw/index.py delete mode 100644 thirdparty/faiss/benchs/bench_fw/optimize.py delete mode 100644 thirdparty/faiss/benchs/bench_fw/utils.py delete mode 100644 thirdparty/faiss/benchs/bench_fw_codecs.py delete mode 100644 thirdparty/faiss/benchs/bench_fw_ivf.py delete mode 100644 thirdparty/faiss/benchs/bench_fw_notebook.ipynb delete mode 100644 thirdparty/faiss/benchs/bench_fw_optimize.py delete mode 100644 thirdparty/faiss/benchs/bench_fw_range.py delete mode 100644 thirdparty/faiss/benchs/bench_hamming_knn.py delete mode 100644 thirdparty/faiss/benchs/bench_hnsw_knowhere.cpp delete mode 100644 thirdparty/faiss/benchs/bench_hybrid_cpu_gpu.py delete mode 100644 thirdparty/faiss/benchs/bench_ivf_fastscan.py delete mode 100644 thirdparty/faiss/benchs/bench_ivf_fastscan_single_query.py delete mode 100644 thirdparty/faiss/benchs/bench_ivf_selector.cpp delete mode 100644 thirdparty/faiss/benchs/bench_pq_transposed_centroid_table.py create mode 100755 thirdparty/faiss/benchs/distributed_ondisk/rpc.py create mode 100755 thirdparty/faiss/benchs/link_and_code/bench_link_and_code.py create mode 100755 thirdparty/faiss/benchs/link_and_code/datasets.py create mode 100755 thirdparty/faiss/benchs/link_and_code/neighbor_codec.py create mode 100644 thirdparty/faiss/conda/Dockerfile.cpu create mode 100644 thirdparty/faiss/conda/Dockerfile.cuda10.2 create mode 100644 thirdparty/faiss/conda/Dockerfile.cuda11.3 delete mode 100644 thirdparty/faiss/conda/faiss-gpu-raft/build-lib.sh delete mode 100644 thirdparty/faiss/conda/faiss-gpu-raft/build-pkg.sh delete mode 100644 thirdparty/faiss/conda/faiss-gpu-raft/meta.yaml delete mode 100755 thirdparty/faiss/conda/faiss-gpu-raft/test_cpu_dispatch.sh create mode 100755 thirdparty/faiss/conda/faiss-gpu/install-cmake.sh delete mode 100755 thirdparty/faiss/conda/faiss/build-lib-arm64.sh delete mode 100755 thirdparty/faiss/conda/faiss/build-lib-osx.sh delete mode 100755 thirdparty/faiss/conda/faiss/build-pkg-arm64.sh delete mode 100755 thirdparty/faiss/conda/faiss/build-pkg-osx.sh create mode 100755 thirdparty/faiss/conda/faiss/install-cmake.sh delete mode 100644 thirdparty/faiss/contrib/big_batch_search.py delete mode 100644 thirdparty/faiss/contrib/clustering.py delete mode 100644 thirdparty/faiss/demos/demo_residual_quantizer.cpp delete mode 100644 thirdparty/faiss/demos/offline_ivf/README.md delete mode 100644 thirdparty/faiss/demos/offline_ivf/__init__.py delete mode 100644 thirdparty/faiss/demos/offline_ivf/config_ssnpp.yaml delete mode 100644 thirdparty/faiss/demos/offline_ivf/create_sharded_ssnpp_files.py delete mode 100644 thirdparty/faiss/demos/offline_ivf/dataset.py delete mode 100644 thirdparty/faiss/demos/offline_ivf/generate_config.py delete mode 100644 thirdparty/faiss/demos/offline_ivf/offline_ivf.py delete mode 100644 thirdparty/faiss/demos/offline_ivf/run.py delete mode 100644 thirdparty/faiss/demos/offline_ivf/tests/test_iterate_input.py delete mode 100644 thirdparty/faiss/demos/offline_ivf/tests/test_offline_ivf.py delete mode 100644 thirdparty/faiss/demos/offline_ivf/tests/testing_utils.py delete mode 100644 thirdparty/faiss/demos/offline_ivf/utils.py delete mode 100644 thirdparty/faiss/demos/rocksdb_ivf/CMakeLists.txt delete mode 100644 thirdparty/faiss/demos/rocksdb_ivf/README.md delete mode 100644 thirdparty/faiss/demos/rocksdb_ivf/RocksDBInvertedLists.cpp delete mode 100644 thirdparty/faiss/demos/rocksdb_ivf/RocksDBInvertedLists.h delete mode 100644 thirdparty/faiss/demos/rocksdb_ivf/demo_rocksdb_ivf.cpp delete mode 100644 thirdparty/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp delete mode 100644 thirdparty/faiss/faiss/IndexAdditiveQuantizerFastScan.h create mode 100644 thirdparty/faiss/faiss/IndexBinaryIVFThreadSafe.cpp delete mode 100644 thirdparty/faiss/faiss/IndexCosine.cpp delete mode 100644 thirdparty/faiss/faiss/IndexCosine.h delete mode 100644 thirdparty/faiss/faiss/IndexFastScan.cpp delete mode 100644 thirdparty/faiss/faiss/IndexFastScan.h delete mode 100644 thirdparty/faiss/faiss/IndexFlatElkan.cpp delete mode 100644 thirdparty/faiss/faiss/IndexFlatElkan.h delete mode 100644 thirdparty/faiss/faiss/IndexIDMap.cpp delete mode 100644 thirdparty/faiss/faiss/IndexIDMap.h delete mode 100644 thirdparty/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp delete mode 100644 thirdparty/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h delete mode 100644 thirdparty/faiss/faiss/IndexIVFFastScan.cpp delete mode 100644 thirdparty/faiss/faiss/IndexIVFFastScan.h delete mode 100644 thirdparty/faiss/faiss/IndexIVFIndependentQuantizer.cpp delete mode 100644 thirdparty/faiss/faiss/IndexIVFIndependentQuantizer.h delete mode 100644 thirdparty/faiss/faiss/IndexIVFScalarQuantizerCC.cpp delete mode 100644 thirdparty/faiss/faiss/IndexIVFScalarQuantizerCC.h create mode 100644 thirdparty/faiss/faiss/IndexIVFThreadSafe.cpp delete mode 100644 thirdparty/faiss/faiss/IndexRowwiseMinMax.cpp delete mode 100644 thirdparty/faiss/faiss/IndexRowwiseMinMax.h delete mode 100644 thirdparty/faiss/faiss/IndexShardsIVF.cpp delete mode 100644 thirdparty/faiss/faiss/IndexShardsIVF.h delete mode 100644 thirdparty/faiss/faiss/cppcontrib/SaDecodeKernels.h delete mode 100644 thirdparty/faiss/faiss/cppcontrib/detail/CoarseBitType.h delete mode 100644 thirdparty/faiss/faiss/cppcontrib/detail/UintReader.h delete mode 100644 thirdparty/faiss/faiss/cppcontrib/knowhere/IndexBruteForceWrapper.cpp delete mode 100644 thirdparty/faiss/faiss/cppcontrib/knowhere/IndexBruteForceWrapper.h delete mode 100644 thirdparty/faiss/faiss/cppcontrib/knowhere/IndexHNSWWrapper.cpp delete mode 100644 thirdparty/faiss/faiss/cppcontrib/knowhere/IndexHNSWWrapper.h delete mode 100644 thirdparty/faiss/faiss/cppcontrib/knowhere/IndexWrapper.cpp delete mode 100644 thirdparty/faiss/faiss/cppcontrib/knowhere/IndexWrapper.h delete mode 100644 thirdparty/faiss/faiss/cppcontrib/knowhere/impl/Bruteforce.h delete mode 100644 thirdparty/faiss/faiss/cppcontrib/knowhere/impl/HnswSearcher.h delete mode 100644 thirdparty/faiss/faiss/cppcontrib/knowhere/impl/Neighbor.h delete mode 100644 thirdparty/faiss/faiss/cppcontrib/knowhere/utils/Bitset.h delete mode 100644 thirdparty/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h delete mode 100644 thirdparty/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h delete mode 100644 thirdparty/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h delete mode 100644 thirdparty/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h delete mode 100644 thirdparty/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h delete mode 100644 thirdparty/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h delete mode 100644 thirdparty/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h delete mode 100644 thirdparty/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h delete mode 100644 thirdparty/faiss/faiss/impl/CodePacker.cpp delete mode 100644 thirdparty/faiss/faiss/impl/CodePacker.h delete mode 100644 thirdparty/faiss/faiss/impl/DistanceComputer.h delete mode 100644 thirdparty/faiss/faiss/impl/IDSelector.cpp delete mode 100644 thirdparty/faiss/faiss/impl/IDSelector.h delete mode 100644 thirdparty/faiss/faiss/impl/LookupTableScaler.h delete mode 100644 thirdparty/faiss/faiss/impl/ProductAdditiveQuantizer.cpp delete mode 100644 thirdparty/faiss/faiss/impl/ProductAdditiveQuantizer.h delete mode 100644 thirdparty/faiss/faiss/impl/Quantizer.h delete mode 100644 thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_neon.h delete mode 100644 thirdparty/faiss/faiss/impl/ScalarQuantizerDC_neon.cpp delete mode 100644 thirdparty/faiss/faiss/impl/ScalarQuantizerDC_neon.h delete mode 100644 thirdparty/faiss/faiss/impl/ScalarQuantizerScanner.h delete mode 100644 thirdparty/faiss/faiss/impl/code_distance/code_distance-avx2.h delete mode 100644 thirdparty/faiss/faiss/impl/code_distance/code_distance-avx512.h delete mode 100644 thirdparty/faiss/faiss/impl/code_distance/code_distance-generic.h delete mode 100644 thirdparty/faiss/faiss/impl/code_distance/code_distance.h delete mode 100644 thirdparty/faiss/faiss/impl/residual_quantizer_encode_steps.cpp delete mode 100644 thirdparty/faiss/faiss/impl/residual_quantizer_encode_steps.h delete mode 100644 thirdparty/faiss/faiss/utils/approx_topk/approx_topk.h delete mode 100644 thirdparty/faiss/faiss/utils/approx_topk/avx2-inl.h delete mode 100644 thirdparty/faiss/faiss/utils/approx_topk/generic.h delete mode 100644 thirdparty/faiss/faiss/utils/approx_topk/mode.h delete mode 100644 thirdparty/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h delete mode 100644 thirdparty/faiss/faiss/utils/bf16.h delete mode 100644 thirdparty/faiss/faiss/utils/data_backup_file.cpp delete mode 100644 thirdparty/faiss/faiss/utils/data_backup_file.h delete mode 100644 thirdparty/faiss/faiss/utils/distances_fused/avx512.cpp delete mode 100644 thirdparty/faiss/faiss/utils/distances_fused/avx512.h delete mode 100644 thirdparty/faiss/faiss/utils/distances_fused/distances_fused.cpp delete mode 100644 thirdparty/faiss/faiss/utils/distances_fused/distances_fused.h delete mode 100644 thirdparty/faiss/faiss/utils/distances_fused/simdlib_based.cpp delete mode 100644 thirdparty/faiss/faiss/utils/distances_fused/simdlib_based.h delete mode 100644 thirdparty/faiss/faiss/utils/distances_if.h delete mode 100644 thirdparty/faiss/faiss/utils/fp16-arm.h delete mode 100644 thirdparty/faiss/faiss/utils/fp16-fp16c.h delete mode 100644 thirdparty/faiss/faiss/utils/fp16-inl.h delete mode 100644 thirdparty/faiss/faiss/utils/fp16.h delete mode 100644 thirdparty/faiss/faiss/utils/hamming_distance/avx2-inl.h delete mode 100644 thirdparty/faiss/faiss/utils/hamming_distance/common.h delete mode 100644 thirdparty/faiss/faiss/utils/hamming_distance/generic-inl.h delete mode 100644 thirdparty/faiss/faiss/utils/hamming_distance/hamdis-inl.h delete mode 100644 thirdparty/faiss/faiss/utils/hamming_distance/neon-inl.h delete mode 100644 thirdparty/faiss/faiss/utils/prefetch.h delete mode 100644 thirdparty/faiss/faiss/utils/simdlib_avx512.h delete mode 100644 thirdparty/faiss/faiss/utils/simdlib_ppc64.h delete mode 100644 thirdparty/faiss/faiss/utils/sorting.cpp delete mode 100644 thirdparty/faiss/faiss/utils/sorting.h delete mode 100644 thirdparty/faiss/faiss/utils/transpose/transpose-avx2-inl.h delete mode 100644 thirdparty/faiss/tests/test_RCQ_cropping.cpp delete mode 100644 thirdparty/faiss/tests/test_approx_topk.cpp delete mode 100644 thirdparty/faiss/tests/test_callback.cpp delete mode 100644 thirdparty/faiss/tests/test_callback_py.py delete mode 100644 thirdparty/faiss/tests/test_clone.py delete mode 100644 thirdparty/faiss/tests/test_code_distance.cpp delete mode 100644 thirdparty/faiss/tests/test_common_ivf_empty_index.cpp delete mode 100644 thirdparty/faiss/tests/test_contrib_with_scipy.py delete mode 100644 thirdparty/faiss/tests/test_cppcontrib_sa_decode.cpp delete mode 100644 thirdparty/faiss/tests/test_cppcontrib_uintreader.cpp delete mode 100644 thirdparty/faiss/tests/test_disable_pq_sdc_tables.cpp delete mode 100644 thirdparty/faiss/tests/test_distances_if.cpp delete mode 100644 thirdparty/faiss/tests/test_distances_simd.cpp delete mode 100644 thirdparty/faiss/tests/test_fastscan_perf.cpp delete mode 100644 thirdparty/faiss/tests/test_graph_based.py delete mode 100644 thirdparty/faiss/tests/test_heap.cpp delete mode 100644 thirdparty/faiss/tests/test_hnsw.cpp delete mode 100644 thirdparty/faiss/tests/test_ivf_index.cpp rename thirdparty/faiss/tests/{test_local_search_quantizer.py => test_lsq.py} (66%) delete mode 100644 thirdparty/faiss/tests/test_merge_index.py delete mode 100644 thirdparty/faiss/tests/test_partitioning.cpp delete mode 100644 thirdparty/faiss/tests/test_rowwise_minmax.py delete mode 100644 thirdparty/faiss/tests/test_search_params.py delete mode 100644 thirdparty/faiss/tests/test_simdlib.cpp delete mode 100644 thirdparty/faiss/tests/test_util.h create mode 100644 thirdparty/faiss/tutorial/cpp/5-GPU.cpp create mode 100644 thirdparty/faiss/tutorial/cpp/6-GPU.cpp delete mode 100644 thirdparty/faiss/tutorial/cpp/6-HNSW.cpp create mode 100644 thirdparty/faiss/tutorial/cpp/6-RUN.cpp create mode 100644 thirdparty/faiss/tutorial/cpp/7-GPU.cpp delete mode 100644 thirdparty/faiss/tutorial/cpp/7-PQFastScan.cpp create mode 100644 thirdparty/faiss/tutorial/cpp/8-GPU.cpp delete mode 100644 thirdparty/faiss/tutorial/cpp/8-PQFastScanRefine.cpp create mode 100644 thirdparty/faiss/tutorial/cpp/9-BinaryFlat.cpp delete mode 100644 thirdparty/faiss/tutorial/cpp/9-RefineComparison.cpp create mode 100644 thirdparty/faiss/tutorial/cpp/tutorial_faiss_test.cpp delete mode 100644 thirdparty/faiss/tutorial/python/7-PQFastScan.py delete mode 100644 thirdparty/faiss/tutorial/python/8-PQFastScanRefine.py delete mode 100644 thirdparty/faiss/tutorial/python/9-RefineComparison.py diff --git a/cmake/libs/RAPIDS.cmake b/cmake/libs/RAPIDS.cmake new file mode 100644 index 000000000..2b9788130 --- /dev/null +++ b/cmake/libs/RAPIDS.cmake @@ -0,0 +1,73 @@ +#============================================================================= +# Copyright (c) 2021-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#============================================================================= +# +# This is the preferred entry point for projects using rapids-cmake +# + +# Allow users to control which version is used +if(NOT rapids-cmake-version) + # Define a default version if the user doesn't set one + set(rapids-cmake-version 23.12) +endif() + +# Allow users to control which GitHub repo is fetched +if(NOT rapids-cmake-repo) + # Define a default repo if the user doesn't set one + set(rapids-cmake-repo rapidsai/rapids-cmake) +endif() + +# Allow users to control which branch is fetched +if(NOT rapids-cmake-branch) + # Define a default branch if the user doesn't set one + set(rapids-cmake-branch "branch-${rapids-cmake-version}") +endif() + +# Allow users to control the exact URL passed to FetchContent +if(NOT rapids-cmake-url) + # Construct a default URL if the user doesn't set one + set(rapids-cmake-url "https://github.com/${rapids-cmake-repo}/") + # In order of specificity + if(rapids-cmake-sha) + # An exact git SHA takes precedence over anything + string(APPEND rapids-cmake-url "archive/${rapids-cmake-sha}.zip") + elseif(rapids-cmake-tag) + # Followed by a git tag name + string(APPEND rapids-cmake-url "archive/refs/tags/${rapids-cmake-tag}.zip") + else() + # Or if neither of the above two were defined, use a branch + string(APPEND rapids-cmake-url "archive/refs/heads/${rapids-cmake-branch}.zip") + endif() +endif() + +if(POLICY CMP0135) + cmake_policy(PUSH) + cmake_policy(SET CMP0135 NEW) +endif() +include(FetchContent) +FetchContent_Declare(rapids-cmake URL "${rapids-cmake-url}") +if(POLICY CMP0135) + cmake_policy(POP) +endif() +FetchContent_GetProperties(rapids-cmake) +if(rapids-cmake_POPULATED) + # Something else has already populated rapids-cmake, only thing + # we need to do is setup the CMAKE_MODULE_PATH + if(NOT "${rapids-cmake-dir}" IN_LIST CMAKE_MODULE_PATH) + list(APPEND CMAKE_MODULE_PATH "${rapids-cmake-dir}") + endif() +else() + FetchContent_MakeAvailable(rapids-cmake) +endif() diff --git a/cmake/libs/libcardinal.cmake b/cmake/libs/libcardinal.cmake deleted file mode 100644 index 5521478c5..000000000 --- a/cmake/libs/libcardinal.cmake +++ /dev/null @@ -1,74 +0,0 @@ -# Use short SHA1 as version -set(CARDINAL_VERSION v2.4.9 ) -set(CARDINAL_REPO_URL "https://github.com/zilliztech/cardinal.git") - -set(CARDINAL_REPO_DIR "${CMAKE_CURRENT_SOURCE_DIR}/thirdparty/cardinal") - -message(STATUS "Build Cardinal-${CARDINAL_VERSION}") - -# Clone and checkout cardinal with given repo url and version -if (NOT EXISTS "${CARDINAL_REPO_DIR}/.git") - execute_process(COMMAND git clone ${CARDINAL_REPO_URL} ${CARDINAL_REPO_DIR} - RESULT_VARIABLE CARDINAL_CLONE_RESULT - OUTPUT_VARIABLE CARDINAL_CLONE_OUTPUT - ERROR_VARIABLE CARDINAL_CLONE_ERROR - OUTPUT_STRIP_TRAILING_WHITESPACE - ERROR_STRIP_TRAILING_WHITESPACE) - if (NOT CARDINAL_CLONE_RESULT EQUAL "0") - message(FATAL_ERROR "Failed to clone cardinal: ${CARDINAL_CLONE_ERROR}") - else() - message(STATUS "Successfully Clone Cardinal Repo") - execute_process(COMMAND git -C ${CARDINAL_REPO_DIR} checkout ${CARDINAL_VERSION} - RESULT_VARIABLE CARDINAL_CHECKOUT_RESULT - OUTPUT_VARIABLE CARDINAL_CHECKOUT_OUTPUT - ERROR_VARIABLE CARDINAL_CHECKOUT_ERROR - OUTPUT_STRIP_TRAILING_WHITESPACE - ERROR_STRIP_TRAILING_WHITESPACE) - if (NOT CARDINAL_CHECKOUT_RESULT EQUAL "0") - message(FATAL_ERROR "Failed to checkout cardinal: ${CARDINAL_CHECKOUT_ERROR}") - else() - message(STATUS "Successfully checkout Cardinal Version : ${CARDINAL_VERSION}") - endif() - endif() -else() - execute_process( - COMMAND git -C ${CARDINAL_REPO_DIR} rev-parse HEAD - OUTPUT_VARIABLE GIT_COMMIT_HASH - OUTPUT_STRIP_TRAILING_WHITESPACE - ) - message(STATUS "Cardinal repo already exist! git commit : ${GIT_COMMIT_HASH}") -endif() - -# Force checkout the version specified as `CARDINAL_VERSION` if `CARDINAL_VERSION_FORCE_CHECKOUT` is set -# Default do not checkout for better development convenience -if(CARDINAL_VERSION_FORCE_CHECKOUT) - message(STATUS "Checking out cardinal version ${CARDINAL_VERSION}") - - execute_process( - COMMAND git -C ${CARDINAL_REPO_DIR} fetch - RESULT_VARIABLE CARDINAL_FETCH_RESULT - OUTPUT_VARIABLE CARDINAL_FETCH_OUTPUT - ERROR_VARIABLE CARDINAL_FETCH_ERROR - OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_STRIP_TRAILING_WHITESPACE) - - if(NOT CARDINAL_FETCH_RESULT EQUAL "0") - message( - FATAL_ERROR "Failed to fetch cardinal: ${CARDINAL_FETCH_ERROR}") - endif() - - message(STATUS "Fetched cardinal ${CARDINAL_FETCH_OUTPUT}") - - execute_process( - COMMAND git -C ${CARDINAL_REPO_DIR} checkout ${CARDINAL_VERSION} - RESULT_VARIABLE CARDINAL_CHECKOUT_RESULT - OUTPUT_VARIABLE CARDINAL_CHECKOUT_OUTPUT - ERROR_VARIABLE CARDINAL_CHECKOUT_ERROR - OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_STRIP_TRAILING_WHITESPACE) - - if(NOT CARDINAL_CHECKOUT_RESULT EQUAL "0") - message( - FATAL_ERROR "Failed to checkout cardinal: ${CARDINAL_CHECKOUT_ERROR}") - endif() -endif() - -include(${CARDINAL_REPO_DIR}/know/libcardinal.cmake) diff --git a/cmake/libs/libcutlass.cmake b/cmake/libs/libcutlass.cmake new file mode 100644 index 000000000..41dae6803 --- /dev/null +++ b/cmake/libs/libcutlass.cmake @@ -0,0 +1,107 @@ +# ============================================================================= +# Copyright (c) 2021-2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +# ============================================================================= + +function(find_and_configure_cutlass) + set(oneValueArgs VERSION REPOSITORY PINNED_TAG) + cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" + ${ARGN}) + + # if(RAFT_ENABLE_DIST_DEPENDENCIES OR RAFT_COMPILE_LIBRARIES) + set(CUTLASS_ENABLE_HEADERS_ONLY + ON + CACHE BOOL "Enable only the header library") + set(CUTLASS_NAMESPACE + "raft_cutlass" + CACHE STRING "Top level namespace of CUTLASS") + set(CUTLASS_ENABLE_CUBLAS + OFF + CACHE BOOL "Disable CUTLASS to build with cuBLAS library.") + + if(CUDA_STATIC_RUNTIME) + set(CUDART_LIBRARY + "${CUDA_cudart_static_LIBRARY}" + CACHE FILEPATH "fixing cutlass cmake code" FORCE) + endif() + + rapids_cpm_find( + NvidiaCutlass + ${PKG_VERSION} + GLOBAL_TARGETS + nvidia::cutlass::cutlass + CPM_ARGS + GIT_REPOSITORY + ${PKG_REPOSITORY} + GIT_TAG + ${PKG_PINNED_TAG} + GIT_SHALLOW + TRUE + OPTIONS + "CUDAToolkit_ROOT ${CUDAToolkit_LIBRARY_DIR}") + + if(TARGET CUTLASS AND NOT TARGET nvidia::cutlass::cutlass) + add_library(nvidia::cutlass::cutlass ALIAS CUTLASS) + endif() + + if(NvidiaCutlass_ADDED) + rapids_export( + BUILD + NvidiaCutlass + EXPORT_SET + NvidiaCutlass + GLOBAL_TARGETS + nvidia::cutlass::cutlass + NAMESPACE + nvidia::cutlass::) + endif() + # endif() + + # We generate the cutlass-config files when we built cutlass locally, so + # always do `find_dependency` + rapids_export_package(BUILD NvidiaCutlass raft-distance-exports + GLOBAL_TARGETS nvidia::cutlass::cutlass) + rapids_export_package(INSTALL NvidiaCutlass raft-distance-exports + GLOBAL_TARGETS nvidia::cutlass::cutlass) + rapids_export_package(BUILD NvidiaCutlass raft-nn-exports GLOBAL_TARGETS + nvidia::cutlass::cutlass) + rapids_export_package(INSTALL NvidiaCutlass raft-nn-exports GLOBAL_TARGETS + nvidia::cutlass::cutlass) + + # Tell cmake where it can find the generated NvidiaCutlass-config.cmake we + # wrote. + include("${rapids-cmake-dir}/export/find_package_root.cmake") + rapids_export_find_package_root( + INSTALL NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}/../]=] + raft-distance-exports) + rapids_export_find_package_root( + BUILD NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-distance-exports) + include("${rapids-cmake-dir}/export/find_package_root.cmake") + rapids_export_find_package_root( + INSTALL NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}/../]=] raft-nn-exports) + rapids_export_find_package_root( + BUILD NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-nn-exports) +endfunction() + +if(NOT RAFT_CUTLASS_GIT_TAG) + set(RAFT_CUTLASS_GIT_TAG v2.9.1) +endif() + +if(NOT RAFT_CUTLASS_GIT_REPOSITORY) + set(RAFT_CUTLASS_GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git) +endif() + +find_and_configure_cutlass( + VERSION 2.9.1 REPOSITORY ${RAFT_CUTLASS_GIT_REPOSITORY} PINNED_TAG + ${RAFT_CUTLASS_GIT_TAG}) diff --git a/cmake/libs/libdiskann.cmake b/cmake/libs/libdiskann.cmake index c93843a9c..75b5d4c64 100644 --- a/cmake/libs/libdiskann.cmake +++ b/cmake/libs/libdiskann.cmake @@ -24,16 +24,12 @@ set(DISKANN_SOURCES find_package(folly REQUIRED) add_library(diskann STATIC ${DISKANN_SOURCES}) -target_link_libraries( - diskann - PUBLIC ${AIO_LIBRARIES} - ${DISKANN_BOOST_PROGRAM_OPTIONS_LIB} - nlohmann_json::nlohmann_json - Folly::folly - fmt::fmt-header-only - prometheus-cpp::core - prometheus-cpp::push - glog::glog) +target_link_libraries(diskann PUBLIC ${AIO_LIBRARIES} + ${DISKANN_BOOST_PROGRAM_OPTIONS_LIB} + nlohmann_json::nlohmann_json + Folly::folly + fmt::fmt-header-only + glog::glog) if(__X86_64) target_compile_options( diskann PRIVATE -fno-builtin-malloc -fno-builtin-calloc diff --git a/cmake/libs/libfaiss.cmake b/cmake/libs/libfaiss.cmake index e9d175f59..6f44c0899 100644 --- a/cmake/libs/libfaiss.cmake +++ b/cmake/libs/libfaiss.cmake @@ -1,29 +1,21 @@ knowhere_file_glob( GLOB FAISS_SRCS thirdparty/faiss/faiss/*.cpp thirdparty/faiss/faiss/impl/*.cpp thirdparty/faiss/faiss/invlists/*.cpp - thirdparty/faiss/faiss/utils/*.cpp - thirdparty/faiss/faiss/cppcontrib/knowhere/*.cpp) + thirdparty/faiss/faiss/utils/*.cpp) knowhere_file_glob(GLOB FAISS_AVX512_SRCS thirdparty/faiss/faiss/impl/*avx512.cpp) -knowhere_file_glob( - GLOB - FAISS_AVX2_SRCS - thirdparty/faiss/faiss/impl/*avx.cpp - thirdparty/faiss/faiss/impl/pq4_fast_scan_search_1.cpp - thirdparty/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp - thirdparty/faiss/faiss/utils/partitioning_avx2.cpp - thirdparty/faiss/faiss/IndexPQFastScan.cpp - thirdparty/faiss/faiss/IndexIVFFastScan.cpp - thirdparty/faiss/faiss/IndexIVFPQFastScan.cpp) +knowhere_file_glob(GLOB FAISS_AVX2_SRCS + thirdparty/faiss/faiss/impl/*avx.cpp + thirdparty/faiss/faiss/impl/pq4_fast_scan_search_1.cpp + thirdparty/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp + thirdparty/faiss/faiss/utils/partitioning_avx2.cpp + thirdparty/faiss/faiss/IndexPQFastScan.cpp + thirdparty/faiss/faiss/IndexIVFPQFastScan.cpp) list(REMOVE_ITEM FAISS_SRCS ${FAISS_AVX512_SRCS}) -# disable RHNSW -knowhere_file_glob(GLOB FAISS_RHNSW_SRCS thirdparty/faiss/faiss/impl/RHNSW.cpp) -list(REMOVE_ITEM FAISS_SRCS ${FAISS_RHNSW_SRCS}) - if(__X86_64) set(UTILS_SRC src/simd/distances_ref.cc src/simd/hook.cc) set(UTILS_SSE_SRC src/simd/distances_sse.cc) @@ -37,7 +29,7 @@ if(__X86_64) target_compile_options(utils_sse PRIVATE -msse4.2 -mpopcnt) target_compile_options(utils_avx PRIVATE -mfma -mf16c -mavx2 -mpopcnt) target_compile_options(utils_avx512 PRIVATE -mfma -mf16c -mavx512f -mavx512dq - -mavx512bw -mpopcnt -mavx512vl) + -mavx512bw -mpopcnt) add_library( knowhere_utils STATIC @@ -47,21 +39,11 @@ if(__X86_64) endif() if(__AARCH64) - set(UTILS_SRC src/simd/hook.cc src/simd/distances_ref.cc - src/simd/distances_neon.cc) - add_library(knowhere_utils STATIC ${UTILS_SRC}) - target_link_libraries(knowhere_utils PUBLIC glog::glog) -endif() - -# ToDo: Add distances_vsx.cc for powerpc64 SIMD acceleration -if(__PPC64) - set(UTILS_SRC src/simd/hook.cc src/simd/distances_ref.cc) + set(UTILS_SRC src/simd/hook.cc src/simd/distances_ref.cc src/simd/distances_neon.cc) add_library(knowhere_utils STATIC ${UTILS_SRC}) target_link_libraries(knowhere_utils PUBLIC glog::glog) endif() -find_package(LAPACK REQUIRED) - if(LINUX) set(BLA_VENDOR OpenBLAS) endif() @@ -71,16 +53,24 @@ if(APPLE) endif() find_package(BLAS REQUIRED) +if(LINUX) + set(BLA_VENDOR "") +endif() + +find_package(LAPACK REQUIRED) if(__X86_64) list(REMOVE_ITEM FAISS_SRCS ${FAISS_AVX2_SRCS}) - knowhere_file_glob(GLOB FAISS_NEON_SRCS thirdparty/faiss/faiss/impl/*neon.cpp) - list(REMOVE_ITEM FAISS_SRCS ${FAISS_NEON_SRCS}) - add_library(faiss_avx2 OBJECT ${FAISS_AVX2_SRCS}) - target_compile_options(faiss_avx2 PRIVATE $<$: -msse4.2 - -mavx2 -mfma -mf16c -mpopcnt>) + target_compile_options( + faiss_avx2 + PRIVATE $<$: + -msse4.2 + -mavx2 + -mfma + -mf16c + -mpopcnt>) add_library(faiss_avx512 OBJECT ${FAISS_AVX512_SRCS}) target_compile_options( faiss_avx512 @@ -92,7 +82,6 @@ if(__X86_64) -mavx512f -mavx512dq -mavx512bw - -mavx512vl -mpopcnt>) add_library(faiss STATIC ${FAISS_SRCS}) @@ -117,40 +106,13 @@ endif() if(__AARCH64) knowhere_file_glob(GLOB FAISS_AVX_SRCS thirdparty/faiss/faiss/impl/*avx.cpp) - list(REMOVE_ITEM FAISS_SRCS ${FAISS_AVX_SRCS}) - add_library(faiss STATIC ${FAISS_SRCS}) - - target_compile_options( - faiss - PRIVATE $<$: - -Wno-sign-compare - -Wno-unused-variable - -Wno-reorder - -Wno-unused-local-typedefs - -Wno-unused-function - -Wno-strict-aliasing>) - - add_dependencies(faiss knowhere_utils) - target_link_libraries(faiss PUBLIC OpenMP::OpenMP_CXX ${BLAS_LIBRARIES} - ${LAPACK_LIBRARIES} knowhere_utils) - target_compile_definitions(faiss PRIVATE FINTEGER=int) -endif() - -if(__PPC64) - knowhere_file_glob(GLOB FAISS_AVX_SRCS thirdparty/faiss/faiss/impl/*avx.cpp) list(REMOVE_ITEM FAISS_SRCS ${FAISS_AVX_SRCS}) - - knowhere_file_glob(GLOB FAISS_NEON_SRCS thirdparty/faiss/faiss/impl/*neon.cpp) - list(REMOVE_ITEM FAISS_SRCS ${FAISS_NEON_SRCS}) - add_library(faiss STATIC ${FAISS_SRCS}) target_compile_options( faiss PRIVATE $<$: - -mcpu=native - -mvsx -Wno-sign-compare -Wno-unused-variable -Wno-reorder diff --git a/cmake/libs/libraft.cmake b/cmake/libs/libraft.cmake index dd4506f56..3de1cfaec 100644 --- a/cmake/libs/libraft.cmake +++ b/cmake/libs/libraft.cmake @@ -14,13 +14,22 @@ # the License. add_definitions(-DKNOWHERE_WITH_RAFT) -add_definitions(-DRAFT_EXPLICIT_INSTANTIATE_ONLY) -set(RAFT_VERSION "${RAPIDS_VERSION}") -set(RAFT_FORK "milvus-io") -set(RAFT_PINNED_TAG "branch-24.04") +include(cmake/utils/fetch_rapids.cmake) +include(rapids-cmake) +include(rapids-cpm) +include(rapids-cuda) +include(rapids-export) +include(rapids-find) + +rapids_cpm_init() -rapids_find_package(CUDAToolkit REQUIRED BUILD_EXPORT_SET knowhere-exports - INSTALL_EXPORT_SET knowhere-exports) +set(CMAKE_CUDA_FLAGS + "${CMAKE_CUDA_FLAGS} --expt-extended-lambda --expt-relaxed-constexpr") + +set(RAPIDS_VERSION 23.04) +set(RAFT_VERSION "${RAPIDS_VERSION}") +set(RAFT_FORK "rapidsai") +set(RAFT_PINNED_TAG "branch-${RAPIDS_VERSION}") function(find_and_configure_raft) set(oneValueArgs VERSION FORK PINNED_TAG) @@ -36,7 +45,7 @@ function(find_and_configure_raft) GLOBAL_TARGETS raft::raft COMPONENTS - compiled_static + ${RAFT_COMPONENTS} CPM_ARGS GIT_REPOSITORY https://github.com/${PKG_FORK}/raft.git @@ -45,26 +54,18 @@ function(find_and_configure_raft) SOURCE_SUBDIR cpp OPTIONS - "RAFT_COMPILE_LIBRARY ON" "BUILD_TESTS OFF" "BUILD_BENCH OFF" "RAFT_USE_FAISS_STATIC OFF") # Turn this on to build FAISS into your binary - if(raft_ADDED) - message(VERBOSE "KNOWHERE: Using RAFT located in ${raft_SOURCE_DIR}") - else() - message(VERBOSE "KNOWHERE: Using RAFT located in ${raft_DIR}") - endif() + if(raft_ADDED) + message(VERBOSE "KNOWHERE: Using RAFT located in ${raft_SOURCE_DIR}") + else() + message(VERBOSE "KNOWHERE: Using RAFT located in ${raft_DIR}") + endif() endfunction() # Change pinned tag here to test a commit in CI To use a different RAFT locally, # set the CMake variable CPM_raft_SOURCE=/path/to/local/raft -find_and_configure_raft( - VERSION - ${RAFT_VERSION}.00 - FORK - ${RAFT_FORK} - PINNED_TAG - ${RAFT_PINNED_TAG} - COMPILE_LIBRARY - OFF) +find_and_configure_raft(VERSION ${RAFT_VERSION}.00 FORK ${RAFT_FORK} PINNED_TAG + ${RAFT_PINNED_TAG} COMPILE_LIBRARY OFF) diff --git a/cmake/utils/compile_flags.cmake b/cmake/utils/compile_flags.cmake index 8de2c2504..26c0d40f8 100644 --- a/cmake/utils/compile_flags.cmake +++ b/cmake/utils/compile_flags.cmake @@ -17,10 +17,6 @@ endif() set(CMAKE_CXX_FLAGS "-Wall -fPIC ${CMAKE_CXX_FLAGS}") -if(__X86_64) - set(CMAKE_CXX_FLAGS "-msse4.2 ${CMAKE_CXX_FLAGS}") -endif() - set(CMAKE_CXX_FLAGS_DEBUG "-O0 -g") set(CMAKE_CXX_FLAGS_RELEASE "-O3 -DNDEBUG") diff --git a/cmake/libs/librapids.cmake b/cmake/utils/fetch_rapids.cmake similarity index 71% rename from cmake/libs/librapids.cmake rename to cmake/utils/fetch_rapids.cmake index a8b4e6c8c..56899f2c5 100644 --- a/cmake/libs/librapids.cmake +++ b/cmake/utils/fetch_rapids.cmake @@ -13,7 +13,7 @@ # License for the specific language governing permissions and limitations under # the License. -set(RAPIDS_VERSION 24.04) +set(RAPIDS_VERSION "23.04") if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake) file( @@ -22,15 +22,3 @@ if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake) ${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake) endif() include(${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake) - -include(rapids-cpm) # Dependency tracking -include(rapids-find) # Wrappers for finding packages -include(rapids-cuda) # Common CMake CUDA logic - -rapids_cuda_init_architectures(knowhere) -message(STATUS "INIT: ${CMAKE_CUDA_ARCHITECTURES}") - -rapids_cpm_init() - -set(CMAKE_CUDA_FLAGS - "${CMAKE_CUDA_FLAGS} --expt-extended-lambda --expt-relaxed-constexpr") diff --git a/cmake/utils/platform_check.cmake b/cmake/utils/platform_check.cmake index afc41d07a..d713a2d44 100644 --- a/cmake/utils/platform_check.cmake +++ b/cmake/utils/platform_check.cmake @@ -3,12 +3,9 @@ include(CheckSymbolExists) macro(detect_target_arch) check_symbol_exists(__aarch64__ "" __AARCH64) check_symbol_exists(__x86_64__ "" __X86_64) - check_symbol_exists(__powerpc64__ "" __PPC64) - if(NOT __AARCH64 - AND NOT __X86_64 - AND NOT __PPC64) - message(FATAL "knowhere only support amd64, ppc64 and arm64 architecture.") + if(NOT __AARCH64 AND NOT __X86_64) + message(FATAL "knowhere only support amd64 and arm64.") endif() endmacro() diff --git a/include/knowhere/binaryset.h b/include/knowhere/binaryset.h index b13611d7a..4fde1e084 100644 --- a/include/knowhere/binaryset.h +++ b/include/knowhere/binaryset.h @@ -89,18 +89,6 @@ class BinarySet { return binary_map_.find(key) != binary_map_.end(); } - // Return the total size of all binary data in binary set. - size_t - Size() const { - size_t size = 0; - for (auto& pair : binary_map_) { - if (pair.second != nullptr) { - size += pair.second->size; - } - } - return size; - } - public: std::map binary_map_; }; diff --git a/include/knowhere/bitsetview.h b/include/knowhere/bitsetview.h index 464bf774b..40d0e6119 100644 --- a/include/knowhere/bitsetview.h +++ b/include/knowhere/bitsetview.h @@ -23,8 +23,7 @@ class BitsetView { BitsetView() = default; ~BitsetView() = default; - BitsetView(const uint8_t* data, size_t num_bits, size_t filtered_out_num = 0) - : bits_(data), num_bits_(num_bits), filtered_out_num_(filtered_out_num) { + BitsetView(const uint8_t* data, size_t num_bits) : bits_(data), num_bits_(num_bits) { } BitsetView(const std::nullptr_t) : BitsetView() { @@ -58,16 +57,6 @@ class BitsetView { size_t count() const { - return filtered_out_num_; - } - - float - filter_ratio() const { - return empty() ? 0.0f : ((float)filtered_out_num_ / num_bits_); - } - - size_t - get_filtered_out_num_() const { size_t ret = 0; auto len_uint8 = byte_size(); auto len_uint64 = len_uint8 >> 3; @@ -111,7 +100,6 @@ class BitsetView { private: const uint8_t* bits_ = nullptr; size_t num_bits_ = 0; - size_t filtered_out_num_ = 0; }; } // namespace knowhere diff --git a/include/knowhere/bitsetview_idselector.h b/include/knowhere/bitsetview_idselector.h deleted file mode 100644 index c09aba222..000000000 --- a/include/knowhere/bitsetview_idselector.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (C) 2019-2023 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#pragma once - -#include - -#include "knowhere/bitsetview.h" - -namespace knowhere { - -struct BitsetViewIDSelector final : faiss::IDSelector { - const BitsetView bitset_view; - - inline BitsetViewIDSelector(BitsetView bitset_view) : bitset_view{bitset_view} { - } - - inline bool - is_member(faiss::idx_t id) const override final { - // it is by design that bitset_view.empty() is not tested here - return (!bitset_view.test(id)); - } -}; - -} // namespace knowhere diff --git a/include/knowhere/cluster/cluster.h b/include/knowhere/cluster/cluster.h deleted file mode 100644 index 84fc340e0..000000000 --- a/include/knowhere/cluster/cluster.h +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright (C) 2019-2023 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#ifndef CLUSTER_H -#define CLUSTER_H - -#include "knowhere/binaryset.h" -#include "knowhere/cluster/cluster_node.h" -#include "knowhere/config.h" -#include "knowhere/dataset.h" -#include "knowhere/expected.h" - -namespace knowhere { -template -class Cluster { - public: - template - friend class Cluster; - - Cluster() : node(nullptr) { - } - - template - static Cluster - Create(Args&&... args) { - return Cluster(new (std::nothrow) T1(std::forward(args)...)); - } - - Cluster(const Cluster& cluster) { - if (cluster.node == nullptr) { - node = nullptr; - return; - } - cluster.node->IncRef(); - node = cluster.node; - } - - Cluster(Cluster&& cluster) { - if (cluster.node == nullptr) { - node = nullptr; - return; - } - node = cluster.node; - cluster.node = nullptr; - } - - template - Cluster(const Cluster& cluster) { - static_assert(std::is_base_of::value); - if (cluster.node == nullptr) { - node = nullptr; - return; - } - cluster.node->IncRef(); - node = cluster.node; - } - - template - Cluster(Cluster&& cluster) { - static_assert(std::is_base_of::value); - if (cluster.node == nullptr) { - node = nullptr; - return; - } - node = cluster.node; - cluster.node = nullptr; - } - - template - Cluster& - operator=(const Cluster& cluster) { - static_assert(std::is_base_of::value); - if (node != nullptr) { - node->DecRef(); - if (!node->Ref()) - delete node; - } - if (cluster.node == nullptr) { - node = nullptr; - return *this; - } - node = cluster.node; - node->IncRef(); - return *this; - } - - template - Cluster& - operator=(Cluster&& cluster) { - static_assert(std::is_base_of::value); - if (node != nullptr) { - node->DecRef(); - if (!node->Ref()) - delete node; - } - node = cluster.node; - cluster.node = nullptr; - return *this; - } - - T1* - Node() { - return node; - } - - const T1* - Node() const { - return node; - } - - expected - Train(const DataSet& dataset, const Json& json); - - expected - Assign(const DataSet& dataset); - - expected - GetCentroids() const; - - std::string - Type() const; - - ~Cluster() { - if (node == nullptr) - return; - node->DecRef(); - if (!node->Ref()) - delete node; - } - - private: - Cluster(T1* node) : node(node) { - static_assert(std::is_base_of::value); - } - - T1* node; -}; - -} // namespace knowhere - -#endif /* CLUSTER_H */ diff --git a/include/knowhere/cluster/cluster_factory.h b/include/knowhere/cluster/cluster_factory.h deleted file mode 100644 index 8b2c60c20..000000000 --- a/include/knowhere/cluster/cluster_factory.h +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright (C) 2019-2023 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#ifndef CLUSTER_FACTORY_H -#define CLUSTER_FACTORY_H - -#include -#include -#include - -#include "knowhere/cluster/cluster.h" -#include "knowhere/utils.h" - -namespace knowhere { -class ClusterFactory { - public: - template - expected> - Create(const std::string& name, const Object& object = nullptr); - template - const ClusterFactory& - Register(const std::string& name, std::function(const Object&)> func); - static ClusterFactory& - Instance(); - - private: - struct FunMapValueBase { - virtual ~FunMapValueBase() = default; - }; - template - struct FunMapValue : FunMapValueBase { - public: - FunMapValue(std::function& input) : fun_value(input) { - } - std::function fun_value; - }; - typedef std::map> FuncMap; - ClusterFactory(); - static FuncMap& - MapInstance(); -}; - -#define KNOWHERE_CLUSTER_CONCAT(x, y) cluster_factory_ref_##x##y -#define KNOWHERE_CLUSTER_REGISTER_GLOBAL(name, func, data_type) \ - const ClusterFactory& KNOWHERE_CLUSTER_CONCAT(name, data_type) = \ - ClusterFactory::Instance().Register(#name, func) -#define KNOWHERE_CLUSTER_SIMPLE_REGISTER_GLOBAL(name, cluster_node, data_type, ...) \ - KNOWHERE_CLUSTER_REGISTER_GLOBAL(name, \ - (static_cast> (*)(const Object&)>( \ - &Cluster>::Create)), \ - data_type) -} // namespace knowhere - -#endif /* CLUSTER_FACTORY_H */ diff --git a/include/knowhere/cluster/cluster_node.h b/include/knowhere/cluster/cluster_node.h deleted file mode 100644 index 1c3f69d0f..000000000 --- a/include/knowhere/cluster/cluster_node.h +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright (C) 2019-2023 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#ifndef CLUSTER_NODE_H -#define CLUSTER_NODE_H - -#include "knowhere/binaryset.h" -#include "knowhere/bitsetview.h" -#include "knowhere/config.h" -#include "knowhere/dataset.h" -#include "knowhere/expected.h" -#include "knowhere/object.h" -#include "knowhere/operands.h" - -namespace knowhere { -class ClusterNode : public Object { - public: - // kmeans train, return id_mapping - // (rows, uint32_t* id_mapping) - virtual expected - Train(const DataSet& dataset, const Config& cfg) = 0; - - // cluster assign, return id_mapping - // (rows, uint32_t* id_mapping) - virtual expected - Assign(const DataSet& dataset) = 0; - - // return centroids, must be called after trained - // (rows, dim, centroid_vector_list) - virtual expected - GetCentroids() const = 0; - - virtual std::unique_ptr - CreateConfig() const = 0; - - virtual std::string - Type() const = 0; - - virtual ~ClusterNode() { - } -}; -} // namespace knowhere - -#endif /* CLUSTER_NODE_H */ diff --git a/include/knowhere/comp/brute_force.h b/include/knowhere/comp/brute_force.h index d3a11fe49..240aa0f4b 100644 --- a/include/knowhere/comp/brute_force.h +++ b/include/knowhere/comp/brute_force.h @@ -11,47 +11,24 @@ #ifndef BRUTE_FORCE_H #define BRUTE_FORCE_H - -#include -#include - #include "knowhere/bitsetview.h" #include "knowhere/dataset.h" -#include "knowhere/index/index_factory.h" -#include "knowhere/index/index_node.h" -#include "knowhere/operands.h" +#include "knowhere/factory.h" namespace knowhere { class BruteForce { public: - template static expected Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, const BitsetView& bitset); - template static Status SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, int64_t* ids, float* dis, const Json& config, const BitsetView& bitset); - template static expected RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, const BitsetView& bitset); - - // Perform row oriented sparse vector brute force search. - static expected - SearchSparse(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, - const BitsetView& bitset); - - static Status - SearchSparseWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, sparse::label_t* ids, float* dis, - const Json& config, const BitsetView& bitset); - - template - static expected> - AnnIterator(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, - const BitsetView& bitset); }; } // namespace knowhere diff --git a/include/knowhere/comp/index_param.h b/include/knowhere/comp/index_param.h index 1b5c847b0..986d07241 100644 --- a/include/knowhere/comp/index_param.h +++ b/include/knowhere/comp/index_param.h @@ -31,41 +31,21 @@ constexpr const char* INDEX_FAISS_IVFFLAT_CC = "IVF_FLAT_CC"; constexpr const char* INDEX_FAISS_IVFPQ = "IVF_PQ"; constexpr const char* INDEX_FAISS_SCANN = "SCANN"; constexpr const char* INDEX_FAISS_IVFSQ8 = "IVF_SQ8"; -constexpr const char* INDEX_FAISS_IVFSQ_CC = "IVF_SQ_CC"; constexpr const char* INDEX_FAISS_GPU_IDMAP = "GPU_FAISS_FLAT"; constexpr const char* INDEX_FAISS_GPU_IVFFLAT = "GPU_FAISS_IVF_FLAT"; constexpr const char* INDEX_FAISS_GPU_IVFPQ = "GPU_FAISS_IVF_PQ"; constexpr const char* INDEX_FAISS_GPU_IVFSQ8 = "GPU_FAISS_IVF_SQ8"; -constexpr const char* INDEX_RAFT_BRUTEFORCE = "GPU_RAFT_BRUTE_FORCE"; constexpr const char* INDEX_RAFT_IVFFLAT = "GPU_RAFT_IVF_FLAT"; constexpr const char* INDEX_RAFT_IVFPQ = "GPU_RAFT_IVF_PQ"; constexpr const char* INDEX_RAFT_CAGRA = "GPU_RAFT_CAGRA"; -constexpr const char* INDEX_GPU_BRUTEFORCE = "GPU_BRUTE_FORCE"; -constexpr const char* INDEX_GPU_IVFFLAT = "GPU_IVF_FLAT"; -constexpr const char* INDEX_GPU_IVFPQ = "GPU_IVF_PQ"; -constexpr const char* INDEX_GPU_CAGRA = "GPU_CAGRA"; - constexpr const char* INDEX_HNSW = "HNSW"; -constexpr const char* INDEX_HNSW_SQ8 = "HNSW_SQ8"; -constexpr const char* INDEX_HNSW_SQ8_REFINE = "HNSW_SQ8_REFINE"; constexpr const char* INDEX_DISKANN = "DISKANN"; -constexpr const char* INDEX_FAISS_HNSW_FLAT = "FAISS_HNSW_FLAT"; -constexpr const char* INDEX_FAISS_HNSW_SQ = "FAISS_HNSW_SQ"; -constexpr const char* INDEX_FAISS_HNSW_PQ = "FAISS_HNSW_PQ"; -constexpr const char* INDEX_FAISS_HNSW_PRQ = "FAISS_HNSW_PRQ"; - -constexpr const char* INDEX_SPARSE_INVERTED_INDEX = "SPARSE_INVERTED_INDEX"; -constexpr const char* INDEX_SPARSE_WAND = "SPARSE_WAND"; } // namespace IndexEnum -namespace ClusterEnum { -constexpr const char* CLUSTER_KMEANS = "KMEANS"; -} // namespace ClusterEnum - namespace meta { constexpr const char* INDEX_TYPE = "index_type"; constexpr const char* METRIC_TYPE = "metric_type"; @@ -76,13 +56,10 @@ constexpr const char* RETRIEVE_FRIENDLY = "retrieve_friendly"; constexpr const char* DIM = "dim"; constexpr const char* TENSOR = "tensor"; constexpr const char* ROWS = "rows"; -constexpr const char* NQ = "nq"; constexpr const char* IDS = "ids"; constexpr const char* DISTANCE = "distance"; constexpr const char* LIMS = "lims"; constexpr const char* TOPK = "k"; -constexpr const char* RANGE_SEARCH_K = "range_search_k"; -constexpr const char* RETAIN_ITERATOR_ORDER = "retain_iterator_order"; constexpr const char* RADIUS = "radius"; constexpr const char* RANGE_FILTER = "range_filter"; constexpr const char* INPUT_IDS = "input_ids"; @@ -92,77 +69,24 @@ constexpr const char* NUM_BUILD_THREAD = "num_build_thread"; constexpr const char* TRACE_VISIT = "trace_visit"; constexpr const char* JSON_INFO = "json_info"; constexpr const char* JSON_ID_SET = "json_id_set"; -constexpr const char* TRACE_ID = "trace_id"; -constexpr const char* SPAN_ID = "span_id"; -constexpr const char* TRACE_FLAGS = "trace_flags"; -constexpr const char* MATERIALIZED_VIEW_SEARCH_INFO = "materialized_view_search_info"; -constexpr const char* MATERIALIZED_VIEW_OPT_FIELDS_PATH = "opt_fields_path"; -constexpr const char* MAX_EMPTY_RESULT_BUCKETS = "max_empty_result_buckets"; -constexpr const char* BM25_K1 = "bm25_k1"; -constexpr const char* BM25_B = "bm25_b"; -// average document length -constexpr const char* BM25_AVGDL = "bm25_avgdl"; -constexpr const char* WAND_BM25_MAX_SCORE_RATIO = "wand_bm25_max_score_ratio"; }; // namespace meta namespace indexparam { // IVF Params constexpr const char* NPROBE = "nprobe"; constexpr const char* NLIST = "nlist"; -constexpr const char* USE_ELKAN = "use_elkan"; constexpr const char* NBITS = "nbits"; // PQ/SQ constexpr const char* M = "m"; // PQ param for IVFPQ constexpr const char* SSIZE = "ssize"; constexpr const char* REORDER_K = "reorder_k"; constexpr const char* WITH_RAW_DATA = "with_raw_data"; -constexpr const char* ENSURE_TOPK_FULL = "ensure_topk_full"; -constexpr const char* CODE_SIZE = "code_size"; -constexpr const char* RAW_DATA_STORE_PREFIX = "raw_data_store_prefix"; -// RAFT Params -constexpr const char* REFINE_RATIO = "refine_ratio"; -constexpr const char* CACHE_DATASET_ON_DEVICE = "cache_dataset_on_device"; -// RAFT-specific IVF Params -constexpr const char* KMEANS_N_ITERS = "kmeans_n_iters"; -constexpr const char* KMEANS_TRAINSET_FRACTION = "kmeans_trainset_fraction"; -constexpr const char* ADAPTIVE_CENTERS = "adaptive_centers"; // IVF FLAT -constexpr const char* CODEBOOK_KIND = "codebook_kind"; // IVF PQ -constexpr const char* FORCE_RANDOM_ROTATION = "force_random_rotation"; // IVF PQ -constexpr const char* CONSERVATIVE_MEMORY_ALLOCATION = "conservative_memory_allocation"; // IVF PQ -constexpr const char* LUT_DTYPE = "lut_dtype"; // IVF PQ -constexpr const char* INTERNAL_DISTANCE_DTYPE = "internal_distance_dtype"; // IVF PQ -constexpr const char* PREFERRED_SHMEM_CARVEOUT = "preferred_shmem_carveout"; // IVF PQ - -// CAGRA Params -constexpr const char* INTERMEDIATE_GRAPH_DEGREE = "intermediate_graph_degree"; -constexpr const char* GRAPH_DEGREE = "graph_degree"; -constexpr const char* ITOPK_SIZE = "itopk_size"; -constexpr const char* MAX_QUERIES = "max_queries"; -constexpr const char* BUILD_ALGO = "build_algo"; -constexpr const char* SEARCH_ALGO = "search_algo"; -constexpr const char* TEAM_SIZE = "team_size"; -constexpr const char* SEARCH_WIDTH = "search_width"; -constexpr const char* MIN_ITERATIONS = "min_iterations"; -constexpr const char* MAX_ITERATIONS = "max_iterations"; -constexpr const char* THREAD_BLOCK_SIZE = "thread_block_size"; -constexpr const char* HASHMAP_MODE = "hashmap_mode"; -constexpr const char* HASHMAP_MIN_BITLEN = "hashmap_min_bitlen"; -constexpr const char* HASHMAP_MAX_FILL_RATE = "hashmap_max_fill_rate"; -constexpr const char* NN_DESCENT_NITER = "nn_descent_niter"; -constexpr const char* ADAPT_FOR_CPU = "adapt_for_cpu"; // HNSW Params constexpr const char* EFCONSTRUCTION = "efConstruction"; constexpr const char* HNSW_M = "M"; constexpr const char* EF = "ef"; +constexpr const char* SEED_EF = "seed_ef"; constexpr const char* OVERVIEW_LEVELS = "overview_levels"; - -// FAISS additional Params -constexpr const char* SQ_TYPE = "sq_type"; // for IVF_SQ and HNSW_SQ -constexpr const char* PRQ_NUM = "nrq"; // for PRQ, number of redisual quantizers - -// Sparse Params -constexpr const char* DROP_RATIO_BUILD = "drop_ratio_build"; -constexpr const char* DROP_RATIO_SEARCH = "drop_ratio_search"; } // namespace indexparam using MetricType = std::string; @@ -175,15 +99,6 @@ constexpr const char* HAMMING = "HAMMING"; constexpr const char* JACCARD = "JACCARD"; constexpr const char* SUBSTRUCTURE = "SUBSTRUCTURE"; constexpr const char* SUPERSTRUCTURE = "SUPERSTRUCTURE"; -constexpr const char* BM25 = "BM25"; } // namespace metric -enum VecType { - VECTOR_BINARY = 100, - VECTOR_FLOAT = 101, - VECTOR_FLOAT16 = 102, - VECTOR_BFLOAT16 = 103, - VECTOR_SPARSE_FLOAT = 104, -}; // keep the same value as milvus proto define - } // namespace knowhere diff --git a/include/knowhere/comp/knowhere_check.h b/include/knowhere/comp/knowhere_check.h deleted file mode 100644 index a485b5ca4..000000000 --- a/include/knowhere/comp/knowhere_check.h +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright (C) 2019-2023 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#ifndef COMP_KNOWHERE_CHECKER_H -#define COMP_KNOWHERE_CHECKER_H - -#include - -#include "knowhere/comp/index_param.h" -#include "knowhere/index/index_factory.h" -#ifdef KNOWHERE_WITH_CARDINAL -#include "cardinal/cardinal_utils.h" -#endif - -namespace knowhere { -namespace KnowhereCheck { -static bool -IndexTypeAndDataTypeCheck(const std::string& index_name, VecType data_type) { - auto& static_index_table = std::get<0>(IndexFactory::StaticIndexTableInstance()); - auto key = std::pair(index_name, data_type); - if (static_index_table.find(key) != static_index_table.end()) { - return true; - } else { - return false; - } -} - -static bool -SupportMmapIndexTypeCheck(const std::string& index_name) { - auto& mmap_index_table = std::get<1>(IndexFactory::StaticIndexTableInstance()); - if (mmap_index_table.find(index_name) != mmap_index_table.end()) { - return true; - } else { - return false; - } -} - -inline bool -CheckBooleanInJson(const knowhere::Json& json, std::string key) { - if (json.find(key) == json.end()) { - return true; - } - if (json[key].is_boolean()) { - return json[key]; - } - if (json[key].is_string()) { - if (json[key] == "true") { - return true; - } else { - return false; - } - } - return false; -} - -template -bool -IndexHasRawData(const knowhere::IndexType& indexType, const knowhere::MetricType& metricType, - const knowhere::IndexVersion& version, const knowhere::Json& params) { - static std::set has_raw_data_index_set = { - IndexEnum::INDEX_FAISS_BIN_IDMAP, IndexEnum::INDEX_FAISS_BIN_IVFFLAT, IndexEnum::INDEX_FAISS_IVFFLAT, - IndexEnum::INDEX_FAISS_IVFFLAT_CC, IndexEnum::INDEX_HNSW_SQ8_REFINE, IndexEnum::INDEX_SPARSE_INVERTED_INDEX, - IndexEnum::INDEX_SPARSE_WAND}; - static std::set has_raw_data_index_alias_set = {"IVFBIN", "BINFLAT", "IVFFLAT", "IVFFLATCC"}; - - static std::set no_raw_data_index_set = { - IndexEnum::INDEX_FAISS_IVFPQ, IndexEnum::INDEX_FAISS_IVFSQ8, IndexEnum::INDEX_HNSW_SQ8, - IndexEnum::INDEX_FAISS_GPU_IDMAP, IndexEnum::INDEX_FAISS_GPU_IVFFLAT, IndexEnum::INDEX_FAISS_GPU_IVFSQ8, - IndexEnum::INDEX_FAISS_GPU_IVFPQ, IndexEnum::INDEX_GPU_BRUTEFORCE, IndexEnum::INDEX_GPU_IVFFLAT, - IndexEnum::INDEX_GPU_IVFPQ, IndexEnum::INDEX_GPU_CAGRA, IndexEnum::INDEX_RAFT_BRUTEFORCE, - IndexEnum::INDEX_RAFT_IVFFLAT, IndexEnum::INDEX_RAFT_IVFPQ, IndexEnum::INDEX_RAFT_CAGRA, - }; - - static std::set no_raw_data_index_alias_set = {"IVFPQ", "IVFSQ"}; - - static std::set conditional_hold_raw_data_index_set = { - IndexEnum::INDEX_FAISS_IDMAP, IndexEnum::INDEX_FAISS_SCANN, IndexEnum::INDEX_FAISS_IVFSQ_CC, - IndexEnum::INDEX_HNSW, IndexEnum::INDEX_DISKANN, - }; - - if (has_raw_data_index_set.find(indexType) != has_raw_data_index_set.end() || - has_raw_data_index_alias_set.find(indexType) != has_raw_data_index_alias_set.end()) { - return true; - } - - if (no_raw_data_index_set.find(indexType) != no_raw_data_index_set.end() || - no_raw_data_index_alias_set.find(indexType) != no_raw_data_index_alias_set.end()) { - return false; - } - - if (conditional_hold_raw_data_index_set.find(indexType) != conditional_hold_raw_data_index_set.end()) { - if (indexType == IndexEnum::INDEX_HNSW) { -#ifdef KNOWHERE_WITH_CARDINAL - return IndexHoldRawData(indexType, metricType, version, params); -#else - return true; -#endif - } else if (indexType == IndexEnum::INDEX_DISKANN) { -#ifdef KNOWHERE_WITH_CARDINAL - return IndexHoldRawData(indexType, metricType, version, params); -#else - return IsMetricType(metricType, metric::L2) || IsMetricType(metricType, metric::COSINE); -#endif - } else if (indexType == IndexEnum::INDEX_FAISS_SCANN) { - return CheckBooleanInJson(params, indexparam::WITH_RAW_DATA); - // INDEX_FAISS_IVFSQ_CC is not online yet - } else if (indexType == IndexEnum::INDEX_FAISS_IVFSQ_CC) { - return params.find(indexparam::RAW_DATA_STORE_PREFIX) != params.end(); - } else if (indexType == IndexEnum::INDEX_FAISS_IDMAP) { - if (knowhere::Version(version) <= Version::GetMinimalVersion()) { - return !IsMetricType(metricType, metric::COSINE); - } else { - return true; - } - } else { - LOG_KNOWHERE_ERROR_ << "unhandled index type : " << indexType; - } - } else { - LOG_KNOWHERE_ERROR_ << "unknown index type : " << indexType; - } - return false; -} - -} // namespace KnowhereCheck -} // namespace knowhere - -#endif /* COMP_KNOWHERE_CHECKER_H */ diff --git a/include/knowhere/comp/knowhere_config.h b/include/knowhere/comp/knowhere_config.h index c1547b866..de845b6cf 100644 --- a/include/knowhere/comp/knowhere_config.h +++ b/include/knowhere/comp/knowhere_config.h @@ -40,17 +40,6 @@ class KnowhereConfig { static std::string SetSimdType(const SimdType simd_type); - /** - *The purpose of this interface is: part of the sealed indexes default to using bf16 as the base data to achieve - *higher capacity; to ensure consistency in computation between growing and sealed, it is necessary to maintain the - *same precision in growing calculations as in sealed. - */ - static void - EnablePatchForComputeFP32AsBF16(); - - static void - DisablePatchForComputeFP32AsBF16(); - /** * Set openblas threshold * if nq < use_blas_threshold, calculated by omp @@ -97,13 +86,9 @@ class KnowhereConfig { static void SetBuildThreadPoolSize(size_t num_threads); - static size_t - GetBuildThreadPoolSize(); static void SetSearchThreadPoolSize(size_t num_threads); - static size_t - GetSearchThreadPoolSize(); /** * init GPU Resource @@ -122,12 +107,6 @@ class KnowhereConfig { */ static void SetRaftMemPool(size_t init_size, size_t max_size); - - /** - * Initialize RAFT with defaults - */ - static void - SetRaftMemPool(); }; } // namespace knowhere diff --git a/include/knowhere/comp/materialized_view.h b/include/knowhere/comp/materialized_view.h deleted file mode 100644 index 232c406ea..000000000 --- a/include/knowhere/comp/materialized_view.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (C) 2019-2024 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#pragma once - -#include -#include -#include -#include - -namespace knowhere { -// MaterializedViewSearchInfo is used to store the search information when performing filtered search (i.e. Materialized -// View - vectors and scalars). -// This information is obtained from expression analysis only, not runtime, so might be inaccurate. -struct MaterializedViewSearchInfo { - // describes which scalar field is involved during search, - // and how many categories are touched - // for example, if we have scalar field `color` with field id `111` and it has three categories: red, green, blue - // expression `color == "red"`, yields `111 -> 1` - // expression `color == "red" && color == "green"`, yields `111 -> 2` - std::unordered_map field_id_to_touched_categories_cnt; - - // whether the search exression has AND (&&) logical operator only - bool is_pure_and = true; - - // whether the search expression has NOT (!) logical unary operator - bool has_not = false; -}; - -// DO NOT CALL THIS FUNCTION MANUALLY -// use `json j = materialized_view_search_info` -void -to_json(nlohmann::json& j, const MaterializedViewSearchInfo& info); - -// DO NOT CALL THIS FUNCTION MANUALLY -// use `auto j = j.get() or j[KEY]` -void -from_json(const nlohmann::json& j, MaterializedViewSearchInfo& info); -} // namespace knowhere diff --git a/include/knowhere/comp/task.h b/include/knowhere/comp/task.h deleted file mode 100644 index 2c24342a8..000000000 --- a/include/knowhere/comp/task.h +++ /dev/null @@ -1,44 +0,0 @@ -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. -#ifndef KNOWHERE_COMP_TASK_H -#define KNOWHERE_COMP_TASK_H -#include -#include -#include -namespace knowhere { - -void -ExecOverSearchThreadPool(std::vector>& tasks); -void -ExecOverBuildThreadPool(std::vector>& tasks); -void -InitBuildThreadPool(uint32_t num_threads); -void -InitSearchThreadPool(uint32_t num_threads); -size_t -GetSearchThreadPoolSize(); -size_t -GetBuildThreadPoolSize(); -class ThreadPool { - public: - class ScopedOmpSetter { - int omp_before; - - public: - explicit ScopedOmpSetter(int num_threads = 0); - ~ScopedOmpSetter(); - }; -}; -std::unique_ptr -CreateScopeOmpSetter(int num_threads = 0); - -} // namespace knowhere - -#endif /* TASK_H */ diff --git a/include/knowhere/comp/thread_pool.h b/include/knowhere/comp/thread_pool.h index a4a64e208..a00920de6 100644 --- a/include/knowhere/comp/thread_pool.h +++ b/include/knowhere/comp/thread_pool.h @@ -12,21 +12,7 @@ #pragma once #include - -#ifdef __linux__ - -#if defined(__PPC64__) || defined(__ppc64__) || defined(__PPC64LE__) || defined(__ppc64le__) || defined(__powerpc64__) -#include -#else -#include -#endif - #include -#if __GLIBC__ == 2 && __GLIBC_MINOR__ < 30 -#include -#define gettid() syscall(SYS_gettid) -#endif -#endif #include #include @@ -35,16 +21,12 @@ #include #include "folly/executors/CPUThreadPoolExecutor.h" -#include "folly/executors/task_queue/UnboundedBlockingQueue.h" #include "folly/futures/Future.h" -#include "knowhere/expected.h" #include "knowhere/log.h" namespace knowhere { class ThreadPool { - public: - enum class QueueType { LIFO, FIFO }; #ifdef __linux__ private: class LowPriorityThreadFactory : public folly::NamedThreadFactory { @@ -65,33 +47,23 @@ class ThreadPool { }; public: - explicit ThreadPool(uint32_t num_threads, const std::string& thread_name_prefix, QueueType queueT = QueueType::LIFO) - : pool_(queueT == QueueType::LIFO - ? folly::CPUThreadPoolExecutor( - num_threads, - std::make_unique>( - num_threads * kTaskQueueFactor), - std::make_shared(thread_name_prefix)) - : folly::CPUThreadPoolExecutor( - num_threads, - std::make_unique>(), - std::make_shared(thread_name_prefix))) { + explicit ThreadPool(uint32_t num_threads, const std::string& thread_name_prefix) + : pool_(folly::CPUThreadPoolExecutor( + num_threads, + std::make_unique< + folly::LifoSemMPMCQueue>( + num_threads * kTaskQueueFactor), + std::make_shared(thread_name_prefix))) { } #else public: - explicit ThreadPool(uint32_t num_threads, const std::string& thread_name_prefix, QueueType queueT = QueueType::LIFO) - : pool_(queueT == QueueType::LIFO - ? folly::CPUThreadPoolExecutor( - num_threads, - std::make_unique>( - num_threads * kTaskQueueFactor), - std::make_shared(thread_name_prefix)) - : folly::CPUThreadPoolExecutor( - num_threads, - std::make_unique>(), - std::make_shared(thread_name_prefix))) { + explicit ThreadPool(uint32_t num_threads, const std::string& thread_name_prefix) + : pool_(folly::CPUThreadPoolExecutor( + num_threads, + std::make_unique< + folly::LifoSemMPMCQueue>( + num_threads * kTaskQueueFactor), + std::make_shared(thread_name_prefix))) { } #endif @@ -112,195 +84,102 @@ class ThreadPool { [func = std::forward(func), &args...](auto&&) mutable { return func(std::forward(args)...); }); } - [[nodiscard]] size_t + [[nodiscard]] int32_t size() const noexcept { return pool_.numThreads(); } - size_t - GetPendingTaskCount() { - return pool_.getPendingTaskCount(); - } - - void - SetNumThreads(uint32_t num_threads) { - if (num_threads == 0) { - LOG_KNOWHERE_ERROR_ << "set number of threads can not be 0"; - return; - } else { - // setNumThreads() adjust the relevant variables instead of changing the number of threads directly; - // If numThreads < active threads, reduce number of running threads. - pool_.setNumThreads(num_threads); - return; - } - } - - static ThreadPool - CreateFIFO(uint32_t num_threads, const std::string& thread_name_prefix) { - return ThreadPool(num_threads, thread_name_prefix, QueueType::FIFO); - } - - static ThreadPool - CreateLIFO(uint32_t num_threads, const std::string& thread_name_prefix) { - return ThreadPool(num_threads, thread_name_prefix, QueueType::LIFO); - } - + /** + * @brief Set the threads number to the global build thread pool of knowhere + * + * @param num_threads + */ static void - InitGlobalBuildThreadPool(uint32_t num_threads) { + InitThreadPool(uint32_t num_threads, uint32_t& thread_pool_size) { if (num_threads <= 0) { LOG_KNOWHERE_ERROR_ << "num_threads should be bigger than 0"; return; } - if (build_pool_ == nullptr) { - std::lock_guard lock(build_pool_mutex_); - if (build_pool_ == nullptr) { - build_pool_ = std::make_shared(num_threads, "knowhere_build"); - LOG_KNOWHERE_INFO_ << "Init global build thread pool with size " << num_threads; + if (thread_pool_size == 0) { + std::lock_guard lock(global_thread_pool_mutex_); + if (thread_pool_size == 0) { + thread_pool_size = num_threads; return; } - } else { - LOG_KNOWHERE_INFO_ << "Global build thread pool size has already been initialized to " - << build_pool_->size(); } } static void - InitGlobalSearchThreadPool(uint32_t num_threads) { - if (num_threads <= 0) { - LOG_KNOWHERE_ERROR_ << "num_threads should be bigger than 0"; - return; - } - - if (search_pool_ == nullptr) { - std::lock_guard lock(search_pool_mutex_); - if (search_pool_ == nullptr) { - search_pool_ = std::make_shared(num_threads, "knowhere_search"); - LOG_KNOWHERE_INFO_ << "Init global search thread pool with size " << num_threads; - return; - } - } else { - LOG_KNOWHERE_INFO_ << "Global search thread pool size has already been initialized to " - << search_pool_->size(); - } - } - - static void - SetGlobalBuildThreadPoolSize(uint32_t num_threads) { - if (build_pool_ == nullptr) { - InitGlobalBuildThreadPool(num_threads); - return; - } else { - build_pool_->SetNumThreads(num_threads); - LOG_KNOWHERE_INFO_ << "Global build thread pool size has already been set to " << build_pool_->size(); - return; - } - } - - static size_t - GetGlobalBuildThreadPoolSize() { - return (build_pool_ == nullptr ? 0 : build_pool_->size()); + InitGlobalBuildThreadPool(uint32_t num_threads) { + InitThreadPool(num_threads, global_build_thread_pool_size_); + LOG_KNOWHERE_WARNING_ << "Global Build ThreadPool has already been initialized with threads num: " + << global_build_thread_pool_size_; } + /** + * @brief Set the threads number to the global search thread pool of knowhere + * + * @param num_threads + */ static void - SetGlobalSearchThreadPoolSize(uint32_t num_threads) { - if (search_pool_ == nullptr) { - InitGlobalSearchThreadPool(num_threads); - return; - } else { - search_pool_->SetNumThreads(num_threads); - LOG_KNOWHERE_INFO_ << "Global search thread pool size has already been set to " << search_pool_->size(); - return; - } - } - - static size_t - GetGlobalSearchThreadPoolSize() { - return (search_pool_ == nullptr ? 0 : search_pool_->size()); - } - - static size_t - GetSearchThreadPoolPendingTaskCount() { - return ThreadPool::GetGlobalSearchThreadPool()->GetPendingTaskCount(); + InitGlobalSearchThreadPool(uint32_t num_threads) { + InitThreadPool(num_threads, global_search_thread_pool_size_); + LOG_KNOWHERE_WARNING_ << "Global Search ThreadPool has already been initialized with threads num: " + << global_search_thread_pool_size_; } - static size_t - GetBuildThreadPoolPendingTaskCount() { - return ThreadPool::GetGlobalBuildThreadPool()->GetPendingTaskCount(); - } + /** + * @brief Get the global thread pool of knowhere. + * + * @return ThreadPool& + */ static std::shared_ptr GetGlobalBuildThreadPool() { - if (build_pool_ == nullptr) { - InitGlobalBuildThreadPool(std::thread::hardware_concurrency()); + if (global_build_thread_pool_size_ == 0) { + InitThreadPool(std::thread::hardware_concurrency(), global_build_thread_pool_size_); + LOG_KNOWHERE_WARNING_ << "Global Build ThreadPool has not been initialized yet, init it with threads num: " + << global_build_thread_pool_size_; } - return build_pool_; + static auto pool = std::make_shared(global_build_thread_pool_size_, "Knowhere_Build"); + return pool; } static std::shared_ptr GetGlobalSearchThreadPool() { - if (search_pool_ == nullptr) { - InitGlobalSearchThreadPool(std::thread::hardware_concurrency()); + if (global_search_thread_pool_size_ == 0) { + InitThreadPool(std::thread::hardware_concurrency(), global_search_thread_pool_size_); + LOG_KNOWHERE_WARNING_ << "Global Search ThreadPool has not been initialized yet, init it with threads num: " + << global_search_thread_pool_size_; } - return search_pool_; + static auto pool = std::make_shared(global_search_thread_pool_size_, "Knowhere_Search"); + return pool; } class ScopedOmpSetter { int omp_before; -#ifdef OPENBLAS_OS_LINUX - int blas_thread_before; -#endif + public: explicit ScopedOmpSetter(int num_threads = 0) { - if (build_pool_ == nullptr) { // this should not happen in prod + if (global_build_thread_pool_size_ == 0) { // this should not happen in prod omp_before = omp_get_max_threads(); } else { - omp_before = build_pool_->size(); + omp_before = global_build_thread_pool_size_; } -#ifdef OPENBLAS_OS_LINUX - blas_thread_before = openblas_get_num_threads(); - openblas_set_num_threads(num_threads <= 0 ? blas_thread_before : num_threads); -#endif - omp_set_num_threads(num_threads <= 0 ? omp_before : num_threads); } ~ScopedOmpSetter() { omp_set_num_threads(omp_before); -#ifdef OPENBLAS_OS_LINUX - openblas_set_num_threads(blas_thread_before); -#endif } }; private: folly::CPUThreadPoolExecutor pool_; - - inline static std::mutex build_pool_mutex_; - inline static std::shared_ptr build_pool_ = nullptr; - - inline static std::mutex search_pool_mutex_; - inline static std::shared_ptr search_pool_ = nullptr; - + inline static uint32_t global_build_thread_pool_size_ = 0; + inline static uint32_t global_search_thread_pool_size_ = 0; + inline static std::mutex global_thread_pool_mutex_; constexpr static size_t kTaskQueueFactor = 16; }; - -// T is either folly::Unit or Status -template -inline Status -WaitAllSuccess(std::vector>& futures) { - static_assert(std::is_same::value || std::is_same::value, - "WaitAllSuccess can only be used with folly::Unit or knowhere::Status"); - auto allFuts = folly::collectAll(futures.begin(), futures.end()).get(); - for (const auto& result : allFuts) { - result.throwUnlessValue(); - if constexpr (!std::is_same_v) { - if (result.value() != Status::success) { - return result.value(); - } - } - } - return Status::success; -} - } // namespace knowhere diff --git a/include/knowhere/config.h b/include/knowhere/config.h index 146938387..94575ba08 100644 --- a/include/knowhere/config.h +++ b/include/knowhere/config.h @@ -23,9 +23,7 @@ #include #include #include -#include -#include "knowhere/comp/materialized_view.h" #include "knowhere/expected.h" #include "knowhere/log.h" #include "nlohmann/json.hpp" @@ -46,12 +44,12 @@ typedef nlohmann::json Json; #define CFG_FLOAT std::optional #endif -#ifndef CFG_BOOL -#define CFG_BOOL std::optional +#ifndef CFG_LIST +#define CFG_LIST std::optional> #endif -#ifndef CFG_MATERIALIZED_VIEW_SEARCH_INFO_TYPE -#define CFG_MATERIALIZED_VIEW_SEARCH_INFO_TYPE std::optional +#ifndef CFG_BOOL +#define CFG_BOOL std::optional #endif template @@ -65,18 +63,17 @@ enum PARAM_TYPE { DESERIALIZE = 1 << 4, DESERIALIZE_FROM_FILE = 1 << 5, ITERATOR = 1 << 6, - CLUSTER = 1 << 7, }; template <> struct Entry { - explicit Entry(CFG_STRING* v) { + explicit Entry(CFG_STRING* v) { val = v; type = 0x0; default_val = std::nullopt; desc = std::nullopt; } - Entry() { + Entry() { val = nullptr; type = 0x0; default_val = std::nullopt; @@ -91,14 +88,14 @@ struct Entry { template <> struct Entry { - explicit Entry(CFG_FLOAT* v) { + explicit Entry(CFG_FLOAT* v) { val = v; default_val = std::nullopt; type = 0x0; range = std::nullopt; desc = std::nullopt; } - Entry() { + Entry() { val = nullptr; default_val = std::nullopt; type = 0x0; @@ -116,14 +113,14 @@ struct Entry { template <> struct Entry { - explicit Entry(CFG_INT* v) { + explicit Entry(CFG_INT* v) { val = v; default_val = std::nullopt; type = 0x0; range = std::nullopt; desc = std::nullopt; } - Entry() { + Entry() { val = nullptr; default_val = std::nullopt; type = 0x0; @@ -140,46 +137,46 @@ struct Entry { }; template <> -struct Entry { - explicit Entry(CFG_BOOL* v) { +struct Entry { + explicit Entry(CFG_LIST* v) { val = v; default_val = std::nullopt; type = 0x0; desc = std::nullopt; } - Entry() { + Entry() { val = nullptr; default_val = std::nullopt; type = 0x0; desc = std::nullopt; } - CFG_BOOL* val; - std::optional default_val; + CFG_LIST* val; + std::optional default_val; uint32_t type; std::optional desc; bool allow_empty_without_default = false; }; template <> -struct Entry { - explicit Entry(CFG_MATERIALIZED_VIEW_SEARCH_INFO_TYPE* v) { +struct Entry { + explicit Entry(CFG_BOOL* v) { val = v; default_val = std::nullopt; type = 0x0; desc = std::nullopt; } - Entry() { + Entry() { val = nullptr; default_val = std::nullopt; type = 0x0; desc = std::nullopt; } - CFG_MATERIALIZED_VIEW_SEARCH_INFO_TYPE* val; - std::optional default_val; + CFG_BOOL* val; + std::optional default_val; uint32_t type; std::optional desc; bool allow_empty_without_default = false; @@ -245,12 +242,6 @@ class EntryAccess { return *this; } - EntryAccess& - for_cluster() { - entry->type |= PARAM_TYPE::CLUSTER; - return *this; - } - EntryAccess& for_deserialize() { entry->type |= PARAM_TYPE::DESERIALIZE; @@ -282,12 +273,6 @@ class Config { static Status Load(Config& cfg, const Json& json, PARAM_TYPE type, std::string* const err_msg = nullptr) { - auto show_err_msg = [&](std::string& msg) { - LOG_KNOWHERE_ERROR_ << msg; - if (err_msg) { - *err_msg = msg; - } - }; for (const auto& it : cfg.__DICT__) { const auto& var = it.second; @@ -295,43 +280,47 @@ class Config { if (!(type & ptr->type)) { continue; } - if (json.find(it.first) == json.end()) { - if (!ptr->default_val.has_value()) { - if (ptr->allow_empty_without_default) { - continue; - } - std::string msg = "param '" + it.first + "' not exist in json"; - show_err_msg(msg); - return Status::invalid_param_in_json; - } else { - *ptr->val = ptr->default_val; + if (json.find(it.first) == json.end() && !ptr->default_val.has_value()) { + if (ptr->allow_empty_without_default) { continue; } + LOG_KNOWHERE_ERROR_ << "Invalid param [" << it.first << "] in json."; + if (err_msg) { + *err_msg = std::string("invalid param ") + it.first; + } + return Status::invalid_param_in_json; + } + if (json.find(it.first) == json.end()) { + *ptr->val = ptr->default_val; + continue; } if (!json[it.first].is_number_integer()) { - std::string msg = "Type conflict in json: param '" + it.first + "' (" + to_string(json[it.first]) + - ") should be integer"; - show_err_msg(msg); + LOG_KNOWHERE_ERROR_ << "Type conflict in json: param [" << it.first << "] should be integer."; + if (err_msg) { + *err_msg = std::string("param ") + it.first + " should be integer"; + } return Status::type_conflict_in_json; } if (ptr->range.has_value()) { if (json[it.first].get() > std::numeric_limits::max()) { - std::string msg = "Arithmetic overflow: param '" + it.first + "' (" + - to_string(json[it.first]) + ") should not bigger than " + - std::to_string(std::numeric_limits::max()); - show_err_msg(msg); + LOG_KNOWHERE_ERROR_ << "Arithmetic overflow: param [" << it.first << "] should be at most " + << std::numeric_limits::max(); + if (err_msg) { + *err_msg = std::string("param ") + it.first + " should be at most 2147483647"; + } return Status::arithmetic_overflow; } CFG_INT::value_type v = json[it.first]; - auto range_val = ptr->range.value(); - if (range_val.first <= v && v <= range_val.second) { + if (ptr->range.value().first <= v && v <= ptr->range.value().second) { *ptr->val = v; } else { - std::string msg = "Out of range in json: param '" + it.first + "' (" + - to_string(json[it.first]) + ") should be in range [" + - std::to_string(range_val.first) + ", " + std::to_string(range_val.second) + - "]"; - show_err_msg(msg); + LOG_KNOWHERE_ERROR_ << "Out of range in json: param [" << it.first << "] should be in [" + << ptr->range.value().first << ", " << ptr->range.value().second << "]."; + if (err_msg) { + *err_msg = std::string("param ") + it.first + " out of range " + "[ " + + std::to_string(ptr->range.value().first) + "," + + std::to_string(ptr->range.value().second) + " ]"; + } return Status::out_of_range_in_json; } } else { @@ -343,43 +332,51 @@ class Config { if (!(type & ptr->type)) { continue; } - if (json.find(it.first) == json.end()) { - if (!ptr->default_val.has_value()) { - if (ptr->allow_empty_without_default) { - continue; - } - std::string msg = "param '" + it.first + "' not exist in json"; - show_err_msg(msg); - return Status::invalid_param_in_json; - } else { - *ptr->val = ptr->default_val; + if (json.find(it.first) == json.end() && !ptr->default_val.has_value()) { + if (ptr->allow_empty_without_default) { continue; } + LOG_KNOWHERE_ERROR_ << "Invalid param [" << it.first << "] in json."; + if (err_msg) { + *err_msg = std::string("invalid param ") + it.first; + } + + return Status::invalid_param_in_json; + } + if (json.find(it.first) == json.end()) { + *ptr->val = ptr->default_val; + continue; } if (!json[it.first].is_number()) { - std::string msg = "Type conflict in json: param '" + it.first + "' (" + to_string(json[it.first]) + - ") should be a number"; - show_err_msg(msg); + LOG_KNOWHERE_ERROR_ << "Type conflict in json: param [" << it.first << "] should be a number."; + if (err_msg) { + *err_msg = std::string("param ") + it.first + " should be a number"; + } + return Status::type_conflict_in_json; } if (ptr->range.has_value()) { if (json[it.first].get() > std::numeric_limits::max()) { - std::string msg = "Arithmetic overflow: param '" + it.first + "' (" + - to_string(json[it.first]) + ") should not bigger than " + - std::to_string(std::numeric_limits::max()); - show_err_msg(msg); + LOG_KNOWHERE_ERROR_ << "Arithmetic overflow: param [" << it.first << "] should be at most " + << std::numeric_limits::max(); + if (err_msg) { + *err_msg = std::string("param ") + it.first + " should be at most 3.402823e+38"; + } + return Status::arithmetic_overflow; } CFG_FLOAT::value_type v = json[it.first]; - auto range_val = ptr->range.value(); - if (range_val.first <= v && v <= range_val.second) { + if (ptr->range.value().first <= v && v <= ptr->range.value().second) { *ptr->val = v; } else { - std::string msg = "Out of range in json: param '" + it.first + "' (" + - to_string(json[it.first]) + ") should be in range [" + - std::to_string(range_val.first) + ", " + std::to_string(range_val.second) + - "]"; - show_err_msg(msg); + LOG_KNOWHERE_ERROR_ << "Out of range in json: param [" << it.first << "] should be in [" + << ptr->range.value().first << ", " << ptr->range.value().second << "]."; + if (err_msg) { + *err_msg = std::string("param ") + it.first + " out of range " + "[ " + + std::to_string(ptr->range.value().first) + "," + + std::to_string(ptr->range.value().second) + " ]"; + } + return Status::out_of_range_in_json; } } else { @@ -391,103 +388,110 @@ class Config { if (!(type & ptr->type)) { continue; } - if (json.find(it.first) == json.end()) { - if (!ptr->default_val.has_value()) { - if (ptr->allow_empty_without_default) { - continue; - } - std::string msg = "param [" + it.first + "] not exist in json"; - show_err_msg(msg); - return Status::invalid_param_in_json; - } else { - *ptr->val = ptr->default_val; + if (json.find(it.first) == json.end() && !ptr->default_val.has_value()) { + if (ptr->allow_empty_without_default) { continue; } + LOG_KNOWHERE_ERROR_ << "Invalid param [" << it.first << "] in json."; + if (err_msg) { + *err_msg = std::string("invalid param ") + it.first; + } + return Status::invalid_param_in_json; + } + if (json.find(it.first) == json.end()) { + *ptr->val = ptr->default_val; + continue; } if (!json[it.first].is_string()) { - std::string msg = "Type conflict in json: param '" + it.first + "' (" + to_string(json[it.first]) + - ") should be a string"; - show_err_msg(msg); + LOG_KNOWHERE_ERROR_ << "Type conflict in json: param [" << it.first << "] should be a string."; + if (err_msg) { + *err_msg = std::string("param ") + it.first + " should be a string"; + } return Status::type_conflict_in_json; } *ptr->val = json[it.first]; } - if (const Entry* ptr = std::get_if>(&var)) { + if (const Entry* ptr = std::get_if>(&var)) { if (!(type & ptr->type)) { continue; } - if (json.find(it.first) == json.end()) { - if (!ptr->default_val.has_value()) { - if (ptr->allow_empty_without_default) { - continue; - } - std::string msg = "param '" + it.first + "' not exist in json"; - show_err_msg(msg); - return Status::invalid_param_in_json; - } else { - *ptr->val = ptr->default_val; + if (json.find(it.first) == json.end() && !ptr->default_val.has_value()) { + if (ptr->allow_empty_without_default) { continue; } + LOG_KNOWHERE_ERROR_ << "Invalid param [" << it.first << "] in json."; + if (err_msg) { + *err_msg = std::string("invalid param ") + it.first; + } + + return Status::invalid_param_in_json; } - if (!json[it.first].is_boolean()) { - std::string msg = "Type conflict in json: param '" + it.first + "' (" + to_string(json[it.first]) + - ") should be a boolean"; - show_err_msg(msg); + if (json.find(it.first) == json.end()) { + *ptr->val = ptr->default_val; + continue; + } + if (!json[it.first].is_array()) { + LOG_KNOWHERE_ERROR_ << "Type conflict in json: param [" << it.first << "] should be an array."; + if (err_msg) { + *err_msg = std::string("param ") + it.first + " should be an array"; + } + return Status::type_conflict_in_json; } - *ptr->val = json[it.first]; + *ptr->val = CFG_LIST::value_type(); + for (auto&& i : json[it.first]) { + ptr->val->value().push_back(i); + } } - if (const Entry* ptr = - std::get_if>(&var)) { + if (const Entry* ptr = std::get_if>(&var)) { if (!(type & ptr->type)) { continue; } - if (json.find(it.first) == json.end()) { - if (!ptr->default_val.has_value()) { - if (ptr->allow_empty_without_default) { - continue; - } - std::string msg = "param '" + it.first + "' not exist in json"; - show_err_msg(msg); - return Status::invalid_param_in_json; - } else { - *ptr->val = ptr->default_val; + if (json.find(it.first) == json.end() && !ptr->default_val.has_value()) { + if (ptr->allow_empty_without_default) { continue; } + LOG_KNOWHERE_ERROR_ << "Invalid param [" << it.first << "] in json."; + if (err_msg) { + *err_msg = std::string("invalid param ") + it.first; + } + + return Status::invalid_param_in_json; + } + if (json.find(it.first) == json.end()) { + *ptr->val = ptr->default_val; + continue; + } + if (!json[it.first].is_boolean()) { + LOG_KNOWHERE_ERROR_ << "Type conflict in json: param [" << it.first << "] should be a boolean."; + if (err_msg) { + *err_msg = std::string("param ") + it.first + " should be a boolean"; + } + + return Status::type_conflict_in_json; } *ptr->val = json[it.first]; } } - if (!err_msg) { - std::string tem_msg; - return cfg.CheckAndAdjust(type, &tem_msg); - } - return cfg.CheckAndAdjust(type, err_msg); + return Status::success; } virtual ~Config() { } - using VarEntry = std::variant, Entry, Entry, Entry, - Entry>; + using VarEntry = + std::variant, Entry, Entry, Entry, Entry>; std::unordered_map __DICT__; - - protected: - inline virtual Status - CheckAndAdjust(PARAM_TYPE param_type, std::string* const err_msg) { - return Status::success; - } }; #define KNOHWERE_DECLARE_CONFIG(CONFIG) CONFIG() -#define KNOWHERE_CONFIG_DECLARE_FIELD(PARAM) \ - __DICT__[#PARAM] = knowhere::Config::VarEntry(std::in_place_type>, &PARAM); \ - knowhere::EntryAccess PARAM##_access( \ - std::get_if>(&__DICT__[#PARAM])); \ +#define KNOWHERE_CONFIG_DECLARE_FIELD(PARAM) \ + __DICT__[#PARAM] = knowhere::Config::VarEntry(std::in_place_type>, &PARAM); \ + EntryAccess PARAM##_access(std::get_if>(&__DICT__[#PARAM])); \ PARAM##_access const float defaultRangeFilter = 1.0f / 0.0; @@ -500,42 +504,17 @@ class BaseConfig : public Config { CFG_BOOL retrieve_friendly; CFG_STRING data_path; CFG_STRING index_prefix; - // for distance metrics, we search for vectors with distance in [range_filter, radius). - // for similarity metrics, we search for vectors with similarity in (radius, range_filter]. CFG_FLOAT radius; - CFG_INT range_search_k; CFG_FLOAT range_filter; - CFG_FLOAT range_search_level; - CFG_BOOL retain_iterator_order; CFG_BOOL trace_visit; CFG_BOOL enable_mmap; - CFG_BOOL enable_mmap_pop; CFG_BOOL for_tuning; - CFG_BOOL shuffle_build; - CFG_STRING trace_id; - CFG_STRING span_id; - CFG_INT trace_flags; - CFG_MATERIALIZED_VIEW_SEARCH_INFO_TYPE materialized_view_search_info; - CFG_STRING opt_fields_path; - CFG_FLOAT iterator_refine_ratio; - /** - * k1, b, avgdl are used by BM25 metric only. - * - k1, b, avgdl must be provided at load time. - * - k1 and b can be overridden at search time for SPARSE_INVERTED_INDEX - * but not for SPARSE_WAND. - * - avgdl must always be provided at search time. - */ - CFG_FLOAT bm25_k1; - CFG_FLOAT bm25_b; - CFG_FLOAT bm25_avgdl; KNOHWERE_DECLARE_CONFIG(BaseConfig) { KNOWHERE_CONFIG_DECLARE_FIELD(metric_type) .set_default("L2") .description("metric type") .for_train_and_search() - .for_iterator() - .for_deserialize() - .for_deserialize_from_file(); + .for_deserialize(); KNOWHERE_CONFIG_DECLARE_FIELD(retrieve_friendly) .description("whether the index holds raw data for fast retrieval") .set_default(false) @@ -563,20 +542,10 @@ class BaseConfig : public Config { .set_default(0.0) .description("radius for range search") .for_range_search(); - KNOWHERE_CONFIG_DECLARE_FIELD(range_search_k) - .set_default(-1) - .description("limit the number of similar results returned by range_search. -1 means no limitations.") - .set_range(-1, std::numeric_limits::max()) - .for_range_search(); KNOWHERE_CONFIG_DECLARE_FIELD(range_filter) .set_default(defaultRangeFilter) .description("result filter for range search") .for_range_search(); - KNOWHERE_CONFIG_DECLARE_FIELD(range_search_level) - .set_default(0.01f) - .description("control the accurancy of range search, [0.0 - 0.5], the larger the more accurate") - .set_range(0, 0.5) - .for_range_search(); KNOWHERE_CONFIG_DECLARE_FIELD(trace_visit) .set_default(false) .description("trace visit for feder") @@ -585,80 +554,28 @@ class BaseConfig : public Config { KNOWHERE_CONFIG_DECLARE_FIELD(enable_mmap) .set_default(false) .description("enable mmap for load index") - .for_deserialize() - .for_deserialize_from_file(); - KNOWHERE_CONFIG_DECLARE_FIELD(enable_mmap_pop) - .set_default(false) - .description("enable map_populate option for mmap") - .for_deserialize() .for_deserialize_from_file(); KNOWHERE_CONFIG_DECLARE_FIELD(for_tuning).set_default(false).description("for tuning").for_search(); - KNOWHERE_CONFIG_DECLARE_FIELD(shuffle_build) - .set_default(true) - .description("shuffle ids before index building") - .for_train(); - KNOWHERE_CONFIG_DECLARE_FIELD(trace_id) - .description("trace id") - .allow_empty_without_default() - .for_search() - .for_range_search(); - KNOWHERE_CONFIG_DECLARE_FIELD(span_id) - .description("span id") - .allow_empty_without_default() - .for_search() - .for_range_search(); - KNOWHERE_CONFIG_DECLARE_FIELD(trace_flags) - .set_default(0) - .description("trace flags") - .for_search() - .for_range_search(); - KNOWHERE_CONFIG_DECLARE_FIELD(materialized_view_search_info) - .description("materialized view search info") - .allow_empty_without_default() - .for_search() - .for_iterator() - .for_range_search(); - KNOWHERE_CONFIG_DECLARE_FIELD(opt_fields_path) - .description("materialized view optional fields path") - .allow_empty_without_default() - .for_train(); - KNOWHERE_CONFIG_DECLARE_FIELD(iterator_refine_ratio) - .set_default(0.5) - .description("refine ratio for iterator") - .for_iterator() - .for_range_search(); - KNOWHERE_CONFIG_DECLARE_FIELD(retain_iterator_order) - .set_default(false) - .description("whether the result of iterator monotonically ordered") - .for_iterator() - .for_range_search(); - KNOWHERE_CONFIG_DECLARE_FIELD(bm25_k1) - .allow_empty_without_default() - .set_range(0.0, 3.0) - .description("BM25 k1 to tune the term frequency scaling factor") - .for_train_and_search() - .for_iterator() - .for_deserialize() - .for_deserialize_from_file(); - KNOWHERE_CONFIG_DECLARE_FIELD(bm25_b) - .allow_empty_without_default() - .set_range(0.0, 1.0) - .description("BM25 beta to tune the document length scaling factor") - .for_train_and_search() - .for_iterator() - .for_deserialize() - .for_deserialize_from_file(); - // This must be provided in any BM25 type search request. - // This is necessary for building/training/deserializing only if the index - // type is WAND. - KNOWHERE_CONFIG_DECLARE_FIELD(bm25_avgdl) - .allow_empty_without_default() - .set_range(1, std::numeric_limits::max()) - .description("average document length") - .for_train_and_search() - .for_iterator() - .for_deserialize() - .for_deserialize_from_file(); + } + + virtual Status + CheckAndAdjustForSearch(std::string* err_msg) { + return Status::success; + } + + virtual Status + CheckAndAdjustForRangeSearch(std::string* err_msg) { + return Status::success; + } + + virtual Status + CheckAndAdjustForIterator() { + return Status::success; + } + + virtual inline Status + CheckAndAdjustForBuild() { + return Status::success; } }; } // namespace knowhere diff --git a/include/knowhere/dataset.h b/include/knowhere/dataset.h index 1a7b37169..c77decaf3 100644 --- a/include/knowhere/dataset.h +++ b/include/knowhere/dataset.h @@ -21,12 +21,10 @@ #include #include "comp/index_param.h" -#include "knowhere/range_util.h" -#include "knowhere/sparse_utils.h" namespace knowhere { -class DataSet : public std::enable_shared_from_this { +class DataSet { public: typedef std::variant Var; DataSet() = default; @@ -56,11 +54,7 @@ class DataSet : public std::enable_shared_from_this { { auto ptr = std::get_if<3>(&x.second); if (ptr != nullptr) { - if (is_sparse) { - delete[](sparse::SparseRow*)(*ptr); - } else { - delete[](char*)(*ptr); - } + delete[](char*)(*ptr); } } } @@ -72,64 +66,24 @@ class DataSet : public std::enable_shared_from_this { this->data_[meta::DISTANCE] = Var(std::in_place_index<0>, dis); } - void - SetDistance(std::unique_ptr&& dis) { - std::unique_lock lock(mutex_); - this->data_[meta::DISTANCE] = Var(std::in_place_index<0>, dis.release()); - } - void SetLims(const size_t* lims) { std::unique_lock lock(mutex_); this->data_[meta::LIMS] = Var(std::in_place_index<1>, lims); } - void - SetLims(std::unique_ptr&& lims) { - std::unique_lock lock(mutex_); - this->data_[meta::LIMS] = Var(std::in_place_index<1>, lims.release()); - } - void SetIds(const int64_t* ids) { std::unique_lock lock(mutex_); this->data_[meta::IDS] = Var(std::in_place_index<2>, ids); } - void - SetIds(std::unique_ptr&& ids) { - static_assert(sizeof(long int) == sizeof(int64_t)); - - std::unique_lock lock(mutex_); - this->data_[meta::IDS] = Var(std::in_place_index<2>, reinterpret_cast(ids.release())); - } - - void - SetIds(std::unique_ptr&& ids) { - static_assert(sizeof(long long int) == sizeof(int64_t)); - - std::unique_lock lock(mutex_); - this->data_[meta::IDS] = Var(std::in_place_index<2>, reinterpret_cast(ids.release())); - } - - /** - * For dense float vector, tensor is a rows * dim float array - * For sparse float vector, tensor is pointer to sparse::Sparse* - * and values in each row should be sorted by column id. - */ void SetTensor(const void* tensor) { std::unique_lock lock(mutex_); this->data_[meta::TENSOR] = Var(std::in_place_index<3>, tensor); } - template - void - SetTensor(std::unique_ptr&& tensor) { - std::unique_lock lock(mutex_); - this->data_[meta::TENSOR] = Var(std::in_place_index<3>, tensor.release()); - } - void SetRows(const int64_t rows) { std::unique_lock lock(mutex_); @@ -248,18 +202,6 @@ class DataSet : public std::enable_shared_from_this { this->is_owner = is_owner; } - bool - GetIsSparse() const { - std::unique_lock lock(mutex_); - return this->is_sparse; - } - - void - SetIsSparse(bool is_sparse) { - std::unique_lock lock(mutex_); - this->is_sparse = is_sparse; - } - // deprecated API template void @@ -283,7 +225,6 @@ class DataSet : public std::enable_shared_from_this { mutable std::shared_mutex mutex_; std::map data_; bool is_owner = true; - bool is_sparse = false; }; using DataSetPtr = std::shared_ptr; @@ -297,16 +238,13 @@ GenDataSet(const int64_t nb, const int64_t dim, const void* xb) { return ret_ds; } -// swig won't compile when using int64_t* or size_t* as parameter -inline DataSetPtr #ifdef NOT_COMPILE_FOR_SWIG +// TOOD: python wheel build error for this API, need check +inline DataSetPtr GenIdsDataSet(const int64_t rows, const int64_t* ids) { -#else -GenIdsDataSet(const int64_t rows, const void* ids) { -#endif auto ret_ds = std::make_shared(); ret_ds->SetRows(rows); - ret_ds->SetIds((const int64_t*)ids); + ret_ds->SetIds(ids); ret_ds->SetIsOwner(false); return ret_ds; } @@ -321,84 +259,24 @@ GenResultDataSet(const int64_t rows, const int64_t dim, const void* tensor) { return ret_ds; } -template -inline DataSetPtr -GenResultDataSet(const int64_t rows, const int64_t dim, std::unique_ptr&& tensor) { - auto ret_ds = std::make_shared(); - ret_ds->SetRows(rows); - ret_ds->SetDim(dim); - ret_ds->SetTensor(std::move(tensor)); - ret_ds->SetIsOwner(true); - return ret_ds; -} - inline DataSetPtr -#ifdef NOT_COMPILE_FOR_SWIG GenResultDataSet(const int64_t nq, const int64_t topk, const int64_t* ids, const float* distance) { -#else -GenResultDataSet(const int64_t nq, const int64_t topk, const void* ids, const float* distance) { -#endif - static_assert(sizeof(int64_t) == sizeof(long long int)); - auto ret_ds = std::make_shared(); ret_ds->SetRows(nq); ret_ds->SetDim(topk); - ret_ds->SetIds((const int64_t*)ids); + ret_ds->SetIds(ids); ret_ds->SetDistance(distance); ret_ds->SetIsOwner(true); return ret_ds; } inline DataSetPtr -GenResultDataSet(const int64_t nq, const int64_t topk, std::unique_ptr&& ids, - std::unique_ptr&& distance) { - static_assert(sizeof(int64_t) == sizeof(long int)); - - auto ret_ds = std::make_shared(); - ret_ds->SetRows(nq); - ret_ds->SetDim(topk); - ret_ds->SetIds(std::move(ids)); - ret_ds->SetDistance(std::move(distance)); - ret_ds->SetIsOwner(true); - return ret_ds; -} - -inline DataSetPtr -GenResultDataSet(const int64_t nq, const int64_t topk, std::unique_ptr&& ids, - std::unique_ptr&& distance) { - static_assert(sizeof(int64_t) == sizeof(long long int)); - - auto ret_ds = std::make_shared(); - ret_ds->SetRows(nq); - ret_ds->SetDim(topk); - ret_ds->SetIds(std::move(ids)); - ret_ds->SetDistance(std::move(distance)); - ret_ds->SetIsOwner(true); - return ret_ds; -} - -inline DataSetPtr -#ifdef NOT_COMPILE_FOR_SWIG GenResultDataSet(const int64_t nq, const int64_t* ids, const float* distance, const size_t* lims) { -#else -GenResultDataSet(const int64_t nq, const void* ids, const float* distance, const void* lims) { -#endif auto ret_ds = std::make_shared(); ret_ds->SetRows(nq); - ret_ds->SetIds((const int64_t*)ids); + ret_ds->SetIds(ids); ret_ds->SetDistance(distance); - ret_ds->SetLims((const size_t*)lims); - ret_ds->SetIsOwner(true); - return ret_ds; -} - -inline DataSetPtr -GenResultDataSet(const int64_t nq, RangeSearchResult&& range_search_result) { - auto ret_ds = std::make_shared(); - ret_ds->SetRows(nq); - ret_ds->SetIds(std::move(range_search_result.labels)); - ret_ds->SetDistance(std::move(range_search_result.distances)); - ret_ds->SetLims(std::move(range_search_result.lims)); + ret_ds->SetLims(lims); ret_ds->SetIsOwner(true); return ret_ds; } @@ -411,6 +289,7 @@ GenResultDataSet(const std::string& json_info, const std::string& json_id_set) { ret_ds->SetIsOwner(true); return ret_ds; } +#endif } // namespace knowhere #endif /* DATASET_H */ diff --git a/include/knowhere/device_bitset.h b/include/knowhere/device_bitset.h new file mode 100644 index 000000000..12532dc4e --- /dev/null +++ b/include/knowhere/device_bitset.h @@ -0,0 +1,90 @@ +// Copyright (C) 2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#ifndef DEVICE_BITSET_H +#define DEVICE_BITSET_H + +#include "knowhere/bitsetview.h" +#include "raft/core/device_mdarray.hpp" +#include "raft/core/device_resources.hpp" +#include "raft/util/cudart_utils.hpp" + +namespace knowhere { + +struct DeviceBitsetView { + __device__ __host__ + DeviceBitsetView(const DeviceBitsetView& other) + : bits_{other.data()}, num_bits_{other.size()} { + } + __device__ __host__ + DeviceBitsetView(const uint8_t* data, size_t num_bits = size_t{}) + : bits_{data}, num_bits_{num_bits} { + } + + __device__ __host__ bool + empty() const { + return num_bits_ == 0; + } + + __device__ __host__ size_t + size() const { + return num_bits_; + } + + __device__ __host__ size_t + byte_size() const { + return (num_bits_ + 8 - 1) >> 3; + } + + __device__ __host__ const uint8_t* + data() const { + return bits_; + } + + __device__ bool + test(int64_t index) const { + auto result = false; + if (index < num_bits_) { + result = bits_[index >> 3] & (0x1 << (index & 0x7)); + } + return result; + } + + private: + const uint8_t* bits_ = nullptr; + size_t num_bits_ = 0; +}; + +struct DeviceBitset { + DeviceBitset(raft::device_resources& res, BitsetView const& other) + : storage_{[&res, &other]() { + auto result = raft::make_device_vector(res, other.byte_size()); + if (!other.empty()) { + raft::copy(result.data_handle(), other.data(), other.byte_size(), res.get_stream()); + } + return result; + }()}, + num_bits_{other.size()} { + } + + auto + view() { + return DeviceBitsetView{storage_.data_handle(), num_bits_}; + } + + private: + raft::device_vector storage_; + size_t num_bits_; +}; + +} // namespace knowhere + +#endif /* DEVICE_BITSET_H */ diff --git a/include/knowhere/expected.h b/include/knowhere/expected.h index 343d2cec6..bf8fd2fdd 100644 --- a/include/knowhere/expected.h +++ b/include/knowhere/expected.h @@ -40,11 +40,6 @@ enum class Status { raft_inner_error = 18, invalid_binary_set = 19, invalid_instruction_set = 20, - cardinal_inner_error = 21, - cuda_runtime_error = 22, - invalid_index_error = 23, - invalid_cluster_error = 24, - cluster_inner_error = 25, }; inline std::string @@ -88,12 +83,6 @@ Status2String(knowhere::Status status) { return "invalid binary set"; case knowhere::Status::invalid_instruction_set: return "the current index is not supported on the current CPU model"; - case knowhere::Status::cardinal_inner_error: - return "cardinal inner error"; - case knowhere::Status::invalid_cluster_error: - return "invalid cluster type"; - case knowhere::Status::cluster_inner_error: - return "cluster inner error"; default: return "unexpected status"; } @@ -108,13 +97,13 @@ class expected { expected(const expected&) = default; - expected(expected&&) = default; + expected(expected&&) noexcept = default; expected& operator=(const expected&) = default; expected& - operator=(expected&&) = default; + operator=(expected&&) noexcept = default; bool has_value() const { @@ -131,11 +120,6 @@ class expected { assert(val.has_value() == true); return val.value(); } - T& - value() { - assert(val.has_value() == true); - return val.value(); - } const std::string& what() const { diff --git a/include/knowhere/factory.h b/include/knowhere/factory.h new file mode 100644 index 000000000..3138f7494 --- /dev/null +++ b/include/knowhere/factory.h @@ -0,0 +1,43 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#ifndef INDEX_FACTORY_H +#define INDEX_FACTORY_H + +#include +#include +#include + +#include "knowhere/index.h" + +namespace knowhere { +class IndexFactory { + public: + Index + Create(const std::string& name, const int32_t& version, const Object& object = nullptr); + const IndexFactory& + Register(const std::string& name, std::function(const int32_t& version, const Object&)> func); + static IndexFactory& + Instance(); + + private: + typedef std::map(const int32_t&, const Object&)>> FuncMap; + IndexFactory(); + static FuncMap& + MapInstance(); +}; + +#define KNOWHERE_CONCAT(x, y) x##y +#define KNOWHERE_REGISTER_GLOBAL(name, func) \ + const IndexFactory& KNOWHERE_CONCAT(index_factory_ref_, name) = IndexFactory::Instance().Register(#name, func) +} // namespace knowhere + +#endif /* INDEX_FACTORY_H */ diff --git a/include/knowhere/index/index.h b/include/knowhere/index.h similarity index 81% rename from include/knowhere/index/index.h rename to include/knowhere/index.h index 9c3992f07..164eea88c 100644 --- a/include/knowhere/index/index.h +++ b/include/knowhere/index.h @@ -16,9 +16,10 @@ #include "knowhere/config.h" #include "knowhere/dataset.h" #include "knowhere/expected.h" -#include "knowhere/index/index_node.h" +#include "knowhere/index_node.h" namespace knowhere { + template class Index { public: @@ -43,20 +44,6 @@ class Index { node = idx.node; } - Index& - operator=(const Index& idx) { - if (&idx == this) { - return *this; - } - if (idx.node == nullptr) { - node = nullptr; - return *this; - } - idx.node->IncRef(); - node = idx.node; - return *this; - } - Index(Index&& idx) { if (idx.node == nullptr) { node = nullptr; @@ -139,32 +126,29 @@ class Index { } Status - Build(const DataSetPtr dataset, const Json& json); + Build(const DataSet& dataset, const Json& json); Status - Train(const DataSetPtr dataset, const Json& json); + Train(const DataSet& dataset, const Json& json); Status - Add(const DataSetPtr dataset, const Json& json); + Add(const DataSet& dataset, const Json& json); expected - Search(const DataSetPtr dataset, const Json& json, const BitsetView& bitset) const; + Search(const DataSet& dataset, const Json& json, const BitsetView& bitset) const; - expected> - AnnIterator(const DataSetPtr dataset, const Json& json, const BitsetView& bitset) const; + expected>> + AnnIterator(const DataSet& dataset, const Json& json, const BitsetView& bitset) const; expected - RangeSearch(const DataSetPtr dataset, const Json& json, const BitsetView& bitset) const; + RangeSearch(const DataSet& dataset, const Json& json, const BitsetView& bitset) const; expected - GetVectorByIds(const DataSetPtr dataset) const; + GetVectorByIds(const DataSet& dataset) const; bool HasRawData(const std::string& metric_type) const; - bool - IsAdditionalScalarSupported() const; - expected GetIndexMeta(const Json& json) const; diff --git a/include/knowhere/index/index_factory.h b/include/knowhere/index/index_factory.h deleted file mode 100644 index 501904f69..000000000 --- a/include/knowhere/index/index_factory.h +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright (C) 2019-2023 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#ifndef INDEX_FACTORY_H -#define INDEX_FACTORY_H - -#include -#include -#include -#include - -#include "knowhere/index/index.h" -#include "knowhere/utils.h" - -namespace knowhere { -class IndexFactory { - public: - template - expected> - Create(const std::string& name, const int32_t& version, const Object& object = nullptr); - template - const IndexFactory& - Register(const std::string& name, std::function(const int32_t&, const Object&)> func); - static IndexFactory& - Instance(); - typedef std::tuple>, std::set> GlobalIndexTable; - static GlobalIndexTable& - StaticIndexTableInstance(); - - private: - struct FunMapValueBase { - virtual ~FunMapValueBase() = default; - }; - template - struct FunMapValue : FunMapValueBase { - public: - FunMapValue(std::function& input) : fun_value(input) { - } - std::function fun_value; - }; - typedef std::map> FuncMap; - IndexFactory(); - static FuncMap& - MapInstance(); -}; - -#define KNOWHERE_CONCAT(x, y) index_factory_ref_##x##y -#define KNOWHERE_REGISTER_GLOBAL(name, func, data_type) \ - const IndexFactory& KNOWHERE_CONCAT(name, data_type) = IndexFactory::Instance().Register(#name, func) -#define KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, data_type, ...) \ - KNOWHERE_REGISTER_GLOBAL( \ - name, \ - (static_cast> (*)(const int32_t&, const Object&)>( \ - &Index>::Create)), \ - data_type) -#define KNOWHERE_MOCK_REGISTER_GLOBAL(name, index_node, data_type, ...) \ - KNOWHERE_REGISTER_GLOBAL( \ - name, \ - [](const int32_t& version, const Object& object) { \ - return (Index>::Create( \ - std::make_unique::type, ##__VA_ARGS__>>(version, object))); \ - }, \ - data_type) -#define KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(name, index_node, data_type, thread_size) \ - KNOWHERE_REGISTER_GLOBAL( \ - name, \ - [](const int32_t& version, const Object& object) { \ - return (Index::Create( \ - std::make_unique::type>>(version, object), thread_size)); \ - }, \ - data_type) -#define KNOWHERE_SET_STATIC_GLOBAL_INDEX_TABLE(table_index, name, index_table) \ - static int name = []() -> int { \ - auto& static_index_table = std::get(IndexFactory::StaticIndexTableInstance()); \ - static_index_table.insert(index_table.begin(), index_table.end()); \ - return 0; \ - }(); -} // namespace knowhere - -#endif /* INDEX_FACTORY_H */ diff --git a/include/knowhere/index/index_node.h b/include/knowhere/index/index_node.h deleted file mode 100644 index 3cf5b7bd8..000000000 --- a/include/knowhere/index/index_node.h +++ /dev/null @@ -1,583 +0,0 @@ -// Copyright (C) 2019-2023 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#ifndef INDEX_NODE_H -#define INDEX_NODE_H - -#include -#include -#include -#include - -#include "knowhere/binaryset.h" -#include "knowhere/bitsetview.h" -#include "knowhere/config.h" -#include "knowhere/dataset.h" -#include "knowhere/expected.h" -#include "knowhere/object.h" -#include "knowhere/operands.h" -#include "knowhere/range_util.h" -#include "knowhere/utils.h" -#include "knowhere/version.h" - -#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) -#include "knowhere/comp/thread_pool.h" -#endif - -namespace knowhere { - -class IndexNode : public Object { - public: - IndexNode(const int32_t ver) : version_(ver) { - } - - IndexNode() : version_(Version::GetDefaultVersion()) { - } - - IndexNode(const IndexNode& other) : version_(other.version_) { - } - - IndexNode(const IndexNode&& other) : version_(other.version_) { - } - - /** - * @brief Builds the index using the provided dataset and configuration. - * - * Mostly, this method combines the `Train` and `Add` steps to create the index structure, but it can be overridden - * if the index doesn't support Train-Add pattern, such as immutable indexes like DiskANN. - * - * @param dataset Dataset to build the index from. - * @param cfg - * @return Status. - * - * @note Indexes need to be ready to search after `Build` is called. TODO:@liliu-z DiskANN is an exception and need - * to be revisited. - */ - virtual Status - Build(const DataSetPtr dataset, const Config& cfg) { - RETURN_IF_ERROR(Train(dataset, cfg)); - return Add(dataset, cfg); - } - - /** - * @brief Trains the index model using the provided dataset and configuration. - * - * @param dataset Dataset used to train the index. - * @param cfg - * @return Status. - * - * @note This interface is only available for growable indexes. For immutable indexes like DiskANN, this method - * should return an error. - */ - virtual Status - Train(const DataSetPtr dataset, const Config& cfg) = 0; - - /** - * @brief Adds data to the trained index. - * - * @param dataset Dataset to add to the index. - * @param cfg - * @return Status - * - * @note - * 1. This interface is only available for growable indexes. For immutable indexes like DiskANN, this method - * should return an error. - * 2. This method need to be thread safe when called with search methods like @see Search, @see RangeSearch and @see - * AnnIterator. - */ - virtual Status - Add(const DataSetPtr dataset, const Config& cfg) = 0; - - /** - * @brief Performs a search operation on the index. - * - * @param dataset Query vectors. - * @param cfg - * @param bitset A BitsetView object for filtering results. - * @return An expected<> object containing the search results or an error. - */ - virtual expected - Search(const DataSetPtr dataset, const Config& cfg, const BitsetView& bitset) const = 0; - - // not thread safe. - class iterator { - public: - virtual std::pair - Next() = 0; - [[nodiscard]] virtual bool - HasNext() = 0; - virtual ~iterator() { - } - }; - using IteratorPtr = std::shared_ptr; - - virtual expected> - AnnIterator(const DataSetPtr dataset, const Config& cfg, const BitsetView& bitset) const { - return expected>>::Err( - Status::not_implemented, "annIterator not supported for current index type"); - } - - /** - * @brief Performs a range search operation on the index. - * - * This method provides a default implementation of range search based on the `AnnIterator`, assuming the iterator - * will buffer an expanded range and return the closest elements on each Next() call. It can be overridden by - * derived classes for more efficient implementations. - * - * @param dataset Query vectors. - * @param cfg - * @param bitset A BitsetView object for filtering results. - * @return An expected<> object containing the range search results or an error. - */ - virtual expected - RangeSearch(const DataSetPtr dataset, const Config& cfg, - const BitsetView& bitset) const { // TODO: @alwayslove2013 test with mock AnnIterator after we - // introduced mock framework into knowhere. Currently this is tested - // in test_sparse.cc with real sparse vector index. - const auto base_cfg = static_cast(cfg); - const float closer_bound = base_cfg.range_filter.value(); - const bool has_closer_bound = closer_bound != defaultRangeFilter; - float further_bound = base_cfg.radius.value(); - - const bool the_larger_the_closer = IsMetricType(base_cfg.metric_type.value(), metric::IP) || - IsMetricType(base_cfg.metric_type.value(), metric::COSINE) || - IsMetricType(base_cfg.metric_type.value(), metric::BM25); - auto is_first_closer = [&the_larger_the_closer](const float dist_1, const float dist_2) { - return the_larger_the_closer ? dist_1 > dist_2 : dist_1 < dist_2; - }; - auto too_close = [&is_first_closer, &closer_bound](float dist) { return is_first_closer(dist, closer_bound); }; - auto same_or_too_far = [&is_first_closer, &further_bound](float dist) { - return !is_first_closer(dist, further_bound); - }; - - /** The `range_search_k` is used to early terminate the iterator-search. - * - `range_search_k < 0` means no early termination. - * - `range_search_k == 0` will return empty results. - * - Note that the number of results is not guaranteed to be exactly range_search_k, it may be more or less. - * */ - const int32_t range_search_k = base_cfg.range_search_k.value(); - LOG_KNOWHERE_DEBUG_ << "range_search_k: " << range_search_k; - if (range_search_k == 0) { - auto nq = dataset->GetRows(); - std::vector> result_id_array(nq); - std::vector> result_dist_array(nq); - auto range_search_result = GetRangeSearchResult(result_dist_array, result_id_array, the_larger_the_closer, - nq, further_bound, closer_bound); - return GenResultDataSet(nq, std::move(range_search_result)); - } - - auto its_or = AnnIterator(dataset, cfg, bitset); - if (!its_or.has_value()) { - return expected::Err(its_or.error(), - "RangeSearch failed due to AnnIterator failure: " + its_or.what()); - } - - const auto its = its_or.value(); - const auto nq = its.size(); - std::vector> result_id_array(nq); - std::vector> result_dist_array(nq); - - const bool retain_iterator_order = base_cfg.retain_iterator_order.value(); - LOG_KNOWHERE_DEBUG_ << "retain_iterator_order: " << retain_iterator_order; - - /** - * use ordered iterator (retain_iterator_order == true) - * - terminate iterator if next distance exceeds `further_bound`. - * - terminate iterator if get enough results. (`range_search_k`) - * */ - auto task_with_ordered_iterator = [&](size_t idx) { - auto it = its[idx]; - while (it->HasNext()) { - auto [id, dist] = it->Next(); - if (has_closer_bound && too_close(dist)) { - continue; - } - if (same_or_too_far(dist)) { - break; - } - result_id_array[idx].push_back(id); - result_dist_array[idx].push_back(dist); - if (range_search_k >= 0 && static_cast(result_id_array[idx].size()) >= range_search_k) { - break; - } - } - }; - - /** - * use default unordered iterator (retain_iterator_order == false) - * - terminate iterator if next distance [consecutively] exceeds `further_bound` several times. - * - if get enough results (`range_search_k`), update a `tighter_further_bound`, to early terminate iterator. - * */ - const auto range_search_level = base_cfg.range_search_level.value(); // from 0 to 0.5 - LOG_KNOWHERE_DEBUG_ << "range_search_level: " << range_search_level; - auto task_with_unordered_iterator = [&](size_t idx) { - // max-heap, use top (the current kth-furthest dist) as the further_bound if size == range_search_k - std::priority_queue, decltype(is_first_closer)> early_stop_further_bounds( - is_first_closer); - auto it = its[idx]; - size_t num_next = 0; - size_t num_consecutive_over_further_bound = 0; - float tighter_further_bound = base_cfg.radius.value(); - auto same_or_too_far = [&is_first_closer, &tighter_further_bound](float dist) { - return !is_first_closer(dist, tighter_further_bound); - }; - while (it->HasNext()) { - auto [id, dist] = it->Next(); - num_next++; - if (has_closer_bound && too_close(dist)) { - continue; - } - if (same_or_too_far(dist)) { - num_consecutive_over_further_bound++; - if (num_consecutive_over_further_bound > - static_cast(std::ceil(num_next * range_search_level))) { - break; - } - continue; - } - if (range_search_k > 0) { - if (static_cast(early_stop_further_bounds.size()) < range_search_k) { - early_stop_further_bounds.emplace(dist); - } else { - early_stop_further_bounds.pop(); - early_stop_further_bounds.emplace(dist); - tighter_further_bound = early_stop_further_bounds.top(); - } - } - num_consecutive_over_further_bound = 0; - result_id_array[idx].push_back(id); - result_dist_array[idx].push_back(dist); - } - }; -#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) - std::vector> futs; - futs.reserve(nq); - if (retain_iterator_order) { - for (size_t i = 0; i < nq; i++) { - futs.emplace_back( - ThreadPool::GetGlobalSearchThreadPool()->push([&, idx = i]() { task_with_ordered_iterator(idx); })); - } - } else { - for (size_t i = 0; i < nq; i++) { - futs.emplace_back(ThreadPool::GetGlobalSearchThreadPool()->push( - [&, idx = i]() { task_with_unordered_iterator(idx); })); - } - } - WaitAllSuccess(futs); -#else - if (retain_iterator_order) { - for (size_t i = 0; i < nq; i++) { - task_with_ordered_iterator(i); - } - } else { - for (size_t i = 0; i < nq; i++) { - task_with_unordered_iterator(i); - } - } -#endif - - auto range_search_result = GetRangeSearchResult(result_dist_array, result_id_array, the_larger_the_closer, nq, - further_bound, closer_bound); - return GenResultDataSet(nq, std::move(range_search_result)); - } - - /** - * @brief Retrieves raw vectors by their IDs from the index. - * - * @param dataset Dataset containing the IDs of the vectors to retrieve. - * @return An expected<> object containing the retrieved vectors or an error. - * - * @note - * 1. This method may return an error if the index does not contain raw data. The returned raw data must be exactly - * the same as the input data when we do @see Add or @see Build. For example, if the datatype is BF16, then we need - * to return a dataset with BF16 vectors. - * 2. It doesn't guarantee the index contains raw data, so it's better to check with @see HasRawData() before - */ - virtual expected - GetVectorByIds(const DataSetPtr dataset) const = 0; - - /** - * @brief Checks if the index contains raw vector data. - * - * @param metric_type The metric type used in the index. - * @return true if the index contains raw data, false otherwise. - */ - virtual bool - HasRawData(const std::string& metric_type) const = 0; - - virtual bool - IsAdditionalScalarSupported() const { - return false; - } - - /** - * @unused Milvus is not using this method, so it cannot guarantee all indexes implement this method. - * - * This is for Feder, and we can ignore it for now. - */ - virtual expected - GetIndexMeta(const Config& cfg) const = 0; - - /** - * @brief Serializes the index to a binary set. - * - * @param binset The BinarySet to store the serialized index. - * @return Status indicating success or failure of the serialization. - */ - virtual Status - Serialize(BinarySet& binset) const = 0; - - /** - * @brief Deserializes the index from a binary set. - * - * @param binset The BinarySet containing the serialized index. - * @param config - * @return Status indicating success or failure of the deserialization. - * - * @note - * 1. The index should be ready to search after deserialization. - * 2. For immutable indexes, the path for now if Build->Serialize->Deserialize->Search. - */ - virtual Status - Deserialize(const BinarySet& binset, const Config& config) = 0; - - /** - * @brief Deserializes the index from a file. - * - * This method is mostly used for mmap deserialization. However, it has some conflicts with the FileManager that we - * used to deserialize DiskANN. TODO: @liliu-z some redesign is needed here. - * - * @param filename Path to the file containing the serialized index. - * @param config - * @return Status indicating success or failure of the deserialization. - */ - virtual Status - DeserializeFromFile(const std::string& filename, const Config& config) = 0; - - virtual std::unique_ptr - CreateConfig() const = 0; - - /** - * @brief Gets the dimensionality of the vectors in the index. - * - * @return The number of dimensions as an int64_t. - */ - virtual int64_t - Dim() const = 0; - - /** - * @unused Milvus is not using this method, so it cannot guarantee all indexes implement this method. - * - * @brief Gets the memory usage of the index in bytes. - * - * @return The size of the index as an int64_t. - * @note This method doesn't have to be very accurate. - */ - virtual int64_t - Size() const = 0; - - /** - * @brief Gets the number of vectors in the index. - * - * @return The count of vectors as an int64_t. - */ - virtual int64_t - Count() const = 0; - - virtual std::string - Type() const = 0; - - virtual ~IndexNode() { - } - - protected: - Version version_; -}; - -// Common superclass for iterators that expand search range as needed. Subclasses need -// to override `next_batch` which will add expanded vectors to the results. For indexes -// with quantization, override `raw_distance`. -class IndexIterator : public IndexNode::iterator { - public: - IndexIterator(bool larger_is_closer, float refine_ratio = 0.0f, bool retain_iterator_order = false) - : refine_ratio_(refine_ratio), - refine_(refine_ratio != 0.0f), - retain_iterator_order_(retain_iterator_order), - sign_(larger_is_closer ? -1 : 1) { - } - - std::pair - Next() override { - if (!initialized_) { - throw std::runtime_error("Next should not be called before initialization"); - } - auto& q = refined_res_.empty() ? res_ : refined_res_; - if (q.empty()) { - throw std::runtime_error("No more elements"); - } - auto ret = q.top(); - q.pop(); - UpdateNext(); - if (retain_iterator_order_) { - while (HasNext()) { - auto& q = refined_res_.empty() ? res_ : refined_res_; - auto next_ret = q.top(); - // with the help of `sign_`, both `res_` and `refine_res` are min-heap. - // such as `COSINE`, `-dist` will be inserted to `res_` or `refine_res`. - // just make sure that the next value is greater than or equal to the current value. - if (next_ret.val >= ret.val) { - break; - } - q.pop(); - UpdateNext(); - } - } - return std::make_pair(ret.id, ret.val * sign_); - } - - [[nodiscard]] bool - HasNext() override { - if (!initialized_) { - throw std::runtime_error("HasNext should not be called before initialization"); - } - return !res_.empty() || !refined_res_.empty(); - } - - virtual void - initialize() { - if (initialized_) { - throw std::runtime_error("initialize should not be called twice"); - } - UpdateNext(); - initialized_ = true; - } - - protected: - virtual void - next_batch(std::function&)> batch_handler) = 0; - // will be called only if refine_ratio_ is not 0. - virtual float - raw_distance(int64_t id) { - if (!refine_) { - throw std::runtime_error("raw_distance should not be called for indexes without quantization"); - } - throw std::runtime_error("raw_distance not implemented"); - } - - const float refine_ratio_; - const bool refine_; - - std::priority_queue, std::greater> res_; - // unused if refine_ is false - std::priority_queue, std::greater> refined_res_; - - private: - inline size_t - min_refine_size() const { - // TODO: maybe make this configurable - return std::max((size_t)20, (size_t)(res_.size() * refine_ratio_)); - } - - void - UpdateNext() { - auto batch_handler = [this](const std::vector& batch) { - if (batch.empty()) { - return; - } - for (const auto& dist_id : batch) { - res_.emplace(dist_id.id, dist_id.val * sign_); - } - if (refine_) { - while (!res_.empty() && (refined_res_.empty() || refined_res_.size() < min_refine_size())) { - auto pair = res_.top(); - res_.pop(); - refined_res_.emplace(pair.id, raw_distance(pair.id) * sign_); - } - } - }; - next_batch(batch_handler); - } - - bool initialized_ = false; - bool retain_iterator_order_ = false; - const int64_t sign_; -}; - -// An iterator implementation that accepts a list of distances and ids and returns them in order. -class PrecomputedDistanceIterator : public IndexNode::iterator { - public: - PrecomputedDistanceIterator(std::vector&& distances_ids, bool larger_is_closer) - : larger_is_closer_(larger_is_closer), results_(std::move(distances_ids)) { - sort_size_ = get_sort_size(results_.size()); - sort_next(); - } - - // Construct an iterator from a list of distances with index being id, filtering out zero distances. - PrecomputedDistanceIterator(const std::vector& distances, bool larger_is_closer) - : larger_is_closer_(larger_is_closer) { - // 30% is a ratio guesstimate of non-zero distances: probability of 2 random sparse splade vectors(100 non zero - // dims out of 30000 total dims) sharing at least 1 common non-zero dimension. - results_.reserve(distances.size() * 0.3); - for (size_t i = 0; i < distances.size(); i++) { - if (distances[i] != 0) { - results_.emplace_back((int64_t)i, distances[i]); - } - } - sort_size_ = get_sort_size(results_.size()); - sort_next(); - } - - std::pair - Next() override { - sort_next(); - auto& result = results_[next_++]; - return std::make_pair(result.id, result.val); - } - - [[nodiscard]] bool - HasNext() override { - return next_ < results_.size() && results_[next_].id != -1; - } - - private: - static inline size_t - get_sort_size(size_t rows) { - return std::max((size_t)50000, rows / 10); - } - - // sort the next sort_size_ elements - inline void - sort_next() { - if (next_ < sorted_) { - return; - } - size_t current_end = std::min(results_.size(), sorted_ + sort_size_); - if (larger_is_closer_) { - std::partial_sort(results_.begin() + sorted_, results_.begin() + current_end, results_.end(), - std::greater()); - } else { - std::partial_sort(results_.begin() + sorted_, results_.begin() + current_end, results_.end(), - std::less()); - } - - sorted_ = current_end; - } - const bool larger_is_closer_; - - std::vector results_; - size_t next_ = 0; - size_t sorted_ = 0; - size_t sort_size_ = 0; -}; - -} // namespace knowhere - -#endif /* INDEX_NODE_H */ diff --git a/include/knowhere/index/index_node_data_mock_wrapper.h b/include/knowhere/index/index_node_data_mock_wrapper.h deleted file mode 100644 index 41a90c380..000000000 --- a/include/knowhere/index/index_node_data_mock_wrapper.h +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright (C) 2019-2023 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#ifndef INDEX_NODE_DATA_MOCK_WRAPPER_H -#define INDEX_NODE_DATA_MOCK_WRAPPER_H - -#include "knowhere/index/index_node.h" -#include "knowhere/utils.h" -namespace knowhere { - -template -class IndexNodeDataMockWrapper : public IndexNode { - public: - IndexNodeDataMockWrapper(std::unique_ptr index_node) : index_node_(std::move(index_node)) { - if constexpr (!std::is_same_v::type>) { - LOG_KNOWHERE_INFO_ << "replace index " << (GetKey(this->Type())) << " with " - << (GetKey::type>(this->Type())); - } - } - - Status - Build(const DataSetPtr dataset, const Config& cfg) override; - - Status - Train(const DataSetPtr dataset, const Config& cfg) override; - - Status - Add(const DataSetPtr dataset, const Config& cfg) override; - - expected - Search(const DataSetPtr dataset, const Config& cfg, const BitsetView& bitset) const override; - - expected - RangeSearch(const DataSetPtr dataset, const Config& cfg, const BitsetView& bitset) const override; - - expected> - AnnIterator(const DataSetPtr dataset, const Config& cfg, const BitsetView& bitset) const override; - - expected - GetVectorByIds(const DataSetPtr dataset) const override; - - bool - HasRawData(const std::string& metric_type) const override { - return index_node_->HasRawData(metric_type); - } - - expected - GetIndexMeta(const Config& cfg) const override { - return index_node_->GetIndexMeta(cfg); - } - - Status - Serialize(BinarySet& binset) const override { - return index_node_->Serialize(binset); - } - - Status - Deserialize(const BinarySet& binset, const Config& config) override { - return index_node_->Deserialize(binset, config); - } - - Status - DeserializeFromFile(const std::string& filename, const Config& config) override { - return index_node_->DeserializeFromFile(filename, config); - } - - std::unique_ptr - CreateConfig() const override { - return index_node_->CreateConfig(); - } - - int64_t - Dim() const override { - return index_node_->Dim(); - } - - int64_t - Size() const override { - return index_node_->Size(); - } - - int64_t - Count() const override { - return index_node_->Count(); - } - - std::string - Type() const override { - return index_node_->Type(); - } - - private: - std::unique_ptr index_node_; -}; - -} // namespace knowhere - -#endif /* INDEX_NODE_DATA_MOCK_WRAPPER_H */ diff --git a/include/knowhere/index/index_table.h b/include/knowhere/index/index_table.h deleted file mode 100644 index 0004d1318..000000000 --- a/include/knowhere/index/index_table.h +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright (C) 2019-2023 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#ifndef INDEX_TABLE_H -#define INDEX_TABLE_H -#include -#include - -#include "knowhere/comp/index_param.h" -#include "knowhere/index/index_factory.h" -namespace knowhere { -static std::set> legal_knowhere_index = { - // binary ivf - {IndexEnum::INDEX_FAISS_BIN_IDMAP, VecType::VECTOR_BINARY}, - {IndexEnum::INDEX_FAISS_BIN_IVFFLAT, VecType::VECTOR_BINARY}, - - // faiss index - {IndexEnum::INDEX_FAISS_IDMAP, VecType::VECTOR_FLOAT}, - {IndexEnum::INDEX_FAISS_IDMAP, VecType::VECTOR_FLOAT16}, - {IndexEnum::INDEX_FAISS_IDMAP, VecType::VECTOR_BFLOAT16}, - - {IndexEnum::INDEX_FAISS_IVFFLAT, VecType::VECTOR_FLOAT}, - {IndexEnum::INDEX_FAISS_IVFFLAT, VecType::VECTOR_FLOAT16}, - {IndexEnum::INDEX_FAISS_IVFFLAT, VecType::VECTOR_BFLOAT16}, - - {IndexEnum::INDEX_FAISS_IVFFLAT_CC, VecType::VECTOR_FLOAT}, - {IndexEnum::INDEX_FAISS_IVFFLAT_CC, VecType::VECTOR_FLOAT16}, - {IndexEnum::INDEX_FAISS_IVFFLAT_CC, VecType::VECTOR_BFLOAT16}, - - {IndexEnum::INDEX_FAISS_IVFPQ, VecType::VECTOR_FLOAT}, - {IndexEnum::INDEX_FAISS_IVFPQ, VecType::VECTOR_FLOAT16}, - {IndexEnum::INDEX_FAISS_IVFPQ, VecType::VECTOR_BFLOAT16}, - - {IndexEnum::INDEX_FAISS_SCANN, VecType::VECTOR_FLOAT}, - {IndexEnum::INDEX_FAISS_SCANN, VecType::VECTOR_FLOAT16}, - {IndexEnum::INDEX_FAISS_SCANN, VecType::VECTOR_BFLOAT16}, - - {IndexEnum::INDEX_FAISS_IVFSQ8, VecType::VECTOR_FLOAT}, - {IndexEnum::INDEX_FAISS_IVFSQ8, VecType::VECTOR_FLOAT16}, - {IndexEnum::INDEX_FAISS_IVFSQ8, VecType::VECTOR_BFLOAT16}, - - {IndexEnum::INDEX_FAISS_IVFSQ_CC, VecType::VECTOR_FLOAT}, - {IndexEnum::INDEX_FAISS_IVFSQ_CC, VecType::VECTOR_FLOAT16}, - {IndexEnum::INDEX_FAISS_IVFSQ_CC, VecType::VECTOR_BFLOAT16}, - - // gpu index - {IndexEnum::INDEX_GPU_BRUTEFORCE, VecType::VECTOR_FLOAT}, - {IndexEnum::INDEX_GPU_IVFFLAT, VecType::VECTOR_FLOAT}, - {IndexEnum::INDEX_GPU_IVFPQ, VecType::VECTOR_FLOAT}, - {IndexEnum::INDEX_GPU_CAGRA, VecType::VECTOR_FLOAT}, - - // hnsw - {IndexEnum::INDEX_HNSW, VecType::VECTOR_FLOAT}, - {IndexEnum::INDEX_HNSW, VecType::VECTOR_FLOAT16}, - {IndexEnum::INDEX_HNSW, VecType::VECTOR_BFLOAT16}, - - {IndexEnum::INDEX_HNSW_SQ8, VecType::VECTOR_FLOAT}, - {IndexEnum::INDEX_HNSW_SQ8, VecType::VECTOR_FLOAT16}, - {IndexEnum::INDEX_HNSW_SQ8, VecType::VECTOR_BFLOAT16}, - - {IndexEnum::INDEX_HNSW_SQ8_REFINE, VecType::VECTOR_FLOAT}, - {IndexEnum::INDEX_HNSW_SQ8_REFINE, VecType::VECTOR_FLOAT16}, - {IndexEnum::INDEX_HNSW_SQ8_REFINE, VecType::VECTOR_BFLOAT16}, - - // faiss hnsw - {IndexEnum::INDEX_FAISS_HNSW_FLAT, VecType::VECTOR_FLOAT}, - {IndexEnum::INDEX_FAISS_HNSW_FLAT, VecType::VECTOR_FLOAT16}, - {IndexEnum::INDEX_FAISS_HNSW_FLAT, VecType::VECTOR_BFLOAT16}, - - {IndexEnum::INDEX_FAISS_HNSW_SQ, VecType::VECTOR_FLOAT}, - {IndexEnum::INDEX_FAISS_HNSW_SQ, VecType::VECTOR_FLOAT16}, - {IndexEnum::INDEX_FAISS_HNSW_SQ, VecType::VECTOR_BFLOAT16}, - - {IndexEnum::INDEX_FAISS_HNSW_PQ, VecType::VECTOR_FLOAT}, - {IndexEnum::INDEX_FAISS_HNSW_PQ, VecType::VECTOR_FLOAT16}, - {IndexEnum::INDEX_FAISS_HNSW_PQ, VecType::VECTOR_BFLOAT16}, - - {IndexEnum::INDEX_FAISS_HNSW_PRQ, VecType::VECTOR_FLOAT}, - {IndexEnum::INDEX_FAISS_HNSW_PRQ, VecType::VECTOR_FLOAT16}, - {IndexEnum::INDEX_FAISS_HNSW_PRQ, VecType::VECTOR_BFLOAT16}, - - // diskann - {IndexEnum::INDEX_DISKANN, VecType::VECTOR_FLOAT}, - {IndexEnum::INDEX_DISKANN, VecType::VECTOR_FLOAT16}, - {IndexEnum::INDEX_DISKANN, VecType::VECTOR_BFLOAT16}, - - // sparse index - {IndexEnum::INDEX_SPARSE_INVERTED_INDEX, VecType::VECTOR_SPARSE_FLOAT}, - {IndexEnum::INDEX_SPARSE_WAND, VecType::VECTOR_SPARSE_FLOAT}, -}; - -static std::set legal_support_mmap_knowhere_index = { - // binary ivf - IndexEnum::INDEX_FAISS_BIN_IDMAP, - IndexEnum::INDEX_FAISS_BIN_IVFFLAT, - - // faiss index - IndexEnum::INDEX_FAISS_IDMAP, - IndexEnum::INDEX_FAISS_IVFFLAT, - IndexEnum::INDEX_FAISS_IVFPQ, - IndexEnum::INDEX_FAISS_SCANN, - IndexEnum::INDEX_FAISS_IVFSQ8, - IndexEnum::INDEX_FAISS_IVFSQ_CC, - - // hnsw - IndexEnum::INDEX_HNSW, - IndexEnum::INDEX_HNSW_SQ8, - IndexEnum::INDEX_HNSW_SQ8_REFINE, - - // faiss hnsw - IndexEnum::INDEX_FAISS_HNSW_FLAT, - IndexEnum::INDEX_FAISS_HNSW_SQ, - IndexEnum::INDEX_FAISS_HNSW_PQ, - IndexEnum::INDEX_FAISS_HNSW_PRQ, - - // sparse index - IndexEnum::INDEX_SPARSE_INVERTED_INDEX, - IndexEnum::INDEX_SPARSE_WAND, - -}; -KNOWHERE_SET_STATIC_GLOBAL_INDEX_TABLE(0, KNOWHERE_STATIC_INDEX, legal_knowhere_index) -KNOWHERE_SET_STATIC_GLOBAL_INDEX_TABLE(1, KNOWHERE_SUPPORT_MMAP_INDEX, legal_support_mmap_knowhere_index) - -} // namespace knowhere -#endif /* INDEX_TABLE_H */ diff --git a/include/knowhere/index_node.h b/include/knowhere/index_node.h new file mode 100644 index 000000000..4613d37ca --- /dev/null +++ b/include/knowhere/index_node.h @@ -0,0 +1,115 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#ifndef INDEX_NODE_H +#define INDEX_NODE_H + +#include "knowhere/binaryset.h" +#include "knowhere/bitsetview.h" +#include "knowhere/config.h" +#include "knowhere/dataset.h" +#include "knowhere/expected.h" +#include "knowhere/object.h" +#include "knowhere/version.h" + +namespace knowhere { + +class IndexNode : public Object { + public: + IndexNode(const int32_t ver) : version_(ver) { + } + + IndexNode() : version_(Version::GetDefaultVersion()) { + } + + IndexNode(const IndexNode& other) : version_(other.version_) { + } + + IndexNode(const IndexNode&& other) : version_(other.version_) { + } + + virtual Status + Build(const DataSet& dataset, const Config& cfg) { + RETURN_IF_ERROR(Train(dataset, cfg)); + return Add(dataset, cfg); + } + + virtual Status + Train(const DataSet& dataset, const Config& cfg) = 0; + + virtual Status + Add(const DataSet& dataset, const Config& cfg) = 0; + + virtual expected + Search(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const = 0; + + // not thread safe. + class iterator { + public: + virtual std::pair + Next() = 0; + [[nodiscard]] virtual bool + HasNext() const = 0; + virtual ~iterator() { + } + }; + + virtual expected>> + AnnIterator(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const { + throw std::runtime_error("annIterator not supported for current index type"); + } + + virtual expected + RangeSearch(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const = 0; + + virtual expected + GetVectorByIds(const DataSet& dataset) const = 0; + + virtual bool + HasRawData(const std::string& metric_type) const = 0; + + virtual expected + GetIndexMeta(const Config& cfg) const = 0; + + virtual Status + Serialize(BinarySet& binset) const = 0; + + virtual Status + Deserialize(const BinarySet& binset, const Config& config) = 0; + + virtual Status + DeserializeFromFile(const std::string& filename, const Config& config) = 0; + + virtual std::unique_ptr + CreateConfig() const = 0; + + virtual int64_t + Dim() const = 0; + + virtual int64_t + Size() const = 0; + + virtual int64_t + Count() const = 0; + + virtual std::string + Type() const = 0; + + virtual ~IndexNode() { + } + + protected: + Version version_; +}; + +} // namespace knowhere + +#endif /* INDEX_NODE_H */ diff --git a/include/knowhere/index/index_node_thread_pool_wrapper.h b/include/knowhere/index_node_thread_pool_wrapper.h similarity index 85% rename from include/knowhere/index/index_node_thread_pool_wrapper.h rename to include/knowhere/index_node_thread_pool_wrapper.h index 43a9e8b73..acc86908a 100644 --- a/include/knowhere/index/index_node_thread_pool_wrapper.h +++ b/include/knowhere/index_node_thread_pool_wrapper.h @@ -12,7 +12,7 @@ #ifndef INDEX_NODE_THREAD_POOL_WRAPPER_H #define INDEX_NODE_THREAD_POOL_WRAPPER_H -#include "knowhere/index/index_node.h" +#include "knowhere/index_node.h" namespace knowhere { @@ -24,23 +24,23 @@ class IndexNodeThreadPoolWrapper : public IndexNode { IndexNodeThreadPoolWrapper(std::unique_ptr index_node, std::shared_ptr thread_pool); Status - Train(const DataSetPtr dataset, const Config& cfg) override { + Train(const DataSet& dataset, const Config& cfg) override { return index_node_->Train(dataset, cfg); } Status - Add(const DataSetPtr dataset, const Config& cfg) override { + Add(const DataSet& dataset, const Config& cfg) override { return index_node_->Add(dataset, cfg); } expected - Search(const DataSetPtr dataset, const Config& cfg, const BitsetView& bitset) const override; + Search(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const override; expected - RangeSearch(const DataSetPtr dataset, const Config& cfg, const BitsetView& bitset) const override; + RangeSearch(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const override; expected - GetVectorByIds(const DataSetPtr dataset) const override { + GetVectorByIds(const DataSet& dataset) const override { return index_node_->GetVectorByIds(dataset); } diff --git a/include/knowhere/object.h b/include/knowhere/object.h index b683dcedc..cadb0065a 100644 --- a/include/knowhere/object.h +++ b/include/knowhere/object.h @@ -13,41 +13,12 @@ #define OBJECT_H #include -#include #include -#include #include "knowhere/file_manager.h" namespace knowhere { -template -struct IdVal { - I id; - T val; - - IdVal() = default; - IdVal(I id, T val) : id(id), val(val) { - } - - inline friend bool - operator<(const IdVal& lhs, const IdVal& rhs) { - return lhs.val < rhs.val || (lhs.val == rhs.val && lhs.id < rhs.id); - } - - inline friend bool - operator>(const IdVal& lhs, const IdVal& rhs) { - return !(lhs < rhs) && !(lhs == rhs); - } - - inline friend bool - operator==(const IdVal& lhs, const IdVal& rhs) { - return lhs.id == rhs.id && lhs.val == rhs.val; - } -}; - -using DistId = IdVal; - class Object { public: Object() = default; diff --git a/include/knowhere/operands.h b/include/knowhere/operands.h deleted file mode 100644 index 6db211ef7..000000000 --- a/include/knowhere/operands.h +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright (C) 2019-2023 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not -// use this file except in compliance with the License. You may obtain a copy of -// the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations under -// the License. - -#ifndef OPERANDS_H -#define OPERANDS_H -#include - -#include -#include - -namespace { -union fp32_bits { - uint32_t as_bits; - float as_value; -}; - -__attribute__((always_inline)) inline float -bf16_float(float f) { - auto u32 = fp32_bits{.as_value = f}.as_bits; - // Round off - return fp32_bits{.as_bits = (u32 + 0x8000) & 0xFFFF0000}.as_value; -} - -inline float -fp32_from_bits(const uint32_t& w) { - return fp32_bits{.as_bits = w}.as_value; -} - -inline uint32_t -fp32_to_bits(const float& f) { - return fp32_bits{.as_value = f}.as_bits; -} -}; // namespace - -namespace knowhere { -using fp32 = float; -using bin1 = uint8_t; - -struct fp16 { - public: - fp16() = default; - fp16(const float& f) { - from_fp32(f); - }; - operator float() const { - return to_fp32(bits); - } - - private: - uint16_t bits = 0; - void - from_fp32(const float f) { - // const float scale_to_inf = 0x1.0p+112f; - // const float scale_to_zero = 0x1.0p-110f; - constexpr uint32_t scale_to_inf_bits = (uint32_t)239 << 23; - constexpr uint32_t scale_to_zero_bits = (uint32_t)17 << 23; - float scale_to_inf_val, scale_to_zero_val; - std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val)); - std::memcpy(&scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val)); - const float scale_to_inf = scale_to_inf_val; - const float scale_to_zero = scale_to_zero_val; - -#if defined(_MSC_VER) && _MSC_VER == 1916 - float base = ((f < 0.0 ? -f : f) * scale_to_inf) * scale_to_zero; -#else - float base = (fabsf(f) * scale_to_inf) * scale_to_zero; -#endif - - const uint32_t w = fp32_to_bits(f); - const uint32_t shl1_w = w + w; - const uint32_t sign = w & UINT32_C(0x80000000); - uint32_t bias = shl1_w & UINT32_C(0xFF000000); - if (bias < UINT32_C(0x71000000)) { - bias = UINT32_C(0x71000000); - } - - base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; - const uint32_t bits = fp32_to_bits(base); - const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); - const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); - const uint32_t nonsign = exp_bits + mantissa_bits; - this->bits = static_cast((sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign)); - } - - float - to_fp32(const uint16_t h) const { - const uint32_t w = (uint32_t)h << 16; - const uint32_t sign = w & UINT32_C(0x80000000); - const uint32_t two_w = w + w; - - constexpr uint32_t exp_offset = UINT32_C(0xE0) << 23; - constexpr uint32_t scale_bits = (uint32_t)15 << 23; - - float exp_scale_val; - std::memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val)); - const float exp_scale = exp_scale_val; - const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; - - constexpr uint32_t magic_mask = UINT32_C(126) << 23; - constexpr float magic_bias = 0.5f; - const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; - constexpr uint32_t denormalized_cutoff = UINT32_C(1) << 27; - const uint32_t result = - sign | (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); - return fp32_from_bits(result); - } -}; - -struct bf16 { - public: - bf16() = default; - bf16(const float& f) { - from_fp32(f); - }; - operator float() const { - return this->to_fp32(bits); - } - - private: - uint16_t bits = 0; - void - from_fp32(const float f) { - volatile uint32_t fp32Bits = fp32_to_bits(f); - volatile uint16_t bf16Bits = (uint16_t)(fp32Bits >> 16); - this->bits = bf16Bits; - } - float - to_fp32(const uint16_t h) const { - uint32_t bits = ((unsigned int)h) << 16; - bits &= 0xFFFF0000; - return fp32_from_bits(bits); - } -}; - -template -using TypeMatch = std::bool_constant<(... | std::is_same_v)>; -template -using KnowhereDataTypeCheck = TypeMatch; -template -using KnowhereFloatTypeCheck = TypeMatch; - -template -struct MockData { - using type = T; -}; - -template <> -struct MockData { - using type = knowhere::fp32; -}; - -template <> -struct MockData { - using type = knowhere::fp32; -}; -} // namespace knowhere -#endif /* OPERANDS_H */ diff --git a/include/knowhere/prometheus_client.h b/include/knowhere/prometheus_client.h index 6ae107dd2..8a437d56a 100644 --- a/include/knowhere/prometheus_client.h +++ b/include/knowhere/prometheus_client.h @@ -52,82 +52,35 @@ class PrometheusClient { /*****************************************************************************/ // prometheus metrics -extern const prometheus::Histogram::BucketBoundaries defaultBuckets; +extern const prometheus::Histogram::BucketBoundaries buckets; extern const std::unique_ptr prometheusClient; -#define CONCATENATE(x, y) x##_##y -#define PROMETHEUS_LABEL_KNOWHERE knowhere -#define PROMETHEUS_LABEL_CARDINAL cardinal - -#define DEFINE_PROMETHEUS_GAUGE_FAMILY(name, desc) \ - prometheus::Family& CONCATENATE(name, family) = \ - prometheus::BuildGauge().Name(#name).Help(desc).Register(knowhere::prometheusClient->GetRegistry()); - -#define DEFINE_PROMETHEUS_GAUGE(name, module) \ - prometheus::Gauge& CONCATENATE(module, name) = CONCATENATE(name, family).Add({{"module", #module}}); - -#define DEFINE_PROMETHEUS_COUNTER_FAMILY(name, desc) \ - prometheus::Family& CONCATENATE(name, family) = \ - prometheus::BuildCounter().Name(#name).Help(desc).Register(knowhere::prometheusClient->GetRegistry()); - -#define DEFINE_PROMETHEUS_COUNTER(name, module) \ - prometheus::Counter& CONCATENATE(module, name) = CONCATENATE(name, family).Add({{"module", #module}}); - -#define DEFINE_PROMETHEUS_HISTOGRAM_FAMILY(name, desc) \ - prometheus::Family& CONCATENATE(name, family) = \ - prometheus::BuildHistogram().Name(#name).Help(desc).Register(knowhere::prometheusClient->GetRegistry()); - -#define DEFINE_PROMETHEUS_HISTOGRAM_WITH_BUCKETS(name, module, buckets) \ - prometheus::Histogram& CONCATENATE(module, name) = CONCATENATE(name, family).Add({{"module", #module}}, buckets); - -#define DEFINE_PROMETHEUS_HISTOGRAM(name, module) DEFINE_PROMETHEUS_HISTOGRAM_WITH_BUCKETS(name, module, defaultBuckets) - -#define DECLARE_PROMETHEUS_GAUGE(name, module) extern prometheus::Gauge& CONCATENATE(module, name); -#define DECLARE_PROMETHEUS_COUNTER(name, module) extern prometheus::Counter& CONCATENATE(module, name); -#define DECLARE_PROMETHEUS_HISTOGRAM(name, module) extern prometheus::Histogram& CONCATENATE(module, name); - -DECLARE_PROMETHEUS_HISTOGRAM(build_latency, PROMETHEUS_LABEL_KNOWHERE); -DECLARE_PROMETHEUS_HISTOGRAM(build_latency, PROMETHEUS_LABEL_CARDINAL); - -DECLARE_PROMETHEUS_HISTOGRAM(load_latency, PROMETHEUS_LABEL_KNOWHERE); -DECLARE_PROMETHEUS_HISTOGRAM(load_latency, PROMETHEUS_LABEL_CARDINAL); - -DECLARE_PROMETHEUS_HISTOGRAM(search_latency, PROMETHEUS_LABEL_KNOWHERE); -DECLARE_PROMETHEUS_HISTOGRAM(search_latency, PROMETHEUS_LABEL_CARDINAL); - -// cardinal uses the RangeSearch function of the parent class `IndexNode` (index_node.h). -// both use the knowhere metric uniformly. -DECLARE_PROMETHEUS_HISTOGRAM(range_search_latency, PROMETHEUS_LABEL_KNOWHERE); - -DECLARE_PROMETHEUS_HISTOGRAM(ann_iterator_init_latency, PROMETHEUS_LABEL_KNOWHERE); -DECLARE_PROMETHEUS_HISTOGRAM(ann_iterator_init_latency, PROMETHEUS_LABEL_CARDINAL); - -DECLARE_PROMETHEUS_HISTOGRAM(search_topk, PROMETHEUS_LABEL_KNOWHERE); -DECLARE_PROMETHEUS_HISTOGRAM(search_topk, PROMETHEUS_LABEL_CARDINAL); - -DECLARE_PROMETHEUS_HISTOGRAM(bitset_ratio, PROMETHEUS_LABEL_CARDINAL); -DECLARE_PROMETHEUS_HISTOGRAM(quant_compute_cnt, PROMETHEUS_LABEL_CARDINAL); -DECLARE_PROMETHEUS_HISTOGRAM(raw_compute_cnt, PROMETHEUS_LABEL_CARDINAL); -DECLARE_PROMETHEUS_HISTOGRAM(cache_hit_cnt, PROMETHEUS_LABEL_CARDINAL); -DECLARE_PROMETHEUS_HISTOGRAM(io_cnt, PROMETHEUS_LABEL_CARDINAL); -DECLARE_PROMETHEUS_HISTOGRAM(queue_latency, PROMETHEUS_LABEL_CARDINAL); -DECLARE_PROMETHEUS_HISTOGRAM(exec_latency, PROMETHEUS_LABEL_CARDINAL); - -DECLARE_PROMETHEUS_HISTOGRAM(graph_search_cnt, PROMETHEUS_LABEL_CARDINAL); -DECLARE_PROMETHEUS_HISTOGRAM(ivf_search_cnt, PROMETHEUS_LABEL_CARDINAL); -DECLARE_PROMETHEUS_HISTOGRAM(bf_search_cnt, PROMETHEUS_LABEL_CARDINAL); -DECLARE_PROMETHEUS_HISTOGRAM(re_search_cnt, PROMETHEUS_LABEL_CARDINAL); - -DECLARE_PROMETHEUS_HISTOGRAM(filter_connectivity_ratio, PROMETHEUS_LABEL_CARDINAL); -DECLARE_PROMETHEUS_HISTOGRAM(filter_mv_only_cnt, PROMETHEUS_LABEL_CARDINAL); -DECLARE_PROMETHEUS_HISTOGRAM(filter_mv_activated_fields_cnt, PROMETHEUS_LABEL_CARDINAL); -DECLARE_PROMETHEUS_HISTOGRAM(filter_mv_change_base_cnt, PROMETHEUS_LABEL_CARDINAL); -DECLARE_PROMETHEUS_HISTOGRAM(filter_mv_supplement_ep_bool_cnt, PROMETHEUS_LABEL_CARDINAL); - -DECLARE_PROMETHEUS_HISTOGRAM(hnsw_bitset_ratio, PROMETHEUS_LABEL_KNOWHERE); -DECLARE_PROMETHEUS_HISTOGRAM(hnsw_search_hops, PROMETHEUS_LABEL_KNOWHERE); - -DECLARE_PROMETHEUS_HISTOGRAM(diskann_bitset_ratio, PROMETHEUS_LABEL_KNOWHERE); -DECLARE_PROMETHEUS_HISTOGRAM(diskann_search_hops, PROMETHEUS_LABEL_KNOWHERE); -DECLARE_PROMETHEUS_HISTOGRAM(diskann_range_search_iters, PROMETHEUS_LABEL_KNOWHERE); +#define DEFINE_PROMETHEUS_GAUGE(name, desc) \ + prometheus::Family& name##_family = \ + prometheus::BuildGauge().Name(#name).Help(desc).Register(knowhere::prometheusClient->GetRegistry()); \ + prometheus::Gauge& name = name##_family.Add({}); + +#define DEFINE_PROMETHEUS_COUNTER(name, desc) \ + prometheus::Family& name##_family = \ + prometheus::BuildCounter().Name(#name).Help(desc).Register(knowhere::prometheusClient->GetRegistry()); \ + prometheus::Counter& name = name##_family.Add({}); + +#define DEFINE_PROMETHEUS_HISTOGRAM(name, desc) \ + prometheus::Family& name##_family = \ + prometheus::BuildHistogram().Name(#name).Help(desc).Register(knowhere::prometheusClient->GetRegistry()); \ + prometheus::Histogram& name = name##_family.Add({}, knowhere::buckets); + +#define DECLARE_PROMETHEUS_GAUGE(name_gauge) extern prometheus::Gauge& name_gauge; +#define DECLARE_PROMETHEUS_COUNTER(name_counter) extern prometheus::Counter& name_counter; +#define DECLARE_PROMETHEUS_HISTOGRAM(name_histogram) extern prometheus::Histogram& name_histogram; + +DECLARE_PROMETHEUS_COUNTER(knowhere_build_count); +DECLARE_PROMETHEUS_COUNTER(knowhere_search_count); +DECLARE_PROMETHEUS_COUNTER(knowhere_ann_iterator_count); +DECLARE_PROMETHEUS_COUNTER(knowhere_range_search_count); +DECLARE_PROMETHEUS_HISTOGRAM(knowhere_build_latency); +DECLARE_PROMETHEUS_HISTOGRAM(knowhere_search_topk); +DECLARE_PROMETHEUS_HISTOGRAM(knowhere_search_latency); +DECLARE_PROMETHEUS_HISTOGRAM(knowhere_ann_iterator_init_latency); +DECLARE_PROMETHEUS_HISTOGRAM(knowhere_range_search_latency); } // namespace knowhere diff --git a/include/knowhere/sparse_utils.h b/include/knowhere/sparse_utils.h deleted file mode 100644 index aca4b3238..000000000 --- a/include/knowhere/sparse_utils.h +++ /dev/null @@ -1,277 +0,0 @@ -// Copyright (C) 2019-2023 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not -// use this file except in compliance with the License. You may obtain a copy of -// the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// valributed under the License is valributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations under -// the License. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include "knowhere/expected.h" -#include "knowhere/object.h" -#include "knowhere/operands.h" - -namespace knowhere::sparse { - -// integer type in SparseRow -using table_t = uint32_t; -// type used to represent the id of a vector in the index interface. -// this is same as other index types. -using label_t = int64_t; - -template -using SparseIdVal = IdVal; - -// DocValueComputer takes a value of a doc vector and returns the a computed -// value that can be used to multiply directly with the corresponding query -// value. The second parameter is the document length of the database vector, -// which is used in BM25. -template -using DocValueComputer = std::function; - -template -auto -GetDocValueOriginalComputer() { - static DocValueComputer lambda = [](const T& right, const float) -> float { return right; }; - return lambda; -} - -template -auto -GetDocValueBM25Computer(float k1, float b, float avgdl) { - return [k1, b, avgdl](const T& tf, const float doc_len) -> float { - return tf * (k1 + 1) / (tf + k1 * (1 - b + b * (doc_len / avgdl))); - }; -} - -template -class SparseRow { - static_assert(std::is_same_v, "SparseRow supports float only"); - - public: - // construct an SparseRow with memory allocated to hold `count` elements. - SparseRow(size_t count = 0) - : data_(count ? new uint8_t[count * element_size()] : nullptr), count_(count), own_data_(true) { - } - - SparseRow(size_t count, uint8_t* data, bool own_data) : data_(data), count_(count), own_data_(own_data) { - } - - // copy constructor and copy assignment operator perform deep copy - SparseRow(const SparseRow& other) : SparseRow(other.count_) { - std::memcpy(data_, other.data_, data_byte_size()); - } - - SparseRow(SparseRow&& other) noexcept : SparseRow() { - swap(*this, other); - } - - SparseRow& - operator=(const SparseRow& other) { - if (this != &other) { - SparseRow tmp(other); - swap(*this, tmp); - } - return *this; - } - - SparseRow& - operator=(SparseRow&& other) noexcept { - swap(*this, other); - return *this; - } - - ~SparseRow() { - if (own_data_ && data_ != nullptr) { - delete[] data_; - data_ = nullptr; - } - } - - size_t - size() const { - return count_; - } - - size_t - memory_usage() const { - return data_byte_size() + sizeof(*this); - } - - // return the number of bytes used by the underlying data array. - size_t - data_byte_size() const { - return count_ * element_size(); - } - - void* - data() { - return data_; - } - - const void* - data() const { - return data_; - } - - // dim of a sparse vector is the max index + 1, or 0 for an empty vector. - int64_t - dim() const { - if (count_ == 0) { - return 0; - } - auto* elem = reinterpret_cast(data_) + count_ - 1; - return elem->index + 1; - } - - SparseIdVal - operator[](size_t i) const { - auto* elem = reinterpret_cast(data_) + i; - return {elem->index, elem->value}; - } - - void - set_at(size_t i, table_t index, T value) { - auto* elem = reinterpret_cast(data_) + i; - elem->index = index; - elem->value = value; - } - - // In the case of asymetric distance functions, this should be the query - // and the other should be the database vector. For example using BM25, we - // should call query_vec.dot(doc_vec) instead of doc_vec.dot(query_vec). - template > - float - dot(const SparseRow& other, Computer computer = GetDocValueOriginalComputer(), const T other_sum = 0) const { - float product_sum = 0.0f; - size_t i = 0; - size_t j = 0; - // TODO: improve with _mm_cmpistrm or the AVX512 alternative. - while (i < count_ && j < other.count_) { - auto* left = reinterpret_cast(data_) + i; - auto* right = reinterpret_cast(other.data_) + j; - - if (left->index < right->index) { - ++i; - } else if (left->index > right->index) { - ++j; - } else { - product_sum += left->value * computer(right->value, other_sum); - ++i; - ++j; - } - } - return product_sum; - } - - friend void - swap(SparseRow& left, SparseRow& right) { - using std::swap; - swap(left.count_, right.count_); - swap(left.data_, right.data_); - swap(left.own_data_, right.own_data_); - } - - static inline size_t - element_size() { - return sizeof(table_t) + sizeof(T); - } - - private: - // ElementProxy is used to access elements in the data_ array and should - // never be actually constructed. - struct __attribute__((packed)) ElementProxy { - table_t index; - T value; - ElementProxy() = delete; - ElementProxy(const ElementProxy&) = delete; - }; - // data_ must be sorted by column id. use raw pointer for easy mmap and zero - // copy. - uint8_t* data_; - size_t count_; - bool own_data_; -}; - -// When pushing new elements into a MaxMinHeap, only `capacity` elements with the -// largest val are kept. pop()/top() returns the smallest element out of them. -template -class MaxMinHeap { - public: - explicit MaxMinHeap(int capacity) : capacity_(capacity), pool_(capacity) { - } - void - push(table_t id, T val) { - if (size_ < capacity_) { - pool_[size_] = {id, val}; - size_ += 1; - std::push_heap(pool_.begin(), pool_.begin() + size_, std::greater>()); - } else if (val > pool_[0].val) { - sift_down(id, val); - } - } - table_t - pop() { - std::pop_heap(pool_.begin(), pool_.begin() + size_, std::greater>()); - size_ -= 1; - return pool_[size_].id; - } - [[nodiscard]] size_t - size() const { - return size_; - } - [[nodiscard]] bool - empty() const { - return size() == 0; - } - SparseIdVal - top() const { - return pool_[0]; - } - [[nodiscard]] bool - full() const { - return size_ == capacity_; - } - - private: - void - sift_down(table_t id, T val) { - size_t i = 0; - for (; 2 * i + 1 < size_;) { - size_t j = i; - size_t l = 2 * i + 1, r = 2 * i + 2; - if (pool_[l].val < val) { - j = l; - } - if (r < size_ && pool_[r].val < std::min(pool_[l].val, val)) { - j = r; - } - if (i == j) { - break; - } - pool_[i] = pool_[j]; - i = j; - } - pool_[i] = {id, val}; - } - - size_t size_ = 0, capacity_; - std::vector> pool_; -}; // class MaxMinHeap - -} // namespace knowhere::sparse diff --git a/include/knowhere/tolower.h b/include/knowhere/tolower.h deleted file mode 100644 index 87c0dbef3..000000000 --- a/include/knowhere/tolower.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (C) 2019-2024 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not -// use this file except in compliance with the License. You may obtain a copy of -// the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// valributed under the License is valributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations under -// the License. - -#pragma once - -#include -#include - -namespace knowhere { - -static inline std::string -str_to_lower(std::string s) { - std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { return std::tolower(c); }); - - return s; -} - -} // namespace knowhere diff --git a/include/knowhere/tracer.h b/include/knowhere/tracer.h deleted file mode 100644 index 101c71e52..000000000 --- a/include/knowhere/tracer.h +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License - -#pragma once - -#include -#include - -#include "opentelemetry/trace/provider.h" - -#define TRACE_SERVICE_KNOWHERE "knowhere" - -namespace knowhere::tracer { - -struct TraceConfig { - std::string exporter; - float sampleFraction; - std::string jaegerURL; - std::string otlpEndpoint; - bool oltpSecure; - - int nodeID; -}; - -struct TraceContext { - const uint8_t* traceID = nullptr; - const uint8_t* spanID = nullptr; - uint8_t traceFlags = 0; -}; -namespace trace = opentelemetry::trace; - -void -initTelemetry(const TraceConfig& cfg); - -std::shared_ptr -GetTracer(); - -std::shared_ptr -StartSpan(const std::string& name, TraceContext* ctx = nullptr); - -void -SetRootSpan(std::shared_ptr span); - -void -CloseRootSpan(); - -void -AddEvent(const std::string& event_label); - -bool -EmptyTraceID(const TraceContext* ctx); - -bool -EmptySpanID(const TraceContext* ctx); - -std::string -BytesToHexStr(const uint8_t* data, size_t len); - -std::string -GetIDFromHexStr(const std::string& hexStr); - -} // namespace knowhere::tracer diff --git a/include/knowhere/utils.h b/include/knowhere/utils.h index 9d494a695..efc80d5cc 100644 --- a/include/knowhere/utils.h +++ b/include/knowhere/utils.h @@ -18,7 +18,6 @@ #include "knowhere/binaryset.h" #include "knowhere/dataset.h" -#include "knowhere/operands.h" namespace knowhere { @@ -31,26 +30,22 @@ IsMetricType(const std::string& str, const knowhere::MetricType& metric_type) { inline bool IsFlatIndex(const knowhere::IndexType& index_type) { - static std::vector flat_index_list = { - IndexEnum::INDEX_FAISS_IDMAP, IndexEnum::INDEX_FAISS_GPU_IDMAP, IndexEnum::INDEX_GPU_BRUTEFORCE}; + static std::vector flat_index_list = {IndexEnum::INDEX_FAISS_IDMAP, + IndexEnum::INDEX_FAISS_GPU_IDMAP}; return std::find(flat_index_list.begin(), flat_index_list.end(), index_type) != flat_index_list.end(); } -template extern float -NormalizeVec(DataType* x, int32_t d); +NormalizeVec(float* x, int32_t d); -template extern std::vector -NormalizeVecs(DataType* x, size_t rows, int32_t dim); +NormalizeVecs(float* x, size_t rows, int32_t dim); -template -extern std::unique_ptr -CopyAndNormalizeVecs(const DataType* x, size_t rows, int32_t dim); - -template extern void -NormalizeDataset(const DataSetPtr dataset); +Normalize(const DataSet& dataset); + +extern std::unique_ptr +CopyAndNormalizeVecs(const float* x, size_t rows, int32_t dim); constexpr inline uint64_t seed = 0xc70f6907UL; @@ -63,15 +58,6 @@ hash_vec(const float* x, size_t d) { return h; } -inline uint64_t -hash_u8_vec(const uint8_t* x, size_t d) { - uint64_t h = seed; - for (size_t i = 0; i < d; ++i) { - h = h * 13331 + *(x + i); - } - return h; -} - inline uint64_t hash_binary_vec(const uint8_t* x, size_t d) { size_t len = (d + 7) / 8; @@ -82,98 +68,6 @@ hash_binary_vec(const uint8_t* x, size_t d) { return h; } -inline uint64_t -hash_half_precision_float(const void* x, size_t d) { - uint64_t h = seed; - auto u16_x = (uint16_t*)(x); - for (size_t i = 0; i < d; ++i) { - h = h * 13331 + u16_x[i]; - } - return h; -} - -template -inline std::string -GetKey(const std::string& name) { - static_assert(KnowhereDataTypeCheck::value == true); - if (std::is_same_v) { - return name + std::string("_fp32"); - } else if (std::is_same_v) { - return name + std::string("_fp16"); - } else if (std::is_same_v) { - return name + std::string("_bf16"); - } else if (std::is_same_v) { - return name + std::string("_bin1"); - } -} - -template -inline DataSetPtr -data_type_conversion(const DataSet& src, const std::optional start = std::nullopt, - const std::optional count = std::nullopt) { - auto dim = src.GetDim(); - auto rows = src.GetRows(); - - // check the acceptable range - int64_t start_row = start.value_or(0); - if (start_row < 0 || start_row >= rows) { - return nullptr; - } - - int64_t count_rows = count.value_or(rows - start_row); - if (count_rows < 0 || start_row + count_rows > rows) { - return nullptr; - } - - // map - auto* des_data = new OutType[dim * count_rows]; - auto* src_data = (const InType*)src.GetTensor(); - for (auto i = 0; i < dim * count_rows; i++) { - des_data[i] = (OutType)src_data[i + start_row * dim]; - } - - auto des = std::make_shared(); - des->SetRows(count_rows); - des->SetDim(dim); - des->SetTensor(des_data); - des->SetIsOwner(true); - return des; -} - -// Convert DataSet from DataType to float -// * no start, no count, float -> returns the source without cloning -// * no start, no count, no float -> returns a clone with a different type -// * start, no count -> returns a clone that starts from a given row 'start' -// * no start, count -> returns a clone that starts from a row 0 and has 'count' rows -// * start, count -> returns a clone that start from a given row 'start' and has 'count' rows -// * invalid start, count values -> returns nullptr -template -inline DataSetPtr -ConvertFromDataTypeIfNeeded(const DataSetPtr& ds, const std::optional start = std::nullopt, - const std::optional count = std::nullopt) { - if constexpr (std::is_same_v::type>) { - if (!start.has_value() && !count.has_value()) { - return ds; - } - } - - return data_type_conversion::type>(*ds, start, count); -} - -// Convert DataSet from float to DataType -template -inline DataSetPtr -ConvertToDataTypeIfNeeded(const DataSetPtr& ds, const std::optional start = std::nullopt, - const std::optional count = std::nullopt) { - if constexpr (std::is_same_v::type>) { - if (!start.has_value() && !count.has_value()) { - return ds; - } - } - - return data_type_conversion::type, DataType>(*ds, start, count); -} - template inline T round_down(const T value, const T align) { @@ -186,26 +80,4 @@ ConvertIVFFlat(const BinarySet& binset, const MetricType metric_type, const uint bool UseDiskLoad(const std::string& index_type, const int32_t& /*version*/); -template -static void -writeBinaryPOD(W& out, const T& podRef) { - out.write((char*)&podRef, sizeof(T)); -} - -template -static void -readBinaryPOD(R& in, T& podRef) { - in.read((char*)&podRef, sizeof(T)); -} - -// taken from -// https://github.com/Microsoft/BLAS-on-flash/blob/master/include/utils.h -// round up X to the nearest multiple of Y -#define ROUND_UP(X, Y) ((((uint64_t)(X) / (Y)) + ((uint64_t)(X) % (Y) != 0)) * (Y)) - -#define DIV_ROUND_UP(X, Y) (((uint64_t)(X) / (Y)) + ((uint64_t)(X) % (Y) != 0)) - -// round down X to the nearest multiple of Y -#define ROUND_DOWN(X, Y) (((uint64_t)(X) / (Y)) * (Y)) - } // namespace knowhere diff --git a/include/knowhere/version.h b/include/knowhere/version.h index 3db6a6c02..c4a82322c 100644 --- a/include/knowhere/version.h +++ b/include/knowhere/version.h @@ -21,7 +21,7 @@ namespace knowhere { namespace { static constexpr int32_t default_version = 0; static constexpr int32_t minimal_version = 0; -static constexpr int32_t current_version = 5; +static constexpr int32_t current_version = 1; } // namespace class Version { diff --git a/python/knowhere/__init__.py b/python/knowhere/__init__.py index cd7004424..815db8e31 100644 --- a/python/knowhere/__init__.py +++ b/python/knowhere/__init__.py @@ -1,44 +1,12 @@ from . import swigknowhere from .swigknowhere import Status -from .swigknowhere import CreateBinarySet, GetBinarySet, GetNullDataSet, GetNullBitSetView -from .swigknowhere import BruteForceSearchFloat, BruteForceRangeSearchFloat -from .swigknowhere import BruteForceSearchFP16, BruteForceRangeSearchFP16 -from .swigknowhere import BruteForceSearchBF16, BruteForceRangeSearchBF16 -from .swigknowhere import BruteForceSearchBin, BruteForceRangeSearchBin - +from .swigknowhere import GetBinarySet, GetNullDataSet, GetNullBitSetView +from .swigknowhere import BruteForceSearch, BruteForceRangeSearch import numpy as np -from bfloat16 import bfloat16 - - -def CreateIndex(name, version, type=np.float32): - if type == np.float32: - return swigknowhere.IndexWrapFloat(name, version) - if type == np.float16: - return swigknowhere.IndexWrapFP16(name, version) - if type == bfloat16: - return swigknowhere.IndexWrapBF16(name, version) - if type == np.uint8: - return swigknowhere.IndexWrapBin(name, version) - -def BruteForceSearch(type=np.float32, *args): - if type == np.float32: - return BruteForceSearchFloat(*args) - if type == np.float16: - return BruteForceSearchFP16(*args) - if type == bfloat16: - return BruteForceSearchBF16(*args) - if type == np.uint8: - return BruteForceSearchBin(*args) - -def BruteForceRangeSearch(type=np.float32, *args): - if type == np.float32: - return BruteForceRangeSearchFloat(*args) - if type == np.float16: - return BruteForceRangeSearchFP16(*args) - if type == bfloat16: - return BruteForceRangeSearchBF16(*args) - if type == np.uint8: - return BruteForceRangeSearchBin(*args) + + +def CreateIndex(name, version): + return swigknowhere.IndexWrap(name, version) def GetCurrentVersion(): @@ -57,49 +25,17 @@ def Dump(binset, file_name): return swigknowhere.Dump(binset, file_name) -def WriteIndexToDisk(binset, index_type, data_path): - return swigknowhere.WriteIndexToDisk(binset, index_type, data_path) - -def ArrayToBinary(arr): - if arr.dtype == np.uint8: - return swigknowhere.Array2Binary(arr) - raise ValueError( - """ - ArrayToBinary only support numpy array dtype uint8. - """ - ) - def ArrayToDataSet(arr): if arr.ndim == 1: return swigknowhere.Array2DataSetIds(arr) if arr.ndim == 2: - if arr.dtype == np.uint8: + if arr.dtype == np.int32: return swigknowhere.Array2DataSetI(arr) if arr.dtype == np.float32: return swigknowhere.Array2DataSetF(arr) - if arr.dtype == np.float16: - arr = arr.astype(np.float32) - return swigknowhere.Array2DataSetFP16(arr) - if arr.dtype == bfloat16: - arr = arr.astype(np.float32) - return swigknowhere.Array2DataSetBF16(arr) raise ValueError( """ - ArrayToDataSet only support numpy array dtype float32,uint8,float16 and bfloat16. - """ - ) - -# follow csr_matrix format -# row i are stored in ``indices[indptr[i]:indptr[i+1]]`` and their -# corresponding values are stored in ``data[indptr[i]:indptr[i+1]] -def ArrayToSparseDataSet(data, indices, indptr): - if data.ndim == 1 and indices.ndim == 1 and indptr.ndim == 1: - assert data.shape[0] == indices.shape[0] - assert indptr.shape[0] > 1 - return swigknowhere.Array2SparseDataSet(data, indices, indptr) - raise ValueError( - """ - ArrayToSparseDataSet input type wrong. + ArrayToDataSet only support numpy array dtype float32 and int32. """ ) @@ -153,34 +89,13 @@ def GetVectorDataSetToArray(ans): swigknowhere.DataSetTensor2Array(ans, data) return data -def GetFloat16VectorDataSetToArray(ans): - dim = swigknowhere.DataSet_Dim(ans) - rows = swigknowhere.DataSet_Rows(ans) - data = np.zeros([rows, dim]).astype(np.float32) - swigknowhere.Float16DataSetTensor2Array(ans, data) - data = data.astype(np.float16) - return data - -def GetBFloat16VectorDataSetToArray(ans): - dim = swigknowhere.DataSet_Dim(ans) - rows = swigknowhere.DataSet_Rows(ans) - data = np.zeros([rows, dim]).astype(np.float32) - swigknowhere.BFloat16DataSetTensor2Array(ans, data) - data = data.astype(bfloat16) - return data def GetBinaryVectorDataSetToArray(ans): - dim = int(swigknowhere.DataSet_Dim(ans) / 8) + dim = int(swigknowhere.DataSet_Dim(ans) / 32) rows = swigknowhere.DataSet_Rows(ans) - data = np.zeros([rows, dim]).astype(np.uint8) + data = np.zeros([rows, dim]).astype(np.int32) swigknowhere.BinaryDataSetTensor2Array(ans, data) return data def SetSimdType(type): swigknowhere.SetSimdType(type) - -def SetBuildThreadPool(num_threads): - swigknowhere.SetBuildThreadPool(num_threads) - -def SetSearchThreadPool(num_threads): - swigknowhere.SetSearchThreadPool(num_threads) diff --git a/python/knowhere/knowhere.i b/python/knowhere/knowhere.i index 892ce004d..18a435d9d 100644 --- a/python/knowhere/knowhere.i +++ b/python/knowhere/knowhere.i @@ -26,18 +26,13 @@ typedef uint64_t size_t; #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION #include #endif -#include #include -#include +#include #include #include -#include -#include "knowhere/operands.h" #include #include #include -#include -#include #include #include using namespace knowhere; @@ -56,15 +51,11 @@ import_array(); %include %include %include -%include %include %shared_ptr(knowhere::DataSet) -%shared_ptr(knowhere::Binary) %shared_ptr(knowhere::BinarySet) %template(DataSetPtr) std::shared_ptr; -%template(BinaryPtr) std::shared_ptr; %template(BinarySetPtr) std::shared_ptr; -%template(int64_float_pair) std::pair; %include %include %include @@ -74,14 +65,9 @@ import_array(); %apply (float* IN_ARRAY2, int DIM1, int DIM2) {(float* xb, int nb, int dim)} %apply (int* IN_ARRAY2, int DIM1, int DIM2) {(int* xb, int nb, int dim)} %apply (uint8_t *IN_ARRAY1, int DIM1) {(uint8_t *block, int size)} -%apply (uint8_t* IN_ARRAY2, int DIM1, int DIM2) {(uint8_t* xb, int nb, int dim)} -%apply (uint8_t* INPLACE_ARRAY2, int DIM1, int DIM2) {(uint8_t *data,int rows,int dim)} %apply (int *IN_ARRAY1, int DIM1) {(int *lims, int len)} %apply (int *IN_ARRAY1, int DIM1) {(int *ids, int len)} %apply (float *IN_ARRAY1, int DIM1) {(float *dis, int len)} -%apply (float *IN_ARRAY1, int DIM1) {(float *data, int nb1)} -%apply (int *IN_ARRAY1, int DIM1) {(int *ids, int nb2)} -%apply (int64_t *IN_ARRAY1, int DIM1) {(int64_t* indptr, int nb3)} %apply (float* INPLACE_ARRAY2, int DIM1, int DIM2){(float *dis,int nq_1,int k_1)} %apply (int *INPLACE_ARRAY2, int DIM1, int DIM2){(int *ids,int nq_2,int k_2)} %apply (float* INPLACE_ARRAY2, int DIM1, int DIM2){(float *data,int rows,int dim)} @@ -118,7 +104,7 @@ del Enum %inline %{ class GILReleaser { - public: +public: GILReleaser() : save(PyEval_SaveThread()) { } ~GILReleaser() { @@ -127,65 +113,41 @@ class GILReleaser { PyThreadState* save; }; -class AnnIteratorWrap { - public: - AnnIteratorWrap(std::shared_ptr it = nullptr) : it_(it) { - if (it_ == nullptr) { - throw std::runtime_error("ann iterator must not be nullptr."); - } - } - ~AnnIteratorWrap() { - } - - bool HasNext() { - return it_->HasNext(); - } - - std::pair Next() { - return it_->Next(); - } - - private: - std::shared_ptr it_; -}; - -template class IndexWrap { public: IndexWrap(const std::string& name, const int32_t& version) { GILReleaser rel; - if (name == std::string(knowhere::IndexEnum::INDEX_DISKANN)) { + if (knowhere::UseDiskLoad(name, version)) { std::shared_ptr file_manager = std::make_shared(); auto diskann_pack = knowhere::Pack(file_manager); - idx = IndexFactory::Instance().Create(name, version, - diskann_pack); + idx = IndexFactory::Instance().Create(name, version, diskann_pack); } else { - idx = IndexFactory::Instance().Create(name, version); + idx = IndexFactory::Instance().Create(name, version); } } knowhere::Status Build(knowhere::DataSetPtr dataset, const std::string& json) { GILReleaser rel; - return idx.value().Build(dataset, knowhere::Json::parse(json)); + return idx.Build(*dataset, knowhere::Json::parse(json)); } knowhere::Status Train(knowhere::DataSetPtr dataset, const std::string& json) { GILReleaser rel; - return idx.value().Train(dataset, knowhere::Json::parse(json)); + return idx.Train(*dataset, knowhere::Json::parse(json)); } knowhere::Status Add(knowhere::DataSetPtr dataset, const std::string& json) { GILReleaser rel; - return idx.value().Add(dataset, knowhere::Json::parse(json)); + return idx.Add(*dataset, knowhere::Json::parse(json)); } knowhere::DataSetPtr Search(knowhere::DataSetPtr dataset, const std::string& json, const knowhere::BitsetView& bitset, knowhere::Status& status) { GILReleaser rel; - auto res = idx.value().Search(dataset, knowhere::Json::parse(json), bitset); + auto res = idx.Search(*dataset, knowhere::Json::parse(json), bitset); if (res.has_value()) { status = knowhere::Status::success; return res.value(); @@ -195,26 +157,10 @@ class IndexWrap { } } - std::vector - GetAnnIterator(knowhere::DataSetPtr dataset, const std::string& json, const knowhere::BitsetView& bitset, knowhere::Status& status) { - GILReleaser rel; - auto res = idx.value().AnnIterator(dataset, knowhere::Json::parse(json), bitset); - std::vector result; - if (!res.has_value()) { - status = res.error(); - return result; - } - status = knowhere::Status::success; - for (auto it : res.value()) { - result.emplace_back(it); - } - return result; - } - knowhere::DataSetPtr RangeSearch(knowhere::DataSetPtr dataset, const std::string& json, const knowhere::BitsetView& bitset, knowhere::Status& status){ GILReleaser rel; - auto res = idx.value().RangeSearch(dataset, knowhere::Json::parse(json), bitset); + auto res = idx.RangeSearch(*dataset, knowhere::Json::parse(json), bitset); if (res.has_value()) { status = knowhere::Status::success; return res.value(); @@ -227,7 +173,7 @@ class IndexWrap { knowhere::DataSetPtr GetVectorByIds(knowhere::DataSetPtr dataset, knowhere::Status& status) { GILReleaser rel; - auto res = idx.value().GetVectorByIds(dataset); + auto res = idx.GetVectorByIds(*dataset); if (res.has_value()) { status = knowhere::Status::success; return res.value(); @@ -240,49 +186,43 @@ class IndexWrap { bool HasRawData(const std::string& metric_type) { GILReleaser rel; - return idx.value().HasRawData(metric_type); + return idx.HasRawData(metric_type); } knowhere::Status Serialize(knowhere::BinarySetPtr binset) { GILReleaser rel; - return idx.value().Serialize(*binset); + return idx.Serialize(*binset); } knowhere::Status Deserialize(knowhere::BinarySetPtr binset, const std::string& json) { GILReleaser rel; - return idx.value().Deserialize(*binset, knowhere::Json::parse(json)); - } - - knowhere::Status - DeserializeFromFile(const std::string& filename, const std::string& json) { - GILReleaser rel; - return idx.value().DeserializeFromFile(filename, knowhere::Json::parse(json)); + return idx.Deserialize(*binset, knowhere::Json::parse(json)); } int64_t Dim() { - return idx.value().Dim(); + return idx.Dim(); } int64_t Size() { - return idx.value().Size(); + return idx.Size(); } int64_t Count() { - return idx.value().Count(); + return idx.Count(); } std::string Type() { - return idx.value().Type(); + return idx.Type(); } private: - expected> idx; + Index idx; }; class BitSet { @@ -321,83 +261,17 @@ Array2DataSetF(float* xb, int nb, int dim) { return ds; }; -knowhere::DataSetPtr -Array2DataSetFP16(float* xb, int nb, int dim) { - auto ds = std::make_shared(); - ds->SetIsOwner(true); - ds->SetRows(nb); - ds->SetDim(dim); - // float to fp16 - auto fp16_data = new knowhere::fp16[nb * dim]; - for (int i = 0; i < nb * dim; ++i) { - fp16_data[i] = knowhere::fp16(xb[i]); - } - ds->SetTensor(fp16_data); - return ds; -}; -#pragma GCC push_options -#pragma GCC optimize("O0") -knowhere::DataSetPtr -Array2DataSetBF16(float* xb, int nb, int dim) { - using bf16 = knowhere::bf16; - auto ds = std::make_shared(); - ds->SetIsOwner(true); - ds->SetRows(nb); - ds->SetDim(dim); - bf16* bf16_data = new bf16[nb * dim]; - for (int i = 0; i < nb * dim; ++i) { - bf16_data[i] = knowhere::bf16(xb[i]); - } - - ds->SetTensor(bf16_data); - return ds; -}; -#pragma GCC pop_options - int32_t CurrentVersion() { return knowhere::Version::GetCurrentVersion().VersionNumber(); } knowhere::DataSetPtr -Array2SparseDataSet(float* data, int nb1, int* ids, int nb2, int64_t* indptr, int nb3) { - int rows = nb3 - 1; - int cols = 0; - for (auto i = 0; i < nb2; ++i) { - if (ids[i] < 0) { - throw std::runtime_error("sparse matrix indics wrong"); - } - cols = std::max(ids[i] + 1, cols); - } - auto ds = std::make_shared(); - auto tensor = std::make_unique[]>(rows); - - for (int32_t i = 0; i < rows; ++i) { - int64_t start = indptr[i]; - int64_t end = indptr[i+1]; - if (start == end) { - throw std::runtime_error("sparse matrix indptr wrong"); - } - knowhere::sparse::SparseRow row(end - start); - for (auto j = start; j < end; ++j) { - row.set_at(j - start, ids[j], data[j]); - } - tensor[i] = std::move(row); - } - ds->SetRows(rows); - ds->SetDim(cols); - ds->SetTensor(tensor.release()); - ds->SetIsOwner(true); - ds->SetIsSparse(true); - return ds; -} - -knowhere::DataSetPtr -Array2DataSetI(uint8_t* xb, int nb, int dim) { +Array2DataSetI(int *xb, int nb, int dim){ auto ds = std::make_shared(); ds->SetIsOwner(false); ds->SetRows(nb); - ds->SetDim(dim*8); + ds->SetDim(dim*32); ds->SetTensor(xb); return ds; }; @@ -431,29 +305,6 @@ GetBinarySet() { return std::make_shared(); } -knowhere::BinaryPtr -Array2Binary(uint8_t* block, int size) { - GILReleaser rel; - - auto binary = std::make_shared(); - std::shared_ptr data_ptr(new uint8_t[size]); - for (int i = 0; i < size; i++) { - *(data_ptr.get() + i) = *(block + i); - } - binary->data = data_ptr; - binary->size = size; - return binary; -} - -knowhere::BinarySetPtr -CreateBinarySet(const std::string& name, knowhere::BinaryPtr binary) { - GILReleaser rel; - - auto binarySet = std::make_shared(); - binarySet->Append(name, binary); - return binarySet; -} - knowhere::DataSetPtr GetNullDataSet() { return nullptr; @@ -486,34 +337,12 @@ DataSetTensor2Array(knowhere::DataSetPtr result, float* data, int rows, int dim) } void -Float16DataSetTensor2Array(knowhere::DataSetPtr result, float* data, int rows, int dim) { - GILReleaser rel; - auto data_ = result->GetTensor(); - for (int i = 0; i < rows; i++) { - for (int j = 0; j < dim; ++j) { - *(data + i * dim + j) = (float)*((knowhere::fp16*)(data_) + i * dim + j); - } - } -} - -void -BFloat16DataSetTensor2Array(knowhere::DataSetPtr result, float* data, int rows, int dim) { - GILReleaser rel; - auto data_ = result->GetTensor(); - for (int i = 0; i < rows; i++) { - for (int j = 0; j < dim; ++j) { - *(data + i * dim + j) = (float)*((knowhere::bf16*)(data_) + i * dim + j); - } - } -} - -void -BinaryDataSetTensor2Array(knowhere::DataSetPtr result, uint8_t* data, int rows, int dim) { +BinaryDataSetTensor2Array(knowhere::DataSetPtr result, int32_t* data, int rows, int dim) { GILReleaser rel; auto data_ = result->GetTensor(); for (int i = 0; i < rows; i++) { for (int j = 0; j < dim; ++j) { - *(data + i * dim + j) = *((uint8_t*)(data_) + i * dim + j); + *(data + i * dim + j) = *((int32_t*)(data_) + i * dim + j); } } } @@ -595,36 +424,11 @@ Load(knowhere::BinarySetPtr binset, const std::string& file_name) { } } -bool -WriteIndexToDisk(const knowhere::BinarySetPtr binset, const std::string& index_type, const std::string& data_path) { - auto bin = binset->GetByName(index_type); - if (bin == nullptr) { - return false; - } - - std::ofstream outfile; - outfile.open(data_path, std::ios::binary | std::ios::trunc); - if (!outfile.good()) { - return false; - } - outfile.write(reinterpret_cast(bin->data.get()), bin->size); - outfile.flush(); - outfile.close(); - - return true; -} - -template knowhere::DataSetPtr BruteForceSearch(knowhere::DataSetPtr base_dataset, knowhere::DataSetPtr query_dataset, const std::string& json, const knowhere::BitsetView& bitset, knowhere::Status& status) { GILReleaser rel; - expected res; - if (base_dataset->GetIsSparse()) { - res = knowhere::BruteForce::SearchSparse(base_dataset, query_dataset, knowhere::Json::parse(json), bitset); - } else { - res = knowhere::BruteForce::Search(base_dataset, query_dataset, knowhere::Json::parse(json), bitset); - } + auto res = knowhere::BruteForce::Search(base_dataset, query_dataset, knowhere::Json::parse(json), bitset); if (res.has_value()) { status = knowhere::Status::success; return res.value(); @@ -634,12 +438,11 @@ BruteForceSearch(knowhere::DataSetPtr base_dataset, knowhere::DataSetPtr query_d } } -template knowhere::DataSetPtr BruteForceRangeSearch(knowhere::DataSetPtr base_dataset, knowhere::DataSetPtr query_dataset, const std::string& json, const knowhere::BitsetView& bitset, knowhere::Status& status) { GILReleaser rel; - auto res = knowhere::BruteForce::RangeSearch(base_dataset, query_dataset, knowhere::Json::parse(json), bitset); + auto res = knowhere::BruteForce::RangeSearch(base_dataset, query_dataset, knowhere::Json::parse(json), bitset); if (res.has_value()) { status = knowhere::Status::success; return res.value(); @@ -664,31 +467,4 @@ SetSimdType(const std::string type) { } } -void -SetBuildThreadPool(uint32_t num_threads) { - knowhere::InitBuildThreadPool(num_threads); -} - -void -SetSearchThreadPool(uint32_t num_threads) { - knowhere::InitSearchThreadPool(num_threads); -} - %} - -%template(AnnIteratorWrapVector) std::vector; - -%template(IndexWrapFloat) IndexWrap; -%template(IndexWrapFP16) IndexWrap; -%template(IndexWrapBF16) IndexWrap; -%template(IndexWrapBin) IndexWrap; - -%template(BruteForceSearchFloat) BruteForceSearch; -%template(BruteForceSearchFP16) BruteForceSearch; -%template(BruteForceSearchBF16) BruteForceSearch; -%template(BruteForceSearchBin) BruteForceSearch; - -%template(BruteForceRangeSearchFloat) BruteForceRangeSearch; -%template(BruteForceRangeSearchFP16) BruteForceRangeSearch; -%template(BruteForceRangeSearchBF16) BruteForceRangeSearch; -%template(BruteForceRangeSearchBin) BruteForceRangeSearch; diff --git a/python/setup.py b/python/setup.py index 8893901fe..814d200a6 100644 --- a/python/setup.py +++ b/python/setup.py @@ -90,7 +90,7 @@ def get_readme(): description=( "A library for efficient similarity search and clustering of vectors." ), - url="https://github.com/zilliztech/knowhere", + url="https://github.com/milvus-io/knowhere", author="Milvus Team", author_email="milvus-team@zilliz.com", license='Apache License 2.0', diff --git a/src/cluster/cluster.cc b/src/cluster/cluster.cc deleted file mode 100644 index 64d9b31b6..000000000 --- a/src/cluster/cluster.cc +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright (C) 2019-2023 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#include "knowhere/cluster/cluster.h" - -#include "knowhere/comp/time_recorder.h" -#include "knowhere/dataset.h" -#include "knowhere/expected.h" -#include "knowhere/log.h" - -#ifdef NOT_COMPILE_FOR_SWIG -#include "knowhere/prometheus_client.h" -#include "knowhere/tracer.h" -#endif - -namespace knowhere { - -inline Status -LoadConfig(Config* cfg, const Json& json, knowhere::PARAM_TYPE param_type, const std::string& method, - std::string* const msg = nullptr) { - Json json_(json); - auto res = Config::FormatAndCheck(*cfg, json_, msg); - LOG_KNOWHERE_DEBUG_ << method << " config dump: " << json_.dump(); - RETURN_IF_ERROR(res); - return Config::Load(*cfg, json_, param_type, msg); -} - -template -inline expected -Cluster::Train(const DataSet& dataset, const Json& json) { - auto cfg = this->node->CreateConfig(); - std::string msg; - auto status = LoadConfig(cfg.get(), json, knowhere::CLUSTER, "Train", &msg); - if (status != Status::success) { - return expected::Err(status, msg); - } - return this->node->Train(dataset, *cfg); -} - -template -inline expected -Cluster::Assign(const DataSet& dataset) { - return this->node->Assign(dataset); -} - -template -inline expected -Cluster::GetCentroids() const { - return this->node->GetCentroids(); -} - -template -inline std::string -Cluster::Type() const { - return this->node->Type(); -} - -template class Cluster; - -} // namespace knowhere diff --git a/src/cluster/cluster_factory.cc b/src/cluster/cluster_factory.cc deleted file mode 100644 index fbcde0b8f..000000000 --- a/src/cluster/cluster_factory.cc +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (C) 2019-2023 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#include "knowhere/cluster/cluster_factory.h" - -#include "knowhere/utils.h" - -namespace knowhere { - -template -expected> -ClusterFactory::Create(const std::string& name, const Object& object) { - static_assert(KnowhereDataTypeCheck::value == true); - auto& func_mapping_ = MapInstance(); - auto key = GetKey(name); - if (func_mapping_.find(key) == func_mapping_.end()) { - LOG_KNOWHERE_ERROR_ << "failed to find cluster type " << key << " in factory"; - return expected>::Err(Status::invalid_cluster_error, "cluster type not supported"); - } - LOG_KNOWHERE_INFO_ << "use key " << key << " to create knowhere cluster worker " << name; - auto fun_map_v = (FunMapValue>*)(func_mapping_[key].get()); - - return fun_map_v->fun_value(object); -} - -template -const ClusterFactory& -ClusterFactory::Register(const std::string& name, std::function(const Object&)> func) { - static_assert(KnowhereDataTypeCheck::value == true); - auto& func_mapping_ = MapInstance(); - auto key = GetKey(name); - assert(func_mapping_.find(key) == func_mapping_.end()); - func_mapping_[key] = std::make_unique>>(func); - return *this; -} - -ClusterFactory& -ClusterFactory::Instance() { - static ClusterFactory factory; - return factory; -} - -ClusterFactory::ClusterFactory() { -} - -ClusterFactory::FuncMap& -ClusterFactory::MapInstance() { - static FuncMap func_map; - return func_map; -} - -} // namespace knowhere - // -template knowhere::expected> -knowhere::ClusterFactory::Create(const std::string&, const Object&); -template knowhere::expected> -knowhere::ClusterFactory::Create(const std::string&, const Object&); -template knowhere::expected> -knowhere::ClusterFactory::Create(const std::string&, const Object&); -template const knowhere::ClusterFactory& -knowhere::ClusterFactory::Register( - const std::string&, std::function(const Object&)>); -template const knowhere::ClusterFactory& -knowhere::ClusterFactory::Register( - const std::string&, std::function(const Object&)>); -template const knowhere::ClusterFactory& -knowhere::ClusterFactory::Register( - const std::string&, std::function(const Object&)>); diff --git a/src/common/comp/brute_force.cc b/src/common/comp/brute_force.cc index d4b3cbf6e..ebd452020 100644 --- a/src/common/comp/brute_force.cc +++ b/src/common/comp/brute_force.cc @@ -14,66 +14,31 @@ #include #include "common/metric.h" +#include "common/range_util.h" #include "faiss/MetricType.h" #include "faiss/utils/binary_distances.h" #include "faiss/utils/distances.h" -#include "knowhere/bitsetview_idselector.h" #include "knowhere/comp/thread_pool.h" #include "knowhere/config.h" #include "knowhere/expected.h" -#include "knowhere/index/index_node.h" #include "knowhere/log.h" -#include "knowhere/range_util.h" -#include "knowhere/sparse_utils.h" #include "knowhere/utils.h" -#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) -#include "knowhere/tracer.h" -#endif - namespace knowhere { /* knowhere wrapper API to call faiss brute force search for all metric types */ class BruteForceConfig : public BaseConfig {}; -namespace { - -template -expected> -GetDocValueComputer(const BruteForceConfig& cfg) { - if (IsMetricType(cfg.metric_type.value(), metric::IP)) { - return sparse::GetDocValueOriginalComputer(); - } else if (IsMetricType(cfg.metric_type.value(), metric::BM25)) { - if (!cfg.bm25_k1.has_value() || !cfg.bm25_b.has_value() || !cfg.bm25_avgdl.has_value()) { - return expected>::Err( - Status::invalid_args, "bm25_k1, bm25_b, bm25_avgdl must be set when searching for bm25 metric"); - } - auto k1 = cfg.bm25_k1.value(); - auto b = cfg.bm25_b.value(); - auto avgdl = cfg.bm25_avgdl.value(); - return sparse::GetDocValueBM25Computer(k1, b, avgdl); - } else { - return expected>::Err(Status::invalid_metric_type, - "metric type not supported for sparse vector"); - } -} - -} // namespace - -template expected BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, const BitsetView& bitset) { - auto base = ConvertFromDataTypeIfNeeded(base_dataset); - auto query = ConvertFromDataTypeIfNeeded(query_dataset); - - auto xb = base->GetTensor(); - auto nb = base->GetRows(); - auto dim = base->GetDim(); + auto xb = base_dataset->GetTensor(); + auto nb = base_dataset->GetRows(); + auto dim = base_dataset->GetDim(); - auto xq = query->GetTensor(); - auto nq = query->GetRows(); + auto xq = query_dataset->GetTensor(); + auto nq = query_dataset->GetRows(); BruteForceConfig cfg; std::string msg; @@ -82,22 +47,6 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset return expected::Err(status, msg); } -#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) - std::shared_ptr span = nullptr; - if (cfg.trace_id.has_value()) { - auto trace_id_str = tracer::GetIDFromHexStr(cfg.trace_id.value()); - auto span_id_str = tracer::GetIDFromHexStr(cfg.span_id.value()); - auto ctx = tracer::TraceContext{(uint8_t*)trace_id_str.c_str(), (uint8_t*)span_id_str.c_str(), - (uint8_t)cfg.trace_flags.value()}; - span = tracer::StartSpan("knowhere bf search", &ctx); - span->SetAttribute(meta::METRIC_TYPE, cfg.metric_type.value()); - span->SetAttribute(meta::TOPK, cfg.k.value()); - span->SetAttribute(meta::ROWS, nb); - span->SetAttribute(meta::DIM, dim); - span->SetAttribute(meta::NQ, nq); - } -#endif - std::string metric_str = cfg.metric_type.value(); auto result = Str2FaissMetricType(metric_str); if (result.error() != Status::success) { @@ -107,26 +56,22 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset bool is_cosine = IsMetricType(metric_str, metric::COSINE); int topk = cfg.k.value(); - auto labels = std::make_unique(nq * topk); - auto distances = std::make_unique(nq * topk); + auto labels = new int64_t[nq * topk]; + auto distances = new float[nq * topk]; auto pool = ThreadPool::GetGlobalSearchThreadPool(); std::vector> futs; futs.reserve(nq); for (int i = 0; i < nq; ++i) { - futs.emplace_back(pool->push([&, index = i, labels_ptr = labels.get(), distances_ptr = distances.get()] { + futs.emplace_back(pool->push([&, index = i] { ThreadPool::ScopedOmpSetter setter(1); - auto cur_labels = labels_ptr + topk * index; - auto cur_distances = distances_ptr + topk * index; - - BitsetViewIDSelector bw_idselector(bitset); - faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector; - + auto cur_labels = labels + topk * index; + auto cur_distances = distances + topk * index; switch (faiss_metric_type) { case faiss::METRIC_L2: { auto cur_query = (const float*)xq + dim * index; faiss::float_maxheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances}; - faiss::knn_L2sqr(cur_query, (const float*)xb, dim, 1, nb, &buf, nullptr, id_selector); + faiss::knn_L2sqr(cur_query, (const float*)xb, dim, 1, nb, &buf, nullptr, bitset); break; } case faiss::METRIC_INNER_PRODUCT: { @@ -134,16 +79,16 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances}; if (is_cosine) { auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); - faiss::knn_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, &buf, id_selector); + faiss::knn_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, &buf, bitset); } else { - faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, id_selector); + faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset); } break; } case faiss::METRIC_Jaccard: { auto cur_query = (const uint8_t*)xq + (dim / 8) * index; faiss::float_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, cur_distances}; - binary_knn_hc(faiss::METRIC_Jaccard, &res, cur_query, (const uint8_t*)xb, nb, dim / 8, id_selector); + binary_knn_hc(faiss::METRIC_Jaccard, &res, cur_query, (const uint8_t*)xb, nb, dim / 8, bitset); break; } case faiss::METRIC_Hamming: { @@ -151,7 +96,7 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset std::vector int_distances(topk); faiss::int_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, int_distances.data()}; binary_knn_hc(faiss::METRIC_Hamming, &res, (const uint8_t*)cur_query, (const uint8_t*)xb, nb, - dim / 8, id_selector); + dim / 8, bitset); for (int i = 0; i < topk; ++i) { cur_distances[i] = int_distances[i]; } @@ -162,7 +107,7 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset // only matched ids will be chosen, not to use heap auto cur_query = (const uint8_t*)xq + (dim / 8) * index; binary_knn_mc(faiss_metric_type, cur_query, (const uint8_t*)xb, 1, nb, topk, dim / 8, cur_distances, - cur_labels, id_selector); + cur_labels, bitset); break; } default: { @@ -173,54 +118,29 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset return Status::success; })); } - auto ret = WaitAllSuccess(futs); - if (ret != Status::success) { - return expected::Err(ret, "failed to brute force search"); - } - auto res = GenResultDataSet(nq, cfg.k.value(), std::move(labels), std::move(distances)); - -#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) - if (cfg.trace_id.has_value()) { - span->End(); + for (auto& fut : futs) { + fut.wait(); + auto ret = fut.result().value(); + if (ret != Status::success) { + return expected::Err(ret, "failed to brute force search"); + } } -#endif - - return res; + return GenResultDataSet(nq, cfg.k.value(), labels, distances); } -template Status BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, int64_t* ids, float* dis, const Json& config, const BitsetView& bitset) { - auto base = ConvertFromDataTypeIfNeeded(base_dataset); - auto query = ConvertFromDataTypeIfNeeded(query_dataset); - - auto xb = base->GetTensor(); - auto nb = base->GetRows(); - auto dim = base->GetDim(); + auto xb = base_dataset->GetTensor(); + auto nb = base_dataset->GetRows(); + auto dim = base_dataset->GetDim(); - auto xq = query->GetTensor(); - auto nq = query->GetRows(); + auto xq = query_dataset->GetTensor(); + auto nq = query_dataset->GetRows(); BruteForceConfig cfg; RETURN_IF_ERROR(Config::Load(cfg, config, knowhere::SEARCH)); -#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) - std::shared_ptr span = nullptr; - if (cfg.trace_id.has_value()) { - auto trace_id_str = tracer::GetIDFromHexStr(cfg.trace_id.value()); - auto span_id_str = tracer::GetIDFromHexStr(cfg.span_id.value()); - auto ctx = tracer::TraceContext{(uint8_t*)trace_id_str.c_str(), (uint8_t*)span_id_str.c_str(), - (uint8_t)cfg.trace_flags.value()}; - span = tracer::StartSpan("knowhere bf search with buf", &ctx); - span->SetAttribute(meta::METRIC_TYPE, cfg.metric_type.value()); - span->SetAttribute(meta::TOPK, cfg.k.value()); - span->SetAttribute(meta::ROWS, nb); - span->SetAttribute(meta::DIM, dim); - span->SetAttribute(meta::NQ, nq); - } -#endif - std::string metric_str = cfg.metric_type.value(); auto result = Str2FaissMetricType(cfg.metric_type.value()); if (result.error() != Status::success) { @@ -241,15 +161,11 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_ ThreadPool::ScopedOmpSetter setter(1); auto cur_labels = labels + topk * index; auto cur_distances = distances + topk * index; - - BitsetViewIDSelector bw_idselector(bitset); - faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector; - switch (faiss_metric_type) { case faiss::METRIC_L2: { auto cur_query = (const float*)xq + dim * index; faiss::float_maxheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances}; - faiss::knn_L2sqr(cur_query, (const float*)xb, dim, 1, nb, &buf, nullptr, id_selector); + faiss::knn_L2sqr(cur_query, (const float*)xb, dim, 1, nb, &buf, nullptr, bitset); break; } case faiss::METRIC_INNER_PRODUCT: { @@ -257,16 +173,16 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_ faiss::float_minheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances}; if (is_cosine) { auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); - faiss::knn_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, &buf, id_selector); + faiss::knn_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, &buf, bitset); } else { - faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, id_selector); + faiss::knn_inner_product(cur_query, (const float*)xb, dim, 1, nb, &buf, bitset); } break; } case faiss::METRIC_Jaccard: { auto cur_query = (const uint8_t*)xq + (dim / 8) * index; faiss::float_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, cur_distances}; - binary_knn_hc(faiss::METRIC_Jaccard, &res, cur_query, (const uint8_t*)xb, nb, dim / 8, id_selector); + binary_knn_hc(faiss::METRIC_Jaccard, &res, cur_query, (const uint8_t*)xb, nb, dim / 8, bitset); break; } case faiss::METRIC_Hamming: { @@ -274,7 +190,7 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_ std::vector int_distances(topk); faiss::int_maxheap_array_t res = {size_t(1), size_t(topk), cur_labels, int_distances.data()}; binary_knn_hc(faiss::METRIC_Hamming, &res, (const uint8_t*)cur_query, (const uint8_t*)xb, nb, - dim / 8, id_selector); + dim / 8, bitset); for (int i = 0; i < topk; ++i) { cur_distances[i] = int_distances[i]; } @@ -285,7 +201,7 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_ // only matched ids will be chosen, not to use heap auto cur_query = (const uint8_t*)xq + (dim / 8) * index; binary_knn_mc(faiss_metric_type, cur_query, (const uint8_t*)xb, 1, nb, topk, dim / 8, cur_distances, - cur_labels, id_selector); + cur_labels, bitset); break; } default: { @@ -296,36 +212,25 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_ return Status::success; })); } - RETURN_IF_ERROR(WaitAllSuccess(futs)); - -#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) - if (cfg.trace_id.has_value()) { - span->End(); + for (auto& fut : futs) { + fut.wait(); + auto ret = fut.result().value(); + RETURN_IF_ERROR(ret); } -#endif - return Status::success; } /** knowhere wrapper API to call faiss brute force range search for all metric types */ -template expected BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, const BitsetView& bitset) { - DataSetPtr base(base_dataset); - DataSetPtr query(query_dataset); - bool is_sparse = std::is_same>::value; - if (!is_sparse) { - base = ConvertFromDataTypeIfNeeded(base_dataset); - query = ConvertFromDataTypeIfNeeded(query_dataset); - } - auto xb = base->GetTensor(); - auto nb = base->GetRows(); - auto dim = base->GetDim(); + auto xb = base_dataset->GetTensor(); + auto nb = base_dataset->GetRows(); + auto dim = base_dataset->GetDim(); - auto xq = query->GetTensor(); - auto nq = query->GetRows(); + auto xq = query_dataset->GetTensor(); + auto nq = query_dataset->GetRows(); BruteForceConfig cfg; std::string msg; @@ -334,44 +239,12 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da return expected::Err(status, std::move(msg)); } -#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) - std::shared_ptr span = nullptr; - if (cfg.trace_id.has_value()) { - auto trace_id_str = tracer::GetIDFromHexStr(cfg.trace_id.value()); - auto span_id_str = tracer::GetIDFromHexStr(cfg.span_id.value()); - auto ctx = tracer::TraceContext{(uint8_t*)trace_id_str.c_str(), (uint8_t*)span_id_str.c_str(), - (uint8_t)cfg.trace_flags.value()}; - span = tracer::StartSpan("knowhere bf range search", &ctx); - span->SetAttribute(meta::METRIC_TYPE, cfg.metric_type.value()); - span->SetAttribute(meta::RADIUS, cfg.radius.value()); - if (cfg.range_filter.value() != defaultRangeFilter) { - span->SetAttribute(meta::RANGE_FILTER, cfg.range_filter.value()); - } - span->SetAttribute(meta::ROWS, nb); - span->SetAttribute(meta::DIM, dim); - span->SetAttribute(meta::NQ, nq); - } -#endif - std::string metric_str = cfg.metric_type.value(); - const bool is_bm25 = IsMetricType(metric_str, metric::BM25); - - faiss::MetricType faiss_metric_type; - sparse::DocValueComputer sparse_computer; - if (!is_sparse) { - auto result = Str2FaissMetricType(metric_str); - if (result.error() != Status::success) { - return expected::Err(result.error(), result.what()); - } - faiss_metric_type = result.value(); - } else { - auto computer_or = GetDocValueComputer(cfg); - if (!computer_or.has_value()) { - return expected::Err(computer_or.error(), computer_or.what()); - } - sparse_computer = computer_or.value(); + auto result = Str2FaissMetricType(metric_str); + if (result.error() != Status::success) { + return expected::Err(result.error(), result.what()); } - + faiss::MetricType faiss_metric_type = result.value(); bool is_cosine = IsMetricType(metric_str, metric::COSINE); auto radius = cfg.radius.value(); @@ -382,44 +255,18 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da std::vector> result_id_array(nq); std::vector> result_dist_array(nq); - + std::vector result_size(nq); + std::vector result_lims(nq + 1); std::vector> futs; futs.reserve(nq); for (int i = 0; i < nq; ++i) { futs.emplace_back(pool->push([&, index = i] { - if (is_sparse) { - auto cur_query = (const sparse::SparseRow*)xq + index; - auto xb_sparse = (const sparse::SparseRow*)xb; - for (int j = 0; j < nb; ++j) { - if (!bitset.empty() && bitset.test(j)) { - continue; - } - float row_sum = 0; - if (is_bm25) { - for (size_t k = 0; k < xb_sparse[j].size(); ++k) { - auto [d, v] = xb_sparse[j][k]; - row_sum += v; - } - } - auto dist = cur_query->dot(xb_sparse[j], sparse_computer, row_sum); - if (dist > radius && dist <= range_filter) { - result_id_array[index].push_back(j); - result_dist_array[index].push_back(dist); - } - } - return Status::success; - } - // else not sparse: ThreadPool::ScopedOmpSetter setter(1); faiss::RangeSearchResult res(1); - - BitsetViewIDSelector bw_idselector(bitset); - faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector; - switch (faiss_metric_type) { case faiss::METRIC_L2: { auto cur_query = (const float*)xq + dim * index; - faiss::range_search_L2sqr(cur_query, (const float*)xb, dim, 1, nb, radius, &res, id_selector); + faiss::range_search_L2sqr(cur_query, (const float*)xb, dim, 1, nb, radius, &res, bitset); break; } case faiss::METRIC_INNER_PRODUCT: { @@ -428,25 +275,24 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da if (is_cosine) { auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); faiss::range_search_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, radius, - &res, id_selector); + &res, bitset); } else { faiss::range_search_inner_product(cur_query, (const float*)xb, dim, 1, nb, radius, &res, - id_selector); + bitset); } break; } case faiss::METRIC_Jaccard: { auto cur_query = (const uint8_t*)xq + (dim / 8) * index; - faiss::binary_range_search, float>(faiss::METRIC_Jaccard, cur_query, - (const uint8_t*)xb, 1, nb, radius, - dim / 8, &res, id_selector); + faiss::binary_range_search, float>( + faiss::METRIC_Jaccard, cur_query, (const uint8_t*)xb, 1, nb, radius, dim / 8, &res, bitset); break; } case faiss::METRIC_Hamming: { auto cur_query = (const uint8_t*)xq + (dim / 8) * index; faiss::binary_range_search, int>(faiss::METRIC_Hamming, cur_query, (const uint8_t*)xb, 1, nb, (int)radius, - dim / 8, &res, id_selector); + dim / 8, &res, bitset); break; } default: { @@ -457,6 +303,7 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da auto elem_cnt = res.lims[1]; result_dist_array[index].resize(elem_cnt); result_id_array[index].resize(elem_cnt); + result_size[index] = elem_cnt; for (size_t j = 0; j < elem_cnt; j++) { result_dist_array[index][j] = res.distances[j]; result_id_array[index][j] = res.labels[j]; @@ -468,386 +315,18 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da return Status::success; })); } - auto ret = WaitAllSuccess(futs); - if (ret != Status::success) { - return expected::Err(ret, "failed to brute force search"); - } - - auto range_search_result = - GetRangeSearchResult(result_dist_array, result_id_array, is_ip || is_bm25, nq, radius, range_filter); - auto res = GenResultDataSet(nq, std::move(range_search_result)); - -#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) - if (cfg.trace_id.has_value()) { - span->End(); - } -#endif - - return res; -} - -Status -BruteForce::SearchSparseWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, sparse::label_t* labels, - float* distances, const Json& config, const BitsetView& bitset) { - auto base = static_cast*>(base_dataset->GetTensor()); - auto rows = base_dataset->GetRows(); - auto dim = base_dataset->GetDim(); - - auto xq = static_cast*>(query_dataset->GetTensor()); - auto nq = query_dataset->GetRows(); - - BruteForceConfig cfg; - std::string msg; - auto status = Config::Load(cfg, config, knowhere::SEARCH, &msg); - if (status != Status::success) { - LOG_KNOWHERE_ERROR_ << "Failed to load config, msg is: " << msg; - return status; - } - -#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) - std::shared_ptr span = nullptr; - if (cfg.trace_id.has_value()) { - auto trace_id_str = tracer::GetIDFromHexStr(cfg.trace_id.value()); - auto span_id_str = tracer::GetIDFromHexStr(cfg.span_id.value()); - auto ctx = tracer::TraceContext{(uint8_t*)trace_id_str.c_str(), (uint8_t*)span_id_str.c_str(), - (uint8_t)cfg.trace_flags.value()}; - span = tracer::StartSpan("knowhere bf search sparse with buf", &ctx); - span->SetAttribute(meta::METRIC_TYPE, cfg.metric_type.value()); - span->SetAttribute(meta::TOPK, cfg.k.value()); - span->SetAttribute(meta::ROWS, rows); - span->SetAttribute(meta::DIM, dim); - span->SetAttribute(meta::NQ, nq); - } -#endif - - std::string metric_str = cfg.metric_type.value(); - const bool is_bm25 = IsMetricType(metric_str, metric::BM25); - - auto computer_or = GetDocValueComputer(cfg); - if (!computer_or.has_value()) { - return computer_or.error(); - } - auto computer = computer_or.value(); - - int topk = cfg.k.value(); - std::fill(distances, distances + nq * topk, std::numeric_limits::quiet_NaN()); - std::fill(labels, labels + nq * topk, -1); - - auto pool = ThreadPool::GetGlobalSearchThreadPool(); - std::vector> futs; - futs.reserve(nq); - for (int64_t i = 0; i < nq; ++i) { - futs.emplace_back(pool->push([&, index = i] { - auto cur_labels = labels + topk * index; - auto cur_distances = distances + topk * index; - - const auto& row = xq[index]; - if (row.size() == 0) { - return; - } - sparse::MaxMinHeap heap(topk); - for (int64_t j = 0; j < rows; ++j) { - if (!bitset.empty() && bitset.test(j)) { - continue; - } - float row_sum = 0; - if (is_bm25) { - for (size_t k = 0; k < base[j].size(); ++k) { - auto [d, v] = base[j][k]; - row_sum += v; - } - } - float dist = row.dot(base[j], computer, row_sum); - if (dist > 0) { - heap.push(j, dist); - } - } - int result_size = heap.size(); - for (int j = result_size - 1; j >= 0; --j) { - cur_labels[j] = heap.top().id; - cur_distances[j] = heap.top().val; - heap.pop(); - } - })); - } - WaitAllSuccess(futs); - -#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) - if (cfg.trace_id.has_value()) { - span->End(); - } -#endif - - return Status::success; -} - -expected -BruteForce::SearchSparse(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, - const BitsetView& bitset) { - auto nq = query_dataset->GetRows(); - BruteForceConfig cfg; - std::string msg; - auto status = Config::Load(cfg, config, knowhere::SEARCH, &msg); - if (status != Status::success) { - return expected::Err(status, msg); - } - - int topk = cfg.k.value(); - auto labels = std::make_unique(nq * topk); - auto distances = std::make_unique(nq * topk); - - SearchSparseWithBuf(base_dataset, query_dataset, labels.get(), distances.get(), config, bitset); - return GenResultDataSet(nq, topk, std::move(labels), std::move(distances)); -} - -template -expected> -BruteForce::AnnIterator(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, - const BitsetView& bitset) { - auto base = ConvertFromDataTypeIfNeeded(base_dataset); - auto query = ConvertFromDataTypeIfNeeded(query_dataset); - - auto xb = base->GetTensor(); - auto nb = base->GetRows(); - auto dim = base->GetDim(); - - auto xq = query->GetTensor(); - auto nq = query->GetRows(); - BruteForceConfig cfg; - std::string msg; - auto status = Config::Load(cfg, config, knowhere::ITERATOR, &msg); - if (status != Status::success) { - return expected>::Err(status, msg); - } - std::string metric_str = cfg.metric_type.value(); - auto result = Str2FaissMetricType(metric_str); - if (result.error() != Status::success) { - return expected>::Err(result.error(), result.what()); - } - -#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) - std::shared_ptr span = nullptr; - if (cfg.trace_id.has_value()) { - auto trace_id_str = tracer::GetIDFromHexStr(cfg.trace_id.value()); - auto span_id_str = tracer::GetIDFromHexStr(cfg.span_id.value()); - auto ctx = tracer::TraceContext{(uint8_t*)trace_id_str.c_str(), (uint8_t*)span_id_str.c_str(), - (uint8_t)cfg.trace_flags.value()}; - span = tracer::StartSpan("knowhere bf ann iterator initialization", &ctx); - span->SetAttribute(meta::METRIC_TYPE, cfg.metric_type.value()); - span->SetAttribute(meta::ROWS, nb); - span->SetAttribute(meta::DIM, dim); - span->SetAttribute(meta::NQ, nq); - } -#endif - faiss::MetricType faiss_metric_type = result.value(); - bool is_cosine = IsMetricType(metric_str, metric::COSINE); - - auto pool = ThreadPool::GetGlobalSearchThreadPool(); - auto vec = std::vector(nq, nullptr); - std::vector> futs; - futs.reserve(nq); - - for (int i = 0; i < nq; ++i) { - futs.emplace_back(pool->push([&, index = i] { - ThreadPool::ScopedOmpSetter setter(1); - - BitsetViewIDSelector bw_idselector(bitset); - faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector; - auto larger_is_closer = faiss::is_similarity_metric(faiss_metric_type) || is_cosine; - auto max_dis = larger_is_closer ? std::numeric_limits::lowest() : std::numeric_limits::max(); - std::vector distances_ids(nb, {-1, max_dis}); - - switch (faiss_metric_type) { - case faiss::METRIC_L2: { - auto cur_query = (const float*)xq + dim * index; - faiss::all_L2sqr(cur_query, (const float*)xb, dim, 1, nb, distances_ids, nullptr, id_selector); - break; - } - case faiss::METRIC_INNER_PRODUCT: { - auto cur_query = (const float*)xq + dim * index; - if (is_cosine) { - auto copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); - faiss::all_cosine(copied_query.get(), (const float*)xb, nullptr, dim, 1, nb, distances_ids, - id_selector); - } else { - faiss::all_inner_product(cur_query, (const float*)xb, dim, 1, nb, distances_ids, id_selector); - } - break; - } - default: { - LOG_KNOWHERE_ERROR_ << "Invalid metric type: " << cfg.metric_type.value(); - return Status::invalid_metric_type; - } - } - vec[index] = std::make_shared(std::move(distances_ids), larger_is_closer); - - return Status::success; - })); - } - - auto ret = WaitAllSuccess(futs); - if (ret != Status::success) { - return expected>::Err(ret, "failed to brute force search for iterator"); - } - -#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) - if (cfg.trace_id.has_value()) { - span->End(); - } -#endif - return vec; -} - -template <> -expected> -BruteForce::AnnIterator>(const DataSetPtr base_dataset, - const DataSetPtr query_dataset, const Json& config, - const BitsetView& bitset) { - auto base = static_cast*>(base_dataset->GetTensor()); - auto rows = base_dataset->GetRows(); - auto dim = base_dataset->GetDim(); - - auto xq = static_cast*>(query_dataset->GetTensor()); - auto nq = query_dataset->GetRows(); - - BruteForceConfig cfg; - std::string msg; - auto status = Config::Load(cfg, config, knowhere::ITERATOR, &msg); - if (status != Status::success) { - LOG_KNOWHERE_ERROR_ << "Failed to load config: " << msg; - return expected>::Err( - status, "Failed to brute force search sparse for iterator: failed to load config: " + msg); - } - - std::string metric_str = cfg.metric_type.value(); - -#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) - std::shared_ptr span = nullptr; - if (cfg.trace_id.has_value()) { - auto trace_id_str = tracer::GetIDFromHexStr(cfg.trace_id.value()); - auto span_id_str = tracer::GetIDFromHexStr(cfg.span_id.value()); - auto ctx = tracer::TraceContext{(uint8_t*)trace_id_str.c_str(), (uint8_t*)span_id_str.c_str(), - (uint8_t)cfg.trace_flags.value()}; - span = tracer::StartSpan("knowhere bf iterator sparse", &ctx); - span->SetAttribute(meta::METRIC_TYPE, metric_str); - span->SetAttribute(meta::ROWS, rows); - span->SetAttribute(meta::DIM, dim); - span->SetAttribute(meta::NQ, nq); - } -#endif - - const bool is_bm25 = IsMetricType(metric_str, metric::BM25); - - auto computer_or = GetDocValueComputer(cfg); - if (!computer_or.has_value()) { - return expected>::Err(computer_or.error(), computer_or.what()); - } - auto computer = computer_or.value(); - - auto pool = ThreadPool::GetGlobalSearchThreadPool(); - auto vec = std::vector(nq, nullptr); - std::vector> futs; - futs.reserve(nq); - for (int64_t i = 0; i < nq; ++i) { - futs.emplace_back(pool->push([&, index = i] { - const auto& row = xq[index]; - std::vector distances_ids; - if (row.size() > 0) { - for (int64_t j = 0; j < rows; ++j) { - if (!bitset.empty() && bitset.test(j)) { - continue; - } - float row_sum = 0; - if (is_bm25) { - for (size_t k = 0; k < base[j].size(); ++k) { - auto [d, v] = base[j][k]; - row_sum += v; - } - } - auto dist = row.dot(base[j], computer, row_sum); - if (dist > 0) { - distances_ids.emplace_back(j, dist); - } - } - } - vec[index] = std::make_shared(std::move(distances_ids), true); - })); - } - WaitAllSuccess(futs); - -#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) - if (cfg.trace_id.has_value()) { - span->End(); + for (auto& fut : futs) { + fut.wait(); + auto ret = fut.result().value(); + if (ret != Status::success) { + return expected::Err(ret, "failed to brute force search"); + } } -#endif - return vec; + int64_t* ids = nullptr; + float* distances = nullptr; + size_t* lims = nullptr; + GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius, range_filter, distances, ids, lims); + return GenResultDataSet(nq, ids, distances, lims); } - } // namespace knowhere -template knowhere::expected -knowhere::BruteForce::Search(const knowhere::DataSetPtr base_dataset, - const knowhere::DataSetPtr query_dataset, const knowhere::Json& config, - const knowhere::BitsetView& bitset); -template knowhere::expected -knowhere::BruteForce::Search(const knowhere::DataSetPtr base_dataset, - const knowhere::DataSetPtr query_dataset, const knowhere::Json& config, - const knowhere::BitsetView& bitset); -template knowhere::expected -knowhere::BruteForce::Search(const knowhere::DataSetPtr base_dataset, - const knowhere::DataSetPtr query_dataset, const knowhere::Json& config, - const knowhere::BitsetView& bitset); -template knowhere::expected -knowhere::BruteForce::Search(const knowhere::DataSetPtr base_dataset, - const knowhere::DataSetPtr query_dataset, const knowhere::Json& config, - const knowhere::BitsetView& bitset); -template knowhere::Status -knowhere::BruteForce::SearchWithBuf(const knowhere::DataSetPtr base_dataset, - const knowhere::DataSetPtr query_dataset, int64_t* ids, float* dis, - const knowhere::Json& config, const knowhere::BitsetView& bitset); -template knowhere::Status -knowhere::BruteForce::SearchWithBuf(const knowhere::DataSetPtr base_dataset, - const knowhere::DataSetPtr query_dataset, int64_t* ids, float* dis, - const knowhere::Json& config, const knowhere::BitsetView& bitset); -template knowhere::Status -knowhere::BruteForce::SearchWithBuf(const knowhere::DataSetPtr base_dataset, - const knowhere::DataSetPtr query_dataset, int64_t* ids, float* dis, - const knowhere::Json& config, const knowhere::BitsetView& bitset); -template knowhere::Status -knowhere::BruteForce::SearchWithBuf(const knowhere::DataSetPtr base_dataset, - const knowhere::DataSetPtr query_dataset, int64_t* ids, float* dis, - const knowhere::Json& config, const knowhere::BitsetView& bitset); -template knowhere::expected -knowhere::BruteForce::RangeSearch(const knowhere::DataSetPtr base_dataset, - const knowhere::DataSetPtr query_dataset, - const knowhere::Json& config, const knowhere::BitsetView& bitset); -template knowhere::expected -knowhere::BruteForce::RangeSearch(const knowhere::DataSetPtr base_dataset, - const knowhere::DataSetPtr query_dataset, - const knowhere::Json& config, const knowhere::BitsetView& bitset); -template knowhere::expected -knowhere::BruteForce::RangeSearch(const knowhere::DataSetPtr base_dataset, - const knowhere::DataSetPtr query_dataset, - const knowhere::Json& config, const knowhere::BitsetView& bitset); -template knowhere::expected -knowhere::BruteForce::RangeSearch(const knowhere::DataSetPtr base_dataset, - const knowhere::DataSetPtr query_dataset, - const knowhere::Json& config, const knowhere::BitsetView& bitset); -template knowhere::expected -knowhere::BruteForce::RangeSearch>(const knowhere::DataSetPtr base_dataset, - const knowhere::DataSetPtr query_dataset, - const knowhere::Json& config, - const knowhere::BitsetView& bitset); - -template knowhere::expected> -knowhere::BruteForce::AnnIterator(const knowhere::DataSetPtr base_dataset, - const knowhere::DataSetPtr query_dataset, - const knowhere::Json& config, const knowhere::BitsetView& bitset); -template knowhere::expected> -knowhere::BruteForce::AnnIterator(const knowhere::DataSetPtr base_dataset, - const knowhere::DataSetPtr query_dataset, - const knowhere::Json& config, const knowhere::BitsetView& bitset); -template knowhere::expected> -knowhere::BruteForce::AnnIterator(const knowhere::DataSetPtr base_dataset, - const knowhere::DataSetPtr query_dataset, - const knowhere::Json& config, const knowhere::BitsetView& bitset); diff --git a/src/common/comp/knowhere_config.cc b/src/common/comp/knowhere_config.cc index bef49666a..126529d21 100644 --- a/src/common/comp/knowhere_config.cc +++ b/src/common/comp/knowhere_config.cc @@ -23,11 +23,10 @@ #ifdef KNOWHERE_WITH_GPU #include "index/gpu/gpu_res_mgr.h" #endif +#include "simd/hook.h" #ifdef KNOWHERE_WITH_RAFT -#include "common/raft/integration/raft_initialization.hpp" -#include "cuda_runtime_api.h" +#include "common/raft/raft_utils.h" #endif -#include "simd/hook.h" namespace knowhere { @@ -90,18 +89,6 @@ KnowhereConfig::SetSimdType(const SimdType simd_type) { return simd_str; } -void -KnowhereConfig::EnablePatchForComputeFP32AsBF16() { - LOG_KNOWHERE_INFO_ << "Enable patch for compute fp32 as bf16"; - faiss::enable_patch_for_fp32_bf16(); -} - -void -KnowhereConfig::DisablePatchForComputeFP32AsBF16() { - LOG_KNOWHERE_INFO_ << "Disable patch for compute fp32 as bf16"; - faiss::disable_patch_for_fp32_bf16(); -} - void KnowhereConfig::SetBlasThreshold(const int64_t use_blas_threshold) { LOG_KNOWHERE_INFO_ << "Set faiss::distance_compute_blas_threshold to " << use_blas_threshold; @@ -148,29 +135,18 @@ KnowhereConfig::SetAioContextPool(size_t num_ctx) { void KnowhereConfig::SetBuildThreadPoolSize(size_t num_threads) { - knowhere::ThreadPool::SetGlobalBuildThreadPoolSize(num_threads); -} - -size_t -KnowhereConfig::GetBuildThreadPoolSize() { - return knowhere::ThreadPool::GetGlobalBuildThreadPoolSize(); + knowhere::ThreadPool::InitGlobalBuildThreadPool(num_threads); } void KnowhereConfig::SetSearchThreadPoolSize(size_t num_threads) { - knowhere::ThreadPool::SetGlobalSearchThreadPoolSize(num_threads); -} - -size_t -KnowhereConfig::GetSearchThreadPoolSize() { - return knowhere::ThreadPool::GetGlobalSearchThreadPoolSize(); + knowhere::ThreadPool::InitGlobalSearchThreadPool(num_threads); } void KnowhereConfig::InitGPUResource(int64_t gpu_id, int64_t res_num) { #ifdef KNOWHERE_WITH_GPU LOG_KNOWHERE_INFO_ << "init GPU resource for gpu id " << gpu_id << ", resource num " << res_num; - knowhere::GPUParams gpu_params(res_num); knowhere::GPUResMgr::GetInstance().InitDevice(gpu_id, gpu_params); knowhere::GPUResMgr::GetInstance().Init(); @@ -184,36 +160,10 @@ KnowhereConfig::FreeGPUResource() { knowhere::GPUResMgr::GetInstance().Free(); #endif } - void KnowhereConfig::SetRaftMemPool(size_t init_size, size_t max_size) { #ifdef KNOWHERE_WITH_RAFT - int count = 0; - auto status = cudaGetDeviceCount(&count); - if (status != cudaSuccess) { - LOG_KNOWHERE_INFO_ << cudaGetErrorString(status); - return; - } - if (count < 1) { - LOG_KNOWHERE_INFO_ << "GPU not available"; - return; - } - - auto config = raft_knowhere::raft_configuration{}; - config.init_mem_pool_size_mb = init_size; - config.max_mem_pool_size_mb = max_size; - // This should probably be a separate configuration option, but fine for now - config.max_workspace_size_mb = max_size; - raft_knowhere::initialize_raft(config); -#endif -} - -void -KnowhereConfig::SetRaftMemPool() { - // Overload for default values -#ifdef KNOWHERE_WITH_RAFT - auto config = raft_knowhere::raft_configuration{}; - raft_knowhere::initialize_raft(config); + raft_utils::set_mem_pool_size(init_size, max_size); #endif } diff --git a/src/common/comp/materialized_view.cc b/src/common/comp/materialized_view.cc deleted file mode 100644 index 1a090d97b..000000000 --- a/src/common/comp/materialized_view.cc +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (C) 2019-2024 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#include "knowhere/comp/materialized_view.h" - -namespace knowhere { - -constexpr std::string_view kFieldIdToTouchedCategoriesCntKey = "field_id_to_touched_categories_cnt"; -constexpr std::string_view kIsPureAndKey = "is_pure_and"; -constexpr std::string_view kHasNotKey = "has_not"; - -void -to_json(nlohmann::json& j, const MaterializedViewSearchInfo& info) { - j = nlohmann::json{{kFieldIdToTouchedCategoriesCntKey, info.field_id_to_touched_categories_cnt}, - {kIsPureAndKey, info.is_pure_and}, - {kHasNotKey, info.has_not}}; -} - -void -from_json(const nlohmann::json& j, MaterializedViewSearchInfo& info) { - if (j.is_null()) { - // When the json is null, we return the default value of struct MaterializedViewSearchInfo - // If `MaterializedViewSearchInfo = j[xxx]` is called, a default constructed MaterializedViewSearchInfo will - // be created. Therefore the second parameter `info` here should have default values. - return; - } - - // if any of the keys is missing, the corresponding field in `info` will have default value - if (j.contains(kFieldIdToTouchedCategoriesCntKey)) { - j.at(kFieldIdToTouchedCategoriesCntKey).get_to(info.field_id_to_touched_categories_cnt); - } - if (j.contains(kIsPureAndKey)) { - j.at(kIsPureAndKey).get_to(info.is_pure_and); - } - if (j.contains(kHasNotKey)) { - j.at(kHasNotKey).get_to(info.has_not); - } -} -} // namespace knowhere diff --git a/src/common/config.cc b/src/common/config.cc index 81d3b80f9..f54c705a3 100644 --- a/src/common/config.cc +++ b/src/common/config.cc @@ -11,24 +11,13 @@ #include "knowhere/config.h" -#include "index/diskann/diskann_config.h" -#include "index/flat/flat_config.h" -#include "index/gpu_raft/gpu_raft_brute_force_config.h" -#include "index/gpu_raft/gpu_raft_cagra_config.h" -#include "index/gpu_raft/gpu_raft_ivf_flat_config.h" -#include "index/gpu_raft/gpu_raft_ivf_pq_config.h" -#include "index/hnsw/hnsw_config.h" -#include "index/ivf/ivf_config.h" -#include "index/sparse/sparse_inverted_index_config.h" #include "knowhere/log.h" - namespace knowhere { static const std::unordered_set ext_legal_json_keys = {"metric_type", "dim", "nlist", // IVF param "nprobe", // IVF param - "use_elkan", // IVF param "ssize", // IVF_FLAT_CC param "nbits", // IVF_PQ param "m", // IVF_PQ param @@ -64,22 +53,21 @@ static const std::unordered_set ext_legal_json_keys = {"metric_type Status Config::FormatAndCheck(const Config& cfg, Json& json, std::string* const err_msg) { - // Deprecated invalid json key check for now - // try { - // for (auto& it : json.items()) { - // // valid only if it.key() exists in one of cfg.__DICT__ and ext_legal_json_keys - // if (cfg.__DICT__.find(it.key()) == cfg.__DICT__.end() && - // ext_legal_json_keys.find(it.key()) == ext_legal_json_keys.end()) { - // throw KnowhereException(std::string("invalid json key ") + it.key()); - // } - // } - // } catch (std::exception& e) { - // LOG_KNOWHERE_ERROR_ << e.what(); - // if (err_msg) { - // *err_msg = e.what(); - // } - // return Status::invalid_param_in_json; - // } + try { + for (auto& it : json.items()) { + // valid only if it.key() exists in one of cfg.__DICT__ and ext_legal_json_keys + if (cfg.__DICT__.find(it.key()) == cfg.__DICT__.end() && + ext_legal_json_keys.find(it.key()) == ext_legal_json_keys.end()) { + throw KnowhereException(std::string("invalid json key ") + it.key()); + } + } + } catch (std::exception& e) { + LOG_KNOWHERE_ERROR_ << e.what(); + if (err_msg) { + *err_msg = e.what(); + } + return Status::invalid_param_in_json; + } try { for (const auto& it : cfg.__DICT__) { @@ -118,55 +106,4 @@ Config::FormatAndCheck(const Config& cfg, Json& json, std::string* const err_msg } return Status::success; } - } // namespace knowhere - -extern "C" __attribute__((visibility("default"))) int -CheckConfig(int index_type, char const* str, int n, int param_type); - -int -CheckConfig(int index_type, const char* str, int n, int param_type) { - if (!str || n <= 0) { - return int(knowhere::Status::invalid_args); - } - knowhere::Json json = knowhere::Json::parse(str, str + n); - std::unique_ptr cfg; - - switch (index_type) { - case 0: - cfg = std::make_unique(); - break; - case 1: - cfg = std::make_unique(); - break; - case 2: - cfg = std::make_unique(); - break; - case 3: - cfg = std::make_unique(); - break; - case 4: - cfg = std::make_unique(); - break; - case 5: - cfg = std::make_unique(); - break; - case 6: - cfg = std::make_unique(); - break; - case 7: - cfg = std::make_unique(); - break; - case 8: - cfg = std::make_unique(); - break; - default: - return int(knowhere::Status::invalid_args); - } - - auto res = knowhere::Config::FormatAndCheck(*cfg, json, nullptr); - if (res != knowhere::Status::success) { - return int(res); - } - return int(knowhere::Config::Load(*cfg, json, knowhere::PARAM_TYPE(param_type), nullptr)); -} diff --git a/src/common/factory.cc b/src/common/factory.cc new file mode 100644 index 000000000..62d21f29d --- /dev/null +++ b/src/common/factory.cc @@ -0,0 +1,45 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "knowhere/factory.h" + +namespace knowhere { + +Index +IndexFactory::Create(const std::string& name, const int32_t& version, const Object& object) { + auto& func_mapping_ = MapInstance(); + assert(func_mapping_.find(name) != func_mapping_.end()); + LOG_KNOWHERE_INFO_ << "create knowhere index " << name << " with version " << version; + return func_mapping_[name](version, object); +} + +const IndexFactory& +IndexFactory::Register(const std::string& name, + std::function(const int32_t& version, const Object&)> func) { + auto& func_mapping_ = MapInstance(); + func_mapping_[name] = func; + return *this; +} + +IndexFactory& +IndexFactory::Instance() { + static IndexFactory factory; + return factory; +} + +IndexFactory::IndexFactory() { +} +IndexFactory::FuncMap& +IndexFactory::MapInstance() { + static FuncMap func_map; + return func_map; +} +} // namespace knowhere diff --git a/src/common/index.cc b/src/common/index.cc new file mode 100644 index 000000000..3d372ffe2 --- /dev/null +++ b/src/common/index.cc @@ -0,0 +1,248 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "knowhere/index.h" + +#include "knowhere/comp/time_recorder.h" +#include "knowhere/dataset.h" +#include "knowhere/expected.h" +#include "knowhere/log.h" + +#ifdef NOT_COMPILE_FOR_SWIG +#include "knowhere/prometheus_client.h" +#endif + +namespace knowhere { + +inline Status +LoadConfig(BaseConfig* cfg, const Json& json, knowhere::PARAM_TYPE param_type, const std::string& method, + std::string* const msg = nullptr) { + Json json_(json); + auto res = Config::FormatAndCheck(*cfg, json_, msg); + LOG_KNOWHERE_DEBUG_ << method << " config dump: " << json_.dump(); + RETURN_IF_ERROR(res); + return Config::Load(*cfg, json_, param_type, msg); +} + +template +inline Status +Index::Build(const DataSet& dataset, const Json& json) { + auto cfg = this->node->CreateConfig(); + RETURN_IF_ERROR(LoadConfig(cfg.get(), json, knowhere::TRAIN, "Build")); + RETURN_IF_ERROR(cfg->CheckAndAdjustForBuild()); + +#ifdef NOT_COMPILE_FOR_SWIG + TimeRecorder rc("Build index", 2); + auto res = this->node->Build(dataset, *cfg); + auto span = rc.ElapseFromBegin("done"); + span *= 0.000001; // convert to s + knowhere_build_latency.Observe(span); + knowhere_build_count.Increment(); +#else + auto res = this->node->Build(dataset, *cfg); +#endif + return res; +} + +template +inline Status +Index::Train(const DataSet& dataset, const Json& json) { + auto cfg = this->node->CreateConfig(); + RETURN_IF_ERROR(LoadConfig(cfg.get(), json, knowhere::TRAIN, "Train")); + return this->node->Train(dataset, *cfg); +} + +template +inline Status +Index::Add(const DataSet& dataset, const Json& json) { + auto cfg = this->node->CreateConfig(); + RETURN_IF_ERROR(LoadConfig(cfg.get(), json, knowhere::TRAIN, "Add")); + return this->node->Add(dataset, *cfg); +} + +template +inline expected +Index::Search(const DataSet& dataset, const Json& json, const BitsetView& bitset) const { + auto cfg = this->node->CreateConfig(); + std::string msg; + const Status load_status = LoadConfig(cfg.get(), json, knowhere::SEARCH, "Search", &msg); + if (load_status != Status::success) { + return expected::Err(load_status, msg); + } + const Status search_status = cfg->CheckAndAdjustForSearch(&msg); + if (search_status != Status::success) { + return expected::Err(search_status, msg); + } + +#ifdef NOT_COMPILE_FOR_SWIG + TimeRecorder rc("Search"); + auto res = this->node->Search(dataset, *cfg, bitset); + auto span = rc.ElapseFromBegin("done"); + span *= 0.001; // convert to ms + knowhere_search_latency.Observe(span); + knowhere_search_count.Increment(); + knowhere_search_topk.Observe(cfg->k.value()); +#else + auto res = this->node->Search(dataset, *cfg, bitset); +#endif + return res; +} + +template +inline expected>> +Index::AnnIterator(const DataSet& dataset, const Json& json, const BitsetView& bitset) const { + auto cfg = this->node->CreateConfig(); + std::string msg; + Status status = LoadConfig(cfg.get(), json, knowhere::ITERATOR, "Iterator", &msg); + if (status != Status::success) { + return expected>>::Err(status, msg); + } + status = cfg->CheckAndAdjustForIterator(); + if (status != Status::success) { + return expected>>::Err(status, "invalid params for iterator"); + } + +#ifdef NOT_COMPILE_FOR_SWIG + // note that this time includes only the initial search phase of iterator. + TimeRecorder rc("AnnIterator"); + auto res = this->node->AnnIterator(dataset, *cfg, bitset); + auto span = rc.ElapseFromBegin("done"); + span *= 0.001; // convert to ms + knowhere_search_latency.Observe(span); + knowhere_ann_iterator_count.Increment(); +#else + auto res = this->node->AnnIterator(dataset, *cfg, bitset); +#endif + return res; +} + +template +inline expected +Index::RangeSearch(const DataSet& dataset, const Json& json, const BitsetView& bitset) const { + auto cfg = this->node->CreateConfig(); + std::string msg; + auto status = LoadConfig(cfg.get(), json, knowhere::RANGE_SEARCH, "RangeSearch", &msg); + if (status != Status::success) { + return expected::Err(status, std::move(msg)); + } + status = cfg->CheckAndAdjustForRangeSearch(&msg); + if (status != Status::success) { + return expected::Err(status, std::move(msg)); + } + +#ifdef NOT_COMPILE_FOR_SWIG + TimeRecorder rc("Range Search"); + auto res = this->node->RangeSearch(dataset, *cfg, bitset); + auto span = rc.ElapseFromBegin("done"); + span *= 0.001; // convert to ms + knowhere_range_search_latency.Observe(span); + knowhere_range_search_count.Increment(); +#else + auto res = this->node->RangeSearch(dataset, *cfg, bitset); +#endif + return res; +} + +template +inline expected +Index::GetVectorByIds(const DataSet& dataset) const { + return this->node->GetVectorByIds(dataset); +} + +template +inline bool +Index::HasRawData(const std::string& metric_type) const { + return this->node->HasRawData(metric_type); +} + +template +inline expected +Index::GetIndexMeta(const Json& json) const { + auto cfg = this->node->CreateConfig(); + std::string msg; + auto status = LoadConfig(cfg.get(), json, knowhere::FEDER, "GetIndexMeta", &msg); + if (status != Status::success) { + return expected::Err(status, msg); + } + return this->node->GetIndexMeta(*cfg); +} + +template +inline Status +Index::Serialize(BinarySet& binset) const { + return this->node->Serialize(binset); +} + +template +inline Status +Index::Deserialize(const BinarySet& binset, const Json& json) { + Json json_(json); + auto cfg = this->node->CreateConfig(); + { + auto res = Config::FormatAndCheck(*cfg, json_); + LOG_KNOWHERE_DEBUG_ << "Deserialize config dump: " << json_.dump(); + if (res != Status::success) { + return res; + } + } + auto res = Config::Load(*cfg, json_, knowhere::DESERIALIZE); + if (res != Status::success) { + return res; + } + return this->node->Deserialize(binset, *cfg); +} + +template +inline Status +Index::DeserializeFromFile(const std::string& filename, const Json& json) { + Json json_(json); + auto cfg = this->node->CreateConfig(); + { + auto res = Config::FormatAndCheck(*cfg, json_); + LOG_KNOWHERE_DEBUG_ << "DeserializeFromFile config dump: " << json_.dump(); + if (res != Status::success) { + return res; + } + } + auto res = Config::Load(*cfg, json_, knowhere::DESERIALIZE_FROM_FILE); + if (res != Status::success) { + return res; + } + return this->node->DeserializeFromFile(filename, *cfg); +} + +template +inline int64_t +Index::Dim() const { + return this->node->Dim(); +} + +template +inline int64_t +Index::Size() const { + return this->node->Size(); +} + +template +inline int64_t +Index::Count() const { + return this->node->Count(); +} + +template +inline std::string +Index::Type() const { + return this->node->Type(); +} + +template class Index; + +} // namespace knowhere diff --git a/src/index/index_node_thread_pool_wrapper.cc b/src/common/index_node_thread_pool_wrapper.cc similarity index 83% rename from src/index/index_node_thread_pool_wrapper.cc rename to src/common/index_node_thread_pool_wrapper.cc index 0f1098809..271ffa0b6 100644 --- a/src/index/index_node_thread_pool_wrapper.cc +++ b/src/common/index_node_thread_pool_wrapper.cc @@ -9,10 +9,10 @@ // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express // or implied. See the License for the specific language governing permissions and limitations under the License. -#include "knowhere/index/index_node_thread_pool_wrapper.h" +#include "knowhere/index_node_thread_pool_wrapper.h" #include "knowhere/comp/thread_pool.h" -#include "knowhere/index/index_node.h" +#include "knowhere/index_node.h" namespace knowhere { @@ -36,12 +36,12 @@ IndexNodeThreadPoolWrapper::IndexNodeThreadPoolWrapper(std::unique_ptr -IndexNodeThreadPoolWrapper::Search(const DataSetPtr dataset, const Config& cfg, const BitsetView& bitset) const { +IndexNodeThreadPoolWrapper::Search(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const { return thread_pool_->push([&]() { return this->index_node_->Search(dataset, cfg, bitset); }).get(); } expected -IndexNodeThreadPoolWrapper::RangeSearch(const DataSetPtr dataset, const Config& cfg, const BitsetView& bitset) const { +IndexNodeThreadPoolWrapper::RangeSearch(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const { return thread_pool_->push([&]() { return this->index_node_->RangeSearch(dataset, cfg, bitset); }).get(); } diff --git a/src/common/prometheus_client.cc b/src/common/prometheus_client.cc index d3bf81355..83b6a21cc 100644 --- a/src/common/prometheus_client.cc +++ b/src/common/prometheus_client.cc @@ -13,11 +13,8 @@ namespace knowhere { -const prometheus::Histogram::BucketBoundaries defaultBuckets = { - 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 1048576}; - -const prometheus::Histogram::BucketBoundaries ratioBuckets = { - 0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0}; +const prometheus::Histogram::BucketBoundaries buckets = {1, 2, 4, 8, 16, 32, 64, 128, 256, + 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536}; const std::unique_ptr prometheusClient = std::make_unique(); @@ -28,94 +25,14 @@ const std::unique_ptr prometheusClient = std::make_unique; -} // namespace raft_knowhere diff --git a/src/common/raft/integration/cagra_index.cu b/src/common/raft/integration/cagra_index.cu deleted file mode 100644 index d43900b50..000000000 --- a/src/common/raft/integration/cagra_index.cu +++ /dev/null @@ -1,30 +0,0 @@ -/** - * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#undef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "common/raft/integration/raft_knowhere_index.cuh" -#include "common/raft/proto/filtered_search_instantiation.cuh" -#include "common/raft/proto/raft_index_kind.hpp" - -RAFT_FILTERED_SEARCH_EXTERN(cagra, raft_knowhere::raft_data_t, - raft_knowhere::raft_indexing_t, - raft_knowhere::raft_input_indexing_t, - raft_knowhere::raft_data_t, - raft_knowhere::knowhere_bitset_data_type, raft_knowhere::knowhere_bitset_indexing_type) - -namespace raft_knowhere { -template struct raft_knowhere_index; -} // namespace raft_knowhere diff --git a/src/common/raft/integration/cagra_instantiations.cu b/src/common/raft/integration/cagra_instantiations.cu deleted file mode 100644 index 651fc893b..000000000 --- a/src/common/raft/integration/cagra_instantiations.cu +++ /dev/null @@ -1,27 +0,0 @@ -/** - * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#undef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "common/raft/integration/type_mappers.hpp" -#include "common/raft/proto/filtered_search_instantiation.cuh" -#include "common/raft/proto/raft_index_kind.hpp" - -RAFT_FILTERED_SEARCH_INSTANTIATION(cagra, raft_knowhere::raft_data_t, - raft_knowhere::raft_indexing_t, - raft_knowhere::raft_input_indexing_t, - raft_knowhere::raft_data_t, - raft_knowhere::knowhere_bitset_data_type, - raft_knowhere::knowhere_bitset_indexing_type) diff --git a/src/common/raft/integration/ivf_flat_index.cu b/src/common/raft/integration/ivf_flat_index.cu deleted file mode 100644 index 8d546048e..000000000 --- a/src/common/raft/integration/ivf_flat_index.cu +++ /dev/null @@ -1,30 +0,0 @@ -/** - * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#undef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "common/raft/integration/raft_knowhere_index.cuh" -#include "common/raft/proto/filtered_search_instantiation.cuh" -#include "common/raft/proto/raft_index_kind.hpp" - -RAFT_FILTERED_SEARCH_EXTERN(ivf_flat, raft_knowhere::raft_data_t, - raft_knowhere::raft_indexing_t, - raft_knowhere::raft_input_indexing_t, - raft_knowhere::raft_data_t, - raft_knowhere::knowhere_bitset_data_type, raft_knowhere::knowhere_bitset_indexing_type) - -namespace raft_knowhere { -template struct raft_knowhere_index; -} // namespace raft_knowhere diff --git a/src/common/raft/integration/ivf_flat_instantiations.cu b/src/common/raft/integration/ivf_flat_instantiations.cu deleted file mode 100644 index a5402083c..000000000 --- a/src/common/raft/integration/ivf_flat_instantiations.cu +++ /dev/null @@ -1,27 +0,0 @@ -/** - * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#undef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "common/raft/integration/type_mappers.hpp" -#include "common/raft/proto/filtered_search_instantiation.cuh" -#include "common/raft/proto/raft_index_kind.hpp" - -RAFT_FILTERED_SEARCH_INSTANTIATION(ivf_flat, raft_knowhere::raft_data_t, - raft_knowhere::raft_indexing_t, - raft_knowhere::raft_input_indexing_t, - raft_knowhere::raft_data_t, - raft_knowhere::knowhere_bitset_data_type, - raft_knowhere::knowhere_bitset_indexing_type) diff --git a/src/common/raft/integration/ivf_pq_index.cu b/src/common/raft/integration/ivf_pq_index.cu deleted file mode 100644 index 9f46b3fc2..000000000 --- a/src/common/raft/integration/ivf_pq_index.cu +++ /dev/null @@ -1,30 +0,0 @@ -/** - * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#undef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "common/raft/integration/raft_knowhere_index.cuh" -#include "common/raft/proto/filtered_search_instantiation.cuh" -#include "common/raft/proto/raft_index_kind.hpp" - -RAFT_FILTERED_SEARCH_EXTERN(ivf_pq, raft_knowhere::raft_data_t, - raft_knowhere::raft_indexing_t, - raft_knowhere::raft_input_indexing_t, - raft_knowhere::raft_data_t, - raft_knowhere::knowhere_bitset_data_type, raft_knowhere::knowhere_bitset_indexing_type) - -namespace raft_knowhere { -template struct raft_knowhere_index; -} // namespace raft_knowhere diff --git a/src/common/raft/integration/ivf_pq_instantiations.cu b/src/common/raft/integration/ivf_pq_instantiations.cu deleted file mode 100644 index 40936fda3..000000000 --- a/src/common/raft/integration/ivf_pq_instantiations.cu +++ /dev/null @@ -1,27 +0,0 @@ -/** - * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#undef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "common/raft/integration/type_mappers.hpp" -#include "common/raft/proto/filtered_search_instantiation.cuh" -#include "common/raft/proto/raft_index_kind.hpp" - -RAFT_FILTERED_SEARCH_INSTANTIATION(ivf_pq, raft_knowhere::raft_data_t, - raft_knowhere::raft_indexing_t, - raft_knowhere::raft_input_indexing_t, - raft_knowhere::raft_data_t, - raft_knowhere::knowhere_bitset_data_type, - raft_knowhere::knowhere_bitset_indexing_type) diff --git a/src/common/raft/integration/raft_initialization.cc b/src/common/raft/integration/raft_initialization.cc deleted file mode 100644 index 59bb6f842..000000000 --- a/src/common/raft/integration/raft_initialization.cc +++ /dev/null @@ -1,79 +0,0 @@ -/** - * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "common/raft/integration/raft_initialization.hpp" - -#include - -#include -#include -#include -#include -namespace raft_knowhere { - -void -initialize_raft(raft_configuration const& config) { - auto static initialization_flag = std::once_flag{}; - std::call_once(initialization_flag, [&config]() { - raft::device_resources_manager::set_streams_per_device(config.streams_per_device); - if (config.stream_pool_size) { - raft::device_resources_manager::set_stream_pools_per_device(config.stream_pools_per_device, - *(config.stream_pool_size)); - } else { - raft::device_resources_manager::set_stream_pools_per_device(config.stream_pools_per_device); - } - if (config.init_mem_pool_size_mb) { - raft::device_resources_manager::set_init_mem_pool_size(*(config.init_mem_pool_size_mb) << 20); - } - if (config.max_mem_pool_size_mb) { - if (*config.max_mem_pool_size_mb > 0) { - raft::device_resources_manager::set_max_mem_pool_size(*(config.max_mem_pool_size_mb) << 20); - } - } else { - raft::device_resources_manager::set_max_mem_pool_size(std::nullopt); - } - if (config.max_workspace_size_mb) { - raft::device_resources_manager::set_workspace_allocation_limit(*(config.max_workspace_size_mb) << 20); - } - auto device_count = []() { - auto result = 0; - RAFT_CUDA_TRY(cudaGetDeviceCount(&result)); - RAFT_EXPECTS(result != 0, "No CUDA devices found"); - return result; - }(); - - for (auto device_id = 0; device_id < device_count; ++device_id) { - auto scoped_device = raft::device_setter{device_id}; - auto workspace_size = std::size_t{}; - if (config.max_workspace_size_mb) { - workspace_size = *(config.max_workspace_size_mb) << 20; - } else { - auto free_mem = std::size_t{}; - auto total_mem = std::size_t{}; - RAFT_CUDA_TRY_NO_THROW(cudaMemGetInfo(&free_mem, &total_mem)); - // Heuristic: If workspace size is not explicitly specified, use half of free memory or a quarter of - // total memory, whichever is larger - workspace_size = std::max(free_mem / std::size_t{2}, total_mem / std::size_t{4}); - } - if (workspace_size > std::size_t{}) { - raft::device_resources_manager::set_workspace_memory_resource( - raft::resource::workspace_resource_factory::default_pool_resource(workspace_size), device_id); - } - } - }); -} - -} // namespace raft_knowhere diff --git a/src/common/raft/integration/raft_initialization.hpp b/src/common/raft/integration/raft_initialization.hpp deleted file mode 100644 index a26f48e4c..000000000 --- a/src/common/raft/integration/raft_initialization.hpp +++ /dev/null @@ -1,31 +0,0 @@ -/** - * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include -namespace raft_knowhere { -struct raft_configuration { - std::size_t streams_per_device = std::size_t{16}; - std::size_t stream_pools_per_device = std::size_t{}; - std::optional stream_pool_size = std::nullopt; - std::optional init_mem_pool_size_mb = std::nullopt; - std::optional max_mem_pool_size_mb = std::nullopt; - std::optional max_workspace_size_mb = std::nullopt; -}; - -void -initialize_raft(raft_configuration const& config); -} // namespace raft_knowhere diff --git a/src/common/raft/integration/raft_knowhere_config.hpp b/src/common/raft/integration/raft_knowhere_config.hpp deleted file mode 100644 index 02682a933..000000000 --- a/src/common/raft/integration/raft_knowhere_config.hpp +++ /dev/null @@ -1,125 +0,0 @@ -/** - * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include -#include - -#include "common/raft/proto/raft_index_kind.hpp" - -namespace raft_knowhere { -// This struct includes all parameters that may be passed to underlying RAFT -// indexes. It is designed to not expose ANY RAFT types in order to cleanly -// separate RAFT from knowhere headers. -struct raft_knowhere_config { - raft_proto::raft_index_kind index_type; - int k = 10; - - // Common Parameters - std::string metric_type = std::string{"L2"}; - float metric_arg = 2.0f; - bool add_data_on_build = true; - bool cache_dataset_on_device = false; - float refine_ratio = 1.0f; - - // Shared IVF Parameters - std::optional nlist = std::nullopt; - std::optional nprobe = std::nullopt; - std::optional kmeans_n_iters = std::nullopt; - std::optional kmeans_trainset_fraction = std::nullopt; - - // IVF Flat only Parameters - std::optional adaptive_centers = std::nullopt; - - // IVFPQ only Parameters - std::optional m = std::nullopt; - std::optional nbits = std::nullopt; - std::optional codebook_kind = std::nullopt; - std::optional force_random_rotation = std::nullopt; - std::optional conservative_memory_allocation = std::nullopt; - std::optional lookup_table_dtype = std::nullopt; - std::optional internal_distance_dtype = std::nullopt; - std::optional preferred_shmem_carveout = std::nullopt; - - // CAGRA Parameters - std::optional intermediate_graph_degree = std::nullopt; - std::optional graph_degree = std::nullopt; - std::optional itopk_size = std::nullopt; - std::optional max_queries = std::nullopt; - std::optional build_algo = std::nullopt; - std::optional search_algo = std::nullopt; - std::optional team_size = std::nullopt; - std::optional search_width = std::nullopt; - std::optional min_iterations = std::nullopt; - std::optional max_iterations = std::nullopt; - std::optional thread_block_size = std::nullopt; - std::optional hashmap_mode = std::nullopt; - std::optional hashmap_min_bitlen = std::nullopt; - std::optional hashmap_max_fill_rate = std::nullopt; - std::optional nn_descent_niter = std::nullopt; -}; - -// The following function provides a single source of truth for default values -// of RAFT index configurations. -[[nodiscard]] inline auto -validate_raft_knowhere_config(raft_knowhere_config config) { - if (config.index_type == raft_proto::raft_index_kind::brute_force) { - config.add_data_on_build = false; - config.cache_dataset_on_device = true; - } - if (config.index_type == raft_proto::raft_index_kind::ivf_flat || - config.index_type == raft_proto::raft_index_kind::ivf_pq) { - config.add_data_on_build = true; - config.nlist = config.nlist.value_or(128); - config.nprobe = config.nprobe.value_or(8); - config.kmeans_n_iters = config.kmeans_n_iters.value_or(20); - config.kmeans_trainset_fraction = config.kmeans_trainset_fraction.value_or(0.5f); - } - if (config.index_type == raft_proto::raft_index_kind::ivf_flat) { - config.adaptive_centers = config.adaptive_centers.value_or(false); - } - if (config.index_type == raft_proto::raft_index_kind::ivf_pq) { - config.m = config.m.value_or(0); - config.nbits = config.nbits.value_or(8); - config.codebook_kind = config.codebook_kind.value_or("PER_SUBSPACE"); - config.force_random_rotation = config.force_random_rotation.value_or(false); - config.conservative_memory_allocation = config.conservative_memory_allocation.value_or(false); - config.lookup_table_dtype = config.lookup_table_dtype.value_or("CUDA_R_32F"); - config.internal_distance_dtype = config.internal_distance_dtype.value_or("CUDA_R_32F"); - config.preferred_shmem_carveout = config.preferred_shmem_carveout.value_or(1.0f); - } - if (config.index_type == raft_proto::raft_index_kind::cagra) { - config.add_data_on_build = true; - config.intermediate_graph_degree = config.intermediate_graph_degree.value_or(128); - config.graph_degree = config.graph_degree.value_or(64); - config.itopk_size = config.itopk_size.value_or(64); - config.max_queries = config.max_queries.value_or(0); - config.build_algo = config.build_algo.value_or("IVF_PQ"); - config.search_algo = config.search_algo.value_or("AUTO"); - config.team_size = config.team_size.value_or(0); - config.search_width = config.search_width.value_or(1); - config.min_iterations = config.min_iterations.value_or(0); - config.max_iterations = config.max_iterations.value_or(0); - config.thread_block_size = config.thread_block_size.value_or(0); - config.hashmap_mode = config.hashmap_mode.value_or("AUTO"); - config.hashmap_min_bitlen = config.hashmap_min_bitlen.value_or(0); - config.hashmap_max_fill_rate = config.hashmap_max_fill_rate.value_or(0.5f); - config.nn_descent_niter = config.nn_descent_niter.value_or(20); - } - return config; -} - -} // namespace raft_knowhere diff --git a/src/common/raft/integration/raft_knowhere_index.cuh b/src/common/raft/integration/raft_knowhere_index.cuh deleted file mode 100644 index 236c32bee..000000000 --- a/src/common/raft/integration/raft_knowhere_index.cuh +++ /dev/null @@ -1,784 +0,0 @@ -/** - * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "common/raft/integration/raft_knowhere_index.hpp" -#include "common/raft/proto/raft_index.cuh" -#include "common/raft/proto/raft_index_kind.hpp" - -namespace raft_knowhere { -namespace detail { - -// This helper struct maps the generic type of RAFT index to the specific -// instantiation of that index used within knowhere. -template -struct raft_index_type_mapper : std::false_type {}; - -template <> -struct raft_index_type_mapper : std::true_type { - using data_type = raft_data_t; - using indexing_type = raft_indexing_t; - using type = raft_proto::raft_index; - using underlying_index_type = typename type::vector_index_type; - using index_params_type = typename type::index_params_type; - using search_params_type = typename type::search_params_type; -}; -template <> -struct raft_index_type_mapper : std::true_type { - using data_type = raft_data_t; - using indexing_type = raft_indexing_t; - using type = raft_proto::raft_index; - using underlying_index_type = typename type::vector_index_type; - using index_params_type = typename type::index_params_type; - using search_params_type = typename type::search_params_type; -}; -template <> -struct raft_index_type_mapper : std::true_type { - using data_type = raft_data_t; - using indexing_type = raft_indexing_t; - using type = raft_proto::raft_index; - using underlying_index_type = typename type::vector_index_type; - using index_params_type = typename type::index_params_type; - using search_params_type = typename type::search_params_type; -}; -template <> -struct raft_index_type_mapper : std::true_type { - using data_type = raft_data_t; - using indexing_type = raft_indexing_t; - using type = raft_proto::raft_index; - using underlying_index_type = typename type::vector_index_type; - using index_params_type = typename type::index_params_type; - using search_params_type = typename type::search_params_type; -}; - -template -struct check_valid_entry { - __device__ __host__ - check_valid_entry(U max_distance, V max_id) - : max_distance_(max_distance), max_id_(max_id) { - } - __device__ auto - operator()(T id_distance) { - auto id = thrust::get<0>(id_distance); - auto distance = thrust::get<1>(id_distance); - return distance >= max_distance_ || distance < 0 || id >= max_id_; - } - - private: - U max_distance_; - V max_id_; -}; - -} // namespace detail - -template -using raft_index_t = typename detail::raft_index_type_mapper::type; - -template -using raft_index_params_t = typename detail::raft_index_type_mapper::index_params_type; -template -using raft_search_params_t = typename detail::raft_index_type_mapper::search_params_type; - -// Metrics are passed between knowhere and RAFT as strings to avoid tight -// coupling between the implementation details of either one. -[[nodiscard]] inline auto -metric_string_to_raft_distance_type(std::string const& metric_string) { - auto result = raft::distance::DistanceType::L2Expanded; - if (metric_string == "L2") { - result = raft::distance::DistanceType::L2Expanded; - } else if (metric_string == "L2SqrtExpanded") { - result = raft::distance::DistanceType::L2SqrtExpanded; - } else if (metric_string == "CosineExpanded") { - result = raft::distance::DistanceType::CosineExpanded; - } else if (metric_string == "L1") { - result = raft::distance::DistanceType::L1; - } else if (metric_string == "L2Unexpanded") { - result = raft::distance::DistanceType::L2Unexpanded; - } else if (metric_string == "L2SqrtUnexpanded") { - result = raft::distance::DistanceType::L2SqrtUnexpanded; - } else if (metric_string == "IP") { - result = raft::distance::DistanceType::InnerProduct; - } else if (metric_string == "Linf") { - result = raft::distance::DistanceType::Linf; - } else if (metric_string == "Canberra") { - result = raft::distance::DistanceType::Canberra; - } else if (metric_string == "LpUnexpanded") { - result = raft::distance::DistanceType::LpUnexpanded; - } else if (metric_string == "CorrelationExpanded") { - result = raft::distance::DistanceType::CorrelationExpanded; - } else if (metric_string == "JACCARD") { - result = raft::distance::DistanceType::JaccardExpanded; - } else if (metric_string == "HellingerExpanded") { - result = raft::distance::DistanceType::HellingerExpanded; - } else if (metric_string == "Haversine") { - result = raft::distance::DistanceType::Haversine; - } else if (metric_string == "BrayCurtis") { - result = raft::distance::DistanceType::BrayCurtis; - } else if (metric_string == "JensenShannon") { - result = raft::distance::DistanceType::JensenShannon; - } else if (metric_string == "HAMMING") { - result = raft::distance::DistanceType::HammingUnexpanded; - } else if (metric_string == "KLDivergence") { - result = raft::distance::DistanceType::KLDivergence; - } else if (metric_string == "RusselRaoExpanded") { - result = raft::distance::DistanceType::RusselRaoExpanded; - } else if (metric_string == "DiceExpanded") { - result = raft::distance::DistanceType::DiceExpanded; - } else if (metric_string == "Precomputed") { - result = raft::distance::DistanceType::Precomputed; - } else { - RAFT_FAIL("Unrecognized metric type %s", metric_string.c_str()); - } - return result; -} - -[[nodiscard]] inline auto -codebook_string_to_raft_codebook_gen(std::string const& codebook_string) { - auto result = raft::neighbors::ivf_pq::codebook_gen::PER_SUBSPACE; - if (codebook_string == "PER_SUBSPACE") { - result = raft::neighbors::ivf_pq::codebook_gen::PER_SUBSPACE; - } else if (codebook_string == "PER_CLUSTER") { - result = raft::neighbors::ivf_pq::codebook_gen::PER_CLUSTER; - } else { - RAFT_FAIL("Unrecognized codebook type %s", codebook_string.c_str()); - } - return result; -} -[[nodiscard]] inline auto -build_algo_string_to_cagra_build_algo(std::string const& algo_string) { - auto result = raft::neighbors::cagra::graph_build_algo::IVF_PQ; - if (algo_string == "IVF_PQ") { - result = raft::neighbors::cagra::graph_build_algo::IVF_PQ; - } else if (algo_string == "NN_DESCENT") { - result = raft::neighbors::cagra::graph_build_algo::NN_DESCENT; - } else { - RAFT_FAIL("Unrecognized CAGRA build algo %s", algo_string.c_str()); - } - return result; -} - -[[nodiscard]] inline auto -search_algo_string_to_cagra_search_algo(std::string const& algo_string) { - auto result = raft::neighbors::cagra::search_algo::AUTO; - if (algo_string == "SINGLE_CTA") { - result = raft::neighbors::cagra::search_algo::SINGLE_CTA; - } else if (algo_string == "MULTI_CTA") { - result = raft::neighbors::cagra::search_algo::MULTI_CTA; - } else if (algo_string == "MULTI_KERNEL") { - result = raft::neighbors::cagra::search_algo::MULTI_KERNEL; - } else if (algo_string == "AUTO") { - result = raft::neighbors::cagra::search_algo::AUTO; - } else { - RAFT_FAIL("Unrecognized CAGRA search algo %s", algo_string.c_str()); - } - return result; -} - -[[nodiscard]] inline auto -hashmap_mode_string_to_cagra_hashmap_mode(std::string const& mode_string) { - auto result = raft::neighbors::cagra::hash_mode::AUTO; - if (mode_string == "HASH") { - result = raft::neighbors::cagra::hash_mode::HASH; - } else if (mode_string == "SMALL") { - result = raft::neighbors::cagra::hash_mode::SMALL; - } else if (mode_string == "AUTO") { - result = raft::neighbors::cagra::hash_mode::AUTO; - } else { - RAFT_FAIL("Unrecognized CAGRA hash mode %s", mode_string.c_str()); - } - return result; -} - -[[nodiscard]] inline auto -dtype_string_to_cuda_dtype(std::string const& dtype_string) { - auto result = CUDA_R_32F; - if (dtype_string == "CUDA_R_16F") { - result = CUDA_R_16F; - } else if (dtype_string == "CUDA_C_16F") { - result = CUDA_C_16F; - } else if (dtype_string == "CUDA_R_16BF") { - result = CUDA_R_16BF; - } else if (dtype_string == "CUDA_R_32F") { - result = CUDA_R_32F; - } else if (dtype_string == "CUDA_C_32F") { - result = CUDA_C_32F; - } else if (dtype_string == "CUDA_R_64F") { - result = CUDA_R_64F; - } else if (dtype_string == "CUDA_C_64F") { - result = CUDA_C_64F; - } else if (dtype_string == "CUDA_R_8I") { - result = CUDA_R_8I; - } else if (dtype_string == "CUDA_C_8I") { - result = CUDA_C_8I; - } else if (dtype_string == "CUDA_R_8U") { - result = CUDA_R_8U; - } else if (dtype_string == "CUDA_C_8U") { - result = CUDA_C_8U; - } else if (dtype_string == "CUDA_R_32I") { - result = CUDA_R_32I; - } else if (dtype_string == "CUDA_C_32I") { - result = CUDA_C_32I; -#if __CUDACC_VER_MAJOR__ >= 12 - } else if (dtype_string == "CUDA_R_8F_E4M3") { - result = CUDA_R_8F_E4M3; - } else if (dtype_string == "CUDA_R_8F_E5M2") { - result = CUDA_R_8F_E5M2; -#endif - } else { - RAFT_FAIL("Unrecognized dtype %s", dtype_string.c_str()); - } - return result; -} - -// Given a generic config without RAFT symbols, convert to RAFT index build -// parameters -template -[[nodiscard]] auto -config_to_index_params(raft_knowhere_config const& raw_config) { - RAFT_EXPECTS(raw_config.index_type == IndexKind, "Incorrect index type for this index"); - auto config = validate_raft_knowhere_config(raw_config); - auto result = raft_index_params_t{}; - - result.metric = metric_string_to_raft_distance_type(config.metric_type); - result.metric_arg = config.metric_arg; - result.add_data_on_build = config.add_data_on_build; - - if constexpr (IndexKind == raft_proto::raft_index_kind::ivf_flat || - IndexKind == raft_proto::raft_index_kind::ivf_pq) { - result.n_lists = *(config.nlist); - result.kmeans_n_iters = *(config.kmeans_n_iters); - result.kmeans_trainset_fraction = *(config.kmeans_trainset_fraction); - result.conservative_memory_allocation = *(config.conservative_memory_allocation); - } - if constexpr (IndexKind == raft_proto::raft_index_kind::ivf_flat) { - result.adaptive_centers = *(config.adaptive_centers); - } - if constexpr (IndexKind == raft_proto::raft_index_kind::ivf_pq) { - result.pq_dim = *(config.m); - result.pq_bits = *(config.nbits); - result.codebook_kind = codebook_string_to_raft_codebook_gen(*(config.codebook_kind)); - result.force_random_rotation = *(config.force_random_rotation); - } - if constexpr (IndexKind == raft_proto::raft_index_kind::cagra) { - result.intermediate_graph_degree = *(config.intermediate_graph_degree); - result.graph_degree = *(config.graph_degree); - result.build_algo = build_algo_string_to_cagra_build_algo(*(config.build_algo)); - result.nn_descent_niter = *(config.nn_descent_niter); - } - return result; -} - -// Given a generic config without RAFT symbols, convert to RAFT index search -// parameters -template -[[nodiscard]] auto -config_to_search_params(raft_knowhere_config const& raw_config) { - RAFT_EXPECTS(raw_config.index_type == IndexKind, "Incorrect index type for this index"); - auto config = validate_raft_knowhere_config(raw_config); - auto result = raft_search_params_t{}; - if constexpr (IndexKind == raft_proto::raft_index_kind::ivf_flat || - IndexKind == raft_proto::raft_index_kind::ivf_pq) { - result.n_probes = *(config.nprobe); - } - if constexpr (IndexKind == raft_proto::raft_index_kind::ivf_pq) { - result.lut_dtype = dtype_string_to_cuda_dtype(*(config.lookup_table_dtype)); - result.internal_distance_dtype = dtype_string_to_cuda_dtype(*(config.internal_distance_dtype)); - result.preferred_shmem_carveout = *(config.preferred_shmem_carveout); - } - if constexpr (IndexKind == raft_proto::raft_index_kind::cagra) { - result.max_queries = *(config.max_queries); - result.itopk_size = *(config.itopk_size); - result.max_iterations = *(config.max_iterations); - result.algo = search_algo_string_to_cagra_search_algo(*(config.search_algo)); - result.team_size = *(config.team_size); - result.search_width = *(config.search_width); - result.min_iterations = *(config.min_iterations); - result.thread_block_size = *(config.thread_block_size); - result.hashmap_mode = hashmap_mode_string_to_cagra_hashmap_mode(*(config.hashmap_mode)); - result.hashmap_min_bitlen = *(config.hashmap_min_bitlen); - result.hashmap_max_fill_rate = *(config.hashmap_max_fill_rate); - } - return result; -} - -inline auto const& -get_device_resources_without_mempool(int device_id = raft::device_setter::get_current_device()) { - auto thread_local res = std::vector([]() { - int device_count; - RAFT_CUDA_TRY(cudaGetDeviceCount(&device_count)); - return device_count; - }()); - - return res[device_id]; -} - -inline auto -select_device_id() { - auto static device_count = []() { - auto result = 0; - RAFT_CUDA_TRY(cudaGetDeviceCount(&result)); - RAFT_EXPECTS(result != 0, "No CUDA devices found"); - return result; - }(); - auto static index_counter = std::atomic{0}; - // Use round-robin assignment to distribute indexes across devices - auto result = index_counter.fetch_add(1) % device_count; - return result; -} - -// This struct is used to connect knowhere to a RAFT index. The implementation -// is provided here, but this header should never be directly included in -// another knowhere header. This ensures that RAFT symbols are not exposed in -// any knowhere header. -template -struct raft_knowhere_index::impl { - auto static constexpr index_kind = IndexKind; - using data_type = raft_data_t; - using indexing_type = raft_indexing_t; - using input_indexing_type = raft_input_indexing_t; - using raft_index_type = raft_index_t; - - impl() { - } - - auto - is_trained() const { - return index_.has_value(); - } - [[nodiscard]] auto - size() const { - auto result = std::int64_t{}; - if (is_trained()) { - result = index_->size(); - } - return result; - } - [[nodiscard]] auto - dim() const { - auto result = std::int64_t{}; - if (is_trained()) { - result = index_->dim(); - } - return result; - } - - void - train(raft_knowhere_config const& config, data_type const* data, knowhere_indexing_type row_count, - knowhere_indexing_type feature_count) { - auto scoped_device = raft::device_setter{device_id}; - auto index_params = config_to_index_params(config); - if constexpr (index_kind == raft_proto::raft_index_kind::ivf_flat || - index_kind == raft_proto::raft_index_kind::ivf_pq) { - index_params.n_lists = std::min(knowhere_indexing_type(index_params.n_lists), row_count); - } - auto const& res = get_device_resources_without_mempool(); - auto host_data = raft::make_host_matrix_view(data, row_count, feature_count); - if constexpr (index_kind == raft_proto::raft_index_kind::ivf_flat) { - device_dataset_storage = - raft::make_device_matrix(res, row_count, feature_count); - auto device_data = device_dataset_storage->view(); - raft::copy(res, device_data, host_data); - index_ = raft_index_type::template build( - res, index_params, raft::make_const_mdspan(device_data)); - if (!config.cache_dataset_on_device) { - device_dataset_storage = std::nullopt; - } - } else { - if (config.cache_dataset_on_device) { - device_dataset_storage = - raft::make_device_matrix(res, row_count, feature_count); - auto device_data = device_dataset_storage->view(); - raft::copy(res, device_data, host_data); - index_ = raft_index_type::template build( - res, index_params, raft::make_const_mdspan(device_data)); - } else { - index_ = raft_index_type::template build( - res, index_params, raft::make_const_mdspan(host_data)); - } - } - } - - void - add(data_type const* data, knowhere_indexing_type row_count, knowhere_indexing_type feature_count, - knowhere_indexing_type const* new_ids) { - if constexpr (index_kind == raft_proto::raft_index_kind::brute_force) { - if (index_) { - RAFT_FAIL("RAFT brute force does not support adding vectors after training"); - } - } else if constexpr (index_kind == raft_proto::raft_index_kind::cagra) { - if (index_) { - RAFT_FAIL("CAGRA does not support adding vectors after training"); - } - } else if constexpr (index_kind == raft_proto::raft_index_kind::ivf_pq) { - if (index_) { - RAFT_FAIL("IVFPQ does not support adding vectors after training"); - } - } else { - if (index_) { - auto const& res = get_device_resources_without_mempool(); - raft::resource::set_workspace_to_pool_resource(res); - auto host_data = raft::make_host_matrix_view(data, row_count, feature_count); - device_dataset_storage = - raft::make_device_matrix(res, row_count, feature_count); - auto device_data = device_dataset_storage->view(); - raft::copy(res, device_data, host_data); - auto device_ids_storage = std::optional>{}; - if (new_ids != nullptr) { - auto host_ids = raft::make_host_vector_view(new_ids, row_count); - device_ids_storage = raft::make_device_vector(res, row_count); - raft::copy(res, device_ids_storage->view(), host_ids); - } - - if (device_ids_storage) { - index_ = raft_index_type::extend( - res, raft::make_const_mdspan(device_data), - std::make_optional(raft::make_const_mdspan(device_ids_storage->view())), *index_); - } else { - index_ = raft_index_type::extend( - res, raft::make_const_mdspan(device_data), - std::optional>{}, *index_); - } - } else { - RAFT_FAIL("Index has not yet been trained"); - } - } - } - - auto - search(raft_knowhere_config const& config, data_type const* data, knowhere_indexing_type row_count, - knowhere_indexing_type feature_count, knowhere_bitset_data_type const* bitset_data, - knowhere_bitset_indexing_type bitset_byte_size, knowhere_bitset_indexing_type bitset_size) const { - auto scoped_device = raft::device_setter{device_id}; - auto const& res = raft::device_resources_manager::get_device_resources(); - auto k = knowhere_indexing_type(config.k); - auto search_params = config_to_search_params(config); - - auto host_data = raft::make_host_matrix_view(data, row_count, feature_count); - auto device_data_storage = - raft::make_device_matrix(res, row_count, feature_count); - raft::copy(res, device_data_storage.view(), host_data); - - auto device_bitset = - std::optional>{}; - auto k_tmp = k; - - if (bitset_data != nullptr && bitset_byte_size != 0) { - device_bitset = - raft::core::bitset(res, bitset_size); - raft::copy(res, device_bitset->to_mdspan(), raft::make_host_vector_view(bitset_data, bitset_byte_size)); - if constexpr (index_kind == raft_proto::raft_index_kind::brute_force) { - k_tmp += device_bitset->count(res); - if (k_tmp == k) { - device_bitset = std::nullopt; - } - k_tmp = std::min(k_tmp, size()); - } - if (device_bitset) { - device_bitset->flip(res); - } - } - - auto output_size = row_count * k; - auto ids = std::unique_ptr(new knowhere_indexing_type[output_size]); - auto distances = std::unique_ptr(new knowhere_data_type[output_size]); - - auto host_ids = raft::make_host_matrix_view(ids.get(), row_count, k); - auto host_distances = raft::make_host_matrix_view(distances.get(), row_count, k); - - auto device_ids_storage = raft::make_device_matrix(res, row_count, k_tmp); - auto device_distances_storage = raft::make_device_matrix(res, row_count, k_tmp); - auto device_ids = device_ids_storage.view(); - auto device_distances = device_distances_storage.view(); - - RAFT_EXPECTS(index_, "Index has not yet been trained"); - auto dataset_view = device_dataset_storage - ? std::make_optional(device_dataset_storage->view()) - : std::optional>{}; - - if (device_bitset) { - raft_index_type::search( - res, *index_, search_params, raft::make_const_mdspan(device_data_storage.view()), device_ids, - device_distances, config.refine_ratio, input_indexing_type{}, dataset_view, - raft::neighbors::filtering::bitset_filter{ - device_bitset->view()}); - } else { - raft_index_type::search(res, *index_, search_params, raft::make_const_mdspan(device_data_storage.view()), - device_ids, device_distances, config.refine_ratio, input_indexing_type{}, - dataset_view); - } - - auto device_knowhere_ids_storage = - std::optional>{}; - auto device_knowhere_ids = [&device_knowhere_ids_storage, &res, row_count, k_tmp, device_ids]() { - if constexpr (std::is_signed_v) { - return device_ids; - } else { - device_knowhere_ids_storage = - raft::make_device_matrix(res, row_count, k_tmp); - raft::copy(res, device_knowhere_ids_storage->view(), device_ids); - return device_knowhere_ids_storage->view(); - } - }(); - - auto max_distance = std::nextafter(std::numeric_limits::max(), 0.0f); - thrust::replace_if( - raft::resource::get_thrust_policy(res), - thrust::device_ptr(device_knowhere_ids.data_handle()), - thrust::device_ptr(device_knowhere_ids.data_handle() + - device_knowhere_ids.size()), - thrust::make_zip_iterator(thrust::make_tuple( - thrust::device_ptr( - device_knowhere_ids.data_handle()), - thrust::device_ptr(device_distances.data_handle()))), - detail::check_valid_entry, - decltype(max_distance), knowhere_indexing_type>{max_distance, - knowhere_indexing_type(size())}, - typename decltype(device_knowhere_ids)::value_type{-1}); - - if constexpr (index_kind == raft_proto::raft_index_kind::brute_force) { - if (k_tmp > k) { - for (auto i = 0; i < host_ids.extent(0); ++i) { - raft::copy(res, raft::make_host_vector_view(host_ids.data_handle() + i * host_ids.extent(1), k), - raft::make_device_vector_view( - device_knowhere_ids.data_handle() + i * device_knowhere_ids.extent(1), k)); - raft::copy( - res, - raft::make_host_vector_view(host_distances.data_handle() + i * host_distances.extent(1), k), - raft::make_device_vector_view(device_distances.data_handle() + i * device_distances.extent(1), - k)); - } - } else { - raft::copy(res, host_ids, device_knowhere_ids); - raft::copy(res, host_distances, device_distances); - } - } else { - raft::copy(res, host_ids, device_knowhere_ids); - raft::copy(res, host_distances, device_distances); - } - return std::make_tuple(ids.release(), distances.release()); - } - void - range_search() const { - RAFT_FAIL("Range search not yet implemented for RAFT indexes"); - } - void - get_vector_by_id() const { - RAFT_FAIL("Vector reconstruction not yet implemented for RAFT indexes"); - } - void - serialize(std::ostream& os) const { - auto scoped_device = raft::device_setter{device_id}; - auto const& res = get_device_resources_without_mempool(); - RAFT_EXPECTS(index_, "Index has not yet been trained"); - raft_index_type::template serialize(res, os, *index_); - if (device_dataset_storage) { - raft::serialize_scalar(res, os, true); - raft::serialize_scalar(res, os, device_dataset_storage->extent(0)); - raft::serialize_scalar(res, os, device_dataset_storage->extent(1)); - raft::serialize_mdspan(res, os, device_dataset_storage->view()); - } else { - raft::serialize_scalar(res, os, false); - } - } - - void - serialize_to_hnswlib(std::ostream& os) const { - // only carga can save to hnswlib format - if constexpr (index_kind == raft_proto::raft_index_kind::cagra) { - auto scoped_device = raft::device_setter{device_id}; - auto const& res = get_device_resources_without_mempool(); - RAFT_EXPECTS(index_, "Index has not yet been trained"); - raft_index_type::template serialize_to_hnswlib(res, os, *index_); - raft::serialize_scalar(res, os, false); - } - } - - auto static deserialize(std::istream& is) { - auto static device_count = []() { - auto result = 0; - RAFT_CUDA_TRY(cudaGetDeviceCount(&result)); - RAFT_EXPECTS(result != 0, "No CUDA devices found"); - return result; - }(); - // The lazy allocation mode cannot completely eliminate uneven distribution, but it can alleviate it well. - int new_device_id = 0; - size_t free, total; - size_t max_free = 0; - for (int i = 0; i < device_count; ++i) { - auto scoped_device = raft::device_setter{i}; - RAFT_CUDA_TRY(cudaMemGetInfo(&free, &total)); - if (max_free < free) { - max_free = free; - new_device_id = i; - } - } - auto scoped_device = raft::device_setter{new_device_id}; - auto const& res = get_device_resources_without_mempool(); - auto des_index = raft_index_type::template deserialize(res, is); - - auto dataset = std::optional>{}; - auto has_dataset = raft::deserialize_scalar(res, is); - if (has_dataset) { - auto rows = raft::deserialize_scalar(res, is); - auto cols = raft::deserialize_scalar(res, is); - dataset = raft::make_device_matrix(res, rows, cols); - raft::deserialize_mdspan(res, is, dataset->view()); - if constexpr (index_kind == raft_proto::raft_index_kind::brute_force || - index_kind == raft_proto::raft_index_kind::cagra) { - raft_index_type::template update_dataset( - res, des_index, raft::make_const_mdspan(dataset->view())); - } - } - return std::make_unique::impl>(std::move(des_index), new_device_id, - std::move(dataset)); - } - - void - synchronize(bool is_without_mempool = false) const { - auto scoped_device = raft::device_setter{device_id}; - if (is_without_mempool) { - get_device_resources_without_mempool().sync_stream(); - - } else { - raft::device_resources_manager::get_device_resources().sync_stream(); - } - } - impl(raft_index_type&& index, int new_device_id, - std::optional>&& dataset) - : index_{std::move(index)}, device_id{new_device_id}, device_dataset_storage{std::move(dataset)} { - } - - private: - std::optional index_ = std::nullopt; - int device_id = select_device_id(); - std::optional> device_dataset_storage = std::nullopt; -}; - -template -raft_knowhere_index::raft_knowhere_index() : pimpl{new raft_knowhere_index::impl()} { -} - -template -raft_knowhere_index::~raft_knowhere_index() = default; - -template -raft_knowhere_index::raft_knowhere_index(raft_knowhere_index&& other) - : pimpl{std::move(other.pimpl)} { -} - -template -raft_knowhere_index& -raft_knowhere_index::operator=(raft_knowhere_index&& other) { - pimpl = std::move(other.pimpl); - return *this; -} - -template -bool -raft_knowhere_index::is_trained() const { - return pimpl->is_trained(); -} - -template -std::int64_t -raft_knowhere_index::size() const { - return pimpl->size(); -} - -template -std::int64_t -raft_knowhere_index::dim() const { - return pimpl->dim(); -} - -template -void -raft_knowhere_index::train(raft_knowhere_config const& config, data_type const* data, - knowhere_indexing_type row_count, knowhere_indexing_type feature_count) { - return pimpl->train(config, data, row_count, feature_count); -} -template -void -raft_knowhere_index::add(data_type const* data, knowhere_indexing_type row_count, - knowhere_indexing_type feature_count, knowhere_indexing_type const* new_ids) { - return pimpl->add(data, row_count, feature_count, new_ids); -} -template -std::tuple -raft_knowhere_index::search(raft_knowhere_config const& config, data_type const* data, - knowhere_indexing_type row_count, knowhere_indexing_type feature_count, - knowhere_bitset_data_type const* bitset_data, - knowhere_bitset_indexing_type bitset_byte_size, - knowhere_bitset_indexing_type bitset_size) const { - return pimpl->search(config, data, row_count, feature_count, bitset_data, bitset_byte_size, bitset_size); -} - -template -void -raft_knowhere_index::range_search() const { - return pimpl->range_search(); -} - -template -void -raft_knowhere_index::get_vector_by_id() const { - return pimpl->get_vector_by_id(); -} - -template -void -raft_knowhere_index::serialize(std::ostream& os) const { - return pimpl->serialize(os); -} - -template -void -raft_knowhere_index::serialize_to_hnswlib(std::ostream& os) const { - return pimpl->serialize_to_hnswlib(os); -} - -template -raft_knowhere_index -raft_knowhere_index::deserialize(std::istream& is) { - return raft_knowhere_index(raft_knowhere_index::impl::deserialize(is)); -} - -template -void -raft_knowhere_index::synchronize(bool is_without_mempool) const { - return pimpl->synchronize(is_without_mempool); -} - -} // namespace raft_knowhere diff --git a/src/common/raft/integration/raft_knowhere_index.hpp b/src/common/raft/integration/raft_knowhere_index.hpp deleted file mode 100644 index 2c48223bc..000000000 --- a/src/common/raft/integration/raft_knowhere_index.hpp +++ /dev/null @@ -1,84 +0,0 @@ -/** - * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include - -#include "common/raft/integration/raft_knowhere_config.hpp" -#include "common/raft/integration/type_mappers.hpp" -#include "common/raft/proto/raft_index_kind.hpp" -namespace raft_knowhere { - -template -struct raft_knowhere_index { - auto static constexpr index_kind = IndexKind; - - using data_type = raft_data_t; - using indexing_type = raft_indexing_t; - using input_indexing_type = raft_input_indexing_t; - - raft_knowhere_index(); - ~raft_knowhere_index(); - - raft_knowhere_index(raft_knowhere_index&& other); - raft_knowhere_index& - operator=(raft_knowhere_index&& other); - - bool - is_trained() const; - std::int64_t - size() const; - std::int64_t - dim() const; - void - train(raft_knowhere_config const&, data_type const*, knowhere_indexing_type, knowhere_indexing_type); - void - add(data_type const* data, knowhere_indexing_type row_count, knowhere_indexing_type feature_count, - knowhere_indexing_type const* new_ids = nullptr); - std::tuple - search(raft_knowhere_config const& config, data_type const* data, knowhere_indexing_type row_count, - knowhere_indexing_type feature_count, knowhere_bitset_data_type const* bitset_data = nullptr, - knowhere_bitset_indexing_type bitset_byte_size = knowhere_bitset_indexing_type{}, - knowhere_bitset_indexing_type bitset_size = knowhere_bitset_indexing_type{}) const; - void - range_search() const; - void - get_vector_by_id() const; - void - serialize(std::ostream& os) const; - void - serialize_to_hnswlib(std::ostream& os) const; - static raft_knowhere_index - deserialize(std::istream& is); - void - synchronize(bool is_without_mempool = false) const; - - private: - // Use a private implementation to completely separate knowhere headers from - // RAFT headers - struct impl; - std::unique_ptr pimpl; - - raft_knowhere_index(std::unique_ptr&& new_pimpl) : pimpl{std::move(new_pimpl)} { - } -}; - -extern template struct raft_knowhere_index; -extern template struct raft_knowhere_index; -extern template struct raft_knowhere_index; -extern template struct raft_knowhere_index; - -} // namespace raft_knowhere diff --git a/src/common/raft/integration/type_mappers.hpp b/src/common/raft/integration/type_mappers.hpp deleted file mode 100644 index da4134c56..000000000 --- a/src/common/raft/integration/type_mappers.hpp +++ /dev/null @@ -1,74 +0,0 @@ -/** - * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include -#include - -#include "common/raft/proto/raft_index_kind.hpp" - -namespace raft_knowhere { - -using knowhere_data_type = float; -using knowhere_indexing_type = std::int64_t; -using knowhere_bitset_data_type = std::uint8_t; -using knowhere_bitset_indexing_type = std::uint32_t; - -namespace detail { - -template -struct raft_io_type_mapper : std::false_type {}; - -template <> -struct raft_io_type_mapper : std::true_type { - using data_type = float; - using indexing_type = std::int64_t; - using input_indexing_type = std::int64_t; -}; - -template <> -struct raft_io_type_mapper : std::true_type { - using data_type = float; - using indexing_type = std::int64_t; - using input_indexing_type = std::int64_t; -}; - -template <> -struct raft_io_type_mapper : std::true_type { - using data_type = float; - using indexing_type = std::int64_t; - using input_indexing_type = std::uint32_t; -}; - -template <> -struct raft_io_type_mapper : std::true_type { - using data_type = float; - using indexing_type = std::uint32_t; - using input_indexing_type = std::int64_t; -}; - -} // namespace detail - -template -using raft_data_t = typename detail::raft_io_type_mapper::data_type; - -template -using raft_indexing_t = typename detail::raft_io_type_mapper::indexing_type; - -template -using raft_input_indexing_t = typename detail::raft_io_type_mapper::input_indexing_type; - -} // namespace raft_knowhere diff --git a/src/common/raft/proto/filtered_search_instantiation.cuh b/src/common/raft/proto/filtered_search_instantiation.cuh deleted file mode 100644 index 09acadb00..000000000 --- a/src/common/raft/proto/filtered_search_instantiation.cuh +++ /dev/null @@ -1,54 +0,0 @@ -/** - * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include -#include -#include -#include -#include -#include - -#include "common/raft/proto/raft_index_kind.hpp" - -namespace raft_proto { -namespace detail { -template -using index_instantiation = std::conditional_t< - K == raft_proto::raft_index_kind::ivf_flat, raft::neighbors::ivf_flat::index, - std::conditional_t< - K == raft_proto::raft_index_kind::ivf_pq, raft::neighbors::ivf_pq::index, - std::conditional_t, - raft::neighbors::ivf_flat::index>>>; -} // namespace detail -} // namespace raft_proto - -#define RAFT_FILTERED_SEARCH_TEMPLATE(index_type, T, IdxT, InpIdxT, DistT, BitsetDataT, BitsetIdxT) \ - template void search_with_filtering>( \ - raft::resources const&, search_params const&, \ - raft_proto::detail::index_instantiation const&, \ - raft::device_matrix_view, raft::device_matrix_view, \ - raft::device_matrix_view, raft::neighbors::filtering::bitset_filter) - -#define RAFT_FILTERED_SEARCH_INSTANTIATION(index_type, T, IdxT, InpIdxT, DistT, BitsetDataT, BitsetIdxT) \ - namespace raft::neighbors::index_type { \ - RAFT_FILTERED_SEARCH_TEMPLATE(index_type, T, IdxT, InpIdxT, DistT, BitsetDataT, BitsetIdxT); \ - } - -#define RAFT_FILTERED_SEARCH_EXTERN(index_type, T, IdxT, InpIdxT, DistT, BitsetDataT, BitsetIdxT) \ - namespace raft::neighbors::index_type { \ - RAFT_FILTERED_SEARCH_TEMPLATE(index_type, T, IdxT, InpIdxT, DistT, BitsetDataT, BitsetIdxT); \ - } diff --git a/src/common/raft/proto/raft_index.cuh b/src/common/raft/proto/raft_index.cuh deleted file mode 100644 index c491c466e..000000000 --- a/src/common/raft/proto/raft_index.cuh +++ /dev/null @@ -1,472 +0,0 @@ -/** - * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "common/raft/proto/raft_index_kind.hpp" - -namespace raft_proto { - -auto static const RAFT_NAME = raft::RAFT_NAME; - -namespace detail { -template typename index_template> -struct template_matches_index_kind : std::false_type {}; - -template <> -struct template_matches_index_kind : std::true_type { -}; - -template <> -struct template_matches_index_kind : std::true_type {}; - -template <> -struct template_matches_index_kind : std::true_type {}; - -template <> -struct template_matches_index_kind : std::true_type {}; - -template typename index_template> -auto static constexpr template_matches_index_kind_v = template_matches_index_kind::value; - -// Note: The following are not at all general or properly SFINAE-guarded. They -// should be replaced if the post_filter proto is required for upstreaming. -template -auto -mdspan_begin(mdspan_t data) { - return thrust::device_ptr(data.data_handle()); -} -template -auto -mdspan_end(mdspan_t data) { - return thrust::device_ptr(data.data_handle() + data.size()); -} - -template -auto -mdspan_begin_row(mdspan_t data, std::size_t row) { - return thrust::device_ptr(data.data_handle() + row * data.extent(1)); -} -template -auto -mdspan_end_row(mdspan_t data, std::size_t row) { - return thrust::device_ptr(data.data_handle() + (row + 1) * data.extent(1)); -} - -template -void -post_filter(raft::resources const& res, filter_lambda_t const& sample_filter, index_mdspan_t index_mdspan, - distance_mdspan_t distance_mdspan) { - auto counter = thrust::counting_iterator(0); - // TODO (wphicks): This could be rolled into the stable_partition calls - // below, but I am not sure whether or not that would be a net benefit. This - // deserves some benchmarking unless pre-filtering gets in before we revisit - // this. - thrust::for_each(raft::resource::get_thrust_policy(res), - thrust::make_zip_iterator( - thrust::make_tuple(counter, mdspan_begin(index_mdspan), mdspan_begin(distance_mdspan))), - thrust::make_zip_iterator(thrust::make_tuple( - counter + index_mdspan.size(), mdspan_end(index_mdspan), mdspan_end(distance_mdspan))), - [=] __device__(auto& index_id_distance) { - auto index = thrust::get<0>(index_id_distance); - auto& id = thrust::get<1>(index_id_distance); - auto& distance = thrust::get<2>(index_id_distance); - if (!sample_filter(index / index_mdspan.extent(1), id)) { - id = std::numeric_limits>::max(); - distance = std::numeric_limits>::max(); - } - }); - for (auto i = 0; i < index_mdspan.extent(0); ++i) { - auto id_row_begin = mdspan_begin_row(index_mdspan, i); - auto id_row_end = mdspan_end_row(index_mdspan, i); - auto distance_row_begin = mdspan_begin_row(distance_mdspan, i); - auto distance_row_end = mdspan_end_row(distance_mdspan, i); - thrust::stable_partition( - raft::resource::get_thrust_policy(res), - thrust::make_zip_iterator(thrust::make_tuple(id_row_begin, distance_row_begin)), - thrust::make_zip_iterator(thrust::make_tuple(id_row_end, distance_row_end)), - [=] __device__(auto& id_distance) { - return thrust::get<0>(id_distance) != - std::numeric_limits(id_distance))>>::max(); - }); - } -} - -template -void -serialize_to_hnswlib(raft::resources const& res, std::ostream& os, - const raft::neighbors::cagra::index& index_) { - size_t metric_type; - if (index_.metric() == raft::distance::L2Expanded) { - metric_type = 0; - } else if (index_.metric() == raft::distance::InnerProduct) { - metric_type = 1; - } else if (index_.metric() == raft::distance::CosineExpanded) { - metric_type = 2; - } - - os.write(reinterpret_cast(&metric_type), sizeof(metric_type)); - size_t data_size = index_.dim() * sizeof(float); - os.write(reinterpret_cast(&data_size), sizeof(data_size)); - size_t dim = index_.dim(); - os.write(reinterpret_cast(&dim), sizeof(dim)); - std::size_t offset_level_0 = 0; - os.write(reinterpret_cast(&offset_level_0), sizeof(std::size_t)); - std::size_t max_element = index_.size(); - os.write(reinterpret_cast(&max_element), sizeof(std::size_t)); - std::size_t curr_element_count = index_.size(); - os.write(reinterpret_cast(&curr_element_count), sizeof(std::size_t)); - auto size_data_per_element = - static_cast(index_.graph_degree() * sizeof(IdxT) + 4 + index_.dim() * sizeof(T) + 8); - os.write(reinterpret_cast(&size_data_per_element), sizeof(std::size_t)); - std::size_t label_offset = size_data_per_element - 8; - os.write(reinterpret_cast(&label_offset), sizeof(std::size_t)); - auto offset_data = static_cast(index_.graph_degree() * sizeof(IdxT) + 4); - os.write(reinterpret_cast(&offset_data), sizeof(std::size_t)); - int max_level = 1; - os.write(reinterpret_cast(&max_level), sizeof(int)); - auto entrypoint_node = static_cast(index_.size() / 2); - os.write(reinterpret_cast(&entrypoint_node), sizeof(int)); - auto max_M = static_cast(index_.graph_degree() / 2); - os.write(reinterpret_cast(&max_M), sizeof(std::size_t)); - std::size_t max_M0 = index_.graph_degree(); - os.write(reinterpret_cast(&max_M0), sizeof(std::size_t)); - auto M = static_cast(index_.graph_degree() / 2); - os.write(reinterpret_cast(&M), sizeof(std::size_t)); - double mult = 0.42424242; - os.write(reinterpret_cast(&mult), sizeof(double)); - std::size_t efConstruction = 500; - os.write(reinterpret_cast(&efConstruction), sizeof(std::size_t)); - - auto dataset = index_.dataset(); - auto host_dataset = raft::make_host_matrix(dataset.extent(0), dataset.extent(1)); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(), sizeof(T) * host_dataset.extent(1), - dataset.data_handle(), sizeof(T) * dataset.stride(0), - sizeof(T) * host_dataset.extent(1), dataset.extent(0), cudaMemcpyDefault, - raft::resource::get_cuda_stream(res))); - raft::resource::sync_stream(res); - - auto graph = index_.graph(); - auto host_graph = raft::make_host_matrix(graph.extent(0), graph.extent(1)); - raft::copy(host_graph.data_handle(), graph.data_handle(), graph.size(), raft::resource::get_cuda_stream(res)); - raft::resource::sync_stream(res); - - for (std::size_t i = 0; i < index_.size(); i++) { - auto graph_degree = static_cast(index_.graph_degree()); - os.write(reinterpret_cast(&graph_degree), sizeof(uint32_t)); - - for (std::size_t j = 0; j < index_.graph_degree(); ++j) { - auto graph_elem = host_graph(i, j); - os.write(reinterpret_cast(&graph_elem), sizeof(IdxT)); - } - - auto data_row = host_dataset.data_handle() + (index_.dim() * i); - for (std::size_t j = 0; j < index_.dim(); ++j) { - auto data_elem = host_dataset(i, j); - os.write(reinterpret_cast(&data_elem), sizeof(T)); - } - - os.write(reinterpret_cast(&i), sizeof(std::size_t)); - } - - for (std::size_t i = 0; i < index_.size(); i++) { - // zeroes - auto zero = 0; - os.write(reinterpret_cast(&zero), sizeof(int)); - } - os.flush(); -} - -} // namespace detail - -template