From 97742b3258955b87ee8bee245faab8422bf6af45 Mon Sep 17 00:00:00 2001 From: Andy Adinets Date: Sat, 16 Nov 2019 00:40:09 +0100 Subject: [PATCH 01/13] AUTO value for FIL algorithm. - 'AUTO' in Python, ALGO_AUTO in C++ - currently it is NAIVE for sparse forests, BATCH_TREE_REORG for dense ones - no need to specify the algorithm explicitly for sparse forests --- cpp/include/cuml/fil/fil.h | 3 +++ cpp/src/fil/fil.cu | 11 ++++++++--- cpp/test/sg/fil_test.cu | 9 +++++++++ python/cuml/fil/fil.pyx | 19 +++++++++++++------ python/cuml/test/test_fil.py | 7 ++++--- 5 files changed, 37 insertions(+), 12 deletions(-) diff --git a/cpp/include/cuml/fil/fil.h b/cpp/include/cuml/fil/fil.h index 67b28991fa..9129653fc7 100644 --- a/cpp/include/cuml/fil/fil.h +++ b/cpp/include/cuml/fil/fil.h @@ -31,6 +31,9 @@ namespace fil { /** Inference algorithm to use. */ enum algo_t { + /** choose the algorithm automatically; currently chooses NAIVE for sparse forests + and BATCH_TREE_REORG for dense ones */ + ALGO_AUTO, /** naive algorithm: 1 thread block predicts 1 row; the row is cached in shared memory, and the trees are distributed cyclically between threads */ NAIVE, diff --git a/cpp/src/fil/fil.cu b/cpp/src/fil/fil.cu index 736b8d59b2..6c96973ec2 100644 --- a/cpp/src/fil/fil.cu +++ b/cpp/src/fil/fil.cu @@ -172,6 +172,7 @@ struct dense_forest : forest { void init(const cumlHandle& h, const dense_node_t* nodes, const forest_params_t* params) { init_common(params); + if (algo_ == algo_t::NAIVE) algo_ = algo_t::BATCH_TREE_REORG; int num_nodes = forest_num_nodes(num_trees_, depth_); nodes_ = (dense_node*)h.getDeviceAllocator()->allocate( @@ -212,6 +213,7 @@ struct sparse_forest : forest { void init(const cumlHandle& h, const int* trees, const sparse_node_t* nodes, const forest_params_t* params) { init_common(params); + if (algo_ == algo_t::ALGO_AUTO) algo_ = algo_t::NAIVE; depth_ = 0; // a placeholder value num_nodes_ = params->num_nodes; @@ -251,18 +253,21 @@ void check_params(const 
forest_params_t* params, bool dense) { } else { ASSERT(params->num_nodes >= 0, "num_nodes must be non-negative for sparse forests"); - ASSERT(params->algo == algo_t::NAIVE, - "only NAIVE algorithm is supported for sparse forests"); + ASSERT(params->algo == algo_t::NAIVE || params->algo == algo_t::ALGO_AUTO, + "only ALGO_AUTO and NAIVE algorithms are supported " + "for sparse forests"); } ASSERT(params->num_trees >= 0, "num_trees must be non-negative"); ASSERT(params->num_cols >= 0, "num_cols must be non-negative"); switch (params->algo) { + case algo_t::ALGO_AUTO: case algo_t::NAIVE: case algo_t::TREE_REORG: case algo_t::BATCH_TREE_REORG: break; default: - ASSERT(false, "aglo should be NAIVE, TREE_REORG or BATCH_TREE_REORG"); + ASSERT(false, + "aglo should be ALGO_AUTO, NAIVE, TREE_REORG or BATCH_TREE_REORG"); } // output_t::RAW == 0, and doesn't have a separate flag output_t all_set = diff --git a/cpp/test/sg/fil_test.cu b/cpp/test/sg/fil_test.cu index e60ae54112..0b71fb309d 100644 --- a/cpp/test/sg/fil_test.cu +++ b/cpp/test/sg/fil_test.cu @@ -497,6 +497,8 @@ std::vector predict_dense_inputs = { {20000, 50, 0.05, 8, 50, 0.05, fil::output_t(fil::output_t::AVG | fil::output_t::THRESHOLD), 1.0, 0.5, fil::algo_t::TREE_REORG, 42, 2e-3f}, + {20000, 50, 0.05, 8, 50, 0.05, fil::output_t::SIGMOID, 0, 0, + fil::algo_t::ALGO_AUTO, 42, 2e-3f}, }; TEST_P(PredictDenseFilTest, Predict) { compare(); } @@ -528,6 +530,9 @@ std::vector predict_sparse_inputs = { {20000, 50, 0.05, 8, 50, 0.05, fil::output_t(fil::output_t::AVG | fil::output_t::THRESHOLD), 1.0, 0.5, fil::algo_t::NAIVE, 42, 2e-3f}, + {20000, 50, 0.05, 8, 50, 0.05, + fil::output_t(fil::output_t::SIGMOID | fil::output_t::THRESHOLD), 0, 0, + fil::algo_t::ALGO_AUTO, 42, 2e-3f}, }; TEST_P(PredictSparseFilTest, Predict) { compare(); } @@ -601,6 +606,8 @@ std::vector import_dense_inputs = { {20000, 50, 0.05, 8, 50, 0.05, fil::output_t(fil::output_t::AVG | fil::output_t::THRESHOLD), 1.0, 0.5, fil::algo_t::TREE_REORG, 42, 2e-3f, 
tl::Operator::kGE}, + {20000, 50, 0.05, 8, 50, 0.05, fil::output_t::SIGMOID, 0, 0, + fil::algo_t::ALGO_AUTO, 42, 2e-3f, tl::Operator::kLE}, }; TEST_P(TreeliteDenseFilTest, Import) { compare(); } @@ -630,6 +637,8 @@ std::vector import_sparse_inputs = { {20000, 50, 0.05, 8, 50, 0.05, fil::output_t(fil::output_t::AVG | fil::output_t::THRESHOLD), 1.0, 0.5, fil::algo_t::NAIVE, 42, 2e-3f, tl::Operator::kGE}, + {20000, 50, 0.05, 8, 50, 0.05, fil::output_t::RAW, 0, 0, + fil::algo_t::ALGO_AUTO, 42, 2e-3f, tl::Operator::kLT}, }; TEST_P(TreeliteSparseFilTest, Import) { compare(); } diff --git a/python/cuml/fil/fil.pyx b/python/cuml/fil/fil.pyx index c273bdc70e..13107cb7a3 100644 --- a/python/cuml/fil/fil.pyx +++ b/python/cuml/fil/fil.pyx @@ -140,6 +140,7 @@ cdef class TreeliteModel(): cdef extern from "cuml/fil/fil.h" namespace "ML::fil": cdef enum algo_t: + ALGO_AUTO, NAIVE, TREE_REORG, BATCH_TREE_REORG @@ -184,7 +185,10 @@ cdef class ForestInference_impl(): self.handle = handle def get_algo(self, algo_str): - algo_dict={'NAIVE': algo_t.NAIVE, + algo_dict={ + 'AUTO': algo_t.ALGO_AUTO, + 'auto': algo_t.ALGO_AUTO, + 'NAIVE': algo_t.NAIVE, 'naive': algo_t.NAIVE, 'BATCH_TREE_REORG': algo_t.BATCH_TREE_REORG, 'batch_tree_reorg': algo_t.BATCH_TREE_REORG, @@ -399,7 +403,7 @@ class ForestInference(Base): return self._impl.predict(X, preds) def load_from_treelite_model(self, model, output_class, - algo='TREE_REORG', + algo='AUTO', threshold=0.5, storage_type='DENSE'): """ @@ -415,6 +419,9 @@ class ForestInference(Base): If true, return a 1 or 0 depending on whether the raw prediction exceeds the threshold. If False, just return the raw prediction. 
algo : string name of the algo from (from algo_t enum) + 'AUTO' or 'auto' - choose the algorithm automatically; + currently 'BATCH_TREE_REORG' is used for dense storage, + and 'NAIVE' for sparse storage 'NAIVE' or 'naive' - simple inference using shared memory 'TREE_REORG' or 'tree_reorg' - similar to naive but trees rearranged to be more coalescing-friendly @@ -429,7 +436,7 @@ class ForestInference(Base): (currently DENSE is always used) 'DENSE' or 'dense' - create a dense forest 'SPARSE' or 'sparse' - create a sparse forest; - requires algo='NAIVE' + requires algo='NAIVE' or algo='AUTO' """ if isinstance(model, TreeliteModel): # TreeliteModel defined in this file @@ -445,7 +452,7 @@ class ForestInference(Base): def load_from_sklearn(skl_model, output_class=False, threshold=0.50, - algo='TREE_REORG', + algo='AUTO', storage_type='DENSE', handle=None): cuml_fm = ForestInference(handle=handle) @@ -459,7 +466,7 @@ class ForestInference(Base): def load(filename, output_class=False, threshold=0.50, - algo='TREE_REORG', + algo='AUTO', storage_type='DENSE', model_type="xgboost", handle=None): @@ -500,7 +507,7 @@ class ForestInference(Base): def load_from_randomforest(self, model_handle, output_class=False, - algo='TREE_REORG', + algo='AUTO', storage_type='DENSE', threshold=0.50): diff --git a/python/cuml/test/test_fil.py b/python/cuml/test/test_fil.py index 05b2f43d36..c775093eed 100644 --- a/python/cuml/test/test_fil.py +++ b/python/cuml/test/test_fil.py @@ -250,8 +250,10 @@ def small_classifier_and_preds(tmpdir_factory): @pytest.mark.skipif(has_xgboost() is False, reason="need to install xgboost") -@pytest.mark.parametrize('algo', ['NAIVE', 'TREE_REORG', 'BATCH_TREE_REORG', - 'naive', 'tree_reorg', 'batch_tree_reorg']) +@pytest.mark.parametrize('algo', ['AUTO', 'NAIVE', 'TREE_REORG', + 'BATCH_TREE_REORG', + 'auto', 'naive', 'tree_reorg', + 'batch_tree_reorg']) def test_output_algos(algo, small_classifier_and_preds): model_path, X, xgb_preds = small_classifier_and_preds 
fm = ForestInference.load(model_path, @@ -271,7 +273,6 @@ def test_output_algos(algo, small_classifier_and_preds): def test_output_storage_type(storage_type, small_classifier_and_preds): model_path, X, xgb_preds = small_classifier_and_preds fm = ForestInference.load(model_path, - algo='NAIVE', output_class=True, storage_type=storage_type, threshold=0.50) From e09e0035a8ef383e5ad9302897520c6ec5b354c1 Mon Sep 17 00:00:00 2001 From: Andy Adinets Date: Sat, 16 Nov 2019 00:45:44 +0100 Subject: [PATCH 02/13] Updated CHANGELOG.md. --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6aced51091..24dd2af701 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,6 +39,7 @@ - PR #1351: Build treelite temporarily for GPU CI testing of FIL Scikit-learn model importing - PR #1345: Removing deprecated should_downcast argument - PR #1362: device_buffer in UMAP + Sparse prims +- PR #1376: AUTO value for FIL algorithm ## Bug Fixes - PR #1281: Making rng.h threadsafe From d66b580d52c3b272364bbbfb2635a0954cb3a42e Mon Sep 17 00:00:00 2001 From: Andy Adinets Date: Sat, 16 Nov 2019 00:46:58 +0100 Subject: [PATCH 03/13] Fixed style errors. --- python/cuml/fil/fil.pyx | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/cuml/fil/fil.pyx b/python/cuml/fil/fil.pyx index 13107cb7a3..e7bfdb914c 100644 --- a/python/cuml/fil/fil.pyx +++ b/python/cuml/fil/fil.pyx @@ -185,8 +185,7 @@ cdef class ForestInference_impl(): self.handle = handle def get_algo(self, algo_str): - algo_dict={ - 'AUTO': algo_t.ALGO_AUTO, + algo_dict={'AUTO': algo_t.ALGO_AUTO, 'auto': algo_t.ALGO_AUTO, 'NAIVE': algo_t.NAIVE, 'naive': algo_t.NAIVE, @@ -420,7 +419,7 @@ class ForestInference(Base): exceeds the threshold. If False, just return the raw prediction. 
algo : string name of the algo from (from algo_t enum) 'AUTO' or 'auto' - choose the algorithm automatically; - currently 'BATCH_TREE_REORG' is used for dense storage, + currently 'BATCH_TREE_REORG' is used for dense storage, and 'NAIVE' for sparse storage 'NAIVE' or 'naive' - simple inference using shared memory 'TREE_REORG' or 'tree_reorg' - similar to naive but trees From 20e5963ff87563e4153b7cf31d01641856ffe426 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 3 Dec 2019 13:49:50 -0500 Subject: [PATCH 04/13] Some small fixes to dataframe inputs for knn classifier --- python/cuml/neighbors/kneighbors_classifier.pyx | 3 ++- python/cuml/neighbors/nearest_neighbors.pyx | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/python/cuml/neighbors/kneighbors_classifier.pyx b/python/cuml/neighbors/kneighbors_classifier.pyx index f5da82ab7f..8b5885b093 100644 --- a/python/cuml/neighbors/kneighbors_classifier.pyx +++ b/python/cuml/neighbors/kneighbors_classifier.pyx @@ -186,6 +186,7 @@ class KNeighborsClassifier(NearestNeighbors): convert_dtype=convert_dtype) cdef uintptr_t inds_ctype + inds, inds_ctype, n_rows, _, _ = \ input_to_dev_array(knn_indices, order='C', check_dtype=np.int64, convert_to_dtype=(np.int64 @@ -227,7 +228,7 @@ class KNeighborsClassifier(NearestNeighbors): if isinstance(X, np.ndarray): return np.array(classes, dtype=np.int32) elif isinstance(X, cudf.DataFrame): - return cudf.DataFrame.from_gpu_matrix(X) + return cudf.Series(classes) else: return classes diff --git a/python/cuml/neighbors/nearest_neighbors.pyx b/python/cuml/neighbors/nearest_neighbors.pyx index 5cebd36ef2..8ef3702448 100644 --- a/python/cuml/neighbors/nearest_neighbors.pyx +++ b/python/cuml/neighbors/nearest_neighbors.pyx @@ -291,10 +291,11 @@ class NearestNeighbors(Base): raise ValueError("Dimensions of X need to match dimensions of " "indices (%d)" % self.n_dims) + X_m, X_ctype, N, _, dtype = \ input_to_dev_array(X, order='F', check_dtype=np.float32, 
convert_to_dtype=(np.float32 if convert_dtype - else None)) + else False)) # Need to establish result matrices for indices (Nxk) # and for distances (Nxk) @@ -343,7 +344,6 @@ class NearestNeighbors(Base): for i in range(0, D_ndarr.shape[1]): dists[str(i)] = D_ndarr[:, i] - return dists, inds elif isinstance(X, np.ndarray): inds = np.asarray(I_ndarr) From ab8e8ee73689738a48aa7b1f51886c5ed90ea89f Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 3 Dec 2019 14:09:51 -0500 Subject: [PATCH 05/13] Small bug fixes --- python/cuml/neighbors/kneighbors_classifier.pyx | 5 ++++- python/cuml/neighbors/kneighbors_regressor.pyx | 2 ++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/python/cuml/neighbors/kneighbors_classifier.pyx b/python/cuml/neighbors/kneighbors_classifier.pyx index 8b5885b093..277819a003 100644 --- a/python/cuml/neighbors/kneighbors_classifier.pyx +++ b/python/cuml/neighbors/kneighbors_classifier.pyx @@ -228,7 +228,10 @@ class KNeighborsClassifier(NearestNeighbors): if isinstance(X, np.ndarray): return np.array(classes, dtype=np.int32) elif isinstance(X, cudf.DataFrame): - return cudf.Series(classes) + + if classes.ndim == 1: + classes = classes.reshape(classes.shape[0], 1) + return cudf.DataFrame.from_gpu_matrix(classes) else: return classes diff --git a/python/cuml/neighbors/kneighbors_regressor.pyx b/python/cuml/neighbors/kneighbors_regressor.pyx index 3825ce67a5..b614fec783 100644 --- a/python/cuml/neighbors/kneighbors_regressor.pyx +++ b/python/cuml/neighbors/kneighbors_regressor.pyx @@ -234,6 +234,8 @@ class KNeighborsRegressor(NearestNeighbors): if isinstance(X, np.ndarray): return np.array(results) elif isinstance(X, cudf.DataFrame): + if results.ndim == 1: + results = results.reshape(results.shape[0], 1) return cudf.DataFrame.from_gpu_matrix(results) else: return results From 4dd2b030c575ffad84861bdd6c7ed1640b404c95 Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Wed, 4 Dec 2019 15:38:12 -0500 Subject: [PATCH 06/13] Adding tests for data frame output from kneighbors classifier / regressor --- python/cuml/metrics/accuracy.pyx | 18 ++--- python/cuml/metrics/regression.pyx | 2 +- .../cuml/neighbors/kneighbors_classifier.pyx | 6 +- python/cuml/neighbors/nearest_neighbors.pyx | 10 +-- .../cuml/test/test_kneighbors_classifier.py | 71 ++++++++++++++++--- python/cuml/test/test_kneighbors_regressor.py | 25 +++++-- 6 files changed, 96 insertions(+), 36 deletions(-) diff --git a/python/cuml/metrics/accuracy.pyx b/python/cuml/metrics/accuracy.pyx index 5239a1c57f..54a4265f2f 100644 --- a/python/cuml/metrics/accuracy.pyx +++ b/python/cuml/metrics/accuracy.pyx @@ -23,6 +23,8 @@ import numpy as np from libc.stdint cimport uintptr_t +import cudf + from cuml.common.handle cimport cumlHandle from cuml.utils import input_to_dev_array import cuml.common.handle @@ -36,7 +38,7 @@ cdef extern from "cuml/metrics/metrics.hpp" namespace "ML::Metrics": int n) except + -def accuracy_score(ground_truth, predictions, handle=None): +def accuracy_score(ground_truth, predictions, handle=None, convert_dtype=True): """ Calcuates the accuracy score of a classification model. 
@@ -57,19 +59,17 @@ def accuracy_score(ground_truth, predictions, handle=None): if handle is None else handle cdef cumlHandle* handle_ =\ handle.getHandle() - if type(predictions) == np.ndarray: - predictions = predictions.astype(np.int32) - ground_truth = ground_truth.astype(np.int32) - else: - predictions = np.asarray(predictions, dtype=np.int32) - ground_truth = np.asarray(ground_truth, dtype=np.int32) cdef uintptr_t preds_ptr, ground_truth_ptr preds_m, preds_ptr, n_rows, _, _ = \ - input_to_dev_array(predictions) + input_to_dev_array(predictions, + convert_to_dtype=np.int32 + if convert_dtype else None) ground_truth_m, ground_truth_ptr, _, _, ground_truth_dtype = \ - input_to_dev_array(ground_truth) + input_to_dev_array(ground_truth, + convert_to_dtype = np.int32 + if convert_dtype else None) acc = accuracy_score_py(handle_[0], preds_ptr, diff --git a/python/cuml/metrics/regression.pyx b/python/cuml/metrics/regression.pyx index 86571f84f8..ee16694fea 100644 --- a/python/cuml/metrics/regression.pyx +++ b/python/cuml/metrics/regression.pyx @@ -76,7 +76,7 @@ def r2_score(y, y_hat, convert_dtype=False, handle=None): n = len(y) - if y.dtype == 'float32': + if y_m.dtype == 'float32': result_f32 = regression.r2_score_py(handle_[0], y_ptr, diff --git a/python/cuml/neighbors/kneighbors_classifier.pyx b/python/cuml/neighbors/kneighbors_classifier.pyx index 277819a003..10e94185ee 100644 --- a/python/cuml/neighbors/kneighbors_classifier.pyx +++ b/python/cuml/neighbors/kneighbors_classifier.pyx @@ -228,7 +228,6 @@ class KNeighborsClassifier(NearestNeighbors): if isinstance(X, np.ndarray): return np.array(classes, dtype=np.int32) elif isinstance(X, cudf.DataFrame): - if classes.ndim == 1: classes = classes.reshape(classes.shape[0], 1) return cudf.DataFrame.from_gpu_matrix(classes) @@ -334,6 +333,7 @@ class KNeighborsClassifier(NearestNeighbors): """ y_hat = self.predict(X, convert_dtype=convert_dtype) if isinstance(y_hat, tuple): - return (accuracy_score(y, y_hat_i) for 
y_hat_i in y_hat) + return (accuracy_score(y, y_hat_i, convert_dtype=convert_dtype) + for y_hat_i in y_hat) else: - return accuracy_score(y, y_hat) + return accuracy_score(y, y_hat, convert_dtype=convert_dtype) diff --git a/python/cuml/neighbors/nearest_neighbors.pyx b/python/cuml/neighbors/nearest_neighbors.pyx index 8ef3702448..773826aefd 100644 --- a/python/cuml/neighbors/nearest_neighbors.pyx +++ b/python/cuml/neighbors/nearest_neighbors.pyx @@ -336,14 +336,8 @@ class NearestNeighbors(Base): D_ndarr = D_ndarr.reshape((N, n_neighbors)) if isinstance(X, cudf.DataFrame): - inds = cudf.DataFrame() - for i in range(0, I_ndarr.shape[1]): - inds[str(i)] = I_ndarr[:, i] - - dists = cudf.DataFrame() - for i in range(0, D_ndarr.shape[1]): - dists[str(i)] = D_ndarr[:, i] - + inds = cudf.DataFrame.from_gpu_matrix(I_ndarr) + dists = cudf.DataFrame.from_gpu_matrix(D_ndarr) elif isinstance(X, np.ndarray): inds = np.asarray(I_ndarr) diff --git a/python/cuml/test/test_kneighbors_classifier.py b/python/cuml/test/test_kneighbors_classifier.py index fd2bd7533b..3f24323f36 100644 --- a/python/cuml/test/test_kneighbors_classifier.py +++ b/python/cuml/test/test_kneighbors_classifier.py @@ -16,6 +16,11 @@ import pytest +import rmm + + +import cudf + from cuml.neighbors import KNeighborsClassifier as cuKNN from sklearn.datasets import make_blobs @@ -23,11 +28,12 @@ from cuml.test.utils import array_equal +@pytest.mark.parametrize("datatype", ["dataframe", "numpy"]) @pytest.mark.parametrize("nrows", [1000, 10000]) @pytest.mark.parametrize("ncols", [50, 100]) @pytest.mark.parametrize("n_neighbors", [2, 5, 10]) @pytest.mark.parametrize("n_clusters", [2, 5, 10]) -def test_neighborhood_predictions(nrows, ncols, n_neighbors, n_clusters): +def test_neighborhood_predictions(nrows, ncols, n_neighbors, n_clusters,datatype): X, y = make_blobs(n_samples=nrows, centers=n_clusters, @@ -37,19 +43,29 @@ def test_neighborhood_predictions(nrows, ncols, n_neighbors, n_clusters): X = X.astype(np.float32) 
+ if datatype == "dataframe": + X = cudf.DataFrame.from_gpu_matrix(rmm.to_device(X)) + y = cudf.DataFrame.from_gpu_matrix(rmm.to_device(y.reshape(nrows, 1))) + knn_cu = cuKNN(n_neighbors=n_neighbors) knn_cu.fit(X, y) predictions = knn_cu.predict(X) + if datatype == "dataframe": + assert isinstance(predictions, cudf.DataFrame) + else: + assert isinstance(predictions, np.ndarray) + assert array_equal(predictions.astype(np.int32), y.astype(np.int32)) +@pytest.mark.parametrize("datatype", ["dataframe", "numpy"]) @pytest.mark.parametrize("nrows", [1000, 10000]) @pytest.mark.parametrize("ncols", [50, 100]) @pytest.mark.parametrize("n_neighbors", [2, 5, 10]) @pytest.mark.parametrize("n_clusters", [2, 5, 10]) -def test_score(nrows, ncols, n_neighbors, n_clusters): +def test_score(nrows, ncols, n_neighbors, n_clusters, datatype): X, y = make_blobs(n_samples=nrows, centers=n_clusters, n_features=ncols, random_state=0, @@ -57,17 +73,22 @@ def test_score(nrows, ncols, n_neighbors, n_clusters): X = X.astype(np.float32) + if datatype == "dataframe": + X = cudf.DataFrame.from_gpu_matrix(rmm.to_device(X)) + y = cudf.DataFrame.from_gpu_matrix(rmm.to_device(y.reshape(nrows, 1))) + knn_cu = cuKNN(n_neighbors=n_neighbors) knn_cu.fit(X, y) - assert knn_cu.score(X, y) == 1.0 + assert knn_cu.score(X, y) >= (1.0 - 0.004) +@pytest.mark.parametrize("datatype", ["dataframe", "numpy"]) @pytest.mark.parametrize("nrows", [1000, 10000]) @pytest.mark.parametrize("ncols", [50, 100]) @pytest.mark.parametrize("n_neighbors", [2, 5, 10]) @pytest.mark.parametrize("n_clusters", [2, 5, 10]) -def test_predict_proba(nrows, ncols, n_neighbors, n_clusters): +def test_predict_proba(nrows, ncols, n_neighbors, n_clusters, datatype): X, y = make_blobs(n_samples=nrows, centers=n_clusters, @@ -77,17 +98,26 @@ def test_predict_proba(nrows, ncols, n_neighbors, n_clusters): X = X.astype(np.float32) + if datatype == "dataframe": + X = cudf.DataFrame.from_gpu_matrix(rmm.to_device(X)) + y = 
cudf.DataFrame.from_gpu_matrix(rmm.to_device(y.reshape(nrows, 1))) + knn_cu = cuKNN(n_neighbors=n_neighbors) knn_cu.fit(X, y) predictions = knn_cu.predict_proba(X) - assert len(predictions[predictions < 0]) == 0 - assert len(predictions[predictions > 1]) == 0 + if datatype == "dataframe": + assert isinstance(predictions, cudf.DataFrame) + predictions = predictions.as_gpu_matrix().copy_to_host() + y = y.as_gpu_matrix().copy_to_host().reshape(nrows) + else: + assert isinstance(predictions, np.ndarray) - y_hat = np.argmax(np.array(predictions), axis=1) + y_hat = np.argmax(predictions, axis=1) assert array_equal(y_hat.astype(np.int32), y.astype(np.int32)) + assert array_equal(predictions.sum(axis=1), np.ones(nrows)) def test_nonmonotonic_labels(): @@ -104,26 +134,39 @@ def test_nonmonotonic_labels(): assert array_equal(p.astype(np.int32), y) -def test_predict_multioutput(): +@pytest.mark.parametrize("datatype", ["dataframe", "numpy"]) +def test_predict_multioutput(datatype): X = np.array([[0, 0, 1], [1, 0, 1]]).astype(np.float32) - y = np.array([[15, 2], [5, 4]]).astype(np.int32) + if datatype == "dataframe": + X = cudf.DataFrame.from_gpu_matrix(rmm.to_device(X)) + y = cudf.DataFrame.from_gpu_matrix(rmm.to_device(y)) + knn_cu = cuKNN(n_neighbors=1) knn_cu.fit(X, y) p = knn_cu.predict(X) + if datatype == "dataframe": + assert isinstance(p, cudf.DataFrame) + else: + assert isinstance(p, np.ndarray) + assert array_equal(p.astype(np.int32), y) -def test_predict_proba_multioutput(): +@pytest.mark.parametrize("datatype", ["dataframe", "numpy"]) +def test_predict_proba_multioutput(datatype): X = np.array([[0, 0, 1], [1, 0, 1]]).astype(np.float32) - y = np.array([[15, 2], [5, 4]]).astype(np.int32) + if datatype == "dataframe": + X = cudf.DataFrame.from_gpu_matrix(rmm.to_device(X)) + y = cudf.DataFrame.from_gpu_matrix(rmm.to_device(y)) + expected = (np.array([[0., 1.], [1., 0.]]).astype(np.float32), np.array([[1., 0.], [0., 1.]]).astype(np.float32)) @@ -134,5 +177,11 @@ def 
test_predict_proba_multioutput(): assert isinstance(p, tuple) + for i in p: + if datatype == "dataframe": + assert isinstance(i, cudf.DataFrame) + else: + assert isinstance(i, np.ndarray) + assert array_equal(p[0].astype(np.float32), expected[0]) assert array_equal(p[1].astype(np.float32), expected[1]) diff --git a/python/cuml/test/test_kneighbors_regressor.py b/python/cuml/test/test_kneighbors_regressor.py index 09991b46b8..ba7b4600a3 100644 --- a/python/cuml/test/test_kneighbors_regressor.py +++ b/python/cuml/test/test_kneighbors_regressor.py @@ -16,6 +16,9 @@ import pytest +import cudf +import rmm + from cuml.neighbors import KNeighborsRegressor as cuKNN from sklearn.datasets import make_blobs @@ -49,7 +52,7 @@ def test_kneighbors_regressor(n_samples=40, assert np.all(abs(y_pred - y_target) < 0.3) -def test_KNeighborsRegressor_multioutput_uniform_weight(): +def test_kneighborsRegressor_multioutput_uniform_weight(): # Test k-neighbors in multi-output regression with uniform weight rng = check_random_state(0) n_features = 5 @@ -75,11 +78,12 @@ def test_KNeighborsRegressor_multioutput_uniform_weight(): assert_array_almost_equal(y_pred, y_pred_idx) +@pytest.mark.parametrize("datatype", ["dataframe", "numpy"]) @pytest.mark.parametrize("nrows", [1000, 10000]) @pytest.mark.parametrize("ncols", [50, 100]) @pytest.mark.parametrize("n_neighbors", [2, 5, 10]) @pytest.mark.parametrize("n_clusters", [2, 5, 10]) -def test_score(nrows, ncols, n_neighbors, n_clusters): +def test_score(nrows, ncols, n_neighbors, n_clusters, datatype): # Using make_blobs here to check averages and neighborhoods X, y = make_blobs(n_samples=nrows, centers=n_clusters, @@ -89,21 +93,34 @@ def test_score(nrows, ncols, n_neighbors, n_clusters): X = X.astype(np.float32) y = y.astype(np.float32) + if datatype == "dataframe": + X = cudf.DataFrame.from_gpu_matrix(rmm.to_device(X)) + y = cudf.DataFrame.from_gpu_matrix(rmm.to_device(y.reshape(nrows, 1))) + knn_cu = cuKNN(n_neighbors=n_neighbors) 
knn_cu.fit(X, y) assert knn_cu.score(X, y) >= 0.9999 -def test_predict_multioutput(): +@pytest.mark.parametrize("datatype", ["dataframe", "numpy"]) +def test_predict_multioutput(datatype): X = np.array([[0, 0, 1], [1, 0, 1]]).astype(np.float32) - y = np.array([[15.0, 2.0], [5.0, 4.0]]).astype(np.int32) + if datatype == "dataframe": + X = cudf.DataFrame.from_gpu_matrix(rmm.to_device(X)) + y = cudf.DataFrame.from_gpu_matrix(rmm.to_device(y)) + knn_cu = cuKNN(n_neighbors=1) knn_cu.fit(X, y) p = knn_cu.predict(X) + if datatype == "dataframe": + assert isinstance(p, cudf.DataFrame) + else: + assert isinstance(p, np.ndarray) + assert array_equal(p.astype(np.int32), y) From 26c02754993a65b39c2ede84c01e268b929c28b3 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 4 Dec 2019 16:24:18 -0500 Subject: [PATCH 07/13] Updating nearest neighbors example in docs --- .../cuml/neighbors/kneighbors_classifier.pyx | 2 +- python/cuml/neighbors/nearest_neighbors.pyx | 108 +++++++++++------- python/cuml/test/test_nearest_neighbors.py | 61 ++++++---- 3 files changed, 108 insertions(+), 63 deletions(-) diff --git a/python/cuml/neighbors/kneighbors_classifier.pyx b/python/cuml/neighbors/kneighbors_classifier.pyx index 10e94185ee..bc174a2121 100644 --- a/python/cuml/neighbors/kneighbors_classifier.pyx +++ b/python/cuml/neighbors/kneighbors_classifier.pyx @@ -101,6 +101,7 @@ class KNeighborsClassifier(NearestNeighbors): Output: + ------- .. 
code-block:: python @@ -108,7 +109,6 @@ class KNeighborsClassifier(NearestNeighbors): array([3, 1, 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 1, 0, 0, 0, 2, 3, 3, 0, 3, 0, 0, 0, 0, 3, 2, 0, 0, 0], dtype=int32) - Notes ------ diff --git a/python/cuml/neighbors/nearest_neighbors.pyx b/python/cuml/neighbors/nearest_neighbors.pyx index 773826aefd..4263e2dbd7 100644 --- a/python/cuml/neighbors/nearest_neighbors.pyx +++ b/python/cuml/neighbors/nearest_neighbors.pyx @@ -92,27 +92,22 @@ class NearestNeighbors(Base): import cudf from cuml.neighbors import NearestNeighbors - import numpy as np + from cuml.datasets import make_blobs - np_float = np.array([ - [1,2,3], # Point 1 - [1,2,4], # Point 2 - [2,2,4] # Point 3 - ]).astype('float32') + X, _ = make_blobs(n_samples=25, centers=5, + n_features=10, random_state=42) - gdf_float = cudf.DataFrame() - gdf_float['dim_0'] = np.ascontiguousarray(np_float[:,0]) - gdf_float['dim_1'] = np.ascontiguousarray(np_float[:,1]) - gdf_float['dim_2'] = np.ascontiguousarray(np_float[:,2]) + # build a cudf Dataframe + X_cudf = cudf.DataFrame.from_gpu_matrix(X) - print('n_samples = 3, n_dims = 3') - print(gdf_float) + # fit model + model = NearestNeighbors(n_neighbors=3) + model.fit(X) - nn_float = NearestNeighbors() - nn_float.fit(gdf_float) # get 3 nearest neighbors - distances,indices = nn_float.kneighbors(gdf_float,k=3) + distances, indices = model.kneighbors(X_cudf) + # print results print(indices) print(distances) @@ -120,38 +115,71 @@ class NearestNeighbors(Base): .. 
code-block:: python - import cudf - - # Both import methods supported - # from cuml.neighbors import NearestNeighbors - from cuml import NearestNeighbors - - n_samples = 3, n_dims = 3 - - dim_0 dim_1 dim_2 - - 0 1.0 2.0 3.0 - 1 1.0 2.0 4.0 - 2 2.0 2.0 4.0 - - # indices: - index_neighbor_0 index_neighbor_1 index_neighbor_2 - 0 0 1 2 - 1 1 0 2 - 2 2 1 0 - # distances: + indices: + + 0 1 2 + 0 0 14 21 + 1 1 19 8 + 2 2 9 23 + 3 3 14 21 + 4 4 15 18 + 5 5 1 19 + 6 6 10 13 + 7 7 14 3 + 8 8 19 1 + 9 9 23 2 + 10 10 24 6 + 11 11 18 15 + 12 12 19 8 + 13 13 17 24 + 14 14 21 7 + 15 15 18 4 + 16 16 23 9 + 17 17 24 13 + 18 18 11 15 + 19 19 1 12 + 20 20 2 9 + 21 21 14 3 + 22 22 18 11 + 23 23 16 9 + 24 24 17 10 + + distances: + + 0 1 2 + 0 0.0 4.883116 5.570006 + 1 0.0 3.047896 4.105496 + 2 0.0 3.558557 3.567704 + 3 0.0 3.806127 3.880100 + 4 0.0 3.604866 4.257425 + 5 0.0 5.013218 5.018098 + 6 0.0 4.701295 5.285057 + 7 0.0 3.761291 4.024879 + 8 0.0 3.942976 4.105496 + 9 0.0 3.404269 3.558557 + 10 0.0 3.818043 4.701295 + 11 0.0 2.254586 3.620186 + 12 0.0 3.457439 4.459381 + 13 0.0 3.987583 4.324923 + 14 0.0 3.149770 3.761291 + 15 0.0 3.097861 3.604866 + 16 0.0 3.357889 3.664541 + 17 0.0 3.428183 3.987583 + 18 0.0 2.254586 3.097861 + 19 0.0 3.047896 3.457439 + 20 0.0 3.963187 4.498718 + 21 0.0 3.149770 3.880100 + 22 0.0 4.210738 4.227068 + 23 0.0 3.357889 3.404269 + 24 0.0 3.428183 3.818043 - distance_neighbor_0 distance_neighbor_1 distance_neighbor_2 - 0 0.0 1.0 2.0 - 1 0.0 1.0 1.0 - 2 0.0 1.0 2.0 Notes ------ For an additional example see `the NearestNeighbors notebook - `_. + `_. For additional docs, see `scikitlearn's NearestNeighbors `_. 
diff --git a/python/cuml/test/test_nearest_neighbors.py b/python/cuml/test/test_nearest_neighbors.py index 6dfc621ca8..9e7e1314a1 100644 --- a/python/cuml/test/test_nearest_neighbors.py +++ b/python/cuml/test/test_nearest_neighbors.py @@ -16,6 +16,8 @@ import pytest +import rmm + from cuml.test.utils import array_equal, unit_param, quality_param, \ stress_param from cuml.neighbors import NearestNeighbors as cuKNN @@ -37,22 +39,33 @@ def predict(neigh_ind, _y, n_neighbors): return ypred.ravel(), count.ravel() * 1.0 / n_neighbors +@pytest.mark.parametrize("datatype", ["dataframe", "numpy"]) @pytest.mark.parametrize("nrows", [500, 1000, 10000]) @pytest.mark.parametrize("ncols", [100, 1000]) @pytest.mark.parametrize("n_neighbors", [10, 50]) @pytest.mark.parametrize("n_clusters", [2, 10]) -def test_neighborhood_predictions(nrows, ncols, n_neighbors, n_clusters): +def test_neighborhood_predictions(nrows, ncols, n_neighbors, n_clusters, + datatype): X, y = make_blobs(n_samples=nrows, centers=n_clusters, n_features=ncols, random_state=0) X = X.astype(np.float32) + if datatype == "dataframe": + X = cudf.DataFrame.from_gpu_matrix(rmm.to_device(X)) + knn_cu = cuKNN() knn_cu.fit(X) neigh_ind = knn_cu.kneighbors(X, n_neighbors=n_neighbors, return_distance=False) + if datatype == "dataframe": + assert isinstance(neigh_ind, cudf.DataFrame) + neigh_ind = neigh_ind.as_gpu_matrix().copy_to_host() + else: + assert isinstance(neigh_ind, np.ndarray) + labels, probs = predict(neigh_ind, y, n_neighbors) assert array_equal(labels, y) @@ -78,7 +91,7 @@ def test_return_dists(): assert len(ret) == 2 -@pytest.mark.parametrize('input_type', ['ndarray']) +@pytest.mark.parametrize('input_type', ['dataframe', 'ndarray']) @pytest.mark.parametrize('nrows', [unit_param(500), quality_param(5000), stress_param(500000)]) @pytest.mark.parametrize('n_feats', [unit_param(3), quality_param(100), @@ -86,30 +99,34 @@ def test_return_dists(): @pytest.mark.parametrize('k', [unit_param(3), quality_param(30), 
stress_param(50)]) def test_cuml_against_sklearn(input_type, nrows, n_feats, k): - n_samples = nrows - X, y = make_blobs(n_samples=n_samples, + X, _ = make_blobs(n_samples=nrows, n_features=n_feats, random_state=0) - knn_cu = cuKNN() - - if input_type == 'ndarray': - - knn_cu.fit(X) - D_cuml, I_cuml = knn_cu.kneighbors(X, k) - - assert type(D_cuml) == np.ndarray - assert type(I_cuml) == np.ndarray + knn_sk = skKNN(metric="euclidean") + knn_sk.fit(X) + D_sk, I_sk = knn_sk.kneighbors(X, k) - D_cuml_arr = D_cuml - I_cuml_arr = I_cuml + if input_type == "dataframe": + X = cudf.DataFrame.from_gpu_matrix(rmm.to_device(X)) - if nrows < 500000: - knn_sk = skKNN(metric="euclidean") - knn_sk.fit(X) - D_sk, I_sk = knn_sk.kneighbors(X, k) - - assert array_equal(D_cuml_arr, D_sk, 1e-2, with_sign=True) - assert I_cuml_arr.all() == I_sk.all() + knn_cu = cuKNN() + knn_cu.fit(X) + D_cuml, I_cuml = knn_cu.kneighbors(X, k) + + if input_type == "dataframe": + assert isinstance(D_cuml, cudf.DataFrame) + assert isinstance(I_cuml, cudf.DataFrame) + D_cuml_arr = D_cuml.as_gpu_matrix().copy_to_host() + I_cuml_arr = I_cuml.as_gpu_matrix().copy_to_host() + else: + assert isinstance(D_cuml, np.ndarray) + assert isinstance(I_cuml, np.ndarray) + D_cuml_arr = D_cuml + I_cuml_arr = I_cuml + + + assert array_equal(D_cuml_arr, D_sk, 1e-2, with_sign=True) + assert I_cuml_arr.all() == I_sk.all() def test_knn_fit_twice(): From f1cc208e011bf2977e9c3f2d8a13bef2efe3d709 Mon Sep 17 00:00:00 2001 From: Andy Adinets Date: Thu, 5 Dec 2019 17:59:53 +0100 Subject: [PATCH 08/13] Update cpp/src/fil/fil.cu Co-Authored-By: John Zedlewski <904524+JohnZed@users.noreply.github.com> --- cpp/src/fil/fil.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/fil/fil.cu b/cpp/src/fil/fil.cu index 6c96973ec2..586b9b2439 100644 --- a/cpp/src/fil/fil.cu +++ b/cpp/src/fil/fil.cu @@ -267,7 +267,7 @@ void check_params(const forest_params_t* params, bool dense) { break; default: ASSERT(false, - "aglo 
should be ALGO_AUTO, NAIVE, TREE_REORG or BATCH_TREE_REORG"); + "algo should be ALGO_AUTO, NAIVE, TREE_REORG or BATCH_TREE_REORG"); } // output_t::RAW == 0, and doesn't have a separate flag output_t all_set = From a2a427ffec3414f30272e8047f0704ad0256deba Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 5 Dec 2019 14:17:12 -0500 Subject: [PATCH 09/13] Using ellipses to truncate output in example --- python/cuml/neighbors/nearest_neighbors.pyx | 40 +++------------------ 1 file changed, 4 insertions(+), 36 deletions(-) diff --git a/python/cuml/neighbors/nearest_neighbors.pyx b/python/cuml/neighbors/nearest_neighbors.pyx index 4263e2dbd7..0a8cfceb62 100644 --- a/python/cuml/neighbors/nearest_neighbors.pyx +++ b/python/cuml/neighbors/nearest_neighbors.pyx @@ -123,24 +123,8 @@ class NearestNeighbors(Base): 1 1 19 8 2 2 9 23 3 3 14 21 - 4 4 15 18 - 5 5 1 19 - 6 6 10 13 - 7 7 14 3 - 8 8 19 1 - 9 9 23 2 - 10 10 24 6 - 11 11 18 15 - 12 12 19 8 - 13 13 17 24 - 14 14 21 7 - 15 15 18 4 - 16 16 23 9 - 17 17 24 13 - 18 18 11 15 - 19 19 1 12 - 20 20 2 9 - 21 21 14 3 + ... + 22 22 18 11 23 23 16 9 24 24 17 10 @@ -152,24 +136,8 @@ class NearestNeighbors(Base): 1 0.0 3.047896 4.105496 2 0.0 3.558557 3.567704 3 0.0 3.806127 3.880100 - 4 0.0 3.604866 4.257425 - 5 0.0 5.013218 5.018098 - 6 0.0 4.701295 5.285057 - 7 0.0 3.761291 4.024879 - 8 0.0 3.942976 4.105496 - 9 0.0 3.404269 3.558557 - 10 0.0 3.818043 4.701295 - 11 0.0 2.254586 3.620186 - 12 0.0 3.457439 4.459381 - 13 0.0 3.987583 4.324923 - 14 0.0 3.149770 3.761291 - 15 0.0 3.097861 3.604866 - 16 0.0 3.357889 3.664541 - 17 0.0 3.428183 3.987583 - 18 0.0 2.254586 3.097861 - 19 0.0 3.047896 3.457439 - 20 0.0 3.963187 4.498718 - 21 0.0 3.149770 3.880100 + ... + 22 0.0 4.210738 4.227068 23 0.0 3.357889 3.404269 24 0.0 3.428183 3.818043 From 54c3d0905db19d88423443effa6eab12c2b99312 Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Thu, 5 Dec 2019 14:24:27 -0500 Subject: [PATCH 10/13] CHANGELOG update and style fixes --- CHANGELOG.md | 1 + python/cuml/test/test_kneighbors_classifier.py | 3 ++- python/cuml/test/test_nearest_neighbors.py | 1 - 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 46f1b9172f..7cfe6d666b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -78,6 +78,7 @@ - PR #1401: Patch for lbfgs solver for logistic regression with no l1 penalty - PR #1416: train_test_split numba and rmm device_array output bugfix - PR #1419: UMAP pickle tests are using wrong n_neighbors value for trustworthiness +- PR #1438: KNN Classifier to properly return Dataframe with Dataframe input # cuML 0.10.0 (16 Oct 2019) diff --git a/python/cuml/test/test_kneighbors_classifier.py b/python/cuml/test/test_kneighbors_classifier.py index 3f24323f36..9220f9edfe 100644 --- a/python/cuml/test/test_kneighbors_classifier.py +++ b/python/cuml/test/test_kneighbors_classifier.py @@ -33,7 +33,8 @@ @pytest.mark.parametrize("ncols", [50, 100]) @pytest.mark.parametrize("n_neighbors", [2, 5, 10]) @pytest.mark.parametrize("n_clusters", [2, 5, 10]) -def test_neighborhood_predictions(nrows, ncols, n_neighbors, n_clusters,datatype): +def test_neighborhood_predictions(nrows, ncols, n_neighbors, + n_clusters, datatype): X, y = make_blobs(n_samples=nrows, centers=n_clusters, diff --git a/python/cuml/test/test_nearest_neighbors.py b/python/cuml/test/test_nearest_neighbors.py index 9e7e1314a1..45cdbff7a0 100644 --- a/python/cuml/test/test_nearest_neighbors.py +++ b/python/cuml/test/test_nearest_neighbors.py @@ -124,7 +124,6 @@ def test_cuml_against_sklearn(input_type, nrows, n_feats, k): D_cuml_arr = D_cuml I_cuml_arr = I_cuml - assert array_equal(D_cuml_arr, D_sk, 1e-2, with_sign=True) assert I_cuml_arr.all() == I_sk.all() From 92416fb2e80bb9e2dea9d9ce295aad881b68e0cd Mon Sep 17 00:00:00 2001 From: "Corey J.
Nolet" Date: Thu, 5 Dec 2019 15:05:59 -0500 Subject: [PATCH 11/13] Style fixes --- python/cuml/metrics/accuracy.pyx | 4 ++-- python/cuml/neighbors/nearest_neighbors.pyx | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/python/cuml/metrics/accuracy.pyx b/python/cuml/metrics/accuracy.pyx index 54a4265f2f..a3c744272b 100644 --- a/python/cuml/metrics/accuracy.pyx +++ b/python/cuml/metrics/accuracy.pyx @@ -66,9 +66,9 @@ def accuracy_score(ground_truth, predictions, handle=None, convert_dtype=True): convert_to_dtype=np.int32 if convert_dtype else None) - ground_truth_m, ground_truth_ptr, _, _, ground_truth_dtype = \ + ground_truth_m, ground_truth_ptr, _, _, ground_truth_dtype=\ input_to_dev_array(ground_truth, - convert_to_dtype = np.int32 + convert_to_dtype=np.int32 if convert_dtype else None) acc = accuracy_score_py(handle_[0], diff --git a/python/cuml/neighbors/nearest_neighbors.pyx b/python/cuml/neighbors/nearest_neighbors.pyx index 0a8cfceb62..6bd4c59271 100644 --- a/python/cuml/neighbors/nearest_neighbors.pyx +++ b/python/cuml/neighbors/nearest_neighbors.pyx @@ -287,7 +287,6 @@ class NearestNeighbors(Base): raise ValueError("Dimensions of X need to match dimensions of " "indices (%d)" % self.n_dims) - X_m, X_ctype, N, _, dtype = \ input_to_dev_array(X, order='F', check_dtype=np.float32, convert_to_dtype=(np.float32 if convert_dtype From 6c0b82ec52fd82d79ad947d73e3ce33e37e66c6e Mon Sep 17 00:00:00 2001 From: John Zedlewski Date: Thu, 5 Dec 2019 14:30:29 -0800 Subject: [PATCH 12/13] Add joblib as explicit dependency --- conda/recipes/cuml/meta.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/conda/recipes/cuml/meta.yaml b/conda/recipes/cuml/meta.yaml index 40c2e58416..ccb634e0f5 100644 --- a/conda/recipes/cuml/meta.yaml +++ b/conda/recipes/cuml/meta.yaml @@ -41,6 +41,7 @@ requirements: - libcumlprims {{ minor_version }} - cupy>=6.5,<7.0 - nccl 2.4.* + - joblib >=0.11 - {{ pin_compatible('cudatoolkit', max_pin='x.x') }} about: From 
266414dcc4bdd88cb9acf5f94bb773093747fe0d Mon Sep 17 00:00:00 2001 From: John Zedlewski Date: Fri, 6 Dec 2019 09:22:54 -0800 Subject: [PATCH 13/13] Update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 34caa9bd67..bafa2ddf59 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -84,6 +84,7 @@ - PR #1416: train_test_split numba and rmm device_array output bugfix - PR #1419: UMAP pickle tests are using wrong n_neighbors value for trustworthiness - PR #1425: Deprecate seed and use random_state similar to Scikit-learn in train_test_split +- PR #1458: Add joblib as an explicit requirement # cuML 0.10.0 (16 Oct 2019)