diff --git a/cpp/bench/sg/umap.cu b/cpp/bench/sg/umap.cu index fa05df69fb..c75dc0e51e 100644 --- a/cpp/bench/sg/umap.cu +++ b/cpp/bench/sg/umap.cu @@ -115,6 +115,7 @@ class UmapSupervised : public UmapBase { protected: void coreBenchmarkMethod() { + auto graph = raft::sparse::COO(stream); UMAP::fit(*this->handle, this->data.X.data(), yFloat, @@ -123,7 +124,8 @@ class UmapSupervised : public UmapBase { nullptr, nullptr, &uParams, - embeddings); + embeddings, + &graph); } }; ML_BENCH_REGISTER(Params, UmapSupervised, "blobs", getInputs()); @@ -135,6 +137,7 @@ class UmapUnsupervised : public UmapBase { protected: void coreBenchmarkMethod() { + auto graph = raft::sparse::COO(stream); UMAP::fit(*this->handle, this->data.X.data(), nullptr, @@ -143,7 +146,8 @@ class UmapUnsupervised : public UmapBase { nullptr, nullptr, &uParams, - embeddings); + embeddings, + &graph); } }; ML_BENCH_REGISTER(Params, UmapUnsupervised, "blobs", getInputs()); @@ -173,6 +177,7 @@ class UmapTransform : public UmapBase { UmapBase::allocateBuffers(state); auto& handle = *this->handle; alloc(transformed, this->params.nrows * uParams.n_components); + auto graph = raft::sparse::COO(stream); UMAP::fit(handle, this->data.X.data(), yFloat, @@ -181,7 +186,8 @@ class UmapTransform : public UmapBase { nullptr, nullptr, &uParams, - embeddings); + embeddings, + &graph); } void deallocateBuffers(const ::benchmark::State& state) { diff --git a/cpp/include/cuml/manifold/umap.hpp b/cpp/include/cuml/manifold/umap.hpp index 1532be3246..008dbb7155 100644 --- a/cpp/include/cuml/manifold/umap.hpp +++ b/cpp/include/cuml/manifold/umap.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -30,6 +30,126 @@ namespace ML { class UMAPParams; namespace UMAP { +/** + * Computes the a and b parameters and stores them in the given ML::UMAPParams object. + * + * @param[in] handle: raft::handle_t + * @param[out] params: pointer to ML::UMAPParams object of which the a and b parameters will be + * updated + */ +void find_ab(const raft::handle_t& handle, UMAPParams* params); + +/** + * Returns the simplicial set to be consumed by the ML::UMAP::refine function. + * + * @param[in] handle: raft::handle_t + * @param[in] X: pointer to input array + * @param[in] y: pointer to labels array + * @param[in] n: n_samples of input array + * @param[in] d: n_features of input array + * @param[in] knn_indices: pointer to knn_indices (optional) + * @param[in] knn_dists: pointer to knn_dists (optional) + * @param[in] params: pointer to ML::UMAPParams object + * @return: simplicial set as a unique pointer to a raft::sparse::COO object + */ +std::unique_ptr> get_graph(const raft::handle_t& handle, + float* X, // input matrix + float* y, // labels + int n, + int d, + int64_t* knn_indices, + float* knn_dists, + UMAPParams* params); + +/** + * Performs a UMAP fit on existing embeddings without reinitializing them, which enables + * iterative fitting without callbacks. 
+ * + * @param[in] handle: raft::handle_t + * @param[in] X: pointer to input array + * @param[in] n: n_samples of input array + * @param[in] d: n_features of input array + * @param[in] graph: pointer to raft::sparse::COO object computed using ML::UMAP::get_graph + * @param[in] params: pointer to ML::UMAPParams object + * @param[in,out] embeddings: pointer to current embedding with shape n * n_components, stores updated + * embeddings on executing refine + */ +void refine(const raft::handle_t& handle, + float* X, + int n, + int d, + raft::sparse::COO* graph, + UMAPParams* params, + float* embeddings); + +/** + * Dense fit + * + * @param[in] handle: raft::handle_t + * @param[in] X: pointer to input array + * @param[in] y: pointer to labels array + * @param[in] n: n_samples of input array + * @param[in] d: n_features of input array + * @param[in] knn_indices: pointer to knn_indices of input (optional) + * @param[in] knn_dists: pointer to knn_dists of input (optional) + * @param[in] params: pointer to ML::UMAPParams object + * @param[out] embeddings: pointer to embedding produced through projection + * @param[out] graph: pointer to fuzzy simplicial set graph + */ +void fit(const raft::handle_t& handle, + float* X, + float* y, + int n, + int d, + int64_t* knn_indices, + float* knn_dists, + UMAPParams* params, + float* embeddings, + raft::sparse::COO* graph); + +/** + * Sparse fit + * + * @param[in] handle: raft::handle_t + * @param[in] indptr: pointer to index pointer array of input array + * @param[in] indices: pointer to index array of input array + * @param[in] data: pointer to data array of input array + * @param[in] nnz: number of stored values of input array + * @param[in] y: pointer to labels array + * @param[in] n: n_samples of input array + * @param[in] d: n_features of input array + * @param[in] params: pointer to ML::UMAPParams object + * @param[out] embeddings: pointer to embedding produced through projection + * @param[out] graph: pointer to fuzzy simplicial set 
graph + */ +void fit_sparse(const raft::handle_t& handle, + int* indptr, + int* indices, + float* data, + size_t nnz, + float* y, + int n, + int d, + UMAPParams* params, + float* embeddings, + raft::sparse::COO* graph); + +/** + * Dense transform + * + * @param[in] handle: raft::handle_t + * @param[in] X: pointer to input array to be inferred + * @param[in] n: n_samples of input array to be inferred + * @param[in] d: n_features of input array to be inferred + * @param[in] knn_indices: pointer to knn_indices of input (optional) + * @param[in] knn_dists: pointer to knn_dists of input (optional) + * @param[in] orig_X: pointer to original training array + * @param[in] orig_n: number of rows in original training array + * @param[in] embedding: pointer to embedding created during training + * @param[in] embedding_n: number of rows in embedding created during training + * @param[in] params: pointer to ML::UMAPParams object + * @param[out] transformed: pointer to embedding produced through projection + */ void transform(const raft::handle_t& handle, float* X, int n, @@ -43,6 +163,26 @@ void transform(const raft::handle_t& handle, UMAPParams* params, float* transformed); +/** + * Sparse transform + * + * @param[in] handle: raft::handle_t + * @param[in] indptr: pointer to index pointer array of input array to be inferred + * @param[in] indices: pointer to index array of input array to be inferred + * @param[in] data: pointer to data array of input array to be inferred + * @param[in] nnz: number of stored values of input array to be inferred + * @param[in] n: n_samples of input array + * @param[in] d: n_features of input array + * @param[in] orig_x_indptr: pointer to index pointer array of original training array + * @param[in] orig_x_indices: pointer to index array of original training array + * @param[in] orig_x_data: pointer to data array of original training array + * @param[in] orig_nnz: number of stored values of original training array + * @param[in] orig_n: number of rows in 
original training array + * @param[in] embedding: pointer to embedding created during training + * @param[in] embedding_n: number of rows in embedding created during training + * @param[in] params: pointer to ML::UMAPParams object + * @param[out] transformed: pointer to embedding produced through projection + */ void transform_sparse(const raft::handle_t& handle, int* indptr, int* indices, @@ -60,67 +200,5 @@ void transform_sparse(const raft::handle_t& handle, UMAPParams* params, float* transformed); -void find_ab(const raft::handle_t& handle, UMAPParams* params); - -void fit(const raft::handle_t& handle, - float* X, // input matrix - float* y, // labels - int n, - int d, - int64_t* knn_indices, - float* knn_dists, - UMAPParams* params, - float* embeddings); - -/** - * refine performs a UMAP fit on existing embeddings without reinitializing them, which enables - * iterative fitting without callbacks. - * - * @param handle: raft::handle_t - * @param X: pointer to input array - * @param n: n_samples of input array - * @param d: n_features of input array - * @param cgraph_coo: pointer to raft::sparse::COO object computed using ML::UMAP::get_graph - * @param params: pointer to ML::UMAPParams object - * @param embeddings: pointer to current embedding with shape n * n_components, stores updated - * embeddings on executing refine - */ -void refine(const raft::handle_t& handle, - float* X, // input matrix - int n, - int d, - raft::sparse::COO* cgraph_coo, - UMAPParams* params, - float* embeddings); - -/** - * returns a simplical set as a raft::sparse:COO object to be consumed by the ML::UMAP::refine - * function. 
- * - * @param handle: raft::handle_t - * @param X: pointer to input array - * @param y: pointer to labels array - * @param n: n_samples of input array - * @param d: n_features of input array - * @param params: pointer to ML::UMAPParams object - * @return: simplical set (pointer to raft::sparse::COO object) - */ -std::unique_ptr> get_graph(const raft::handle_t& handle, - float* X, // input matrix - float* y, // labels - int n, - int d, - UMAPParams* params); - -void fit_sparse(const raft::handle_t& handle, - int* indptr, // input matrix - int* indices, - float* data, - size_t nnz, - float* y, - int n, // rows - int d, // cols - UMAPParams* params, - float* embeddings); } // namespace UMAP } // namespace ML diff --git a/cpp/src/umap/runner.cuh b/cpp/src/umap/runner.cuh index 4c57a388b6..02f970d348 100644 --- a/cpp/src/umap/runner.cuh +++ b/cpp/src/umap/runner.cuh @@ -87,89 +87,11 @@ __global__ void init_transform(int* indices, */ void find_ab(UMAPParams* params, cudaStream_t stream) { Optimize::find_params_ab(params, stream); } -template -void _fit(const raft::handle_t& handle, - const umap_inputs& inputs, - UMAPParams* params, - value_t* embeddings) -{ - raft::common::nvtx::range fun_scope("umap::unsupervised::fit"); - cudaStream_t stream = handle.get_stream(); - - int k = params->n_neighbors; - - ML::Logger::get().setLevel(params->verbosity); - - CUML_LOG_DEBUG("n_neighbors=%d", params->n_neighbors); - - raft::common::nvtx::push_range("umap::knnGraph"); - std::unique_ptr> knn_indices_b = nullptr; - std::unique_ptr> knn_dists_b = nullptr; - - knn_graph knn_graph(inputs.n, k); - - /** - * If not given precomputed knn graph, compute it - */ - if (inputs.alloc_knn_graph()) { - /** - * Allocate workspace for kNN graph - */ - knn_indices_b = std::make_unique>(inputs.n * k, stream); - knn_dists_b = std::make_unique>(inputs.n * k, stream); - - knn_graph.knn_indices = knn_indices_b->data(); - knn_graph.knn_dists = knn_dists_b->data(); - } - - CUML_LOG_DEBUG("Calling knn 
graph run"); - - kNNGraph::run( - handle, inputs, inputs, knn_graph, k, params, stream); - raft::common::nvtx::pop_range(); - - CUML_LOG_DEBUG("Done. Calling fuzzy simplicial set"); - - raft::common::nvtx::push_range("umap::simplicial_set"); - raft::sparse::COO rgraph_coo(stream); - FuzzySimplSet::run( - inputs.n, knn_graph.knn_indices, knn_graph.knn_dists, k, &rgraph_coo, params, stream); - - CUML_LOG_DEBUG("Done. Calling remove zeros"); - /** - * Remove zeros from simplicial set - */ - raft::sparse::COO cgraph_coo(stream); - raft::sparse::op::coo_remove_zeros(&rgraph_coo, &cgraph_coo, stream); - raft::common::nvtx::pop_range(); - - /** - * Run initialization method - */ - raft::common::nvtx::push_range("umap::embedding"); - InitEmbed::run(handle, inputs.n, inputs.d, &cgraph_coo, params, embeddings, stream, params->init); - - if (params->callback) { - params->callback->setup(inputs.n, params->n_components); - params->callback->on_preprocess_end(embeddings); - } - - /** - * Run simplicial set embedding to approximate low-dimensional representation - */ - SimplSetEmbed::run(inputs.n, inputs.d, &cgraph_coo, params, embeddings, stream); - raft::common::nvtx::pop_range(); - - if (params->callback) params->callback->on_train_end(embeddings); -} - template void _get_graph(const raft::handle_t& handle, const umap_inputs& inputs, UMAPParams* params, - raft::sparse::COO* cgraph_coo // assumes single-precision int as the - // second template argument for COO -) + raft::sparse::COO* graph) { raft::common::nvtx::range fun_scope("umap::supervised::_get_graph"); cudaStream_t stream = handle.get_stream(); @@ -209,27 +131,24 @@ void _get_graph(const raft::handle_t& handle, CUML_LOG_DEBUG("Done. 
Calling fuzzy simplicial set"); raft::common::nvtx::push_range("umap::simplicial_set"); - raft::sparse::COO rgraph_coo(stream); + raft::sparse::COO fss_graph(stream); FuzzySimplSet::run( - inputs.n, knn_graph.knn_indices, knn_graph.knn_dists, k, &rgraph_coo, params, stream); + inputs.n, knn_graph.knn_indices, knn_graph.knn_dists, k, &fss_graph, params, stream); CUML_LOG_DEBUG("Done. Calling remove zeros"); /** * Remove zeros from simplicial set */ - raft::sparse::op::coo_remove_zeros(&rgraph_coo, cgraph_coo, stream); + raft::sparse::op::coo_remove_zeros(&fss_graph, graph, stream); raft::common::nvtx::pop_range(); } template -void _get_graph_supervised( - const raft::handle_t& handle, - const umap_inputs& inputs, - UMAPParams* params, - raft::sparse::COO* cgraph_coo // assumes single-precision int as the - // second template argument for COO -) +void _get_graph_supervised(const raft::handle_t& handle, + const umap_inputs& inputs, + UMAPParams* params, + raft::sparse::COO* graph) { raft::common::nvtx::range fun_scope("umap::supervised::_get_graph_supervised"); cudaStream_t stream = handle.get_stream(); @@ -269,24 +188,27 @@ void _get_graph_supervised( * Allocate workspace for fuzzy simplicial set. 
*/ raft::common::nvtx::push_range("umap::simplicial_set"); - raft::sparse::COO rgraph_coo(stream); - raft::sparse::COO tmp_coo(stream); - - /** - * Run Fuzzy simplicial set - */ - // int nnz = n*k*2; - FuzzySimplSet::run(inputs.n, - knn_graph.knn_indices, - knn_graph.knn_dists, - params->n_neighbors, - &tmp_coo, - params, - stream); - RAFT_CUDA_TRY(cudaPeekAtLastError()); - raft::sparse::op::coo_remove_zeros(&tmp_coo, &rgraph_coo, stream); + raft::sparse::COO fss_graph(stream); + { + raft::sparse::COO fss_graph_tmp(stream); + /** + * Run Fuzzy simplicial set + */ + // int nnz = n*k*2; + FuzzySimplSet::run(inputs.n, + knn_graph.knn_indices, + knn_graph.knn_dists, + params->n_neighbors, + &fss_graph_tmp, + params, + stream); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + + raft::sparse::op::coo_remove_zeros(&fss_graph_tmp, &fss_graph, stream); + } + raft::sparse::COO ci_graph(stream); /** * If target metric is 'categorical', perform * categorical simplicial set intersection. @@ -294,7 +216,7 @@ void _get_graph_supervised( if (params->target_metric == ML::UMAPParams::MetricType::CATEGORICAL) { CUML_LOG_DEBUG("Performing categorical intersection"); Supervised::perform_categorical_intersection( - inputs.y, &rgraph_coo, cgraph_coo, params, stream); + inputs.y, &fss_graph, &ci_graph, params, stream); /** * Otherwise, perform general simplicial set intersection @@ -302,16 +224,14 @@ void _get_graph_supervised( } else { CUML_LOG_DEBUG("Performing general intersection"); Supervised::perform_general_intersection( - handle, inputs.y, &rgraph_coo, cgraph_coo, params, stream); + handle, inputs.y, &fss_graph, &ci_graph, params, stream); } /** * Remove zeros */ - raft::sparse::op::coo_sort(cgraph_coo, stream); - - raft::sparse::COO ocoo(stream); - raft::sparse::op::coo_remove_zeros(cgraph_coo, &ocoo, stream); + raft::sparse::op::coo_sort(&ci_graph, stream); + raft::sparse::op::coo_remove_zeros(&ci_graph, graph, stream); raft::common::nvtx::pop_range(); } @@ -319,112 +239,72 @@ template 
void _refine(const raft::handle_t& handle, const umap_inputs& inputs, UMAPParams* params, - raft::sparse::COO* cgraph_coo, + raft::sparse::COO* graph, value_t* embeddings) { cudaStream_t stream = handle.get_stream(); /** * Run simplicial set embedding to approximate low-dimensional representation */ - SimplSetEmbed::run(inputs.n, inputs.d, cgraph_coo, params, embeddings, stream); + SimplSetEmbed::run(inputs.n, inputs.d, graph, params, embeddings, stream); } template -void _fit_supervised(const raft::handle_t& handle, - const umap_inputs& inputs, - UMAPParams* params, - value_t* embeddings) +void _fit(const raft::handle_t& handle, + const umap_inputs& inputs, + UMAPParams* params, + value_t* embeddings, + raft::sparse::COO* graph) { - raft::common::nvtx::range fun_scope("umap::supervised::fit"); - cudaStream_t stream = handle.get_stream(); - - int k = params->n_neighbors; + raft::common::nvtx::range fun_scope("umap::unsupervised::fit"); + cudaStream_t stream = handle.get_stream(); ML::Logger::get().setLevel(params->verbosity); - if (params->target_n_neighbors == -1) params->target_n_neighbors = params->n_neighbors; - - raft::common::nvtx::push_range("umap::knnGraph"); - std::unique_ptr> knn_indices_b = nullptr; - std::unique_ptr> knn_dists_b = nullptr; - - knn_graph knn_graph(inputs.n, k); + UMAPAlgo::_get_graph(handle, inputs, params, graph); /** - * If not given precomputed knn graph, compute it + * Run initialization method */ - if (inputs.alloc_knn_graph()) { - /** - * Allocate workspace for kNN graph - */ - knn_indices_b = std::make_unique>(inputs.n * k, stream); - knn_dists_b = std::make_unique>(inputs.n * k, stream); + raft::common::nvtx::push_range("umap::embedding"); + InitEmbed::run(handle, inputs.n, inputs.d, graph, params, embeddings, stream, params->init); - knn_graph.knn_indices = knn_indices_b->data(); - knn_graph.knn_dists = knn_dists_b->data(); + if (params->callback) { + params->callback->setup(inputs.n, params->n_components); + 
params->callback->on_preprocess_end(embeddings); } - kNNGraph::run( - handle, inputs, inputs, knn_graph, k, params, stream); - - raft::common::nvtx::pop_range(); - /** - * Allocate workspace for fuzzy simplicial set. - */ - raft::common::nvtx::push_range("umap::simplicial_set"); - raft::sparse::COO rgraph_coo(stream); - raft::sparse::COO tmp_coo(stream); - - /** - * Run Fuzzy simplicial set + * Run simplicial set embedding to approximate low-dimensional representation */ - // int nnz = n*k*2; - FuzzySimplSet::run(inputs.n, - knn_graph.knn_indices, - knn_graph.knn_dists, - params->n_neighbors, - &tmp_coo, - params, - stream); - RAFT_CUDA_TRY(cudaPeekAtLastError()); - - raft::sparse::op::coo_remove_zeros(&tmp_coo, &rgraph_coo, stream); + SimplSetEmbed::run(inputs.n, inputs.d, graph, params, embeddings, stream); + raft::common::nvtx::pop_range(); - raft::sparse::COO final_coo(stream); + if (params->callback) params->callback->on_train_end(embeddings); - /** - * If target metric is 'categorical', perform - * categorical simplicial set intersection. 
- */ - if (params->target_metric == ML::UMAPParams::MetricType::CATEGORICAL) { - CUML_LOG_DEBUG("Performing categorical intersection"); - Supervised::perform_categorical_intersection( - inputs.y, &rgraph_coo, &final_coo, params, stream); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} - /** - * Otherwise, perform general simplicial set intersection - */ - } else { - CUML_LOG_DEBUG("Performing general intersection"); - Supervised::perform_general_intersection( - handle, inputs.y, &rgraph_coo, &final_coo, params, stream); - } +template +void _fit_supervised(const raft::handle_t& handle, + const umap_inputs& inputs, + UMAPParams* params, + value_t* embeddings, + raft::sparse::COO* graph) +{ + raft::common::nvtx::range fun_scope("umap::supervised::fit"); - /** - * Remove zeros - */ - raft::sparse::op::coo_sort(&final_coo, stream); + cudaStream_t stream = handle.get_stream(); + ML::Logger::get().setLevel(params->verbosity); - raft::sparse::COO ocoo(stream); - raft::sparse::op::coo_remove_zeros(&final_coo, &ocoo, stream); - raft::common::nvtx::pop_range(); + UMAPAlgo::_get_graph_supervised( + handle, inputs, params, graph); /** * Initialize embeddings */ raft::common::nvtx::push_range("umap::supervised::fit"); - InitEmbed::run(handle, inputs.n, inputs.d, &ocoo, params, embeddings, stream, params->init); + InitEmbed::run(handle, inputs.n, inputs.d, graph, params, embeddings, stream, params->init); if (params->callback) { params->callback->setup(inputs.n, params->n_components); @@ -434,7 +314,7 @@ void _fit_supervised(const raft::handle_t& handle, /** * Run simplicial set embedding to approximate low-dimensional representation */ - SimplSetEmbed::run(inputs.n, inputs.d, &ocoo, params, embeddings, stream); + SimplSetEmbed::run(inputs.n, inputs.d, graph, params, embeddings, stream); raft::common::nvtx::pop_range(); if (params->callback) params->callback->on_train_end(embeddings); diff --git a/cpp/src/umap/umap.cu b/cpp/src/umap/umap.cu index 08eb3bcfc0..b5be1b2543 100644 --- 
a/cpp/src/umap/umap.cu +++ b/cpp/src/umap/umap.cu @@ -28,84 +28,78 @@ namespace UMAP { static const int TPB_X = 256; -/** - * \brief Dense transform - * - * @param X: pointer to input array - * @param n: n_samples of input array - * @param d: n_features of input array - * @param orig_X: self.X_m.ptr - * @param orig_n: self.n_rows - * @param embedding: pointer to stored embedding - * @param embedding_n: n_samples in embedding, equals to orig_n - * @param transformed: output array with shape n * n_components - */ -void transform(const raft::handle_t& handle, - float* X, - int n, - int d, - knn_indices_dense_t* knn_indices, - float* knn_dists, - float* orig_X, - int orig_n, - float* embedding, - int embedding_n, - UMAPParams* params, - float* transformed) +void find_ab(const raft::handle_t& handle, UMAPParams* params) +{ + cudaStream_t stream = handle.get_stream(); + UMAPAlgo::find_ab(params, stream); +} + +std::unique_ptr> get_graph( + const raft::handle_t& handle, + float* X, // input matrix + float* y, // labels + int n, + int d, + knn_indices_dense_t* knn_indices, // precomputed indices + float* knn_dists, // precomputed distances + UMAPParams* params) { + auto graph = std::make_unique>(handle.get_stream()); if (knn_indices != nullptr && knn_dists != nullptr) { + CUML_LOG_DEBUG("Calling UMAP::get_graph() with precomputed KNN"); + manifold_precomputed_knn_inputs_t inputs( - knn_indices, knn_dists, X, nullptr, n, d, params->n_neighbors); - UMAPAlgo::_transform, - TPB_X>( - handle, inputs, inputs, embedding, embedding_n, params, transformed); + knn_indices, knn_dists, X, y, n, d, params->n_neighbors); + if (y != nullptr) { + UMAPAlgo::_get_graph_supervised, + TPB_X>(handle, inputs, params, graph.get()); + } else { + UMAPAlgo::_get_graph, + TPB_X>(handle, inputs, params, graph.get()); + } + return graph; } else { - manifold_dense_inputs_t inputs(X, nullptr, n, d); - manifold_dense_inputs_t orig_inputs(orig_X, nullptr, orig_n, d); - UMAPAlgo::_transform, TPB_X>( - 
handle, inputs, orig_inputs, embedding, embedding_n, params, transformed); + manifold_dense_inputs_t inputs(X, y, n, d); + if (y != nullptr) { + UMAPAlgo:: + _get_graph_supervised, TPB_X>( + handle, inputs, params, graph.get()); + } else { + UMAPAlgo::_get_graph, TPB_X>( + handle, inputs, params, graph.get()); + } + return graph; } } -// Sparse transform -void transform_sparse(const raft::handle_t& handle, - int* indptr, - int* indices, - float* data, - size_t nnz, - int n, - int d, - int* orig_x_indptr, - int* orig_x_indices, - float* orig_x_data, - size_t orig_nnz, - int orig_n, - float* embedding, - int embedding_n, - UMAPParams* params, - float* transformed) +void refine(const raft::handle_t& handle, + float* X, + int n, + int d, + raft::sparse::COO* graph, + UMAPParams* params, + float* embeddings) { - manifold_sparse_inputs_t inputs( - indptr, indices, data, nullptr, nnz, n, d); - manifold_sparse_inputs_t orig_x_inputs( - orig_x_indptr, orig_x_indices, orig_x_data, nullptr, orig_nnz, orig_n, d); - - UMAPAlgo::_transform, TPB_X>( - handle, inputs, orig_x_inputs, embedding, embedding_n, params, transformed); + CUML_LOG_DEBUG("Calling UMAP::refine() with precomputed KNN"); + manifold_dense_inputs_t inputs(X, nullptr, n, d); + UMAPAlgo::_refine, TPB_X>( + handle, inputs, params, graph, embeddings); } -// Dense fit void fit(const raft::handle_t& handle, - float* X, // input matrix - float* y, // labels + float* X, + float* y, int n, int d, knn_indices_dense_t* knn_indices, float* knn_dists, UMAPParams* params, - float* embeddings) + float* embeddings, + raft::sparse::COO* graph) { if (knn_indices != nullptr && knn_dists != nullptr) { CUML_LOG_DEBUG("Calling UMAP::fit() with precomputed KNN"); @@ -116,90 +110,102 @@ void fit(const raft::handle_t& handle, UMAPAlgo::_fit_supervised, - TPB_X>(handle, inputs, params, embeddings); + TPB_X>(handle, inputs, params, embeddings, graph); } else { UMAPAlgo::_fit, - TPB_X>(handle, inputs, params, embeddings); + TPB_X>(handle, 
inputs, params, embeddings, graph); } } else { manifold_dense_inputs_t inputs(X, y, n, d); if (y != nullptr) { UMAPAlgo::_fit_supervised, TPB_X>( - handle, inputs, params, embeddings); + handle, inputs, params, embeddings, graph); } else { UMAPAlgo::_fit, TPB_X>( - handle, inputs, params, embeddings); + handle, inputs, params, embeddings, graph); } } } -// get graph -std::unique_ptr> get_graph(const raft::handle_t& handle, - float* X, // input matrix - float* y, // labels - int n, - int d, - UMAPParams* params) -{ - manifold_dense_inputs_t inputs(X, y, n, d); - auto cgraph_coo = std::make_unique>(handle.get_stream()); - if (y != nullptr) { - UMAPAlgo:: - _get_graph_supervised, TPB_X>( - handle, inputs, params, cgraph_coo.get()); - } else { - UMAPAlgo::_get_graph, TPB_X>( - handle, inputs, params, cgraph_coo.get()); - } - - return cgraph_coo; -} - -// refine -void refine(const raft::handle_t& handle, - float* X, // input matrix - int n, - int d, - raft::sparse::COO* cgraph_coo, - UMAPParams* params, - float* embeddings) -{ - CUML_LOG_DEBUG("Calling UMAP::refine() with precomputed KNN"); - manifold_dense_inputs_t inputs(X, nullptr, n, d); - UMAPAlgo::_refine, TPB_X>( - handle, inputs, params, cgraph_coo, embeddings); -} - -// Sparse fit void fit_sparse(const raft::handle_t& handle, - int* indptr, // input matrix + int* indptr, int* indices, float* data, size_t nnz, float* y, - int n, // rows - int d, // cols + int n, + int d, UMAPParams* params, - float* embeddings) + float* embeddings, + raft::sparse::COO* graph) { manifold_sparse_inputs_t inputs(indptr, indices, data, y, nnz, n, d); if (y != nullptr) { UMAPAlgo:: _fit_supervised, TPB_X>( - handle, inputs, params, embeddings); + handle, inputs, params, embeddings, graph); } else { UMAPAlgo::_fit, TPB_X>( - handle, inputs, params, embeddings); + handle, inputs, params, embeddings, graph); } } -void find_ab(const raft::handle_t& handle, UMAPParams* params) +void transform(const raft::handle_t& handle, + float* X, + 
int n, + int d, + knn_indices_dense_t* knn_indices, + float* knn_dists, + float* orig_X, + int orig_n, + float* embedding, + int embedding_n, + UMAPParams* params, + float* transformed) { - cudaStream_t stream = handle.get_stream(); - UMAPAlgo::find_ab(params, stream); + if (knn_indices != nullptr && knn_dists != nullptr) { + manifold_precomputed_knn_inputs_t inputs( + knn_indices, knn_dists, X, nullptr, n, d, params->n_neighbors); + UMAPAlgo::_transform, + TPB_X>( + handle, inputs, inputs, embedding, embedding_n, params, transformed); + } else { + manifold_dense_inputs_t inputs(X, nullptr, n, d); + manifold_dense_inputs_t orig_inputs(orig_X, nullptr, orig_n, d); + UMAPAlgo::_transform, TPB_X>( + handle, inputs, orig_inputs, embedding, embedding_n, params, transformed); + } +} + +void transform_sparse(const raft::handle_t& handle, + int* indptr, + int* indices, + float* data, + size_t nnz, + int n, + int d, + int* orig_x_indptr, + int* orig_x_indices, + float* orig_x_data, + size_t orig_nnz, + int orig_n, + float* embedding, + int embedding_n, + UMAPParams* params, + float* transformed) +{ + manifold_sparse_inputs_t inputs( + indptr, indices, data, nullptr, nnz, n, d); + manifold_sparse_inputs_t orig_x_inputs( + orig_x_indptr, orig_x_indices, orig_x_data, nullptr, orig_nnz, orig_n, d); + + UMAPAlgo::_transform, TPB_X>( + handle, inputs, orig_x_inputs, embedding, embedding_n, params, transformed); } } // namespace UMAP diff --git a/cpp/test/sg/umap_parametrizable_test.cu b/cpp/test/sg/umap_parametrizable_test.cu index 765da3c740..cf90cba1b4 100644 --- a/cpp/test/sg/umap_parametrizable_test.cu +++ b/cpp/test/sg/umap_parametrizable_test.cu @@ -169,9 +169,19 @@ class UMAPParametrizableTest : public ::testing::Test { handle.sync_stream(stream); + auto graph = raft::sparse::COO(stream); + if (test_params.supervised) { - ML::UMAP::fit( - handle, X, y, n_samples, n_features, knn_indices, knn_dists, &umap_params, model_embedding); + ML::UMAP::fit(handle, + X, + y, + 
n_samples, + n_features, + knn_indices, + knn_dists, + &umap_params, + model_embedding, + &graph); } else { ML::UMAP::fit(handle, X, @@ -181,18 +191,20 @@ class UMAPParametrizableTest : public ::testing::Test { knn_indices, knn_dists, &umap_params, - model_embedding); + model_embedding, + &graph); } if (test_params.refine) { std::cout << "using refine"; if (test_params.supervised) { - auto cgraph_coo = ML::UMAP::get_graph(handle, X, y, n_samples, n_features, &umap_params); + auto cgraph_coo = + ML::UMAP::get_graph(handle, X, y, n_samples, n_features, nullptr, nullptr, &umap_params); ML::UMAP::refine( handle, X, n_samples, n_features, cgraph_coo.get(), &umap_params, model_embedding); } else { - auto cgraph_coo = - ML::UMAP::get_graph(handle, X, nullptr, n_samples, n_features, &umap_params); + auto cgraph_coo = ML::UMAP::get_graph( + handle, X, nullptr, n_samples, n_features, nullptr, nullptr, &umap_params); ML::UMAP::refine( handle, X, n_samples, n_features, cgraph_coo.get(), &umap_params, model_embedding); } diff --git a/python/cuml/manifold/simpl_set.pyx b/python/cuml/manifold/simpl_set.pyx new file mode 100644 index 0000000000..b1ea268900 --- /dev/null +++ b/python/cuml/manifold/simpl_set.pyx @@ -0,0 +1,336 @@ +# +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# distutils: language = c++ + +import numpy as np +import cupy as cp + +from cuml.manifold.umap_utils cimport * +from cuml.manifold.umap_utils import GraphHolder, find_ab_params + +import cuml.internals +from cuml.common.base import Base +from cuml.common.input_utils import input_to_cuml_array +from cuml.common.array import CumlArray + +from raft.common.handle cimport handle_t +from raft.common.handle import Handle + +from libc.stdint cimport uintptr_t +from libc.stdlib cimport free +from libcpp.memory cimport unique_ptr + + +cdef extern from "cuml/manifold/umap.hpp" namespace "ML::UMAP": + + unique_ptr[COO] get_graph(handle_t &handle, + float* X, + float* y, + int n, + int d, + int64_t* knn_indices, + float* knn_dists, + UMAPParams* params) + + void refine(handle_t &handle, + float* X, + int n, + int d, + COO* cgraph_coo, + UMAPParams* params, + float* embeddings) + + +def fuzzy_simplicial_set(X, + n_neighbors, + random_state=None, + metric="euclidean", + metric_kwds={}, + knn_indices=None, + knn_dists=None, + set_op_mix_ratio=1.0, + local_connectivity=1.0, + verbose=False): + """Given a set of data X, a neighborhood size, and a measure of distance + compute the fuzzy simplicial set (here represented as a fuzzy graph in + the form of a sparse matrix) associated to the data. This is done by + locally approximating geodesic distance at each point, creating a fuzzy + simplicial set for each such point, and then combining all the local + fuzzy simplicial sets into a global one via a fuzzy union. + Parameters + ---------- + X: array of shape (n_samples, n_features) + The data to be modelled as a fuzzy simplicial set. + n_neighbors: int + The number of neighbors to use to approximate geodesic distance. + Larger numbers induce more global estimates of the manifold that can + miss finer detail, while smaller values will focus on fine manifold + structure to the detriment of the larger picture. 
+ random_state: numpy RandomState or equivalent + A state capable being used as a numpy random state. + metric: string or function (optional, default 'euclidean') + unused + metric_kwds: dict (optional, default {}) + unused + knn_indices: array of shape (n_samples, n_neighbors) (optional) + If the k-nearest neighbors of each point has already been calculated + you can pass them in here to save computation time. This should be + an array with the indices of the k-nearest neighbors as a row for + each data point. + knn_dists: array of shape (n_samples, n_neighbors) (optional) + If the k-nearest neighbors of each point has already been calculated + you can pass them in here to save computation time. This should be + an array with the distances of the k-nearest neighbors as a row for + each data point. + set_op_mix_ratio: float (optional, default 1.0) + Interpolate between (fuzzy) union and intersection as the set operation + used to combine local fuzzy simplicial sets to obtain a global fuzzy + simplicial sets. Both fuzzy set operations use the product t-norm. + The value of this parameter should be between 0.0 and 1.0; a value of + 1.0 will use a pure fuzzy union, while 0.0 will use a pure fuzzy + intersection. + local_connectivity: int (optional, default 1) + The local connectivity required -- i.e. the number of nearest + neighbors that should be assumed to be connected at a local level. + The higher this value the more connected the manifold becomes + locally. In practice this should be not more than the local intrinsic + dimension of the manifold. + verbose: bool (optional, default False) + Whether to report information on the current progress of the algorithm. + Returns + ------- + fuzzy_simplicial_set: coo_matrix + A fuzzy simplicial set represented as a sparse matrix. The (i, + j) entry of the matrix represents the membership strength of the + 1-simplex between the ith and jth sample points. 
+ """ + + deterministic = random_state is not None + if not isinstance(random_state, int): + if isinstance(random_state, np.random.RandomState): + rs = random_state + else: + rs = np.random.RandomState(random_state) + random_state = rs.randint(low=np.iinfo(np.int32).min, + high=np.iinfo(np.int32).max, + dtype=np.int32) + + cdef UMAPParams* umap_params = new UMAPParams() + umap_params.n_neighbors = n_neighbors + umap_params.random_state = random_state + umap_params.deterministic = deterministic + umap_params.set_op_mix_ratio = set_op_mix_ratio + umap_params.local_connectivity = local_connectivity + umap_params.verbosity = verbose + + X_m, n_rows, n_cols, _ = \ + input_to_cuml_array(X, + order='C', + check_dtype=np.float32, + convert_to_dtype=np.float32) + + if knn_indices is not None and knn_dists is not None: + knn_indices_m, _, _, _ = \ + input_to_cuml_array(knn_indices, + order='C', + check_dtype=np.int64, + convert_to_dtype=np.int64) + knn_dists_m, _, _, _ = \ + input_to_cuml_array(knn_dists, + order='C', + check_dtype=np.float32, + convert_to_dtype=np.float32) + X_ptr = 0 + knn_indices_ptr = knn_indices_m.ptr + knn_dists_ptr = knn_dists_m.ptr + else: + X_m, _, _, _ = \ + input_to_cuml_array(X, + order='C', + check_dtype=np.float32, + convert_to_dtype=np.float32) + X_ptr = X_m.ptr + knn_indices_ptr = 0 + knn_dists_ptr = 0 + + handle = Handle() + cdef handle_t* handle_ = handle.getHandle() + cdef unique_ptr[COO] fss_graph_ptr = get_graph( + handle_[0], + X_ptr, + NULL, + X.shape[0], + X.shape[1], + knn_indices_ptr, + knn_dists_ptr, + umap_params) + fss_graph = GraphHolder.from_ptr(fss_graph_ptr) + + free(umap_params) + + return fss_graph.get_cupy_coo() + + +def simplicial_set_embedding( + data, + graph, + n_components=2, + initial_alpha=1.0, + a=None, + b=None, + repulsion_strength=1.0, + negative_sample_rate=5, + n_epochs=None, + init="spectral", + random_state=None, + metric="euclidean", + metric_kwds={}, + output_metric="euclidean", + output_metric_kwds={}, + 
verbose=False, +): + """Perform a fuzzy simplicial set embedding, using a specified + initialisation method and then minimizing the fuzzy set cross entropy + between the 1-skeletons of the high and low dimensional fuzzy simplicial + sets. + Parameters + ---------- + data: array of shape (n_samples, n_features) + The source data to be embedded by UMAP. + graph: sparse matrix + The 1-skeleton of the high dimensional fuzzy simplicial set as + represented by a graph for which we require a sparse matrix for the + (weighted) adjacency matrix. + n_components: int + The dimensionality of the euclidean space into which to embed the data. + initial_alpha: float + Initial learning rate for the SGD. + a: float + Parameter of differentiable approximation of right adjoint functor + b: float + Parameter of differentiable approximation of right adjoint functor + repulsion_strength: float + Weight to apply to negative samples. + negative_sample_rate: int (optional, default 5) + The number of negative samples to select per positive sample + in the optimization process. Increasing this value will result + in greater repulsive force being applied, greater optimization + cost, but slightly more accuracy. + n_epochs: int (optional, default 0) + The number of training epochs to be used in optimizing the + low dimensional embedding. Larger values result in more accurate + embeddings. If 0 is specified a value will be selected based on + the size of the input dataset (200 for large datasets, 500 for small). + init: string + How to initialize the low dimensional embedding. Options are: + * 'spectral': use a spectral embedding of the fuzzy 1-skeleton + * 'random': assign initial embedding positions at random. + * A numpy array of initial embedding positions. + random_state: numpy RandomState or equivalent + A state capable being used as a numpy random state. 
+ metric: string or callable + unused + metric_kwds: dict + unused + output_metric: function + Function returning the distance between two points in embedding space + and the gradient of the distance wrt the first argument. + output_metric_kwds: dict + Key word arguments to be passed to the output_metric function. + verbose: bool (optional, default False) + Whether to report information on the current progress of the algorithm. + Returns + ------- + embedding: array of shape (n_samples, n_components) + The optimized of ``graph`` into an ``n_components`` dimensional + euclidean space. + """ + + if init not in ['spectral', 'random']: + raise Exception("Initialization strategy not supported: %d" % init) + + if output_metric not in ['euclidean', 'categorical']: + raise Exception("Invalid output metric: {}" % output_metric) + + n_epochs = n_epochs if n_epochs else 0 + + if a is None or b is None: + spread = 1.0 + min_dist = 0.1 + a, b = find_ab_params(spread, min_dist) + + deterministic = random_state is not None + if not isinstance(random_state, int): + if isinstance(random_state, np.random.RandomState): + rs = random_state + else: + rs = np.random.RandomState(random_state) + random_state = rs.randint(low=np.iinfo(np.int32).min, + high=np.iinfo(np.int32).max, + dtype=np.int32) + + cdef UMAPParams* umap_params = new UMAPParams() + umap_params.n_components = n_components + umap_params.initial_alpha = initial_alpha + umap_params.a = a + umap_params.b = b + umap_params.repulsion_strength = repulsion_strength + umap_params.negative_sample_rate = negative_sample_rate + umap_params.n_epochs = n_epochs + if init == 'spectral': + umap_params.init = 1 + else: # init == 'random' + umap_params.init = 0 + umap_params.random_state = random_state + umap_params.deterministic = deterministic + if output_metric == 'euclidean': + umap_params.target_metric = MetricType.EUCLIDEAN + else: # output_metric == 'categorical' + umap_params.target_metric = MetricType.CATEGORICAL + 
umap_params.target_weight = output_metric_kwds['p'] \ + if 'p' in output_metric_kwds else 0 + umap_params.verbosity = verbose + + X_m, n_rows, n_cols, dtype = \ + input_to_cuml_array(data, order='C', check_dtype=np.float32) + + graph = graph.tocoo() + graph.sum_duplicates() + if not isinstance(graph, cp.sparse.coo_matrix): + graph = cp.sparse.coo_matrix(graph) + + handle = Handle() + cdef handle_t* handle_ = handle.getHandle() + cdef GraphHolder fss_graph = GraphHolder.from_coo_array(GraphHolder(), + handle, + graph) + + embedding = CumlArray.zeros((X_m.shape[0], n_components), + order="C", dtype=np.float32, + index=X_m.index) + + refine(handle_[0], + X_m.ptr, + X_m.shape[0], + X_m.shape[1], + fss_graph.get(), + umap_params, + embedding.ptr) + + free(umap_params) + + return embedding diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index ba79ff290c..0979b5a6d1 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -30,6 +30,9 @@ import cupyx import numba.cuda as cuda +from cuml.manifold.umap_utils cimport * +from cuml.manifold.umap_utils import GraphHolder, find_ab_params + from cuml.common.sparsefuncs import extract_knn_graph from cupyx.scipy.sparse import csr_matrix as cp_csr_matrix,\ coo_matrix as cp_coo_matrix, csc_matrix as cp_csc_matrix @@ -48,6 +51,9 @@ from cuml.common.array_sparse import SparseCumlArray from cuml.common.mixins import CMajorInputTagMixin from cuml.common.sparse_utils import is_sparse +from cuml.manifold.simpl_set import fuzzy_simplicial_set, \ + simplicial_set_embedding + if has_scipy(True): import scipy.sparse @@ -55,54 +61,14 @@ from cuml.common.array_descriptor import CumlArrayDescriptor import rmm -from libcpp cimport bool from libc.stdint cimport uintptr_t -from libc.stdint cimport uint64_t -from libc.stdint cimport int64_t -from libc.stdlib cimport calloc, malloc, free +from libc.stdlib cimport free from libcpp.memory cimport shared_ptr cimport cuml.common.cuda -cdef extern from 
"cuml/manifold/umapparams.h" namespace "ML::UMAPParams": - - enum MetricType: - EUCLIDEAN = 0, - CATEGORICAL = 1 - -cdef extern from "cuml/common/callback.hpp" namespace "ML::Internals": - - cdef cppclass GraphBasedDimRedCallback - -cdef extern from "cuml/manifold/umapparams.h" namespace "ML": - - cdef cppclass UMAPParams: - int n_neighbors, - int n_components, - int n_epochs, - float learning_rate, - float min_dist, - float spread, - float set_op_mix_ratio, - float local_connectivity, - float repulsion_strength, - int negative_sample_rate, - float transform_queue_size, - int verbosity, - float a, - float b, - float initial_alpha, - int init, - int target_n_neighbors, - MetricType target_metric, - float target_weight, - uint64_t random_state, - bool deterministic, - GraphBasedDimRedCallback * callback - - cdef extern from "cuml/manifold/umap.hpp" namespace "ML::UMAP": void fit(handle_t & handle, @@ -113,7 +79,8 @@ cdef extern from "cuml/manifold/umap.hpp" namespace "ML::UMAP": int64_t * knn_indices, float * knn_dists, UMAPParams * params, - float * embeddings) except + + float * embeddings, + COO * graph) except + void fit_sparse(handle_t &handle, int *indptr, @@ -124,7 +91,8 @@ cdef extern from "cuml/manifold/umap.hpp" namespace "ML::UMAP": int n, int d, UMAPParams *params, - float *embeddings) except + + float *embeddings, + COO * graph) except + void transform(handle_t & handle, float * X, @@ -465,27 +433,7 @@ class UMAP(Base, @staticmethod def find_ab_params(spread, min_dist): - """ Function taken from UMAP-learn : https://github.com/lmcinnes/umap - Fit a, b params for the differentiable curve used in lower - dimensional fuzzy simplicial complex construction. We want the - smooth curve (from a pre-defined family with simple gradient) that - best matches an offset exponential decay. 
- """ - - def curve(x, a, b): - return 1.0 / (1.0 + a * x ** (2 * b)) - - if has_scipy(): - from scipy.optimize import curve_fit - else: - raise RuntimeError('Scipy is needed to run find_ab_params') - - xv = np.linspace(0, spread * 3, 300) - yv = np.zeros(xv.shape) - yv[xv < min_dist] = 1.0 - yv[xv >= min_dist] = np.exp(-(xv[xv >= min_dist] - min_dist) / spread) - params, covar = curve_fit(curve, xv, yv) - return params[0], params[1] + return find_ab_params(spread, min_dist) @generate_docstring(convert_dtype_cast='np.float32', X='dense_sparse', @@ -579,6 +527,7 @@ class UMAP(Base, else None)) y_raw = y_m.ptr + fss_graph = GraphHolder.new_graph(handle_.get_stream()) if self.sparse_fit: fit_sparse(handle_[0], self.X_m.indptr.ptr, @@ -589,7 +538,8 @@ class UMAP(Base, self.n_rows, self.n_dims, umap_params, - embed_raw) + embed_raw, + fss_graph.get()) else: fit(handle_[0], @@ -600,7 +550,10 @@ class UMAP(Base, knn_indices_raw, knn_dists_raw, umap_params, - embed_raw) + embed_raw, + fss_graph.get()) + + self.graph_ = fss_graph.get_cupy_coo() self.handle.sync() @@ -726,10 +679,6 @@ class UMAP(Base, n_rows = X_m.shape[0] n_cols = X_m.shape[1] - if n_rows <= 1: - raise ValueError("There needs to be more than 1 sample to " - "build nearest the neighbors graph") - if n_cols != self.n_dims: raise ValueError("n_features of X must match n_features of " "training data") diff --git a/python/cuml/manifold/umap_utils.pxd b/python/cuml/manifold/umap_utils.pxd new file mode 100644 index 0000000000..f4bebeb7b9 --- /dev/null +++ b/python/cuml/manifold/umap_utils.pxd @@ -0,0 +1,88 @@ +# +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# distutils: language = c++ + +from rmm._lib.cuda_stream_view cimport cuda_stream_view +from libcpp.memory cimport unique_ptr + +from libc.stdint cimport uint64_t, uintptr_t, int64_t +from libcpp cimport bool +from libcpp.memory cimport shared_ptr + + +cdef extern from "cuml/manifold/umapparams.h" namespace "ML::UMAPParams": + + enum MetricType: + EUCLIDEAN = 0, + CATEGORICAL = 1 + +cdef extern from "cuml/common/callback.hpp" namespace "ML::Internals": + + cdef cppclass GraphBasedDimRedCallback + +cdef extern from "cuml/manifold/umapparams.h" namespace "ML": + + cdef cppclass UMAPParams: + int n_neighbors, + int n_components, + int n_epochs, + float learning_rate, + float min_dist, + float spread, + float set_op_mix_ratio, + float local_connectivity, + float repulsion_strength, + int negative_sample_rate, + float transform_queue_size, + int verbosity, + float a, + float b, + float initial_alpha, + int init, + int target_n_neighbors, + MetricType target_metric, + float target_weight, + uint64_t random_state, + bool deterministic, + GraphBasedDimRedCallback * callback + +cdef extern from "raft/sparse/coo.hpp": + cdef cppclass COO "raft::sparse::COO": + COO(cuda_stream_view stream) + void allocate(int nnz, int size, bool init, cuda_stream_view stream) + int nnz + float* vals() + int* rows() + int* cols() + +cdef class GraphHolder: + cdef unique_ptr[COO] c_graph + + @staticmethod + cdef GraphHolder new_graph(cuda_stream_view stream) + + @staticmethod + cdef GraphHolder from_ptr(unique_ptr[COO]& ptr) + + @staticmethod + cdef GraphHolder 
from_coo_array(graph, handle, coo_array) + + cdef COO* get(GraphHolder self) + cdef uintptr_t vals(GraphHolder self) + cdef uintptr_t rows(GraphHolder self) + cdef uintptr_t cols(GraphHolder self) + cdef uint64_t get_nnz(GraphHolder self) diff --git a/python/cuml/manifold/umap_utils.pyx b/python/cuml/manifold/umap_utils.pyx new file mode 100644 index 0000000000..499aea605c --- /dev/null +++ b/python/cuml/manifold/umap_utils.pyx @@ -0,0 +1,127 @@ +# +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# distutils: language = c++ + +from raft.common.handle cimport handle_t +from cuml.manifold.umap_utils cimport * +from libcpp.utility cimport move +import numpy as np +import cupy as cp + + +cdef class GraphHolder: + @staticmethod + cdef GraphHolder new_graph(cuda_stream_view stream): + cdef GraphHolder graph = GraphHolder.__new__(GraphHolder) + graph.c_graph.reset(new COO(stream)) + return graph + + @staticmethod + cdef GraphHolder from_ptr(unique_ptr[COO]& ptr): + cdef GraphHolder graph = GraphHolder.__new__(GraphHolder) + graph.c_graph = move(ptr) + return graph + + @staticmethod + cdef GraphHolder from_coo_array(graph, handle, coo_array): + def copy_from_array(dst_raft_coo_ptr, src_cp_coo): + size = src_cp_coo.size + itemsize = np.dtype(src_cp_coo.dtype).itemsize + dest_buff = cp.cuda.UnownedMemory(ptr=dst_raft_coo_ptr, + size=size * itemsize, + owner=None, + device_id=-1) + dest_mptr = cp.cuda.memory.MemoryPointer(dest_buff, 0) + src_buff = cp.cuda.UnownedMemory(ptr=src_cp_coo.data.ptr, + size=size * itemsize, + owner=None, + device_id=-1) + src_mptr = cp.cuda.memory.MemoryPointer(src_buff, 0) + dest_mptr.copy_from_device(src_mptr, size * itemsize) + + cdef handle_t* handle_ = handle.getHandle() + graph.c_graph.reset(new COO(handle_.get_stream())) + graph.get().allocate(coo_array.nnz, + coo_array.shape[0], + False, + handle_.get_stream()) + handle_.sync_stream() + + copy_from_array(graph.vals(), coo_array.data.astype('float32')) + copy_from_array(graph.rows(), coo_array.row.astype('int32')) + copy_from_array(graph.cols(), coo_array.col.astype('int32')) + + return graph + + cdef inline COO* get(self): + return self.c_graph.get() + + cdef uintptr_t vals(self): + return self.get().vals() + + cdef uintptr_t rows(self): + return self.get().rows() + + cdef uintptr_t cols(self): + return self.get().cols() + + cdef uint64_t get_nnz(self): + return self.get().nnz + + def get_cupy_coo(self): + def create_nonowning_cp_array(ptr, dtype): + mem = 
cp.cuda.UnownedMemory(ptr=ptr, + size=(self.get_nnz() * + np.dtype(dtype).itemsize), + owner=self, + device_id=-1) + memptr = cp.cuda.memory.MemoryPointer(mem, 0) + return cp.ndarray(self.get_nnz(), dtype=dtype, memptr=memptr) + + vals = create_nonowning_cp_array(self.vals(), np.float32) + rows = create_nonowning_cp_array(self.rows(), np.int32) + cols = create_nonowning_cp_array(self.cols(), np.int32) + + return cp.sparse.coo_matrix(((vals, (rows, cols)))) + + def __dealloc__(self): + self.c_graph.reset(NULL) + + +def find_ab_params(spread, min_dist): + """ Function taken from UMAP-learn : https://github.com/lmcinnes/umap + Fit a, b params for the differentiable curve used in lower + dimensional fuzzy simplicial complex construction. We want the + smooth curve (from a pre-defined family with simple gradient) that + best matches an offset exponential decay. + """ + + def curve(x, a, b): + return 1.0 / (1.0 + a * x ** (2 * b)) + + from cuml.common.import_utils import has_scipy + if has_scipy(): + from scipy.optimize import curve_fit + else: + raise RuntimeError('Scipy is needed to run find_ab_params') + + xv = np.linspace(0, spread * 3, 300) + yv = np.zeros(xv.shape) + yv[xv < min_dist] = 1.0 + yv[xv >= min_dist] = np.exp(-(xv[xv >= min_dist] - min_dist) / spread) + params, covar = curve_fit(curve, xv, yv) + return params[0], params[1] diff --git a/python/cuml/tests/test_umap.py b/python/cuml/tests/test_umap.py index cd7b5c3bf6..e7dbf83748 100644 --- a/python/cuml/tests/test_umap.py +++ b/python/cuml/tests/test_umap.py @@ -24,6 +24,7 @@ import cupyx import scipy.sparse +import cupy as cp from cuml.manifold.umap import UMAP as cuUMAP from cuml.testing.utils import array_equal, unit_param, \ @@ -530,3 +531,56 @@ def test_equality(e1, e2): test_equality(embedding3, embedding4) test_equality(embedding5, embedding6) test_equality(embedding6, embedding7) + + +def correctness_sparse(a, b, atol=0.1, rtol=0.2, threshold=0.95): + n_ref_zeros = (a == 0).sum() + 
n_ref_non_zero_elms = a.size - n_ref_zeros + n_correct = (cp.abs(a - b) <= (atol + rtol * cp.abs(b))).sum() + correctness = (n_correct - n_ref_zeros) / n_ref_non_zero_elms + return correctness >= threshold + + +@pytest.mark.parametrize('n_rows', [800, 5000]) +@pytest.mark.parametrize('n_features', [8, 32]) +@pytest.mark.parametrize('n_neighbors', [8, 16]) +@pytest.mark.parametrize('precomputed_nearest_neighbors', [False, True]) +def test_fuzzy_simplicial_set(n_rows, + n_features, + n_neighbors, + precomputed_nearest_neighbors): + n_clusters = 30 + random_state = 42 + metric = 'euclidean' + + X, _ = make_blobs(n_samples=n_rows, centers=n_clusters, + n_features=n_features, random_state=random_state) + + if precomputed_nearest_neighbors: + nn = NearestNeighbors(n_neighbors=n_neighbors, + metric=metric) + nn.fit(X) + knn_dists, knn_indices = nn.kneighbors(X, + n_neighbors, + return_distance=True) + knn_graph = nn.kneighbors_graph(X, mode="distance") + model = cuUMAP(n_neighbors=n_neighbors) + model.fit(X, + knn_graph=knn_graph) + cu_fss_graph = model.graph_ + + knn_indices = knn_indices + knn_dists = knn_dists + + else: + model = cuUMAP(n_neighbors=n_neighbors) + model.fit(X) + cu_fss_graph = model.graph_ + + cu_fss_graph = cu_fss_graph.todense() + ref_fss_graph = cu_fss_graph + assert correctness_sparse(ref_fss_graph, + cu_fss_graph, + atol=0.1, + rtol=0.2, + threshold=0.95)