diff --git a/ci/test_wheel.sh b/ci/test_wheel.sh index de6144c84b..52b83ec659 100755 --- a/ci/test_wheel.sh +++ b/ci/test_wheel.sh @@ -13,7 +13,7 @@ if [[ "$(arch)" == "aarch64" ]]; then fi # Always install latest dask for testing -python -m pip install git+https://github.com/dask/dask.git@main git+https://github.com/dask/distributed.git@main git+https://github.com/rapidsai/dask-cuda.git@branch-23.10 +python -m pip install git+https://github.com/dask/dask.git@2023.7.1 git+https://github.com/dask/distributed.git@2023.7.1 git+https://github.com/rapidsai/dask-cuda.git@branch-23.10 # echo to expand wildcard before adding `[extra]` requires for pip python -m pip install $(echo ./dist/cuml*.whl)[test] diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index aba757bf3d..2592bc6969 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -15,13 +15,13 @@ dependencies: - cudf==23.10.* - cupy>=12.0.0 - cxx-compiler -- cython>=0.29,<0.30 -- dask-core>=2023.5.1 +- cython>=3.0.0 +- dask-core==2023.7.1 - dask-cuda==23.10.* - dask-cudf==23.10.* - dask-ml -- dask>=2023.5.1 -- distributed>=2023.5.1 +- dask==2023.7.1 +- distributed==2023.7.1 - doxygen=1.8.20 - gcc_linux-64=11.* - gmock>=1.13.0 diff --git a/conda/environments/all_cuda-120_arch-x86_64.yaml b/conda/environments/all_cuda-120_arch-x86_64.yaml index 49992fb455..62e9025eed 100644 --- a/conda/environments/all_cuda-120_arch-x86_64.yaml +++ b/conda/environments/all_cuda-120_arch-x86_64.yaml @@ -17,13 +17,13 @@ dependencies: - cudf==23.10.* - cupy>=12.0.0 - cxx-compiler -- cython>=0.29,<0.30 -- dask-core>=2023.5.1 +- cython>=3.0.0 +- dask-core==2023.7.1 - dask-cuda==23.10.* - dask-cudf==23.10.* - dask-ml -- dask>=2023.5.1 -- distributed>=2023.5.1 +- dask==2023.7.1 +- distributed==2023.7.1 - doxygen=1.8.20 - gcc_linux-64=11.* - gmock>=1.13.0 diff --git a/conda/recipes/cuml/meta.yaml b/conda/recipes/cuml/meta.yaml index 438f4b74e9..e01cff0430 100644 --- a/conda/recipes/cuml/meta.yaml +++ b/conda/recipes/cuml/meta.yaml @@ -59,7 +59,7 @@ requirements: - cuda-python ==12.0.0 {% endif %} - cudf ={{ minor_version }} - - cython >=0.29,<0.30 + - cython >=3.0.0 - libcuml ={{ version }} - libcumlprims ={{ minor_version }} - pylibraft ={{ minor_version }} @@ -76,9 +76,9 @@ requirements: - cudf ={{ minor_version }} - cupy >=12.0.0 - dask-cudf ={{ minor_version }} - - dask >=2023.5.1 - - dask-core>=2023.5.1 - - distributed >=2023.5.1 + - dask ==2023.7.1 + - dask-core==2023.7.1 + - distributed ==2023.7.1 - joblib >=0.11 - libcuml ={{ version }} - libcumlprims ={{ minor_version }} diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index bacf1a6a79..68d9dda1a3 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -71,6 +71,9 @@ inline void launcher(const raft::handle_t& handle, out.knn_indices, out.knn_dists, n_neighbors, + true, + true, + static_cast*>(nullptr), params->metric, params->p); } diff --git a/dependencies.yaml b/dependencies.yaml index 9714750ca5..659743a1e5 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -151,7 +151,7 @@ dependencies: - output_types: [conda, requirements, pyproject] packages: - scikit-build>=0.13.1 - - cython>=0.29,<0.30 + - cython>=3.0.0 - &treelite treelite==3.2.0 - pylibraft==23.10.* - rmm==23.10.* @@ -175,10 +175,10 @@ dependencies: - output_types: [conda, requirements, pyproject] packages: - cudf==23.10.* - - dask>=2023.5.1 + - dask==2023.7.1 - dask-cuda==23.10.* - dask-cudf==23.10.* - - distributed>=2023.5.1 + - distributed==2023.7.1 - joblib>=0.11 - numba>=0.57 # TODO: Is scipy really a hard dependency, or should @@ -192,7 +192,7 @@ dependencies: - cupy>=12.0.0 - output_types: conda packages: - - dask-core>=2023.5.1 + - dask-core==2023.7.1 - output_types: pyproject packages: - *treelite_runtime @@ -360,9 +360,11 @@ dependencies: common: - output_types: [conda, requirements] packages: + - dask-ml==2023.3.24 - jupyter - matplotlib - numpy - pandas - *scikit_learn - seaborn + diff --git a/python/README.md b/python/README.md index e39a8d3949..1ed0cdf2f2 100644 --- a/python/README.md +++ b/python/README.md @@ -70,8 +70,8 @@ Packages required for multigpu algorithms*: - ucx-py version matching the cuML version - dask-cudf version matching the cuML version - nccl>=2.5 -- dask>=2023.5.1 -- distributed>=2023.5.1 +- dask==2023.7.1 +- distributed==2023.7.1 * this can be avoided with `--singlegpu` argument flag. diff --git a/python/cuml/internals/base_return_types.py b/python/cuml/internals/base_return_types.py index b5d952ba20..5aa0d7f75d 100644 --- a/python/cuml/internals/base_return_types.py +++ b/python/cuml/internals/base_return_types.py @@ -52,7 +52,10 @@ def _get_base_return_type(class_name, attr): # A NameError is raised if the return type is the same as the # type being defined (which is incomplete). Check that here and # return base if the name matches - if attr.__annotations__["return"] == class_name: + # Cython 3 changed to preferring types rather than strings for + # annotations. Strings end up wrapped in an extra layer of quotes, + # which we have to replace here. + if attr.__annotations__["return"].replace("'", "") == class_name: return "base" except Exception: assert False, "Shouldn't get here" diff --git a/python/cuml/internals/logger.pyx b/python/cuml/internals/logger.pyx index 1ced439822..0acc141f72 100644 --- a/python/cuml/internals/logger.pyx +++ b/python/cuml/internals/logger.pyx @@ -29,8 +29,8 @@ cdef extern from "cuml/common/logger.hpp" namespace "ML" nogil: Logger& get() void setLevel(int level) void setPattern(const string& pattern) - void setCallback(void(*callback)(int, char*)) - void setFlush(void(*flush)()) + void setCallback(void(*callback)(int, const char*) except *) + void setFlush(void(*flush)() except *) bool shouldLogFor(int level) const int getLevel() const string getPattern() const diff --git a/python/cuml/manifold/simpl_set.pyx b/python/cuml/manifold/simpl_set.pyx index a22f4da38a..d0f30e3e88 100644 --- a/python/cuml/manifold/simpl_set.pyx +++ b/python/cuml/manifold/simpl_set.pyx @@ -22,7 +22,8 @@ from cuml.internals.safe_imports import gpu_only_import cp = gpu_only_import('cupy') from cuml.manifold.umap_utils cimport * -from cuml.manifold.umap_utils import GraphHolder, find_ab_params +from cuml.manifold.umap_utils import GraphHolder, find_ab_params, \ + metric_parsing from cuml.internals.input_utils import input_to_cuml_array from cuml.internals.array import CumlArray @@ -82,10 +83,17 @@ def fuzzy_simplicial_set(X, structure to the detriment of the larger picture. random_state: numpy RandomState or equivalent A state capable being used as a numpy random state. - metric: string or function (optional, default 'euclidean') - unused - metric_kwds: dict (optional, default {}) - unused + metric: string (default='euclidean'). + Distance metric to use. Supported distances are ['l1, 'cityblock', + 'taxicab', 'manhattan', 'euclidean', 'l2', 'sqeuclidean', 'canberra', + 'minkowski', 'chebyshev', 'linf', 'cosine', 'correlation', 'hellinger', + 'hamming', 'jaccard'] + Metrics that take arguments (such as minkowski) can have arguments + passed via the metric_kwds dictionary. + Note: The 'jaccard' distance metric is only supported for sparse + inputs. + metric_kwds: dict (optional, default=None) + Metric argument knn_indices: array of shape (n_samples, n_neighbors) (optional) If the k-nearest neighbors of each point has already been calculated you can pass them in here to save computation time. This should be @@ -138,6 +146,14 @@ def fuzzy_simplicial_set(X, umap_params.deterministic = deterministic umap_params.set_op_mix_ratio = set_op_mix_ratio umap_params.local_connectivity = local_connectivity + try: + umap_params.metric = metric_parsing[metric.lower()] + except KeyError: + raise ValueError(f"Invalid value for metric: {metric}") + if metric_kwds is None: + umap_params.p = 2.0 + else: + umap_params.p = metric_kwds.get("p", 2.0) umap_params.verbosity = verbose X_m, _, _, _ = \ @@ -245,10 +261,17 @@ def simplicial_set_embedding( * A numpy array of initial embedding positions. random_state: numpy RandomState or equivalent A state capable being used as a numpy random state. - metric: string or callable - unused - metric_kwds: dict - unused + metric: string (default='euclidean'). + Distance metric to use. Supported distances are ['l1, 'cityblock', + 'taxicab', 'manhattan', 'euclidean', 'l2', 'sqeuclidean', 'canberra', + 'minkowski', 'chebyshev', 'linf', 'cosine', 'correlation', 'hellinger', + 'hamming', 'jaccard'] + Metrics that take arguments (such as minkowski) can have arguments + passed via the metric_kwds dictionary. + Note: The 'jaccard' distance metric is only supported for sparse + inputs. + metric_kwds: dict (optional, default=None) + Metric argument output_metric: function Function returning the distance between two points in embedding space and the gradient of the distance wrt the first argument. @@ -306,6 +329,14 @@ def simplicial_set_embedding( umap_params.init = 0 umap_params.random_state = random_state umap_params.deterministic = deterministic + try: + umap_params.metric = metric_parsing[metric.lower()] + except KeyError: + raise ValueError(f"Invalid value for metric: {metric}") + if metric_kwds is None: + umap_params.p = 2.0 + else: + umap_params.p = metric_kwds.get("p", 2.0) if output_metric == 'euclidean': umap_params.target_metric = MetricType.EUCLIDEAN else: # output_metric == 'categorical' diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index eb02258e2c..33082f4d4c 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -29,7 +29,8 @@ cupyx = gpu_only_import('cupyx') cuda = gpu_only_import('numba.cuda') from cuml.manifold.umap_utils cimport * -from cuml.manifold.umap_utils import GraphHolder, find_ab_params +from cuml.manifold.umap_utils import GraphHolder, find_ab_params, \ + metric_parsing, DENSE_SUPPORTED_METRICS, SPARSE_SUPPORTED_METRICS from cuml.common.sparsefuncs import extract_knn_infos from cuml.internals.safe_imports import gpu_only_import_from @@ -47,7 +48,6 @@ from cuml.internals.array import CumlArray from cuml.internals.array_sparse import SparseCumlArray from cuml.internals.mixins import CMajorInputTagMixin from cuml.common.sparse_utils import is_sparse -from cuml.metrics.distance_type cimport DistanceType from cuml.manifold.simpl_set import fuzzy_simplicial_set # no-cython-lint from cuml.manifold.simpl_set import simplicial_set_embedding # no-cython-lint @@ -152,13 +152,17 @@ class UMAP(UniversalBase, n_components: int (optional, default 2) The dimension of the space to embed into. This defaults to 2 to provide easy visualization, but can reasonably be set to any - metric : string (default='euclidean'). + metric: string (default='euclidean'). Distance metric to use. Supported distances are ['l1, 'cityblock', 'taxicab', 'manhattan', 'euclidean', 'l2', 'sqeuclidean', 'canberra', 'minkowski', 'chebyshev', 'linf', 'cosine', 'correlation', 'hellinger', 'hamming', 'jaccard'] Metrics that take arguments (such as minkowski) can have arguments passed via the metric_kwds dictionary. + Note: The 'jaccard' distance metric is only supported for sparse + inputs. + metric_kwds: dict (optional, default=None) + Metric argument n_epochs: int (optional, default None) The number of training epochs to be used in optimizing the low dimensional embedding. Larger values result in more accurate @@ -419,7 +423,7 @@ class UMAP(UniversalBase, raise ValueError("min_dist should be <= spread") @staticmethod - def _build_umap_params(cls): + def _build_umap_params(cls, sparse): cdef UMAPParams* umap_params = new UMAPParams() umap_params.n_neighbors = cls.n_neighbors umap_params.n_components = cls.n_components @@ -448,37 +452,20 @@ class UMAP(UniversalBase, umap_params.random_state = cls.random_state umap_params.deterministic = cls.deterministic - # metric - metric_parsing = { - "l2": DistanceType.L2SqrtUnexpanded, - "euclidean": DistanceType.L2SqrtUnexpanded, - "sqeuclidean": DistanceType.L2Unexpanded, - "cityblock": DistanceType.L1, - "l1": DistanceType.L1, - "manhattan": DistanceType.L1, - "taxicab": DistanceType.L1, - "minkowski": DistanceType.LpUnexpanded, - "chebyshev": DistanceType.Linf, - "linf": DistanceType.Linf, - "cosine": DistanceType.CosineExpanded, - "correlation": DistanceType.CorrelationExpanded, - "hellinger": DistanceType.HellingerExpanded, - "hamming": DistanceType.HammingUnexpanded, - "jaccard": DistanceType.JaccardExpanded, - "canberra": DistanceType.Canberra - } - - if cls.metric.lower() in metric_parsing: + try: umap_params.metric = metric_parsing[cls.metric.lower()] - else: - raise ValueError("Invalid value for metric: {}" - .format(cls.metric)) - + if sparse: + if umap_params.metric not in SPARSE_SUPPORTED_METRICS: + raise NotImplementedError(f"Metric '{cls.metric}' not supported for sparse inputs.") + elif umap_params.metric not in DENSE_SUPPORTED_METRICS: + raise NotImplementedError(f"Metric '{cls.metric}' not supported for dense inputs.") + + except KeyError: + raise ValueError(f"Invalid value for metric: {cls.metric}") if cls.metric_kwds is None: umap_params.p = 2.0 else: - umap_params.p = cls.metric_kwds.get('p') - + umap_params.p = cls.metric_kwds.get("p", 2.0) cdef uintptr_t callback_ptr = 0 if cls.callback: callback_ptr = cls.callback.get_native_callback() @@ -576,7 +563,7 @@ class UMAP(UniversalBase, cdef uintptr_t embed_raw = self.embedding_.ptr cdef UMAPParams* umap_params = \ - UMAP._build_umap_params(self) + UMAP._build_umap_params(self, self.sparse_fit) cdef uintptr_t y_raw = 0 @@ -742,7 +729,7 @@ class UMAP(UniversalBase, cdef uintptr_t embed_ptr = self.embedding_.ptr cdef UMAPParams* umap_params = \ - UMAP._build_umap_params(self) + UMAP._build_umap_params(self, self.sparse_fit) if self.sparse_fit: transform_sparse(handle_[0], diff --git a/python/cuml/manifold/umap_utils.pyx b/python/cuml/manifold/umap_utils.pyx index 5d69ead34c..a68af62195 100644 --- a/python/cuml/manifold/umap_utils.pyx +++ b/python/cuml/manifold/umap_utils.pyx @@ -19,6 +19,7 @@ from rmm._lib.memory_resource cimport get_current_device_resource from pylibraft.common.handle cimport handle_t from cuml.manifold.umap_utils cimport * +from cuml.metrics.distance_type cimport DistanceType from libcpp.utility cimport move from cuml.internals.safe_imports import cpu_only_import np = cpu_only_import('numpy') @@ -41,7 +42,7 @@ cdef class GraphHolder: return graph @staticmethod - cdef GraphHolder from_coo_array(graph, handle, coo_array): + cdef GraphHolder from_coo_array(GraphHolder graph, handle, coo_array): def copy_from_array(dst_raft_coo_ptr, src_cp_coo): size = src_cp_coo.size itemsize = np.dtype(src_cp_coo.dtype).itemsize @@ -130,3 +131,53 @@ def find_ab_params(spread, min_dist): yv[xv >= min_dist] = np.exp(-(xv[xv >= min_dist] - min_dist) / spread) params, _ = curve_fit(curve, xv, yv) return params[0], params[1] + + +metric_parsing = { + "l2": DistanceType.L2SqrtUnexpanded, + "euclidean": DistanceType.L2SqrtUnexpanded, + "sqeuclidean": DistanceType.L2Unexpanded, + "cityblock": DistanceType.L1, + "l1": DistanceType.L1, + "manhattan": DistanceType.L1, + "taxicab": DistanceType.L1, + "minkowski": DistanceType.LpUnexpanded, + "chebyshev": DistanceType.Linf, + "linf": DistanceType.Linf, + "cosine": DistanceType.CosineExpanded, + "correlation": DistanceType.CorrelationExpanded, + "hellinger": DistanceType.HellingerExpanded, + "hamming": DistanceType.HammingUnexpanded, + "jaccard": DistanceType.JaccardExpanded, + "canberra": DistanceType.Canberra +} + + +DENSE_SUPPORTED_METRICS = [ + DistanceType.Canberra, + DistanceType.CorrelationExpanded, + DistanceType.CosineExpanded, + DistanceType.HammingUnexpanded, + DistanceType.HellingerExpanded, + # DistanceType.JaccardExpanded, # not supported + DistanceType.L1, + DistanceType.L2SqrtUnexpanded, + DistanceType.L2Unexpanded, + DistanceType.Linf, + DistanceType.LpUnexpanded, +] + + +SPARSE_SUPPORTED_METRICS = [ + DistanceType.Canberra, + DistanceType.CorrelationExpanded, + DistanceType.CosineExpanded, + DistanceType.HammingUnexpanded, + DistanceType.HellingerExpanded, + DistanceType.JaccardExpanded, + DistanceType.L1, + DistanceType.L2SqrtUnexpanded, + DistanceType.L2Unexpanded, + DistanceType.Linf, + DistanceType.LpUnexpanded, +] diff --git a/python/cuml/metrics/regression.pyx b/python/cuml/metrics/regression.pyx index e9a090fb37..7ebe8a9c5c 100644 --- a/python/cuml/metrics/regression.pyx +++ b/python/cuml/metrics/regression.pyx @@ -31,7 +31,7 @@ from cuml.internals.input_utils import input_to_cuml_array @cuml.internals.api_return_any() -def r2_score(y, y_hat, convert_dtype=True, handle=None) -> double: +def r2_score(y, y_hat, convert_dtype=True, handle=None) -> float: """ Calculates r2 score between y and y_hat diff --git a/python/cuml/metrics/trustworthiness.pyx b/python/cuml/metrics/trustworthiness.pyx index 9d77238921..6db6e8fb65 100644 --- a/python/cuml/metrics/trustworthiness.pyx +++ b/python/cuml/metrics/trustworthiness.pyx @@ -55,7 +55,7 @@ def _get_array_ptr(obj): @cuml.internals.api_return_any() def trustworthiness(X, X_embedded, handle=None, n_neighbors=5, metric='euclidean', - convert_dtype=True, batch_size=512) -> double: + convert_dtype=True, batch_size=512) -> float: """ Expresses to what extent the local structure is retained in embedding. The score is defined in the range [0, 1]. diff --git a/python/cuml/solvers/qn.pyx b/python/cuml/solvers/qn.pyx index 4069b032ad..fcc96dac18 100644 --- a/python/cuml/solvers/qn.pyx +++ b/python/cuml/solvers/qn.pyx @@ -577,7 +577,7 @@ class QN(Base, handle_[0], qnpams, X_m.ptr, - __is_col_major(X_m), + _is_col_major(X_m), y_ptr, n_rows, self.n_cols, @@ -612,7 +612,7 @@ class QN(Base, handle_[0], qnpams, X_m.ptr, - __is_col_major(X_m), + _is_col_major(X_m), y_ptr, n_rows, self.n_cols, @@ -729,7 +729,7 @@ class QN(Base, handle_[0], qnpams, X_m.ptr, - __is_col_major(X_m), + _is_col_major(X_m), n_rows, n_cols, _num_classes, @@ -755,7 +755,7 @@ class QN(Base, handle_[0], qnpams, X_m.ptr, - __is_col_major(X_m), + _is_col_major(X_m), n_rows, n_cols, _num_classes, @@ -856,7 +856,7 @@ class QN(Base, handle_[0], qnpams, X_m.ptr, - __is_col_major(X_m), + _is_col_major(X_m), n_rows, n_cols, _num_classes, @@ -882,7 +882,7 @@ class QN(Base, handle_[0], qnpams, X_m.ptr, - __is_col_major(X_m), + _is_col_major(X_m), n_rows, n_cols, _num_classes, @@ -940,5 +940,5 @@ class QN(Base, 'warm_start', 'delta', 'penalty_normalized'] -def __is_col_major(X): +def _is_col_major(X): return getattr(X, "order", "F").upper() == "F" diff --git a/python/cuml/solvers/sgd.pyx b/python/cuml/solvers/sgd.pyx index 0de27971c7..abf966eb7c 100644 --- a/python/cuml/solvers/sgd.pyx +++ b/python/cuml/solvers/sgd.pyx @@ -266,11 +266,12 @@ class SGD(Base, raise TypeError("This option will be supported in the future") - if self.alpha == 0: - raise ValueError("alpha must be > 0 since " - "learning_rate is 'optimal'. alpha is " - "used to compute the optimal learning " - " rate.") + # TODO: uncomment this when optimal learning rate is supported + # if self.alpha == 0: + # raise ValueError("alpha must be > 0 since " + # "learning_rate is 'optimal'. alpha is " + # "used to compute the optimal learning " + # " rate.") elif learning_rate == 'constant': self.lr_type = 1 diff --git a/python/cuml/tests/test_umap.py b/python/cuml/tests/test_umap.py index 479853dc07..34d899e7bf 100644 --- a/python/cuml/tests/test_umap.py +++ b/python/cuml/tests/test_umap.py @@ -674,26 +674,26 @@ def test_fuzzy_simplicial_set(n_rows, n_features, n_neighbors): @pytest.mark.parametrize( - "metric", + "metric,supported", [ - "l2", - "euclidean", - "sqeuclidean", - "l1", - "manhattan", - "minkowski", - "chebyshev", - "cosine", - "correlation", - "jaccard", - "hamming", - "canberra", + ("l2", True), + ("euclidean", True), + ("sqeuclidean", True), + ("l1", True), + ("manhattan", True), + ("minkowski", True), + ("chebyshev", True), + ("cosine", True), + ("correlation", True), + ("jaccard", False), + ("hamming", True), + ("canberra", True), ], ) @pytest.mark.skipif( IS_ARM, reason="https://github.com/rapidsai/cuml/issues/5441" ) -def test_umap_distance_metrics_fit_transform_trust(metric): +def test_umap_distance_metrics_fit_transform_trust(metric, supported): data, labels = make_blobs( n_samples=1000, n_features=64, centers=5, random_state=42 ) @@ -707,7 +707,13 @@ def test_umap_distance_metrics_fit_transform_trust(metric): cuml_model = cuUMAP( n_neighbors=10, min_dist=0.01, metric=metric, init="random" ) + if not supported: + with pytest.raises(NotImplementedError): + cuml_model.fit_transform(data) + return + umap_embedding = umap_model.fit_transform(data) + cuml_embedding = cuml_model.fit_transform(data) umap_trust = trustworthiness( @@ -721,24 +727,28 @@ def test_umap_distance_metrics_fit_transform_trust(metric): @pytest.mark.parametrize( - "metric", + "metric,supported,umap_learn_supported", [ - "euclidean", - "l1", - "manhattan", - "minkowski", - "chebyshev", - "cosine", - "correlation", - "jaccard", - "hamming", - "canberra", + ("l2", True, False), + ("euclidean", True, True), + ("sqeuclidean", True, False), + ("l1", True, True), + ("manhattan", True, True), + ("minkowski", True, True), + ("chebyshev", True, True), + ("cosine", True, True), + ("correlation", True, True), + ("jaccard", True, True), + ("hamming", True, True), + ("canberra", True, True), ], ) @pytest.mark.skipif( IS_ARM, reason="https://github.com/rapidsai/cuml/issues/5441" ) -def test_umap_distance_metrics_fit_transform_trust_on_sparse_input(metric): +def test_umap_distance_metrics_fit_transform_trust_on_sparse_input( + metric, supported, umap_learn_supported +): data, labels = make_blobs( n_samples=1000, n_features=64, centers=5, random_state=42 ) @@ -752,20 +762,31 @@ def test_umap_distance_metrics_fit_transform_trust_on_sparse_input(metric): new_data = scipy_sparse.csr_matrix(data[~data_selection]) - umap_model = umap.UMAP( - n_neighbors=10, min_dist=0.01, metric=metric, init="random" - ) + if umap_learn_supported: + umap_model = umap.UMAP( + n_neighbors=10, min_dist=0.01, metric=metric, init="random" + ) + umap_embedding = umap_model.fit_transform(new_data) + umap_trust = trustworthiness( + data[~data_selection], + umap_embedding, + n_neighbors=10, + metric=metric, + ) + cuml_model = cuUMAP( n_neighbors=10, min_dist=0.01, metric=metric, init="random" ) - umap_embedding = umap_model.fit_transform(new_data) - cuml_embedding = cuml_model.fit_transform(new_data) - umap_trust = trustworthiness( - data[~data_selection], umap_embedding, n_neighbors=10, metric=metric - ) + if not supported: + with pytest.raises(NotImplementedError): + cuml_model.fit_transform(new_data) + return + + cuml_embedding = cuml_model.fit_transform(new_data) cuml_trust = trustworthiness( data[~data_selection], cuml_embedding, n_neighbors=10, metric=metric ) - assert array_equal(umap_trust, cuml_trust, 0.05, with_sign=True) + if umap_learn_supported: + assert array_equal(umap_trust, cuml_trust, 0.05, with_sign=True) diff --git a/python/pyproject.toml b/python/pyproject.toml index 25d9cde814..869196d27f 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -16,7 +16,7 @@ requires = [ "cmake>=3.26.4", "cuda-python>=11.7.1,<12.0a0", - "cython>=0.29,<0.30", + "cython>=3.0.0", "ninja", "pylibraft==23.10.*", "rmm==23.10.*", @@ -61,8 +61,8 @@ dependencies = [ "cupy-cuda11x>=12.0.0", "dask-cuda==23.10.*", "dask-cudf==23.10.*", - "dask>=2023.5.1", - "distributed>=2023.5.1", + "dask==2023.7.1", + "distributed==2023.7.1", "joblib>=0.11", "numba>=0.57", "raft-dask==23.10.*",