From 311d4a4383a57497f09b6df07df921b40684b41a Mon Sep 17 00:00:00 2001 From: viclafargue Date: Tue, 26 Apr 2022 20:08:39 +0200 Subject: [PATCH 01/12] Simplicial set --- python/cuml/manifold/simpl_set.pyx | 367 ++++++++++++++++++++++++++++ python/cuml/tests/test_simpl_set.py | 132 ++++++++++ 2 files changed, 499 insertions(+) create mode 100644 python/cuml/manifold/simpl_set.pyx create mode 100644 python/cuml/tests/test_simpl_set.py diff --git a/python/cuml/manifold/simpl_set.pyx b/python/cuml/manifold/simpl_set.pyx new file mode 100644 index 0000000000..c954ab2cbb --- /dev/null +++ b/python/cuml/manifold/simpl_set.pyx @@ -0,0 +1,367 @@ +# +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# distutils: language = c++ + +import numpy as np +import cupy as cp + +import cuml.internals +from cuml.common.base import Base +from cuml.common.input_utils import input_to_cuml_array +from cuml.common.array import CumlArray +from cuml.manifold.umap import UMAP + +from rmm._lib.cuda_stream_view cimport cuda_stream_view +from raft.common.handle cimport handle_t +from raft.common.handle import Handle + +from libc.stdint cimport uintptr_t +from libc.stdint cimport uint64_t +from libc.stdint cimport int64_t +from libcpp cimport bool +from libcpp.memory cimport unique_ptr + + +cdef extern from "cuml/manifold/umapparams.h" namespace "ML::UMAPParams": + + enum MetricType: + EUCLIDEAN = 0, + CATEGORICAL = 1 + + +cdef extern from "cuml/manifold/umapparams.h" namespace "ML": + + cdef cppclass UMAPParams: + int n_neighbors, + int n_components, + int n_epochs, + float learning_rate, + float min_dist, + float spread, + float set_op_mix_ratio, + float local_connectivity, + float repulsion_strength, + int negative_sample_rate, + float transform_queue_size, + int verbosity, + float a, + float b, + float initial_alpha, + int init, + int target_n_neighbors, + MetricType target_metric, + float target_weight, + uint64_t random_state, + bool deterministic, + int optim_batch_size + + +cdef extern from "raft/sparse/coo.hpp": + cdef cppclass COO "raft::sparse::COO": + COO(cuda_stream_view stream) + void allocate(int nnz, int size, bool init, cuda_stream_view stream) + int nnz + float* vals() + int* rows() + int* cols() + + +cdef extern from "cuml/manifold/umap.hpp" namespace "ML::UMAP": + + unique_ptr[COO] get_graph(handle_t &handle, + float* X, + float* y, + int n, + int d, + UMAPParams* params) + + void refine(handle_t &handle, + float* X, + int n, + int d, + COO* cgraph_coo, + UMAPParams* params, + float* embeddings) + + +def fuzzy_simplicial_set(X, + n_neighbors, + random_state=None, + metric="euclidean", + metric_kwds={}, + set_op_mix_ratio=1.0, + local_connectivity=1.0, + 
verbose=False): + """Given a set of data X, a neighborhood size, and a measure of distance + compute the fuzzy simplicial set (here represented as a fuzzy graph in + the form of a sparse matrix) associated to the data. This is done by + locally approximating geodesic distance at each point, creating a fuzzy + simplicial set for each such point, and then combining all the local + fuzzy simplicial sets into a global one via a fuzzy union. + Parameters + ---------- + X: array of shape (n_samples, n_features) + The data to be modelled as a fuzzy simplicial set. + n_neighbors: int + The number of neighbors to use to approximate geodesic distance. + Larger numbers induce more global estimates of the manifold that can + miss finer detail, while smaller values will focus on fine manifold + structure to the detriment of the larger picture. + random_state: numpy RandomState or equivalent + A state capable of being used as a numpy random state. + metric: string or function (optional, default 'euclidean') + unused + metric_kwds: dict (optional, default {}) + unused + set_op_mix_ratio: float (optional, default 1.0) + Interpolate between (fuzzy) union and intersection as the set operation + used to combine local fuzzy simplicial sets to obtain a global fuzzy + simplicial sets. Both fuzzy set operations use the product t-norm. + The value of this parameter should be between 0.0 and 1.0; a value of + 1.0 will use a pure fuzzy union, while 0.0 will use a pure fuzzy + intersection. + local_connectivity: int (optional, default 1) + The local connectivity required -- i.e. the number of nearest + neighbors that should be assumed to be connected at a local level. + The higher this value the more connected the manifold becomes + locally. In practice this should be not more than the local intrinsic + dimension of the manifold. + verbose: bool (optional, default False) + Whether to report information on the current progress of the algorithm. 
+ Returns + ------- + fuzzy_simplicial_set: coo_matrix + A fuzzy simplicial set represented as a sparse matrix. The (i, + j) entry of the matrix represents the membership strength of the + 1-simplex between the ith and jth sample points. + """ + + deterministic = random_state is not None + if not isinstance(random_state, int): + if isinstance(random_state, np.random.RandomState): + rs = random_state + else: + rs = np.random.RandomState(random_state) + random_state = rs.randint(low=np.iinfo(np.int32).min, + high=np.iinfo(np.int32).max, + dtype=np.int32) + + cdef UMAPParams* umap_params = new UMAPParams() + umap_params.n_neighbors = n_neighbors + umap_params.random_state = random_state + umap_params.deterministic = deterministic + umap_params.set_op_mix_ratio = set_op_mix_ratio + umap_params.local_connectivity = local_connectivity + umap_params.verbosity = verbose + + X_m, n_rows, n_cols, dtype = \ + input_to_cuml_array(X, order='C', check_dtype=np.float32) + + handle = Handle() + cdef handle_t* handle_ = handle.getHandle() + cdef unique_ptr[COO] fss_graph = get_graph(handle_[0], + X_m.ptr, + NULL, + X_m.shape[0], + X_m.shape[1], + umap_params) + + cdef int nnz = fss_graph.get().nnz + coo_arrays = [] + for ptr, dtype in [(fss_graph.get().vals(), np.float32), + (fss_graph.get().rows(), np.int32), + (fss_graph.get().cols(), np.int32)]: + mem = cp.cuda.UnownedMemory(ptr=ptr, + size=nnz * 4, + owner=None, + device_id=-1) + memptr = cp.cuda.memory.MemoryPointer(mem, 0) + coo_arrays.append(cp.ndarray(nnz, dtype=dtype, memptr=memptr)) + + result = cp.sparse.coo_matrix(((coo_arrays[0], + (coo_arrays[1], coo_arrays[2])))) + result = result.copy() + + return result + + +@cuml.internals.api_return_generic() +def simplicial_set_embedding( + data, + graph, + n_components=2, + initial_alpha=1.0, + a=None, + b=None, + repulsion_strength=1.0, + negative_sample_rate=5, + n_epochs=None, + init="spectral", + random_state=None, + metric="euclidean", + metric_kwds={}, + 
output_metric="euclidean", + output_metric_kwds={}, + verbose=False, +): + """Perform a fuzzy simplicial set embedding, using a specified + initialisation method and then minimizing the fuzzy set cross entropy + between the 1-skeletons of the high and low dimensional fuzzy simplicial + sets. + Parameters + ---------- + data: array of shape (n_samples, n_features) + The source data to be embedded by UMAP. + graph: sparse matrix + The 1-skeleton of the high dimensional fuzzy simplicial set as + represented by a graph for which we require a sparse matrix for the + (weighted) adjacency matrix. + n_components: int + The dimensionality of the euclidean space into which to embed the data. + initial_alpha: float + Initial learning rate for the SGD. + a: float + Parameter of differentiable approximation of right adjoint functor + b: float + Parameter of differentiable approximation of right adjoint functor + repulsion_strength: float + Weight to apply to negative samples. + negative_sample_rate: int (optional, default 5) + The number of negative samples to select per positive sample + in the optimization process. Increasing this value will result + in greater repulsive force being applied, greater optimization + cost, but slightly more accuracy. + n_epochs: int (optional, default 0) + The number of training epochs to be used in optimizing the + low dimensional embedding. Larger values result in more accurate + embeddings. If 0 is specified a value will be selected based on + the size of the input dataset (200 for large datasets, 500 for small). + init: string + How to initialize the low dimensional embedding. Options are: + * 'spectral': use a spectral embedding of the fuzzy 1-skeleton + * 'random': assign initial embedding positions at random. + * A numpy array of initial embedding positions. + random_state: numpy RandomState or equivalent + A state capable of being used as a numpy random state. 
+ metric: string or callable + unused + metric_kwds: dict + unused + output_metric: function + Function returning the distance between two points in embedding space + and the gradient of the distance wrt the first argument. + output_metric_kwds: dict + Key word arguments to be passed to the output_metric function. + verbose: bool (optional, default False) + Whether to report information on the current progress of the algorithm. + Returns + ------- + embedding: array of shape (n_samples, n_components) + The optimized embedding of ``graph`` into an ``n_components`` dimensional + euclidean space. + """ + + if init not in ['spectral', 'random']: + raise Exception("Initialization strategy not supported: %d" % init) + + if output_metric not in ['euclidean', 'categorical']: + raise Exception("Invalid output metric: {}" % output_metric) + + n_epochs = n_epochs if n_epochs else 0 + + if a is None or b is None: + spread = 1.0 + min_dist = 0.1 + a, b = UMAP.find_ab_params(spread, min_dist) + + deterministic = random_state is not None + if not isinstance(random_state, int): + if isinstance(random_state, np.random.RandomState): + rs = random_state + else: + rs = np.random.RandomState(random_state) + random_state = rs.randint(low=np.iinfo(np.int32).min, + high=np.iinfo(np.int32).max, + dtype=np.int32) + + cdef UMAPParams* umap_params = new UMAPParams() + umap_params.n_components = n_components + umap_params.initial_alpha = initial_alpha + umap_params.a = a + umap_params.b = b + umap_params.repulsion_strength = repulsion_strength + umap_params.negative_sample_rate = negative_sample_rate + umap_params.n_epochs = n_epochs + if init == 'spectral': + umap_params.init = 1 + else: # init == 'random' + umap_params.init = 0 + umap_params.random_state = random_state + umap_params.deterministic = deterministic + if output_metric == 'euclidean': + umap_params.target_metric = MetricType.EUCLIDEAN + else: # output_metric == 'categorical' + umap_params.target_metric = MetricType.CATEGORICAL + 
umap_params.target_weight = output_metric_kwds['p'] \ + if 'p' in output_metric_kwds else 0 + umap_params.verbosity = verbose + + X_m, n_rows, n_cols, dtype = \ + input_to_cuml_array(data, order='C', check_dtype=np.float32) + + handle = Handle() + cdef handle_t* handle_ = handle.getHandle() + cdef COO* fss_graph = new COO(handle_.get_stream()) + fss_graph.allocate(graph.nnz, graph.shape[0], False, handle_.get_stream()) + handle_.sync_stream() + + graph = graph.tocoo() + graph.sum_duplicates() + if not isinstance(graph, cp.sparse.coo_matrix): + graph = cp.sparse.coo_matrix(graph) + + nnz_size = graph.nnz * 4 # sizeof(float) == sizeof(int) == 4 + for dst_ptr, src_ptr in \ + [(fss_graph.vals(), graph.data.data.ptr), + (fss_graph.rows(), graph.row.data.ptr), + (fss_graph.cols(), graph.col.data.ptr)]: + dest_buff = cp.cuda.UnownedMemory(ptr=dst_ptr, + size=nnz_size, + owner=None, + device_id=-1) + dest_mptr = cp.cuda.memory.MemoryPointer(dest_buff, 0) + src_buff = cp.cuda.UnownedMemory(ptr=src_ptr, + size=nnz_size, + owner=None, + device_id=-1) + src_mptr = cp.cuda.memory.MemoryPointer(src_buff, 0) + dest_mptr.copy_from_device(src_mptr, nnz_size) + + embedding = CumlArray.zeros((X_m.shape[0], n_components), + order="C", dtype=np.float32, + index=X_m.index) + + refine(handle_[0], + X_m.ptr, + X_m.shape[0], + X_m.shape[1], + fss_graph, + umap_params, + embedding.ptr) + + cuml.internals.set_api_output_type("cupy") + return embedding diff --git a/python/cuml/tests/test_simpl_set.py b/python/cuml/tests/test_simpl_set.py new file mode 100644 index 0000000000..450127e1b8 --- /dev/null +++ b/python/cuml/tests/test_simpl_set.py @@ -0,0 +1,132 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from cuml.datasets import make_blobs +import numpy as np +import umap.distances as dist +from cuml.manifold.umap import UMAP +from umap.umap_ import fuzzy_simplicial_set as ref_fuzzy_simplicial_set +from cuml.manifold.simpl_set import fuzzy_simplicial_set \ + as cu_fuzzy_simplicial_set +from umap.umap_ import simplicial_set_embedding as ref_simplicial_set_embedding +from cuml.manifold.simpl_set import simplicial_set_embedding \ + as cu_simplicial_set_embedding + + +def test_fuzzy_simplicial_set(): + n_rows = 10000 + n_features = 16 + n_clusters = 30 + n_neighbors = 5 + random_state = 42 + metric = 'euclidean' + + X, _ = make_blobs(n_samples=n_rows, centers=n_clusters, + n_features=n_features, random_state=random_state) + X = X.get() + + ref_fss_graph = ref_fuzzy_simplicial_set(X, + n_neighbors, + random_state, + metric)[0].tocoo() + cu_fss_graph = cu_fuzzy_simplicial_set(X, + n_neighbors, + random_state, + metric) + + print('ref_fss_graph:', type(ref_fss_graph), ref_fss_graph.shape, + ref_fss_graph.nnz, ref_fss_graph.data, ref_fss_graph.row, + ref_fss_graph.col) + print('cu_fss_graph:', type(cu_fss_graph), cu_fss_graph.shape, + cu_fss_graph.nnz, cu_fss_graph.data, cu_fss_graph.row, + cu_fss_graph.col) + + +def test_simplicial_set_embedding(): + n_rows = 10000 + n_features = 16 + n_clusters = 30 + n_neighbors = 5 + random_state = 42 + metric = 'euclidean' + n_components = 3 + initial_alpha = 1.0 + a, b = UMAP.find_ab_params(1.0, 0.1) + gamma = 1 + negative_sample_rate = 5 + n_epochs = 0 + init = 'random' + metric = 'euclidean' + metric_kwds = {} + 
densmap = False + densmap_kwds = {} + output_dens = False + output_metric = 'euclidean' + output_metric_kwds = {} + + X, _ = make_blobs(n_samples=n_rows, centers=n_clusters, + n_features=n_features, random_state=random_state) + X = X.get() + + ref_fss_graph = ref_fuzzy_simplicial_set(X, + n_neighbors, + random_state, + metric)[0] + ref_embedding = ref_simplicial_set_embedding( + X, + ref_fss_graph, + n_components, + initial_alpha, + a, + b, + gamma, + negative_sample_rate, + n_epochs, + init, + np.random.RandomState(random_state), + dist.named_distances_with_gradients[metric], + metric_kwds, + densmap, + densmap_kwds, + output_dens, + output_metric=output_metric, + output_metric_kwds=output_metric_kwds)[0] + + cu_fss_graph = cu_fuzzy_simplicial_set(X, + n_neighbors, + random_state, + metric) + + cu_embedding = cu_simplicial_set_embedding( + X, + cu_fss_graph, + n_components, + initial_alpha, + a, + b, + gamma, + negative_sample_rate, + n_epochs, + init, + random_state, + metric, + metric_kwds, + output_metric=output_metric, + output_metric_kwds=output_metric_kwds) + + print('ref_embedding:', type(ref_embedding), ref_embedding.shape, + ref_embedding) + print('cu_embedding:', type(cu_embedding), cu_embedding.shape, + cu_embedding) From 8fd0a5958bddf25a30a0bf21f41c8f1e9f3ecc2e Mon Sep 17 00:00:00 2001 From: viclafargue Date: Wed, 27 Apr 2022 13:37:29 +0200 Subject: [PATCH 02/12] Tests --- python/cuml/manifold/simpl_set.pyx | 10 +++---- python/cuml/tests/test_simpl_set.py | 42 ++++++++++++++++++++--------- 2 files changed, 35 insertions(+), 17 deletions(-) diff --git a/python/cuml/manifold/simpl_set.pyx b/python/cuml/manifold/simpl_set.pyx index c954ab2cbb..0e5e1eabbd 100644 --- a/python/cuml/manifold/simpl_set.pyx +++ b/python/cuml/manifold/simpl_set.pyx @@ -323,17 +323,17 @@ def simplicial_set_embedding( X_m, n_rows, n_cols, dtype = \ input_to_cuml_array(data, order='C', check_dtype=np.float32) + graph = graph.tocoo() + graph.sum_duplicates() + if not 
isinstance(graph, cp.sparse.coo_matrix): + graph = cp.sparse.coo_matrix(graph) + handle = Handle() cdef handle_t* handle_ = handle.getHandle() cdef COO* fss_graph = new COO(handle_.get_stream()) fss_graph.allocate(graph.nnz, graph.shape[0], False, handle_.get_stream()) handle_.sync_stream() - graph = graph.tocoo() - graph.sum_duplicates() - if not isinstance(graph, cp.sparse.coo_matrix): - graph = cp.sparse.coo_matrix(graph) - nnz_size = graph.nnz * 4 # sizeof(float) == sizeof(int) == 4 for dst_ptr, src_ptr in \ [(fss_graph.vals(), graph.data.data.ptr), diff --git a/python/cuml/tests/test_simpl_set.py b/python/cuml/tests/test_simpl_set.py index 450127e1b8..0bd22b0d01 100644 --- a/python/cuml/tests/test_simpl_set.py +++ b/python/cuml/tests/test_simpl_set.py @@ -15,6 +15,7 @@ from cuml.datasets import make_blobs import numpy as np +import cupy as cp import umap.distances as dist from cuml.manifold.umap import UMAP from umap.umap_ import fuzzy_simplicial_set as ref_fuzzy_simplicial_set @@ -25,6 +26,21 @@ as cu_simplicial_set_embedding +def correctness_dense(a, b, rtol=0.1, threshold=0.8): + n_elms = a.size + n_correct = (cp.abs(a - b) <= (rtol * cp.abs(b))).sum() + correctness = n_correct / n_elms + return correctness >= threshold + + +def correctness_sparse(a, b, atol=0.1, rtol=0.2, threshold=0.8): + n_ref_zeros = (a == 0).sum() + n_ref_non_zero_elms = a.size - n_ref_zeros + n_correct = (cp.abs(a - b) <= (atol + rtol * cp.abs(b))).sum() + correctness = (n_correct - n_ref_zeros) / n_ref_non_zero_elms + return correctness >= threshold + + def test_fuzzy_simplicial_set(): n_rows = 10000 n_features = 16 @@ -46,12 +62,13 @@ def test_fuzzy_simplicial_set(): random_state, metric) - print('ref_fss_graph:', type(ref_fss_graph), ref_fss_graph.shape, - ref_fss_graph.nnz, ref_fss_graph.data, ref_fss_graph.row, - ref_fss_graph.col) - print('cu_fss_graph:', type(cu_fss_graph), cu_fss_graph.shape, - cu_fss_graph.nnz, cu_fss_graph.data, cu_fss_graph.row, - cu_fss_graph.col) + 
ref_fss_graph = cp.sparse.coo_matrix(ref_fss_graph).todense() + cu_fss_graph = cu_fss_graph.todense() + assert correctness_sparse(ref_fss_graph, + cu_fss_graph, + atol=0.1, + rtol=0.2, + threshold=0.8) def test_simplicial_set_embedding(): @@ -64,9 +81,9 @@ def test_simplicial_set_embedding(): n_components = 3 initial_alpha = 1.0 a, b = UMAP.find_ab_params(1.0, 0.1) - gamma = 1 + gamma = 0 negative_sample_rate = 5 - n_epochs = 0 + n_epochs = 500 init = 'random' metric = 'euclidean' metric_kwds = {} @@ -126,7 +143,8 @@ def test_simplicial_set_embedding(): output_metric=output_metric, output_metric_kwds=output_metric_kwds) - print('ref_embedding:', type(ref_embedding), ref_embedding.shape, - ref_embedding) - print('cu_embedding:', type(cu_embedding), cu_embedding.shape, - cu_embedding) + ref_embedding = cp.array(ref_embedding) + assert correctness_dense(ref_embedding, + cu_embedding, + rtol=0.1, + threshold=0.8) From 62523366332da698da430d6689fc95340322a9f1 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Wed, 27 Apr 2022 19:22:26 +0200 Subject: [PATCH 03/12] Remove check in UMAP on the number of rows --- python/cuml/manifold/umap.pyx | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index 27b770aa7b..79f530068f 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -732,10 +732,6 @@ class UMAP(Base, n_rows = X_m.shape[0] n_cols = X_m.shape[1] - if n_rows <= 1: - raise ValueError("There needs to be more than 1 sample to " - "build nearest the neighbors graph") - if n_cols != self.n_dims: raise ValueError("n_features of X must match n_features of " "training data") From ca389ce6b8007d4ce98d944fd13b15dfd2c4e845 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Fri, 29 Apr 2022 15:09:52 +0200 Subject: [PATCH 04/12] changes for review --- cpp/include/cuml/manifold/umap.hpp | 2 + cpp/src/umap/umap.cu | 51 ++++++++++----- cpp/test/sg/umap_parametrizable_test.cu | 7 ++- 
python/cuml/manifold/simpl_set.pyx | 81 +++++++++++++++++++----- python/cuml/tests/test_simpl_set.py | 84 ++++++++++++++++++------- 5 files changed, 167 insertions(+), 58 deletions(-) diff --git a/cpp/include/cuml/manifold/umap.hpp b/cpp/include/cuml/manifold/umap.hpp index 1532be3246..dd44cf2daa 100644 --- a/cpp/include/cuml/manifold/umap.hpp +++ b/cpp/include/cuml/manifold/umap.hpp @@ -110,6 +110,8 @@ std::unique_ptr> get_graph(const raft::handle_t& h float* y, // labels int n, int d, + int64_t* knn_indices, + float* knn_dists, UMAPParams* params); void fit_sparse(const raft::handle_t& handle, diff --git a/cpp/src/umap/umap.cu b/cpp/src/umap/umap.cu index 08eb3bcfc0..359783b2d7 100644 --- a/cpp/src/umap/umap.cu +++ b/cpp/src/umap/umap.cu @@ -137,25 +137,46 @@ void fit(const raft::handle_t& handle, } // get graph -std::unique_ptr> get_graph(const raft::handle_t& handle, - float* X, // input matrix - float* y, // labels - int n, - int d, - UMAPParams* params) +std::unique_ptr> get_graph( + const raft::handle_t& handle, + float* X, // input matrix + float* y, // labels + int n, + int d, + knn_indices_dense_t* knn_indices, // precomputed indices + float* knn_dists, // precomputed distances + UMAPParams* params) { - manifold_dense_inputs_t inputs(X, y, n, d); auto cgraph_coo = std::make_unique>(handle.get_stream()); - if (y != nullptr) { - UMAPAlgo:: - _get_graph_supervised, TPB_X>( - handle, inputs, params, cgraph_coo.get()); + if (knn_indices != nullptr && knn_dists != nullptr) { + CUML_LOG_DEBUG("Calling UMAP::get_graph() with precomputed KNN"); + + manifold_precomputed_knn_inputs_t inputs( + knn_indices, knn_dists, X, y, n, d, params->n_neighbors); + if (y != nullptr) { + UMAPAlgo::_get_graph_supervised, + TPB_X>(handle, inputs, params, cgraph_coo.get()); + } else { + UMAPAlgo::_get_graph, + TPB_X>(handle, inputs, params, cgraph_coo.get()); + } + return cgraph_coo; } else { - UMAPAlgo::_get_graph, TPB_X>( - handle, inputs, params, cgraph_coo.get()); + 
manifold_dense_inputs_t inputs(X, y, n, d); + if (y != nullptr) { + UMAPAlgo:: + _get_graph_supervised, TPB_X>( + handle, inputs, params, cgraph_coo.get()); + } else { + UMAPAlgo::_get_graph, TPB_X>( + handle, inputs, params, cgraph_coo.get()); + } + return cgraph_coo; } - - return cgraph_coo; } // refine diff --git a/cpp/test/sg/umap_parametrizable_test.cu b/cpp/test/sg/umap_parametrizable_test.cu index 765da3c740..dd3512cefe 100644 --- a/cpp/test/sg/umap_parametrizable_test.cu +++ b/cpp/test/sg/umap_parametrizable_test.cu @@ -187,12 +187,13 @@ class UMAPParametrizableTest : public ::testing::Test { if (test_params.refine) { std::cout << "using refine"; if (test_params.supervised) { - auto cgraph_coo = ML::UMAP::get_graph(handle, X, y, n_samples, n_features, &umap_params); + auto cgraph_coo = + ML::UMAP::get_graph(handle, X, y, n_samples, n_features, nullptr, nullptr, &umap_params); ML::UMAP::refine( handle, X, n_samples, n_features, cgraph_coo.get(), &umap_params, model_embedding); } else { - auto cgraph_coo = - ML::UMAP::get_graph(handle, X, nullptr, n_samples, n_features, &umap_params); + auto cgraph_coo = ML::UMAP::get_graph( + handle, X, nullptr, n_samples, n_features, nullptr, nullptr, &umap_params); ML::UMAP::refine( handle, X, n_samples, n_features, cgraph_coo.get(), &umap_params, model_embedding); } diff --git a/python/cuml/manifold/simpl_set.pyx b/python/cuml/manifold/simpl_set.pyx index 0e5e1eabbd..0202df545b 100644 --- a/python/cuml/manifold/simpl_set.pyx +++ b/python/cuml/manifold/simpl_set.pyx @@ -87,6 +87,8 @@ cdef extern from "cuml/manifold/umap.hpp" namespace "ML::UMAP": float* y, int n, int d, + int64_t* knn_indices, + float* knn_dists, UMAPParams* params) void refine(handle_t &handle, @@ -103,6 +105,8 @@ def fuzzy_simplicial_set(X, random_state=None, metric="euclidean", metric_kwds={}, + knn_indices=None, + knn_dists=None, set_op_mix_ratio=1.0, local_connectivity=1.0, verbose=False): @@ -127,6 +131,16 @@ def fuzzy_simplicial_set(X, unused 
metric_kwds: dict (optional, default {}) unused + knn_indices: array of shape (n_samples, n_neighbors) (optional) + If the k-nearest neighbors of each point has already been calculated + you can pass them in here to save computation time. This should be + an array with the indices of the k-nearest neighbors as a row for + each data point. + knn_dists: array of shape (n_samples, n_neighbors) (optional) + If the k-nearest neighbors of each point has already been calculated + you can pass them in here to save computation time. This should be + an array with the distances of the k-nearest neighbors as a row for + each data point. set_op_mix_ratio: float (optional, default 1.0) Interpolate between (fuzzy) union and intersection as the set operation used to combine local fuzzy simplicial sets to obtain a global fuzzy @@ -168,17 +182,47 @@ def fuzzy_simplicial_set(X, umap_params.local_connectivity = local_connectivity umap_params.verbosity = verbose - X_m, n_rows, n_cols, dtype = \ - input_to_cuml_array(X, order='C', check_dtype=np.float32) + X_m, n_rows, n_cols, _ = \ + input_to_cuml_array(X, + order='C', + check_dtype=np.float32, + convert_to_dtype=np.float32) + + if knn_indices is not None and knn_dists is not None: + knn_indices_m, _, _, _ = \ + input_to_cuml_array(knn_indices, + order='C', + check_dtype=np.int64, + convert_to_dtype=np.int64) + knn_dists_m, _, _, _ = \ + input_to_cuml_array(knn_dists, + order='C', + check_dtype=np.float32, + convert_to_dtype=np.float32) + X_ptr = 0 + knn_indices_ptr = knn_indices_m.ptr + knn_dists_ptr = knn_dists_m.ptr + else: + X_m, _, _, _ = \ + input_to_cuml_array(X, + order='C', + check_dtype=np.float32, + convert_to_dtype=np.float32) + X_ptr = X_m.ptr + knn_indices_ptr = 0 + knn_dists_ptr = 0 handle = Handle() cdef handle_t* handle_ = handle.getHandle() - cdef unique_ptr[COO] fss_graph = get_graph(handle_[0], - X_m.ptr, - NULL, - X_m.shape[0], - X_m.shape[1], - umap_params) + cdef unique_ptr[COO] fss_graph = get_graph( + 
handle_[0], + X_ptr, + NULL, + X.shape[0], + X.shape[1], + knn_indices_ptr, + knn_dists_ptr, + umap_params) cdef int nnz = fss_graph.get().nnz coo_arrays = [] @@ -186,7 +230,7 @@ def fuzzy_simplicial_set(X, (fss_graph.get().rows(), np.int32), (fss_graph.get().cols(), np.int32)]: mem = cp.cuda.UnownedMemory(ptr=ptr, - size=nnz * 4, + size=nnz * np.dtype(dtype).itemsize, owner=None, device_id=-1) memptr = cp.cuda.memory.MemoryPointer(mem, 0) @@ -194,6 +238,8 @@ def fuzzy_simplicial_set(X, result = cp.sparse.coo_matrix(((coo_arrays[0], (coo_arrays[1], coo_arrays[2])))) + + # TODO : remove this copy once RAFT sparse objects are redesigned result = result.copy() return result @@ -334,22 +380,23 @@ def simplicial_set_embedding( fss_graph.allocate(graph.nnz, graph.shape[0], False, handle_.get_stream()) handle_.sync_stream() - nnz_size = graph.nnz * 4 # sizeof(float) == sizeof(int) == 4 - for dst_ptr, src_ptr in \ - [(fss_graph.vals(), graph.data.data.ptr), - (fss_graph.rows(), graph.row.data.ptr), - (fss_graph.cols(), graph.col.data.ptr)]: + nnz = graph.nnz + for dst_ptr, src_ptr, dtype in \ + [(fss_graph.vals(), graph.data.data.ptr, np.float32), + (fss_graph.rows(), graph.row.data.ptr, np.int32), + (fss_graph.cols(), graph.col.data.ptr, np.int32)]: + itemsize = np.dtype(dtype).itemsize dest_buff = cp.cuda.UnownedMemory(ptr=dst_ptr, - size=nnz_size, + size=nnz * itemsize, owner=None, device_id=-1) dest_mptr = cp.cuda.memory.MemoryPointer(dest_buff, 0) src_buff = cp.cuda.UnownedMemory(ptr=src_ptr, - size=nnz_size, + size=nnz * itemsize, owner=None, device_id=-1) src_mptr = cp.cuda.memory.MemoryPointer(src_buff, 0) - dest_mptr.copy_from_device(src_mptr, nnz_size) + dest_mptr.copy_from_device(src_mptr, nnz * itemsize) embedding = CumlArray.zeros((X_m.shape[0], n_components), order="C", dtype=np.float32, diff --git a/python/cuml/tests/test_simpl_set.py b/python/cuml/tests/test_simpl_set.py index 0bd22b0d01..ebcf1d4727 100644 --- a/python/cuml/tests/test_simpl_set.py +++ 
b/python/cuml/tests/test_simpl_set.py @@ -13,11 +13,14 @@ # limitations under the License. # +import pytest from cuml.datasets import make_blobs import numpy as np import cupy as cp import umap.distances as dist from cuml.manifold.umap import UMAP +from cuml.neighbors import NearestNeighbors + from umap.umap_ import fuzzy_simplicial_set as ref_fuzzy_simplicial_set from cuml.manifold.simpl_set import fuzzy_simplicial_set \ as cu_fuzzy_simplicial_set @@ -26,14 +29,14 @@ as cu_simplicial_set_embedding -def correctness_dense(a, b, rtol=0.1, threshold=0.8): +def correctness_dense(a, b, rtol=0.1, threshold=0.95): n_elms = a.size n_correct = (cp.abs(a - b) <= (rtol * cp.abs(b))).sum() correctness = n_correct / n_elms return correctness >= threshold -def correctness_sparse(a, b, atol=0.1, rtol=0.2, threshold=0.8): +def correctness_sparse(a, b, atol=0.1, rtol=0.2, threshold=0.95): n_ref_zeros = (a == 0).sum() n_ref_non_zero_elms = a.size - n_ref_zeros n_correct = (cp.abs(a - b) <= (atol + rtol * cp.abs(b))).sum() @@ -41,44 +44,79 @@ def correctness_sparse(a, b, atol=0.1, rtol=0.2, threshold=0.8): return correctness >= threshold -def test_fuzzy_simplicial_set(): - n_rows = 10000 - n_features = 16 +@pytest.mark.parametrize('n_rows', [800, 5000]) +@pytest.mark.parametrize('n_features', [8, 32]) +@pytest.mark.parametrize('n_neighbors', [8, 16]) +@pytest.mark.parametrize('precomputed_nearest_neighbors', [False, True]) +def test_fuzzy_simplicial_set(n_rows, + n_features, + n_neighbors, + precomputed_nearest_neighbors): n_clusters = 30 - n_neighbors = 5 random_state = 42 metric = 'euclidean' X, _ = make_blobs(n_samples=n_rows, centers=n_clusters, n_features=n_features, random_state=random_state) - X = X.get() - ref_fss_graph = ref_fuzzy_simplicial_set(X, - n_neighbors, - random_state, - metric)[0].tocoo() - cu_fss_graph = cu_fuzzy_simplicial_set(X, - n_neighbors, - random_state, - metric) + if precomputed_nearest_neighbors: + nn = NearestNeighbors(n_neighbors=n_neighbors, + 
metric=metric) + nn.fit(X) + knn_dists, knn_indices = nn.kneighbors(X, + n_neighbors, + return_distance=True) + cu_fss_graph = cu_fuzzy_simplicial_set( + X, + n_neighbors, + random_state, + metric, + knn_indices=knn_indices, + knn_dists=knn_dists) + + knn_indices = knn_indices.get() + knn_dists = knn_dists.get() + ref_fss_graph = ref_fuzzy_simplicial_set( + X, + n_neighbors, + random_state, + metric, + knn_indices=knn_indices, + knn_dists=knn_dists)[0].tocoo() + else: + cu_fss_graph = cu_fuzzy_simplicial_set( + X, + n_neighbors, + random_state, + metric) + + X = X.get() + ref_fss_graph = ref_fuzzy_simplicial_set( + X, + n_neighbors, + random_state, + metric)[0].tocoo() - ref_fss_graph = cp.sparse.coo_matrix(ref_fss_graph).todense() cu_fss_graph = cu_fss_graph.todense() + ref_fss_graph = cp.sparse.coo_matrix(ref_fss_graph).todense() assert correctness_sparse(ref_fss_graph, cu_fss_graph, atol=0.1, rtol=0.2, - threshold=0.8) + threshold=0.95) -def test_simplicial_set_embedding(): - n_rows = 10000 - n_features = 16 +@pytest.mark.parametrize('n_rows', [800, 5000]) +@pytest.mark.parametrize('n_features', [8, 32]) +@pytest.mark.parametrize('n_neighbors', [8, 16]) +@pytest.mark.parametrize('n_components', [2, 5]) +def test_simplicial_set_embedding(n_rows, + n_features, + n_neighbors, + n_components): n_clusters = 30 - n_neighbors = 5 random_state = 42 metric = 'euclidean' - n_components = 3 initial_alpha = 1.0 a, b = UMAP.find_ab_params(1.0, 0.1) gamma = 0 @@ -147,4 +185,4 @@ def test_simplicial_set_embedding(): assert correctness_dense(ref_embedding, cu_embedding, rtol=0.1, - threshold=0.8) + threshold=0.95) From bfad70bb0e4f4c8f6470a35b8a88c0af87a81b3c Mon Sep 17 00:00:00 2001 From: viclafargue Date: Fri, 29 Apr 2022 16:23:06 +0200 Subject: [PATCH 05/12] remove cupy copy --- python/cuml/manifold/simpl_set.pyx | 35 ++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/python/cuml/manifold/simpl_set.pyx 
b/python/cuml/manifold/simpl_set.pyx index 0202df545b..b117ef5d30 100644 --- a/python/cuml/manifold/simpl_set.pyx +++ b/python/cuml/manifold/simpl_set.pyx @@ -100,6 +100,25 @@ cdef extern from "cuml/manifold/umap.hpp" namespace "ML::UMAP": float* embeddings) +cdef class GraphHolder: + cdef unique_ptr[COO] c_graph + + def vals(self): + return self.c_graph.get().vals() + + def rows(self): + return self.c_graph.get().rows() + + def cols(self): + return self.c_graph.get().cols() + + def get_nnz(self): + return self.c_graph.get().nnz + + def __dealloc__(self): + self.c_graph.reset(NULL) + + def fuzzy_simplicial_set(X, n_neighbors, random_state=None, @@ -214,7 +233,8 @@ def fuzzy_simplicial_set(X, handle = Handle() cdef handle_t* handle_ = handle.getHandle() - cdef unique_ptr[COO] fss_graph = get_graph( + fss_graph = GraphHolder() + fss_graph.c_graph = get_graph( handle_[0], X_ptr, NULL, @@ -224,14 +244,14 @@ def fuzzy_simplicial_set(X, knn_dists_ptr, umap_params) - cdef int nnz = fss_graph.get().nnz + cdef int nnz = fss_graph.get_nnz() coo_arrays = [] - for ptr, dtype in [(fss_graph.get().vals(), np.float32), - (fss_graph.get().rows(), np.int32), - (fss_graph.get().cols(), np.int32)]: + for ptr, dtype in [(fss_graph.vals(), np.float32), + (fss_graph.rows(), np.int32), + (fss_graph.cols(), np.int32)]: mem = cp.cuda.UnownedMemory(ptr=ptr, size=nnz * np.dtype(dtype).itemsize, - owner=None, + owner=fss_graph, device_id=-1) memptr = cp.cuda.memory.MemoryPointer(mem, 0) coo_arrays.append(cp.ndarray(nnz, dtype=dtype, memptr=memptr)) @@ -239,9 +259,6 @@ def fuzzy_simplicial_set(X, result = cp.sparse.coo_matrix(((coo_arrays[0], (coo_arrays[1], coo_arrays[2])))) - # TODO : remove this copy once RAFT sparse objects are redesigned - result = result.copy() - return result From cf19646a806592a80e57d8c59f88ae603d2e730d Mon Sep 17 00:00:00 2001 From: viclafargue Date: Fri, 29 Apr 2022 16:44:12 +0200 Subject: [PATCH 06/12] create umap include file --- python/cuml/manifold/simpl_set.pyx 
| 39 ++------------------------ python/cuml/manifold/umap.pyx | 43 ++--------------------------- python/cuml/manifold/umap_utils.pxd | 41 +++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 78 deletions(-) create mode 100644 python/cuml/manifold/umap_utils.pxd diff --git a/python/cuml/manifold/simpl_set.pyx b/python/cuml/manifold/simpl_set.pyx index b117ef5d30..a0d15ae3d8 100644 --- a/python/cuml/manifold/simpl_set.pyx +++ b/python/cuml/manifold/simpl_set.pyx @@ -19,6 +19,8 @@ import numpy as np import cupy as cp +from cuml.manifold.umap_utils cimport * + import cuml.internals from cuml.common.base import Base from cuml.common.input_utils import input_to_cuml_array @@ -30,46 +32,9 @@ from raft.common.handle cimport handle_t from raft.common.handle import Handle from libc.stdint cimport uintptr_t -from libc.stdint cimport uint64_t -from libc.stdint cimport int64_t -from libcpp cimport bool from libcpp.memory cimport unique_ptr -cdef extern from "cuml/manifold/umapparams.h" namespace "ML::UMAPParams": - - enum MetricType: - EUCLIDEAN = 0, - CATEGORICAL = 1 - - -cdef extern from "cuml/manifold/umapparams.h" namespace "ML": - - cdef cppclass UMAPParams: - int n_neighbors, - int n_components, - int n_epochs, - float learning_rate, - float min_dist, - float spread, - float set_op_mix_ratio, - float local_connectivity, - float repulsion_strength, - int negative_sample_rate, - float transform_queue_size, - int verbosity, - float a, - float b, - float initial_alpha, - int init, - int target_n_neighbors, - MetricType target_metric, - float target_weight, - uint64_t random_state, - bool deterministic, - int optim_batch_size - - cdef extern from "raft/sparse/coo.hpp": cdef cppclass COO "raft::sparse::COO": COO(cuda_stream_view stream) diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index 79f530068f..19b25ac8d2 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -30,6 +30,8 @@ import cupyx import numba.cuda as 
cuda +from cuml.manifold.umap_utils cimport * + from cuml.common.sparsefuncs import extract_knn_graph from cupyx.scipy.sparse import csr_matrix as cp_csr_matrix,\ coo_matrix as cp_coo_matrix, csc_matrix as cp_csc_matrix @@ -55,10 +57,7 @@ from cuml.common.array_descriptor import CumlArrayDescriptor import rmm -from libcpp cimport bool from libc.stdint cimport uintptr_t -from libc.stdint cimport uint64_t -from libc.stdint cimport int64_t from libc.stdlib cimport calloc, malloc, free from libcpp.memory cimport shared_ptr @@ -66,44 +65,6 @@ from libcpp.memory cimport shared_ptr cimport cuml.common.cuda -cdef extern from "cuml/manifold/umapparams.h" namespace "ML::UMAPParams": - - enum MetricType: - EUCLIDEAN = 0, - CATEGORICAL = 1 - -cdef extern from "cuml/common/callback.hpp" namespace "ML::Internals": - - cdef cppclass GraphBasedDimRedCallback - -cdef extern from "cuml/manifold/umapparams.h" namespace "ML": - - cdef cppclass UMAPParams: - int n_neighbors, - int n_components, - int n_epochs, - float learning_rate, - float min_dist, - float spread, - float set_op_mix_ratio, - float local_connectivity, - float repulsion_strength, - int negative_sample_rate, - float transform_queue_size, - int verbosity, - float a, - float b, - float initial_alpha, - int init, - int target_n_neighbors, - MetricType target_metric, - float target_weight, - uint64_t random_state, - bool deterministic, - int optim_batch_size, - GraphBasedDimRedCallback * callback - - cdef extern from "cuml/manifold/umap.hpp" namespace "ML::UMAP": void fit(handle_t & handle, diff --git a/python/cuml/manifold/umap_utils.pxd b/python/cuml/manifold/umap_utils.pxd new file mode 100644 index 0000000000..ae26afbf4a --- /dev/null +++ b/python/cuml/manifold/umap_utils.pxd @@ -0,0 +1,41 @@ +from libc.stdint cimport uint64_t +from libc.stdint cimport int64_t +from libcpp cimport bool + + +cdef extern from "cuml/manifold/umapparams.h" namespace "ML::UMAPParams": + + enum MetricType: + EUCLIDEAN = 0, + CATEGORICAL = 1 + 
+cdef extern from "cuml/common/callback.hpp" namespace "ML::Internals": + + cdef cppclass GraphBasedDimRedCallback + +cdef extern from "cuml/manifold/umapparams.h" namespace "ML": + + cdef cppclass UMAPParams: + int n_neighbors, + int n_components, + int n_epochs, + float learning_rate, + float min_dist, + float spread, + float set_op_mix_ratio, + float local_connectivity, + float repulsion_strength, + int negative_sample_rate, + float transform_queue_size, + int verbosity, + float a, + float b, + float initial_alpha, + int init, + int target_n_neighbors, + MetricType target_metric, + float target_weight, + uint64_t random_state, + bool deterministic, + int optim_batch_size, + GraphBasedDimRedCallback * callback From b4a8cc8290b6a257c84cc0169fd7554b7b0b38ad Mon Sep 17 00:00:00 2001 From: viclafargue Date: Mon, 9 May 2022 12:18:02 +0200 Subject: [PATCH 07/12] Fix style --- cpp/include/cuml/manifold/umap.hpp | 2 +- python/cuml/manifold/umap_utils.pxd | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/cpp/include/cuml/manifold/umap.hpp b/cpp/include/cuml/manifold/umap.hpp index dd44cf2daa..ea7fd4a33d 100644 --- a/cpp/include/cuml/manifold/umap.hpp +++ b/cpp/include/cuml/manifold/umap.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/python/cuml/manifold/umap_utils.pxd b/python/cuml/manifold/umap_utils.pxd index ae26afbf4a..5915a6bf87 100644 --- a/python/cuml/manifold/umap_utils.pxd +++ b/python/cuml/manifold/umap_utils.pxd @@ -1,3 +1,21 @@ +# +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# distutils: language = c++ + from libc.stdint cimport uint64_t from libc.stdint cimport int64_t from libcpp cimport bool From e2b4873608cb516499d09024db0bef021a26e3bb Mon Sep 17 00:00:00 2001 From: viclafargue Date: Wed, 11 May 2022 18:51:26 +0200 Subject: [PATCH 08/12] Add fuzzy simplicial set as the graph_ attribute in UMAP --- cpp/bench/sg/umap.cu | 12 +- cpp/include/cuml/manifold/umap.hpp | 204 +++++++++++++------ cpp/src/umap/runner.cuh | 253 +++++++----------------- cpp/src/umap/umap.cu | 239 +++++++++++----------- cpp/test/sg/umap_parametrizable_test.cu | 17 +- python/cuml/manifold/simpl_set.pyx | 83 ++------ python/cuml/manifold/umap.pyx | 18 +- python/cuml/manifold/umap_utils.pxd | 34 +++- python/cuml/manifold/umap_utils.pyx | 102 ++++++++++ python/cuml/tests/test_umap.py | 66 +++++++ 10 files changed, 566 insertions(+), 462 deletions(-) create mode 100644 python/cuml/manifold/umap_utils.pyx diff --git a/cpp/bench/sg/umap.cu b/cpp/bench/sg/umap.cu index fa05df69fb..c75dc0e51e 100644 --- a/cpp/bench/sg/umap.cu +++ b/cpp/bench/sg/umap.cu @@ -115,6 +115,7 @@ class UmapSupervised : public UmapBase { protected: void coreBenchmarkMethod() { + auto graph = raft::sparse::COO(stream); UMAP::fit(*this->handle, this->data.X.data(), yFloat, @@ -123,7 +124,8 @@ class UmapSupervised : public UmapBase { nullptr, nullptr, &uParams, - embeddings); + embeddings, + &graph); } }; ML_BENCH_REGISTER(Params, UmapSupervised, "blobs", getInputs()); @@ -135,6 +137,7 @@ class UmapUnsupervised : public UmapBase { protected: void coreBenchmarkMethod() { + auto graph = 
raft::sparse::COO(stream); UMAP::fit(*this->handle, this->data.X.data(), nullptr, @@ -143,7 +146,8 @@ class UmapUnsupervised : public UmapBase { nullptr, nullptr, &uParams, - embeddings); + embeddings, + &graph); } }; ML_BENCH_REGISTER(Params, UmapUnsupervised, "blobs", getInputs()); @@ -173,6 +177,7 @@ class UmapTransform : public UmapBase { UmapBase::allocateBuffers(state); auto& handle = *this->handle; alloc(transformed, this->params.nrows * uParams.n_components); + auto graph = raft::sparse::COO(stream); UMAP::fit(handle, this->data.X.data(), yFloat, @@ -181,7 +186,8 @@ class UmapTransform : public UmapBase { nullptr, nullptr, &uParams, - embeddings); + embeddings, + &graph); } void deallocateBuffers(const ::benchmark::State& state) { diff --git a/cpp/include/cuml/manifold/umap.hpp b/cpp/include/cuml/manifold/umap.hpp index ea7fd4a33d..008dbb7155 100644 --- a/cpp/include/cuml/manifold/umap.hpp +++ b/cpp/include/cuml/manifold/umap.hpp @@ -30,6 +30,126 @@ namespace ML { class UMAPParams; namespace UMAP { +/** + * Computes the a and b parameters and stores them in the given ML::UMAPParams object. + * + * @param[in] handle: raft::handle_t + * @param[out] params: pointer to ML::UMAPParams object of which the a and b parameters will be + * updated + */ +void find_ab(const raft::handle_t& handle, UMAPParams* params); + +/** + * Returns the simplicial set to be consumed by the ML::UMAP::refine function.
+ * + * @param[in] handle: raft::handle_t + * @param[in] X: pointer to input array + * @param[in] y: pointer to labels array + * @param[in] n: n_samples of input array + * @param[in] d: n_features of input array + * @param[in] knn_indices: pointer to knn_indices (optional) + * @param[in] knn_dists: pointer to knn_dists (optional) + * @param[in] params: pointer to ML::UMAPParams object + * @return: simplicial set as a unique pointer to a raft::sparse::COO object + */ +std::unique_ptr> get_graph(const raft::handle_t& handle, + float* X, // input matrix + float* y, // labels + int n, + int d, + int64_t* knn_indices, + float* knn_dists, + UMAPParams* params); + +/** + * Performs a UMAP fit on existing embeddings without reinitializing them, which enables + * iterative fitting without callbacks. + * + * @param[in] handle: raft::handle_t + * @param[in] X: pointer to input array + * @param[in] n: n_samples of input array + * @param[in] d: n_features of input array + * @param[in] graph: pointer to raft::sparse::COO object computed using ML::UMAP::get_graph + * @param[in] params: pointer to ML::UMAPParams object + * @param[out] embeddings: pointer to current embedding with shape n * n_components, stores updated + * embeddings on executing refine + */ +void refine(const raft::handle_t& handle, + float* X, + int n, + int d, + raft::sparse::COO* graph, + UMAPParams* params, + float* embeddings); + +/** + * Dense fit + * + * @param[in] handle: raft::handle_t + * @param[in] X: pointer to input array + * @param[in] y: pointer to labels array + * @param[in] n: n_samples of input array + * @param[in] d: n_features of input array + * @param[in] knn_indices: pointer to knn_indices of input (optional) + * @param[in] knn_dists: pointer to knn_dists of input (optional) + * @param[in] params: pointer to ML::UMAPParams object + * @param[out] embeddings: pointer to embedding produced through projection + * @param[out] graph: pointer to fuzzy simplicial set graph + */ +void fit(const
raft::handle_t& handle, + float* X, + float* y, + int n, + int d, + int64_t* knn_indices, + float* knn_dists, + UMAPParams* params, + float* embeddings, + raft::sparse::COO* graph); + +/** + * Sparse fit + * + * @param[in] handle: raft::handle_t + * @param[in] indptr: pointer to index pointer array of input array + * @param[in] indices: pointer to index array of input array + * @param[in] data: pointer to data array of input array + * @param[in] nnz: number of stored values of input array + * @param[in] y: pointer to labels array + * @param[in] n: n_samples of input array + * @param[in] d: n_features of input array + * @param[in] params: pointer to ML::UMAPParams object + * @param[out] embeddings: pointer to embedding produced through projection + * @param[out] graph: pointer to fuzzy simplicial set graph + */ +void fit_sparse(const raft::handle_t& handle, + int* indptr, + int* indices, + float* data, + size_t nnz, + float* y, + int n, + int d, + UMAPParams* params, + float* embeddings, + raft::sparse::COO* graph); + +/** + * Dense transform + * + * @param[in] handle: raft::handle_t + * @param[in] X: pointer to input array to be inferred + * @param[in] n: n_samples of input array to be inferred + * @param[in] d: n_features of input array to be inferred + * @param[in] knn_indices: pointer to knn_indices of input (optional) + * @param[in] knn_dists: pointer to knn_dists of input (optional) + * @param[in] orig_X: pointer to original training array + * @param[in] orig_n: number of rows in original training array + * @param[in] embedding: pointer to embedding created during training + * @param[in] embedding_n: number of rows in embedding created during training + * @param[in] params: pointer to ML::UMAPParams object + * @param[out] transformed: pointer to embedding produced through projection + */ void transform(const raft::handle_t& handle, float* X, int n, @@ -43,6 +163,26 @@ void transform(const raft::handle_t& handle, UMAPParams* params, float* transformed); +/** + *
Sparse transform + * + * @param[in] handle: raft::handle_t + * @param[in] indptr: pointer to index pointer array of input array to be inferred + * @param[in] indices: pointer to index array of input array to be inferred + * @param[in] data: pointer to data array of input array to be inferred + * @param[in] nnz: number of stored values of input array to be inferred + * @param[in] n: n_samples of input array + * @param[in] d: n_features of input array + * @param[in] orig_x_indptr: pointer to index pointer array of original training array + * @param[in] orig_x_indices: pointer to index array of original training array + * @param[in] orig_x_data: pointer to data array of original training array + * @param[in] orig_nnz: number of stored values of original training array + * @param[in] orig_n: number of rows in original training array + * @param[in] embedding: pointer to embedding created during training + * @param[in] embedding_n: number of rows in embedding created during training + * @param[in] params: pointer to ML::UMAPParams object + * @param[out] transformed: pointer to embedding produced through projection + */ void transform_sparse(const raft::handle_t& handle, int* indptr, int* indices, @@ -60,69 +200,5 @@ void transform_sparse(const raft::handle_t& handle, UMAPParams* params, float* transformed); -void find_ab(const raft::handle_t& handle, UMAPParams* params); - -void fit(const raft::handle_t& handle, - float* X, // input matrix - float* y, // labels - int n, - int d, - int64_t* knn_indices, - float* knn_dists, - UMAPParams* params, - float* embeddings); - -/** - * refine performs a UMAP fit on existing embeddings without reinitializing them, which enables - * iterative fitting without callbacks.
- * - * @param handle: raft::handle_t - * @param X: pointer to input array - * @param n: n_samples of input array - * @param d: n_features of input array - * @param cgraph_coo: pointer to raft::sparse::COO object computed using ML::UMAP::get_graph - * @param params: pointer to ML::UMAPParams object - * @param embeddings: pointer to current embedding with shape n * n_components, stores updated - * embeddings on executing refine - */ -void refine(const raft::handle_t& handle, - float* X, // input matrix - int n, - int d, - raft::sparse::COO* cgraph_coo, - UMAPParams* params, - float* embeddings); - -/** - * returns a simplical set as a raft::sparse:COO object to be consumed by the ML::UMAP::refine - * function. - * - * @param handle: raft::handle_t - * @param X: pointer to input array - * @param y: pointer to labels array - * @param n: n_samples of input array - * @param d: n_features of input array - * @param params: pointer to ML::UMAPParams object - * @return: simplical set (pointer to raft::sparse::COO object) - */ -std::unique_ptr> get_graph(const raft::handle_t& handle, - float* X, // input matrix - float* y, // labels - int n, - int d, - int64_t* knn_indices, - float* knn_dists, - UMAPParams* params); - -void fit_sparse(const raft::handle_t& handle, - int* indptr, // input matrix - int* indices, - float* data, - size_t nnz, - float* y, - int n, // rows - int d, // cols - UMAPParams* params, - float* embeddings); } // namespace UMAP } // namespace ML diff --git a/cpp/src/umap/runner.cuh b/cpp/src/umap/runner.cuh index 4c57a388b6..bbbc48d1b6 100644 --- a/cpp/src/umap/runner.cuh +++ b/cpp/src/umap/runner.cuh @@ -87,89 +87,11 @@ __global__ void init_transform(int* indices, */ void find_ab(UMAPParams* params, cudaStream_t stream) { Optimize::find_params_ab(params, stream); } -template -void _fit(const raft::handle_t& handle, - const umap_inputs& inputs, - UMAPParams* params, - value_t* embeddings) -{ - raft::common::nvtx::range fun_scope("umap::unsupervised::fit"); 
- cudaStream_t stream = handle.get_stream(); - - int k = params->n_neighbors; - - ML::Logger::get().setLevel(params->verbosity); - - CUML_LOG_DEBUG("n_neighbors=%d", params->n_neighbors); - - raft::common::nvtx::push_range("umap::knnGraph"); - std::unique_ptr> knn_indices_b = nullptr; - std::unique_ptr> knn_dists_b = nullptr; - - knn_graph knn_graph(inputs.n, k); - - /** - * If not given precomputed knn graph, compute it - */ - if (inputs.alloc_knn_graph()) { - /** - * Allocate workspace for kNN graph - */ - knn_indices_b = std::make_unique>(inputs.n * k, stream); - knn_dists_b = std::make_unique>(inputs.n * k, stream); - - knn_graph.knn_indices = knn_indices_b->data(); - knn_graph.knn_dists = knn_dists_b->data(); - } - - CUML_LOG_DEBUG("Calling knn graph run"); - - kNNGraph::run( - handle, inputs, inputs, knn_graph, k, params, stream); - raft::common::nvtx::pop_range(); - - CUML_LOG_DEBUG("Done. Calling fuzzy simplicial set"); - - raft::common::nvtx::push_range("umap::simplicial_set"); - raft::sparse::COO rgraph_coo(stream); - FuzzySimplSet::run( - inputs.n, knn_graph.knn_indices, knn_graph.knn_dists, k, &rgraph_coo, params, stream); - - CUML_LOG_DEBUG("Done. 
Calling remove zeros"); - /** - * Remove zeros from simplicial set - */ - raft::sparse::COO cgraph_coo(stream); - raft::sparse::op::coo_remove_zeros(&rgraph_coo, &cgraph_coo, stream); - raft::common::nvtx::pop_range(); - - /** - * Run initialization method - */ - raft::common::nvtx::push_range("umap::embedding"); - InitEmbed::run(handle, inputs.n, inputs.d, &cgraph_coo, params, embeddings, stream, params->init); - - if (params->callback) { - params->callback->setup(inputs.n, params->n_components); - params->callback->on_preprocess_end(embeddings); - } - - /** - * Run simplicial set embedding to approximate low-dimensional representation - */ - SimplSetEmbed::run(inputs.n, inputs.d, &cgraph_coo, params, embeddings, stream); - raft::common::nvtx::pop_range(); - - if (params->callback) params->callback->on_train_end(embeddings); -} - template void _get_graph(const raft::handle_t& handle, const umap_inputs& inputs, UMAPParams* params, - raft::sparse::COO* cgraph_coo // assumes single-precision int as the - // second template argument for COO -) + raft::sparse::COO* graph) { raft::common::nvtx::range fun_scope("umap::supervised::_get_graph"); cudaStream_t stream = handle.get_stream(); @@ -209,27 +131,24 @@ void _get_graph(const raft::handle_t& handle, CUML_LOG_DEBUG("Done. Calling fuzzy simplicial set"); raft::common::nvtx::push_range("umap::simplicial_set"); - raft::sparse::COO rgraph_coo(stream); + raft::sparse::COO fss_graph(stream); FuzzySimplSet::run( - inputs.n, knn_graph.knn_indices, knn_graph.knn_dists, k, &rgraph_coo, params, stream); + inputs.n, knn_graph.knn_indices, knn_graph.knn_dists, k, &fss_graph, params, stream); CUML_LOG_DEBUG("Done. 
Calling remove zeros"); /** * Remove zeros from simplicial set */ - raft::sparse::op::coo_remove_zeros(&rgraph_coo, cgraph_coo, stream); + raft::sparse::op::coo_remove_zeros(&fss_graph, graph, stream); raft::common::nvtx::pop_range(); } template -void _get_graph_supervised( - const raft::handle_t& handle, - const umap_inputs& inputs, - UMAPParams* params, - raft::sparse::COO* cgraph_coo // assumes single-precision int as the - // second template argument for COO -) +void _get_graph_supervised(const raft::handle_t& handle, + const umap_inputs& inputs, + UMAPParams* params, + raft::sparse::COO* graph) { raft::common::nvtx::range fun_scope("umap::supervised::_get_graph_supervised"); cudaStream_t stream = handle.get_stream(); @@ -269,24 +188,27 @@ void _get_graph_supervised( * Allocate workspace for fuzzy simplicial set. */ raft::common::nvtx::push_range("umap::simplicial_set"); - raft::sparse::COO rgraph_coo(stream); - raft::sparse::COO tmp_coo(stream); - - /** - * Run Fuzzy simplicial set - */ - // int nnz = n*k*2; - FuzzySimplSet::run(inputs.n, - knn_graph.knn_indices, - knn_graph.knn_dists, - params->n_neighbors, - &tmp_coo, - params, - stream); - RAFT_CUDA_TRY(cudaPeekAtLastError()); - raft::sparse::op::coo_remove_zeros(&tmp_coo, &rgraph_coo, stream); + raft::sparse::COO fss_graph(stream); + { + raft::sparse::COO fss_graph_tmp(stream); + /** + * Run Fuzzy simplicial set + */ + // int nnz = n*k*2; + FuzzySimplSet::run(inputs.n, + knn_graph.knn_indices, + knn_graph.knn_dists, + params->n_neighbors, + &fss_graph_tmp, + params, + stream); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + + raft::sparse::op::coo_remove_zeros(&fss_graph_tmp, &fss_graph, stream); + } + raft::sparse::COO ci_graph(stream); /** * If target metric is 'categorical', perform * categorical simplicial set intersection. 
@@ -294,7 +216,7 @@ void _get_graph_supervised( if (params->target_metric == ML::UMAPParams::MetricType::CATEGORICAL) { CUML_LOG_DEBUG("Performing categorical intersection"); Supervised::perform_categorical_intersection( - inputs.y, &rgraph_coo, cgraph_coo, params, stream); + inputs.y, &fss_graph, &ci_graph, params, stream); /** * Otherwise, perform general simplicial set intersection @@ -302,16 +224,14 @@ void _get_graph_supervised( } else { CUML_LOG_DEBUG("Performing general intersection"); Supervised::perform_general_intersection( - handle, inputs.y, &rgraph_coo, cgraph_coo, params, stream); + handle, inputs.y, &fss_graph, &ci_graph, params, stream); } /** * Remove zeros */ - raft::sparse::op::coo_sort(cgraph_coo, stream); - - raft::sparse::COO ocoo(stream); - raft::sparse::op::coo_remove_zeros(cgraph_coo, &ocoo, stream); + raft::sparse::op::coo_sort(&ci_graph, stream); + raft::sparse::op::coo_remove_zeros(&ci_graph, graph, stream); raft::common::nvtx::pop_range(); } @@ -319,112 +239,71 @@ template void _refine(const raft::handle_t& handle, const umap_inputs& inputs, UMAPParams* params, - raft::sparse::COO* cgraph_coo, + raft::sparse::COO* graph, value_t* embeddings) { cudaStream_t stream = handle.get_stream(); /** * Run simplicial set embedding to approximate low-dimensional representation */ - SimplSetEmbed::run(inputs.n, inputs.d, cgraph_coo, params, embeddings, stream); + SimplSetEmbed::run(inputs.n, inputs.d, graph, params, embeddings, stream); } template -void _fit_supervised(const raft::handle_t& handle, - const umap_inputs& inputs, - UMAPParams* params, - value_t* embeddings) +void _fit(const raft::handle_t& handle, + const umap_inputs& inputs, + UMAPParams* params, + value_t* embeddings, + raft::sparse::COO* graph) { - raft::common::nvtx::range fun_scope("umap::supervised::fit"); - cudaStream_t stream = handle.get_stream(); - - int k = params->n_neighbors; + raft::common::nvtx::range fun_scope("umap::unsupervised::fit"); + cudaStream_t stream = 
handle.get_stream(); ML::Logger::get().setLevel(params->verbosity); - if (params->target_n_neighbors == -1) params->target_n_neighbors = params->n_neighbors; - - raft::common::nvtx::push_range("umap::knnGraph"); - std::unique_ptr> knn_indices_b = nullptr; - std::unique_ptr> knn_dists_b = nullptr; - - knn_graph knn_graph(inputs.n, k); + UMAPAlgo::_get_graph(handle, inputs, params, graph); /** - * If not given precomputed knn graph, compute it + * Run initialization method */ - if (inputs.alloc_knn_graph()) { - /** - * Allocate workspace for kNN graph - */ - knn_indices_b = std::make_unique>(inputs.n * k, stream); - knn_dists_b = std::make_unique>(inputs.n * k, stream); + raft::common::nvtx::push_range("umap::embedding"); + InitEmbed::run(handle, inputs.n, inputs.d, graph, params, embeddings, stream, params->init); - knn_graph.knn_indices = knn_indices_b->data(); - knn_graph.knn_dists = knn_dists_b->data(); + if (params->callback) { + params->callback->setup(inputs.n, params->n_components); + params->callback->on_preprocess_end(embeddings); } - kNNGraph::run( - handle, inputs, inputs, knn_graph, k, params, stream); - - raft::common::nvtx::pop_range(); - /** - * Allocate workspace for fuzzy simplicial set. 
- */ - raft::common::nvtx::push_range("umap::simplicial_set"); - raft::sparse::COO rgraph_coo(stream); - raft::sparse::COO tmp_coo(stream); - - /** - * Run Fuzzy simplicial set + * Run simplicial set embedding to approximate low-dimensional representation */ - // int nnz = n*k*2; - FuzzySimplSet::run(inputs.n, - knn_graph.knn_indices, - knn_graph.knn_dists, - params->n_neighbors, - &tmp_coo, - params, - stream); - RAFT_CUDA_TRY(cudaPeekAtLastError()); - - raft::sparse::op::coo_remove_zeros(&tmp_coo, &rgraph_coo, stream); + SimplSetEmbed::run(inputs.n, inputs.d, graph, params, embeddings, stream); + raft::common::nvtx::pop_range(); - raft::sparse::COO final_coo(stream); + if (params->callback) params->callback->on_train_end(embeddings); - /** - * If target metric is 'categorical', perform - * categorical simplicial set intersection. - */ - if (params->target_metric == ML::UMAPParams::MetricType::CATEGORICAL) { - CUML_LOG_DEBUG("Performing categorical intersection"); - Supervised::perform_categorical_intersection( - inputs.y, &rgraph_coo, &final_coo, params, stream); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} - /** - * Otherwise, perform general simplicial set intersection - */ - } else { - CUML_LOG_DEBUG("Performing general intersection"); - Supervised::perform_general_intersection( - handle, inputs.y, &rgraph_coo, &final_coo, params, stream); - } +template +void _fit_supervised(const raft::handle_t& handle, + const umap_inputs& inputs, + UMAPParams* params, + value_t* embeddings, + raft::sparse::COO* graph) +{ + raft::common::nvtx::range fun_scope("umap::supervised::fit"); - /** - * Remove zeros - */ - raft::sparse::op::coo_sort(&final_coo, stream); + cudaStream_t stream = handle.get_stream(); + ML::Logger::get().setLevel(params->verbosity); - raft::sparse::COO ocoo(stream); - raft::sparse::op::coo_remove_zeros(&final_coo, &ocoo, stream); - raft::common::nvtx::pop_range(); + UMAPAlgo::_get_graph(handle, inputs, params, graph); /** * Initialize embeddings */ 
raft::common::nvtx::push_range("umap::supervised::fit"); - InitEmbed::run(handle, inputs.n, inputs.d, &ocoo, params, embeddings, stream, params->init); + InitEmbed::run(handle, inputs.n, inputs.d, graph, params, embeddings, stream, params->init); if (params->callback) { params->callback->setup(inputs.n, params->n_components); @@ -434,7 +313,7 @@ void _fit_supervised(const raft::handle_t& handle, /** * Run simplicial set embedding to approximate low-dimensional representation */ - SimplSetEmbed::run(inputs.n, inputs.d, &ocoo, params, embeddings, stream); + SimplSetEmbed::run(inputs.n, inputs.d, graph, params, embeddings, stream); raft::common::nvtx::pop_range(); if (params->callback) params->callback->on_train_end(embeddings); diff --git a/cpp/src/umap/umap.cu b/cpp/src/umap/umap.cu index 359783b2d7..b5be1b2543 100644 --- a/cpp/src/umap/umap.cu +++ b/cpp/src/umap/umap.cu @@ -28,115 +28,12 @@ namespace UMAP { static const int TPB_X = 256; -/** - * \brief Dense transform - * - * @param X: pointer to input array - * @param n: n_samples of input array - * @param d: n_features of input array - * @param orig_X: self.X_m.ptr - * @param orig_n: self.n_rows - * @param embedding: pointer to stored embedding - * @param embedding_n: n_samples in embedding, equals to orig_n - * @param transformed: output array with shape n * n_components - */ -void transform(const raft::handle_t& handle, - float* X, - int n, - int d, - knn_indices_dense_t* knn_indices, - float* knn_dists, - float* orig_X, - int orig_n, - float* embedding, - int embedding_n, - UMAPParams* params, - float* transformed) -{ - if (knn_indices != nullptr && knn_dists != nullptr) { - manifold_precomputed_knn_inputs_t inputs( - knn_indices, knn_dists, X, nullptr, n, d, params->n_neighbors); - UMAPAlgo::_transform, - TPB_X>( - handle, inputs, inputs, embedding, embedding_n, params, transformed); - } else { - manifold_dense_inputs_t inputs(X, nullptr, n, d); - manifold_dense_inputs_t orig_inputs(orig_X, nullptr, orig_n, 
d); - UMAPAlgo::_transform, TPB_X>( - handle, inputs, orig_inputs, embedding, embedding_n, params, transformed); - } -} - -// Sparse transform -void transform_sparse(const raft::handle_t& handle, - int* indptr, - int* indices, - float* data, - size_t nnz, - int n, - int d, - int* orig_x_indptr, - int* orig_x_indices, - float* orig_x_data, - size_t orig_nnz, - int orig_n, - float* embedding, - int embedding_n, - UMAPParams* params, - float* transformed) -{ - manifold_sparse_inputs_t inputs( - indptr, indices, data, nullptr, nnz, n, d); - manifold_sparse_inputs_t orig_x_inputs( - orig_x_indptr, orig_x_indices, orig_x_data, nullptr, orig_nnz, orig_n, d); - - UMAPAlgo::_transform, TPB_X>( - handle, inputs, orig_x_inputs, embedding, embedding_n, params, transformed); -} - -// Dense fit -void fit(const raft::handle_t& handle, - float* X, // input matrix - float* y, // labels - int n, - int d, - knn_indices_dense_t* knn_indices, - float* knn_dists, - UMAPParams* params, - float* embeddings) +void find_ab(const raft::handle_t& handle, UMAPParams* params) { - if (knn_indices != nullptr && knn_dists != nullptr) { - CUML_LOG_DEBUG("Calling UMAP::fit() with precomputed KNN"); - - manifold_precomputed_knn_inputs_t inputs( - knn_indices, knn_dists, X, y, n, d, params->n_neighbors); - if (y != nullptr) { - UMAPAlgo::_fit_supervised, - TPB_X>(handle, inputs, params, embeddings); - } else { - UMAPAlgo::_fit, - TPB_X>(handle, inputs, params, embeddings); - } - - } else { - manifold_dense_inputs_t inputs(X, y, n, d); - if (y != nullptr) { - UMAPAlgo::_fit_supervised, TPB_X>( - handle, inputs, params, embeddings); - } else { - UMAPAlgo::_fit, TPB_X>( - handle, inputs, params, embeddings); - } - } + cudaStream_t stream = handle.get_stream(); + UMAPAlgo::find_ab(params, stream); } -// get graph std::unique_ptr> get_graph( const raft::handle_t& handle, float* X, // input matrix @@ -147,7 +44,7 @@ std::unique_ptr> get_graph( float* knn_dists, // precomputed distances UMAPParams* params) { 
- auto cgraph_coo = std::make_unique>(handle.get_stream()); + auto graph = std::make_unique>(handle.get_stream()); if (knn_indices != nullptr && knn_dists != nullptr) { CUML_LOG_DEBUG("Calling UMAP::get_graph() with precomputed KNN"); @@ -157,70 +54,158 @@ std::unique_ptr> get_graph( UMAPAlgo::_get_graph_supervised, - TPB_X>(handle, inputs, params, cgraph_coo.get()); + TPB_X>(handle, inputs, params, graph.get()); } else { UMAPAlgo::_get_graph, - TPB_X>(handle, inputs, params, cgraph_coo.get()); + TPB_X>(handle, inputs, params, graph.get()); } - return cgraph_coo; + return graph; } else { manifold_dense_inputs_t inputs(X, y, n, d); if (y != nullptr) { UMAPAlgo:: _get_graph_supervised, TPB_X>( - handle, inputs, params, cgraph_coo.get()); + handle, inputs, params, graph.get()); } else { UMAPAlgo::_get_graph, TPB_X>( - handle, inputs, params, cgraph_coo.get()); + handle, inputs, params, graph.get()); } - return cgraph_coo; + return graph; } } -// refine void refine(const raft::handle_t& handle, - float* X, // input matrix + float* X, int n, int d, - raft::sparse::COO* cgraph_coo, + raft::sparse::COO* graph, UMAPParams* params, float* embeddings) { CUML_LOG_DEBUG("Calling UMAP::refine() with precomputed KNN"); manifold_dense_inputs_t inputs(X, nullptr, n, d); UMAPAlgo::_refine, TPB_X>( - handle, inputs, params, cgraph_coo, embeddings); + handle, inputs, params, graph, embeddings); +} + +void fit(const raft::handle_t& handle, + float* X, + float* y, + int n, + int d, + knn_indices_dense_t* knn_indices, + float* knn_dists, + UMAPParams* params, + float* embeddings, + raft::sparse::COO* graph) +{ + if (knn_indices != nullptr && knn_dists != nullptr) { + CUML_LOG_DEBUG("Calling UMAP::fit() with precomputed KNN"); + + manifold_precomputed_knn_inputs_t inputs( + knn_indices, knn_dists, X, y, n, d, params->n_neighbors); + if (y != nullptr) { + UMAPAlgo::_fit_supervised, + TPB_X>(handle, inputs, params, embeddings, graph); + } else { + UMAPAlgo::_fit, + TPB_X>(handle, inputs, 
params, embeddings, graph); + } + + } else { + manifold_dense_inputs_t inputs(X, y, n, d); + if (y != nullptr) { + UMAPAlgo::_fit_supervised, TPB_X>( + handle, inputs, params, embeddings, graph); + } else { + UMAPAlgo::_fit, TPB_X>( + handle, inputs, params, embeddings, graph); + } + } } -// Sparse fit void fit_sparse(const raft::handle_t& handle, - int* indptr, // input matrix + int* indptr, int* indices, float* data, size_t nnz, float* y, - int n, // rows - int d, // cols + int n, + int d, UMAPParams* params, - float* embeddings) + float* embeddings, + raft::sparse::COO* graph) { manifold_sparse_inputs_t inputs(indptr, indices, data, y, nnz, n, d); if (y != nullptr) { UMAPAlgo:: _fit_supervised, TPB_X>( - handle, inputs, params, embeddings); + handle, inputs, params, embeddings, graph); } else { UMAPAlgo::_fit, TPB_X>( - handle, inputs, params, embeddings); + handle, inputs, params, embeddings, graph); } } -void find_ab(const raft::handle_t& handle, UMAPParams* params) +void transform(const raft::handle_t& handle, + float* X, + int n, + int d, + knn_indices_dense_t* knn_indices, + float* knn_dists, + float* orig_X, + int orig_n, + float* embedding, + int embedding_n, + UMAPParams* params, + float* transformed) { - cudaStream_t stream = handle.get_stream(); - UMAPAlgo::find_ab(params, stream); + if (knn_indices != nullptr && knn_dists != nullptr) { + manifold_precomputed_knn_inputs_t inputs( + knn_indices, knn_dists, X, nullptr, n, d, params->n_neighbors); + UMAPAlgo::_transform, + TPB_X>( + handle, inputs, inputs, embedding, embedding_n, params, transformed); + } else { + manifold_dense_inputs_t inputs(X, nullptr, n, d); + manifold_dense_inputs_t orig_inputs(orig_X, nullptr, orig_n, d); + UMAPAlgo::_transform, TPB_X>( + handle, inputs, orig_inputs, embedding, embedding_n, params, transformed); + } +} + +void transform_sparse(const raft::handle_t& handle, + int* indptr, + int* indices, + float* data, + size_t nnz, + int n, + int d, + int* orig_x_indptr, + int* 
orig_x_indices, + float* orig_x_data, + size_t orig_nnz, + int orig_n, + float* embedding, + int embedding_n, + UMAPParams* params, + float* transformed) +{ + manifold_sparse_inputs_t inputs( + indptr, indices, data, nullptr, nnz, n, d); + manifold_sparse_inputs_t orig_x_inputs( + orig_x_indptr, orig_x_indices, orig_x_data, nullptr, orig_nnz, orig_n, d); + + UMAPAlgo::_transform, TPB_X>( + handle, inputs, orig_x_inputs, embedding, embedding_n, params, transformed); } } // namespace UMAP diff --git a/cpp/test/sg/umap_parametrizable_test.cu b/cpp/test/sg/umap_parametrizable_test.cu index dd3512cefe..cf90cba1b4 100644 --- a/cpp/test/sg/umap_parametrizable_test.cu +++ b/cpp/test/sg/umap_parametrizable_test.cu @@ -169,9 +169,19 @@ class UMAPParametrizableTest : public ::testing::Test { handle.sync_stream(stream); + auto graph = raft::sparse::COO(stream); + if (test_params.supervised) { - ML::UMAP::fit( - handle, X, y, n_samples, n_features, knn_indices, knn_dists, &umap_params, model_embedding); + ML::UMAP::fit(handle, + X, + y, + n_samples, + n_features, + knn_indices, + knn_dists, + &umap_params, + model_embedding, + &graph); } else { ML::UMAP::fit(handle, X, @@ -181,7 +191,8 @@ class UMAPParametrizableTest : public ::testing::Test { knn_indices, knn_dists, &umap_params, - model_embedding); + model_embedding, + &graph); } if (test_params.refine) { diff --git a/python/cuml/manifold/simpl_set.pyx b/python/cuml/manifold/simpl_set.pyx index a0d15ae3d8..7334529a30 100644 --- a/python/cuml/manifold/simpl_set.pyx +++ b/python/cuml/manifold/simpl_set.pyx @@ -20,6 +20,7 @@ import numpy as np import cupy as cp from cuml.manifold.umap_utils cimport * +from cuml.manifold.umap_utils import GraphHolder import cuml.internals from cuml.common.base import Base @@ -27,24 +28,14 @@ from cuml.common.input_utils import input_to_cuml_array from cuml.common.array import CumlArray from cuml.manifold.umap import UMAP -from rmm._lib.cuda_stream_view cimport cuda_stream_view from 
raft.common.handle cimport handle_t from raft.common.handle import Handle from libc.stdint cimport uintptr_t +from libc.stdlib cimport free from libcpp.memory cimport unique_ptr -cdef extern from "raft/sparse/coo.hpp": - cdef cppclass COO "raft::sparse::COO": - COO(cuda_stream_view stream) - void allocate(int nnz, int size, bool init, cuda_stream_view stream) - int nnz - float* vals() - int* rows() - int* cols() - - cdef extern from "cuml/manifold/umap.hpp" namespace "ML::UMAP": unique_ptr[COO] get_graph(handle_t &handle, @@ -65,25 +56,6 @@ cdef extern from "cuml/manifold/umap.hpp" namespace "ML::UMAP": float* embeddings) -cdef class GraphHolder: - cdef unique_ptr[COO] c_graph - - def vals(self): - return self.c_graph.get().vals() - - def rows(self): - return self.c_graph.get().rows() - - def cols(self): - return self.c_graph.get().cols() - - def get_nnz(self): - return self.c_graph.get().nnz - - def __dealloc__(self): - self.c_graph.reset(NULL) - - def fuzzy_simplicial_set(X, n_neighbors, random_state=None, @@ -198,8 +170,7 @@ def fuzzy_simplicial_set(X, handle = Handle() cdef handle_t* handle_ = handle.getHandle() - fss_graph = GraphHolder() - fss_graph.c_graph = get_graph( + cdef unique_ptr[COO] fss_graph_ptr = get_graph( handle_[0], X_ptr, NULL, @@ -208,26 +179,13 @@ def fuzzy_simplicial_set(X, knn_indices_ptr, knn_dists_ptr, umap_params) + fss_graph = GraphHolder.from_ptr(fss_graph_ptr) - cdef int nnz = fss_graph.get_nnz() - coo_arrays = [] - for ptr, dtype in [(fss_graph.vals(), np.float32), - (fss_graph.rows(), np.int32), - (fss_graph.cols(), np.int32)]: - mem = cp.cuda.UnownedMemory(ptr=ptr, - size=nnz * np.dtype(dtype).itemsize, - owner=fss_graph, - device_id=-1) - memptr = cp.cuda.memory.MemoryPointer(mem, 0) - coo_arrays.append(cp.ndarray(nnz, dtype=dtype, memptr=memptr)) + free(umap_params) - result = cp.sparse.coo_matrix(((coo_arrays[0], - (coo_arrays[1], coo_arrays[2])))) + return fss_graph.get_cupy_coo() - return result - 
-@cuml.internals.api_return_generic() def simplicial_set_embedding( data, graph, @@ -358,27 +316,9 @@ def simplicial_set_embedding( handle = Handle() cdef handle_t* handle_ = handle.getHandle() - cdef COO* fss_graph = new COO(handle_.get_stream()) - fss_graph.allocate(graph.nnz, graph.shape[0], False, handle_.get_stream()) - handle_.sync_stream() - - nnz = graph.nnz - for dst_ptr, src_ptr, dtype in \ - [(fss_graph.vals(), graph.data.data.ptr, np.float32), - (fss_graph.rows(), graph.row.data.ptr, np.int32), - (fss_graph.cols(), graph.col.data.ptr, np.int32)]: - itemsize = np.dtype(dtype).itemsize - dest_buff = cp.cuda.UnownedMemory(ptr=dst_ptr, - size=nnz * itemsize, - owner=None, - device_id=-1) - dest_mptr = cp.cuda.memory.MemoryPointer(dest_buff, 0) - src_buff = cp.cuda.UnownedMemory(ptr=src_ptr, - size=nnz * itemsize, - owner=None, - device_id=-1) - src_mptr = cp.cuda.memory.MemoryPointer(src_buff, 0) - dest_mptr.copy_from_device(src_mptr, nnz * itemsize) + cdef GraphHolder fss_graph = GraphHolder.from_coo_array(GraphHolder(), + handle, + graph) embedding = CumlArray.zeros((X_m.shape[0], n_components), order="C", dtype=np.float32, @@ -388,9 +328,10 @@ def simplicial_set_embedding( X_m.ptr, X_m.shape[0], X_m.shape[1], - fss_graph, + fss_graph.get(), umap_params, embedding.ptr) - cuml.internals.set_api_output_type("cupy") + free(umap_params) + return embedding diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index 19b25ac8d2..9c07c70555 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -31,6 +31,7 @@ import cupyx import numba.cuda as cuda from cuml.manifold.umap_utils cimport * +from cuml.manifold.umap_utils import GraphHolder from cuml.common.sparsefuncs import extract_knn_graph from cupyx.scipy.sparse import csr_matrix as cp_csr_matrix,\ @@ -58,7 +59,7 @@ from cuml.common.array_descriptor import CumlArrayDescriptor import rmm from libc.stdint cimport uintptr_t -from libc.stdlib cimport calloc, malloc, 
free +from libc.stdlib cimport free from libcpp.memory cimport shared_ptr @@ -75,7 +76,8 @@ cdef extern from "cuml/manifold/umap.hpp" namespace "ML::UMAP": int64_t * knn_indices, float * knn_dists, UMAPParams * params, - float * embeddings) except + + float * embeddings, + COO * graph) except + void fit_sparse(handle_t &handle, int *indptr, @@ -86,7 +88,8 @@ cdef extern from "cuml/manifold/umap.hpp" namespace "ML::UMAP": int n, int d, UMAPParams *params, - float *embeddings) except + + float *embeddings, + COO * graph) except + void transform(handle_t & handle, float * X, @@ -546,6 +549,7 @@ class UMAP(Base, else None)) y_raw = y_m.ptr + fss_graph = GraphHolder.new_graph(handle_.get_stream()) if self.sparse_fit: fit_sparse(handle_[0], self.X_m.indptr.ptr, @@ -556,7 +560,8 @@ class UMAP(Base, self.n_rows, self.n_dims, umap_params, - embed_raw) + embed_raw, + fss_graph.get()) else: fit(handle_[0], @@ -567,7 +572,10 @@ class UMAP(Base, knn_indices_raw, knn_dists_raw, umap_params, - embed_raw) + embed_raw, + fss_graph.get()) + + self.graph_ = fss_graph.get_cupy_coo() self.handle.sync() diff --git a/python/cuml/manifold/umap_utils.pxd b/python/cuml/manifold/umap_utils.pxd index 5915a6bf87..10d3f3e80c 100644 --- a/python/cuml/manifold/umap_utils.pxd +++ b/python/cuml/manifold/umap_utils.pxd @@ -16,9 +16,12 @@ # distutils: language = c++ -from libc.stdint cimport uint64_t -from libc.stdint cimport int64_t +from rmm._lib.cuda_stream_view cimport cuda_stream_view +from libcpp.memory cimport unique_ptr + +from libc.stdint cimport uint64_t, uintptr_t, int64_t from libcpp cimport bool +from libcpp.memory cimport shared_ptr cdef extern from "cuml/manifold/umapparams.h" namespace "ML::UMAPParams": @@ -57,3 +60,30 @@ cdef extern from "cuml/manifold/umapparams.h" namespace "ML": bool deterministic, int optim_batch_size, GraphBasedDimRedCallback * callback + +cdef extern from "raft/sparse/coo.hpp": + cdef cppclass COO "raft::sparse::COO": + COO(cuda_stream_view stream) + void 
allocate(int nnz, int size, bool init, cuda_stream_view stream) + int nnz + float* vals() + int* rows() + int* cols() + +cdef class GraphHolder: + cdef unique_ptr[COO] c_graph + + @staticmethod + cdef GraphHolder new_graph(cuda_stream_view stream) + + @staticmethod + cdef GraphHolder from_ptr(unique_ptr[COO]& ptr) + + @staticmethod + cdef GraphHolder from_coo_array(graph, handle, coo_array) + + cdef COO* get(GraphHolder self) + cdef uintptr_t vals(GraphHolder self) + cdef uintptr_t rows(GraphHolder self) + cdef uintptr_t cols(GraphHolder self) + cdef uint64_t get_nnz(GraphHolder self) diff --git a/python/cuml/manifold/umap_utils.pyx b/python/cuml/manifold/umap_utils.pyx new file mode 100644 index 0000000000..e7f10bbd53 --- /dev/null +++ b/python/cuml/manifold/umap_utils.pyx @@ -0,0 +1,102 @@ +# +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# distutils: language = c++ + +from raft.common.handle cimport handle_t +from cuml.manifold.umap_utils cimport * +from libcpp.utility cimport move +import numpy as np +import cupy as cp + + +cdef class GraphHolder: + @staticmethod + cdef GraphHolder new_graph(cuda_stream_view stream): + cdef GraphHolder graph = GraphHolder.__new__(GraphHolder) + graph.c_graph.reset(new COO(stream)) + return graph + + @staticmethod + cdef GraphHolder from_ptr(unique_ptr[COO]& ptr): + cdef GraphHolder graph = GraphHolder.__new__(GraphHolder) + graph.c_graph = move(ptr) + return graph + + @staticmethod + cdef GraphHolder from_coo_array(graph, handle, coo_array): + def copy_from_array(dst_raft_coo_ptr, src_cp_coo): + size = src_cp_coo.size + itemsize = np.dtype(src_cp_coo.dtype).itemsize + dest_buff = cp.cuda.UnownedMemory(ptr=dst_raft_coo_ptr, + size=size * itemsize, + owner=None, + device_id=-1) + dest_mptr = cp.cuda.memory.MemoryPointer(dest_buff, 0) + src_buff = cp.cuda.UnownedMemory(ptr=src_cp_coo.data.ptr, + size=size * itemsize, + owner=None, + device_id=-1) + src_mptr = cp.cuda.memory.MemoryPointer(src_buff, 0) + dest_mptr.copy_from_device(src_mptr, size * itemsize) + + cdef handle_t* handle_ = handle.getHandle() + graph.c_graph.reset(new COO(handle_.get_stream())) + graph.get().allocate(coo_array.nnz, + coo_array.shape[0], + False, + handle_.get_stream()) + handle_.sync_stream() + + copy_from_array(graph.vals(), coo_array.data.astype('float32')) + copy_from_array(graph.rows(), coo_array.row.astype('int32')) + copy_from_array(graph.cols(), coo_array.col.astype('int32')) + + return graph + + cdef inline COO* get(self): + return self.c_graph.get() + + cdef uintptr_t vals(self): + return self.get().vals() + + cdef uintptr_t rows(self): + return self.get().rows() + + cdef uintptr_t cols(self): + return self.get().cols() + + cdef uint64_t get_nnz(self): + return self.get().nnz + + def get_cupy_coo(self): + def create_nonowning_cp_array(ptr, dtype): + mem = 
cp.cuda.UnownedMemory(ptr=ptr, + size=(self.get_nnz() * + np.dtype(dtype).itemsize), + owner=self, + device_id=-1) + memptr = cp.cuda.memory.MemoryPointer(mem, 0) + return cp.ndarray(self.get_nnz(), dtype=dtype, memptr=memptr) + + vals = create_nonowning_cp_array(self.vals(), np.float32) + rows = create_nonowning_cp_array(self.rows(), np.int32) + cols = create_nonowning_cp_array(self.cols(), np.int32) + + return cp.sparse.coo_matrix(((vals, (rows, cols)))) + + def __dealloc__(self): + self.c_graph.reset(NULL) diff --git a/python/cuml/tests/test_umap.py b/python/cuml/tests/test_umap.py index cd7b5c3bf6..84c2774c3f 100644 --- a/python/cuml/tests/test_umap.py +++ b/python/cuml/tests/test_umap.py @@ -24,6 +24,7 @@ import cupyx import scipy.sparse +import cupy as cp from cuml.manifold.umap import UMAP as cuUMAP from cuml.testing.utils import array_equal, unit_param, \ @@ -530,3 +531,68 @@ def test_equality(e1, e2): test_equality(embedding3, embedding4) test_equality(embedding5, embedding6) test_equality(embedding6, embedding7) + + +def correctness_sparse(a, b, atol=0.1, rtol=0.2, threshold=0.95): + n_ref_zeros = (a == 0).sum() + n_ref_non_zero_elms = a.size - n_ref_zeros + n_correct = (cp.abs(a - b) <= (atol + rtol * cp.abs(b))).sum() + correctness = (n_correct - n_ref_zeros) / n_ref_non_zero_elms + return correctness >= threshold + + +@pytest.mark.parametrize('n_rows', [800, 5000]) +@pytest.mark.parametrize('n_features', [8, 32]) +@pytest.mark.parametrize('n_neighbors', [8, 16]) +@pytest.mark.parametrize('precomputed_nearest_neighbors', [False, True]) +def test_fuzzy_simplicial_set(n_rows, + n_features, + n_neighbors, + precomputed_nearest_neighbors): + n_clusters = 30 + random_state = 42 + metric = 'euclidean' + + X, _ = make_blobs(n_samples=n_rows, centers=n_clusters, + n_features=n_features, random_state=random_state) + + if precomputed_nearest_neighbors: + nn = NearestNeighbors(n_neighbors=n_neighbors, + metric=metric) + nn.fit(X) + knn_dists, knn_indices = 
nn.kneighbors(X, + n_neighbors, + return_distance=True) + knn_graph = nn.kneighbors_graph(X, mode="distance") + model = cuUMAP(n_neighbors=n_neighbors) + model.fit(X, + knn_graph=knn_graph) + cu_fss_graph = model.graph_ + + knn_indices = knn_indices + knn_dists = knn_dists + ref_fss_graph = umap.umap_.fuzzy_simplicial_set( + X, + n_neighbors, + random_state, + metric, + knn_indices=knn_indices, + knn_dists=knn_dists)[0].tocoo() + else: + model = cuUMAP(n_neighbors=n_neighbors) + model.fit(X) + cu_fss_graph = model.graph_ + + ref_fss_graph = umap.umap_.fuzzy_simplicial_set( + X, + n_neighbors, + random_state, + metric)[0].tocoo() + + cu_fss_graph = cu_fss_graph.todense() + ref_fss_graph = cp.sparse.coo_matrix(ref_fss_graph).todense() + assert correctness_sparse(ref_fss_graph, + cu_fss_graph, + atol=0.1, + rtol=0.2, + threshold=0.95) From 3ed11996f17c91018ee9e7c4d4010675d077568e Mon Sep 17 00:00:00 2001 From: viclafargue Date: Thu, 12 May 2022 14:40:17 +0200 Subject: [PATCH 09/12] import simplicial set functions from umap --- python/cuml/manifold/simpl_set.pyx | 5 ++--- python/cuml/manifold/umap.pyx | 26 ++++---------------------- python/cuml/manifold/umap_utils.pyx | 25 +++++++++++++++++++++++++ python/cuml/tests/test_simpl_set.py | 4 ++-- 4 files changed, 33 insertions(+), 27 deletions(-) diff --git a/python/cuml/manifold/simpl_set.pyx b/python/cuml/manifold/simpl_set.pyx index 7334529a30..b1ea268900 100644 --- a/python/cuml/manifold/simpl_set.pyx +++ b/python/cuml/manifold/simpl_set.pyx @@ -20,13 +20,12 @@ import numpy as np import cupy as cp from cuml.manifold.umap_utils cimport * -from cuml.manifold.umap_utils import GraphHolder +from cuml.manifold.umap_utils import GraphHolder, find_ab_params import cuml.internals from cuml.common.base import Base from cuml.common.input_utils import input_to_cuml_array from cuml.common.array import CumlArray -from cuml.manifold.umap import UMAP from raft.common.handle cimport handle_t from raft.common.handle import Handle @@ 
-272,7 +271,7 @@ def simplicial_set_embedding( if a is None or b is None: spread = 1.0 min_dist = 0.1 - a, b = UMAP.find_ab_params(spread, min_dist) + a, b = find_ab_params(spread, min_dist) deterministic = random_state is not None if not isinstance(random_state, int): diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index 9c07c70555..01e92c8975 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -31,7 +31,7 @@ import cupyx import numba.cuda as cuda from cuml.manifold.umap_utils cimport * -from cuml.manifold.umap_utils import GraphHolder +from cuml.manifold.umap_utils import GraphHolder, find_ab_params from cuml.common.sparsefuncs import extract_knn_graph from cupyx.scipy.sparse import csr_matrix as cp_csr_matrix,\ @@ -51,6 +51,8 @@ from cuml.common.array_sparse import SparseCumlArray from cuml.common.mixins import CMajorInputTagMixin from cuml.common.sparse_utils import is_sparse +from cuml.manifold.simpl_set import fuzzy_simplicial_set, simplicial_set_embedding + if has_scipy(True): import scipy.sparse @@ -435,27 +437,7 @@ class UMAP(Base, @staticmethod def find_ab_params(spread, min_dist): - """ Function taken from UMAP-learn : https://github.com/lmcinnes/umap - Fit a, b params for the differentiable curve used in lower - dimensional fuzzy simplicial complex construction. We want the - smooth curve (from a pre-defined family with simple gradient) that - best matches an offset exponential decay. 
- """ - - def curve(x, a, b): - return 1.0 / (1.0 + a * x ** (2 * b)) - - if has_scipy(): - from scipy.optimize import curve_fit - else: - raise RuntimeError('Scipy is needed to run find_ab_params') - - xv = np.linspace(0, spread * 3, 300) - yv = np.zeros(xv.shape) - yv[xv < min_dist] = 1.0 - yv[xv >= min_dist] = np.exp(-(xv[xv >= min_dist] - min_dist) / spread) - params, covar = curve_fit(curve, xv, yv) - return params[0], params[1] + return find_ab_params(spread, min_dist) @generate_docstring(convert_dtype_cast='np.float32', X='dense_sparse', diff --git a/python/cuml/manifold/umap_utils.pyx b/python/cuml/manifold/umap_utils.pyx index e7f10bbd53..499aea605c 100644 --- a/python/cuml/manifold/umap_utils.pyx +++ b/python/cuml/manifold/umap_utils.pyx @@ -100,3 +100,28 @@ cdef class GraphHolder: def __dealloc__(self): self.c_graph.reset(NULL) + + +def find_ab_params(spread, min_dist): + """ Function taken from UMAP-learn : https://github.com/lmcinnes/umap + Fit a, b params for the differentiable curve used in lower + dimensional fuzzy simplicial complex construction. We want the + smooth curve (from a pre-defined family with simple gradient) that + best matches an offset exponential decay. 
+ """ + + def curve(x, a, b): + return 1.0 / (1.0 + a * x ** (2 * b)) + + from cuml.common.import_utils import has_scipy + if has_scipy(): + from scipy.optimize import curve_fit + else: + raise RuntimeError('Scipy is needed to run find_ab_params') + + xv = np.linspace(0, spread * 3, 300) + yv = np.zeros(xv.shape) + yv[xv < min_dist] = 1.0 + yv[xv >= min_dist] = np.exp(-(xv[xv >= min_dist] - min_dist) / spread) + params, covar = curve_fit(curve, xv, yv) + return params[0], params[1] diff --git a/python/cuml/tests/test_simpl_set.py b/python/cuml/tests/test_simpl_set.py index ebcf1d4727..e7ae152d96 100644 --- a/python/cuml/tests/test_simpl_set.py +++ b/python/cuml/tests/test_simpl_set.py @@ -22,10 +22,10 @@ from cuml.neighbors import NearestNeighbors from umap.umap_ import fuzzy_simplicial_set as ref_fuzzy_simplicial_set -from cuml.manifold.simpl_set import fuzzy_simplicial_set \ +from cuml.manifold.umap import fuzzy_simplicial_set \ as cu_fuzzy_simplicial_set from umap.umap_ import simplicial_set_embedding as ref_simplicial_set_embedding -from cuml.manifold.simpl_set import simplicial_set_embedding \ +from cuml.manifold.umap import simplicial_set_embedding \ as cu_simplicial_set_embedding From f6c36c4b046e33822331d63fcb1877337c602f6f Mon Sep 17 00:00:00 2001 From: viclafargue Date: Fri, 13 May 2022 15:01:44 +0200 Subject: [PATCH 10/12] Fix check style --- python/cuml/manifold/umap.pyx | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index 01e92c8975..801de7e656 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -51,7 +51,8 @@ from cuml.common.array_sparse import SparseCumlArray from cuml.common.mixins import CMajorInputTagMixin from cuml.common.sparse_utils import is_sparse -from cuml.manifold.simpl_set import fuzzy_simplicial_set, simplicial_set_embedding +from cuml.manifold.simpl_set import fuzzy_simplicial_set, \ + simplicial_set_embedding if 
has_scipy(True): import scipy.sparse From dacd29274f8ab3fb18a1c98e8dde673a19aeaa9c Mon Sep 17 00:00:00 2001 From: viclafargue Date: Mon, 23 May 2022 19:00:12 +0200 Subject: [PATCH 11/12] Fix fit_supervised --- cpp/src/umap/runner.cuh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/src/umap/runner.cuh b/cpp/src/umap/runner.cuh index bbbc48d1b6..02f970d348 100644 --- a/cpp/src/umap/runner.cuh +++ b/cpp/src/umap/runner.cuh @@ -297,7 +297,8 @@ void _fit_supervised(const raft::handle_t& handle, cudaStream_t stream = handle.get_stream(); ML::Logger::get().setLevel(params->verbosity); - UMAPAlgo::_get_graph(handle, inputs, params, graph); + UMAPAlgo::_get_graph_supervised( + handle, inputs, params, graph); /** * Initialize embeddings From e71a558543c08ffd0238c19f387d6bc7ad92bf95 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Tue, 24 May 2022 11:09:39 +0200 Subject: [PATCH 12/12] Remove calls to numba compiled UMAP functions + remove simpl_set testing --- python/cuml/tests/test_simpl_set.py | 188 ---------------------------- python/cuml/tests/test_umap.py | 16 +-- 2 files changed, 2 insertions(+), 202 deletions(-) delete mode 100644 python/cuml/tests/test_simpl_set.py diff --git a/python/cuml/tests/test_simpl_set.py b/python/cuml/tests/test_simpl_set.py deleted file mode 100644 index e7ae152d96..0000000000 --- a/python/cuml/tests/test_simpl_set.py +++ /dev/null @@ -1,188 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# - -import pytest -from cuml.datasets import make_blobs -import numpy as np -import cupy as cp -import umap.distances as dist -from cuml.manifold.umap import UMAP -from cuml.neighbors import NearestNeighbors - -from umap.umap_ import fuzzy_simplicial_set as ref_fuzzy_simplicial_set -from cuml.manifold.umap import fuzzy_simplicial_set \ - as cu_fuzzy_simplicial_set -from umap.umap_ import simplicial_set_embedding as ref_simplicial_set_embedding -from cuml.manifold.umap import simplicial_set_embedding \ - as cu_simplicial_set_embedding - - -def correctness_dense(a, b, rtol=0.1, threshold=0.95): - n_elms = a.size - n_correct = (cp.abs(a - b) <= (rtol * cp.abs(b))).sum() - correctness = n_correct / n_elms - return correctness >= threshold - - -def correctness_sparse(a, b, atol=0.1, rtol=0.2, threshold=0.95): - n_ref_zeros = (a == 0).sum() - n_ref_non_zero_elms = a.size - n_ref_zeros - n_correct = (cp.abs(a - b) <= (atol + rtol * cp.abs(b))).sum() - correctness = (n_correct - n_ref_zeros) / n_ref_non_zero_elms - return correctness >= threshold - - -@pytest.mark.parametrize('n_rows', [800, 5000]) -@pytest.mark.parametrize('n_features', [8, 32]) -@pytest.mark.parametrize('n_neighbors', [8, 16]) -@pytest.mark.parametrize('precomputed_nearest_neighbors', [False, True]) -def test_fuzzy_simplicial_set(n_rows, - n_features, - n_neighbors, - precomputed_nearest_neighbors): - n_clusters = 30 - random_state = 42 - metric = 'euclidean' - - X, _ = make_blobs(n_samples=n_rows, centers=n_clusters, - n_features=n_features, random_state=random_state) - - if precomputed_nearest_neighbors: - nn = NearestNeighbors(n_neighbors=n_neighbors, - metric=metric) - nn.fit(X) - knn_dists, knn_indices = nn.kneighbors(X, - n_neighbors, - return_distance=True) - cu_fss_graph = cu_fuzzy_simplicial_set( - X, - n_neighbors, - random_state, - metric, - knn_indices=knn_indices, - knn_dists=knn_dists) 
- - knn_indices = knn_indices.get() - knn_dists = knn_dists.get() - ref_fss_graph = ref_fuzzy_simplicial_set( - X, - n_neighbors, - random_state, - metric, - knn_indices=knn_indices, - knn_dists=knn_dists)[0].tocoo() - else: - cu_fss_graph = cu_fuzzy_simplicial_set( - X, - n_neighbors, - random_state, - metric) - - X = X.get() - ref_fss_graph = ref_fuzzy_simplicial_set( - X, - n_neighbors, - random_state, - metric)[0].tocoo() - - cu_fss_graph = cu_fss_graph.todense() - ref_fss_graph = cp.sparse.coo_matrix(ref_fss_graph).todense() - assert correctness_sparse(ref_fss_graph, - cu_fss_graph, - atol=0.1, - rtol=0.2, - threshold=0.95) - - -@pytest.mark.parametrize('n_rows', [800, 5000]) -@pytest.mark.parametrize('n_features', [8, 32]) -@pytest.mark.parametrize('n_neighbors', [8, 16]) -@pytest.mark.parametrize('n_components', [2, 5]) -def test_simplicial_set_embedding(n_rows, - n_features, - n_neighbors, - n_components): - n_clusters = 30 - random_state = 42 - metric = 'euclidean' - initial_alpha = 1.0 - a, b = UMAP.find_ab_params(1.0, 0.1) - gamma = 0 - negative_sample_rate = 5 - n_epochs = 500 - init = 'random' - metric = 'euclidean' - metric_kwds = {} - densmap = False - densmap_kwds = {} - output_dens = False - output_metric = 'euclidean' - output_metric_kwds = {} - - X, _ = make_blobs(n_samples=n_rows, centers=n_clusters, - n_features=n_features, random_state=random_state) - X = X.get() - - ref_fss_graph = ref_fuzzy_simplicial_set(X, - n_neighbors, - random_state, - metric)[0] - ref_embedding = ref_simplicial_set_embedding( - X, - ref_fss_graph, - n_components, - initial_alpha, - a, - b, - gamma, - negative_sample_rate, - n_epochs, - init, - np.random.RandomState(random_state), - dist.named_distances_with_gradients[metric], - metric_kwds, - densmap, - densmap_kwds, - output_dens, - output_metric=output_metric, - output_metric_kwds=output_metric_kwds)[0] - - cu_fss_graph = cu_fuzzy_simplicial_set(X, - n_neighbors, - random_state, - metric) - - cu_embedding = 
cu_simplicial_set_embedding( - X, - cu_fss_graph, - n_components, - initial_alpha, - a, - b, - gamma, - negative_sample_rate, - n_epochs, - init, - random_state, - metric, - metric_kwds, - output_metric=output_metric, - output_metric_kwds=output_metric_kwds) - - ref_embedding = cp.array(ref_embedding) - assert correctness_dense(ref_embedding, - cu_embedding, - rtol=0.1, - threshold=0.95) diff --git a/python/cuml/tests/test_umap.py b/python/cuml/tests/test_umap.py index 84c2774c3f..e7dbf83748 100644 --- a/python/cuml/tests/test_umap.py +++ b/python/cuml/tests/test_umap.py @@ -571,26 +571,14 @@ def test_fuzzy_simplicial_set(n_rows, knn_indices = knn_indices knn_dists = knn_dists - ref_fss_graph = umap.umap_.fuzzy_simplicial_set( - X, - n_neighbors, - random_state, - metric, - knn_indices=knn_indices, - knn_dists=knn_dists)[0].tocoo() + else: model = cuUMAP(n_neighbors=n_neighbors) model.fit(X) cu_fss_graph = model.graph_ - ref_fss_graph = umap.umap_.fuzzy_simplicial_set( - X, - n_neighbors, - random_state, - metric)[0].tocoo() - cu_fss_graph = cu_fss_graph.todense() - ref_fss_graph = cp.sparse.coo_matrix(ref_fss_graph).todense() + ref_fss_graph = cu_fss_graph assert correctness_sparse(ref_fss_graph, cu_fss_graph, atol=0.1,