Skip to content

Commit

Permalink
Moving FAISS ANN wrapper to raft (#3963)
Browse files Browse the repository at this point in the history
This accompanies the RAFT PR: rapidsai/raft#265

Closes #3954

Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Victor Lafargue (https://github.com/viclafargue)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #3963
  • Loading branch information
cjnolet authored Jun 16, 2021
1 parent afdeda9 commit 1be5e3a
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 375 deletions.
1 change: 0 additions & 1 deletion cpp/cmake/thirdparty/get_raft.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ function(find_and_configure_raft)
SOURCE_SUBDIR cpp
OPTIONS
"BUILD_TESTS OFF"

)

message(VERBOSE "CUML: Using RAFT located in ${raft_SOURCE_DIR}")
Expand Down
56 changes: 6 additions & 50 deletions cpp/include/cuml/neighbors/knn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,59 +16,14 @@

#pragma once

#include <faiss/gpu/GpuIndex.h>
#include <faiss/gpu/StandardGpuResources.h>
#include <raft/linalg/distance_type.h>
#include <raft/spatial/knn/ann_common.h>

namespace raft {
class handle_t;
}

namespace ML {
struct knnIndex {
faiss::gpu::GpuIndex *index;
raft::distance::DistanceType metric;
float metricArg;

faiss::gpu::StandardGpuResources *gpu_res;
int device;
~knnIndex() {
delete index;
delete gpu_res;
}
};

typedef enum {
QT_8bit,
QT_4bit,
QT_8bit_uniform,
QT_4bit_uniform,
QT_fp16,
QT_8bit_direct,
QT_6bit
} QuantizerType;

struct knnIndexParam {
virtual ~knnIndexParam() {}
};

struct IVFParam : knnIndexParam {
int nlist;
int nprobe;
};

struct IVFFlatParam : IVFParam {};

struct IVFPQParam : IVFParam {
int M;
int n_bits;
bool usePrecomputedTables;
};

struct IVFSQParam : IVFParam {
QuantizerType qtype;
bool encodeResidual;
};

/**
* @brief Flat C++ API function to perform a brute force knn on
Expand Down Expand Up @@ -112,8 +67,9 @@ void brute_force_knn(const raft::handle_t &handle, std::vector<float *> &input,
* @param[in] n number of rows in the index array
* @param[in] D the dimensionality of the index array
*/
void approx_knn_build_index(raft::handle_t &handle, ML::knnIndex *index,
ML::knnIndexParam *params,
void approx_knn_build_index(raft::handle_t &handle,
raft::spatial::knn::knnIndex *index,
raft::spatial::knn::knnIndexParam *params,
raft::distance::DistanceType metric,
float metricArg, float *index_array, int n, int D);

Expand All @@ -131,8 +87,8 @@ void approx_knn_build_index(raft::handle_t &handle, ML::knnIndex *index,
* @param[in] n number of rows in the query array
*/
void approx_knn_search(raft::handle_t &handle, float *distances,
int64_t *indices, ML::knnIndex *index, int k,
float *query_array, int n);
int64_t *indices, raft::spatial::knn::knnIndex *index,
int k, float *query_array, int n);

/**
* @brief Flat C++ API function to perform a knn classification using a
Expand Down
18 changes: 10 additions & 8 deletions cpp/src/knn/knn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <ml_mg_utils.cuh>

#include <label/classlabels.cuh>
#include <raft/spatial/knn/ann.hpp>
#include <raft/spatial/knn/knn.hpp>
#include <selection/knn.cuh>

Expand All @@ -44,19 +45,20 @@ void brute_force_knn(const raft::handle_t &handle, std::vector<float *> &input,
rowMajorQuery, nullptr, metric, metric_arg);
}

void approx_knn_build_index(raft::handle_t &handle, ML::knnIndex *index,
ML::knnIndexParam *params,
void approx_knn_build_index(raft::handle_t &handle,
raft::spatial::knn::knnIndex *index,
raft::spatial::knn::knnIndexParam *params,
raft::distance::DistanceType metric,
float metricArg, float *index_array, int n, int D) {
MLCommon::Selection::approx_knn_build_index(handle, index, params, metric,
metricArg, index_array, n, D);
raft::spatial::knn::approx_knn_build_index(handle, index, params, metric,
metricArg, index_array, n, D);
}

void approx_knn_search(raft::handle_t &handle, float *distances,
int64_t *indices, ML::knnIndex *index, int k,
float *query_array, int n) {
MLCommon::Selection::approx_knn_search(handle, distances, indices, index, k,
query_array, n);
int64_t *indices, raft::spatial::knn::knnIndex *index,
int k, float *query_array, int n) {
raft::spatial::knn::approx_knn_search(handle, distances, indices, index, k,
query_array, n);
}

void knn_classify(raft::handle_t &handle, int *out, int64_t *knn_indices,
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/knn/knn_opg_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,7 @@ void reduce(opg_knn_param<in_t, ind_t, dist_t, out_t> &params,
}

// Merge all KNN local results
MLCommon::Selection::knn_merge_parts(
raft::spatial::knn::knn_merge_parts(
work.res_D.data(), work.res_I.data(), distances, indices, batch_size,
work.idxRanks.size(), params.k, handle.get_stream(), trans.data());
CUDA_CHECK(cudaStreamSynchronize(handle.get_stream()));
Expand Down
Loading

0 comments on commit 1be5e3a

Please sign in to comment.