Skip to content

Commit

Permalink
Merge pull request #2 from rapidsai/branch-0.12
Browse files Browse the repository at this point in the history
merge
  • Loading branch information
vishalmehta1991 authored Dec 9, 2019
2 parents ba5a222 + 1de0fea commit 3cb4696
Show file tree
Hide file tree
Showing 15 changed files with 215 additions and 109 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
- PR #1258: Documenting new MPI communicator for multi-node multi-GPU testing
- PR #1345: Removing deprecated should_downcast argument
- PR #1362: device_buffer in UMAP + Sparse prims
- PR #1376: AUTO value for FIL algorithm
- PR #1408: Updated pickle tests to delete the pre-pickled model to prevent pointer leakage
- PR #1357: Run benchmarks multiple times for CI
- PR #1382: ARIMA optimization: move functions to C++ side
Expand Down Expand Up @@ -91,7 +92,9 @@
- PR #1401: Patch for lbfgs solver for logistic regression with no l1 penalty
- PR #1416: train_test_split numba and rmm device_array output bugfix
- PR #1419: UMAP pickle tests are using wrong n_neighbors value for trustworthiness
- PR #1438: KNN Classifier to properly return Dataframe with Dataframe input
- PR #1425: Deprecate seed and use random_state similar to Scikit-learn in train_test_split
- PR #1458: Add joblib as an explicit requirement

# cuML 0.10.0 (16 Oct 2019)

Expand Down
1 change: 1 addition & 0 deletions conda/recipes/cuml/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ requirements:
- libcumlprims {{ minor_version }}
- cupy>=6.5,<7.0
- nccl 2.4.*
- joblib >=0.11
- {{ pin_compatible('cudatoolkit', max_pin='x.x') }}

about:
Expand Down
3 changes: 3 additions & 0 deletions cpp/include/cuml/fil/fil.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ namespace fil {

/** Inference algorithm to use. */
enum algo_t {
/** choose the algorithm automatically; currently chooses NAIVE for sparse forests
and BATCH_TREE_REORG for dense ones */
ALGO_AUTO,
/** naive algorithm: 1 thread block predicts 1 row; the row is cached in
shared memory, and the trees are distributed cyclically between threads */
NAIVE,
Expand Down
11 changes: 8 additions & 3 deletions cpp/src/fil/fil.cu
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ struct dense_forest : forest {
void init(const cumlHandle& h, const dense_node_t* nodes,
const forest_params_t* params) {
init_common(params);
// Resolve the automatic choice only: ALGO_AUTO maps to BATCH_TREE_REORG for
// dense forests, while an explicitly requested algorithm is left untouched.
if (algo_ == algo_t::ALGO_AUTO) algo_ = algo_t::BATCH_TREE_REORG;

int num_nodes = forest_num_nodes(num_trees_, depth_);
nodes_ = (dense_node*)h.getDeviceAllocator()->allocate(
Expand Down Expand Up @@ -212,6 +213,7 @@ struct sparse_forest : forest {
void init(const cumlHandle& h, const int* trees, const sparse_node_t* nodes,
const forest_params_t* params) {
init_common(params);
if (algo_ == algo_t::ALGO_AUTO) algo_ = algo_t::NAIVE;
depth_ = 0; // a placeholder value
num_nodes_ = params->num_nodes;

Expand Down Expand Up @@ -251,18 +253,21 @@ void check_params(const forest_params_t* params, bool dense) {
} else {
ASSERT(params->num_nodes >= 0,
"num_nodes must be non-negative for sparse forests");
ASSERT(params->algo == algo_t::NAIVE,
"only NAIVE algorithm is supported for sparse forests");
ASSERT(params->algo == algo_t::NAIVE || params->algo == algo_t::ALGO_AUTO,
"only ALGO_AUTO and NAIVE algorithms are supported "
"for sparse forests");
}
ASSERT(params->num_trees >= 0, "num_trees must be non-negative");
ASSERT(params->num_cols >= 0, "num_cols must be non-negative");
switch (params->algo) {
case algo_t::ALGO_AUTO:
case algo_t::NAIVE:
case algo_t::TREE_REORG:
case algo_t::BATCH_TREE_REORG:
break;
default:
ASSERT(false, "aglo should be NAIVE, TREE_REORG or BATCH_TREE_REORG");
ASSERT(false,
"algo should be ALGO_AUTO, NAIVE, TREE_REORG or BATCH_TREE_REORG");
}
// output_t::RAW == 0, and doesn't have a separate flag
output_t all_set =
Expand Down
9 changes: 9 additions & 0 deletions cpp/test/sg/fil_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,8 @@ std::vector<FilTestParams> predict_dense_inputs = {
{20000, 50, 0.05, 8, 50, 0.05,
fil::output_t(fil::output_t::AVG | fil::output_t::THRESHOLD), 1.0, 0.5,
fil::algo_t::TREE_REORG, 42, 2e-3f},
{20000, 50, 0.05, 8, 50, 0.05, fil::output_t::SIGMOID, 0, 0,
fil::algo_t::ALGO_AUTO, 42, 2e-3f},
};

TEST_P(PredictDenseFilTest, Predict) { compare(); }
Expand Down Expand Up @@ -528,6 +530,9 @@ std::vector<FilTestParams> predict_sparse_inputs = {
{20000, 50, 0.05, 8, 50, 0.05,
fil::output_t(fil::output_t::AVG | fil::output_t::THRESHOLD), 1.0, 0.5,
fil::algo_t::NAIVE, 42, 2e-3f},
{20000, 50, 0.05, 8, 50, 0.05,
fil::output_t(fil::output_t::SIGMOID | fil::output_t::THRESHOLD), 0, 0,
fil::algo_t::ALGO_AUTO, 42, 2e-3f},
};

TEST_P(PredictSparseFilTest, Predict) { compare(); }
Expand Down Expand Up @@ -601,6 +606,8 @@ std::vector<FilTestParams> import_dense_inputs = {
{20000, 50, 0.05, 8, 50, 0.05,
fil::output_t(fil::output_t::AVG | fil::output_t::THRESHOLD), 1.0, 0.5,
fil::algo_t::TREE_REORG, 42, 2e-3f, tl::Operator::kGE},
{20000, 50, 0.05, 8, 50, 0.05, fil::output_t::SIGMOID, 0, 0,
fil::algo_t::ALGO_AUTO, 42, 2e-3f, tl::Operator::kLE},
};

TEST_P(TreeliteDenseFilTest, Import) { compare(); }
Expand Down Expand Up @@ -630,6 +637,8 @@ std::vector<FilTestParams> import_sparse_inputs = {
{20000, 50, 0.05, 8, 50, 0.05,
fil::output_t(fil::output_t::AVG | fil::output_t::THRESHOLD), 1.0, 0.5,
fil::algo_t::NAIVE, 42, 2e-3f, tl::Operator::kGE},
{20000, 50, 0.05, 8, 50, 0.05, fil::output_t::RAW, 0, 0,
fil::algo_t::ALGO_AUTO, 42, 2e-3f, tl::Operator::kLT},
};

TEST_P(TreeliteSparseFilTest, Import) { compare(); }
Expand Down
18 changes: 12 additions & 6 deletions python/cuml/fil/fil.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ cdef class TreeliteModel():

cdef extern from "cuml/fil/fil.h" namespace "ML::fil":
cdef enum algo_t:
ALGO_AUTO,
NAIVE,
TREE_REORG,
BATCH_TREE_REORG
Expand Down Expand Up @@ -184,7 +185,9 @@ cdef class ForestInference_impl():
self.handle = handle

def get_algo(self, algo_str):
algo_dict={'NAIVE': algo_t.NAIVE,
algo_dict={'AUTO': algo_t.ALGO_AUTO,
'auto': algo_t.ALGO_AUTO,
'NAIVE': algo_t.NAIVE,
'naive': algo_t.NAIVE,
'BATCH_TREE_REORG': algo_t.BATCH_TREE_REORG,
'batch_tree_reorg': algo_t.BATCH_TREE_REORG,
Expand Down Expand Up @@ -399,7 +402,7 @@ class ForestInference(Base):
return self._impl.predict(X, preds)

def load_from_treelite_model(self, model, output_class,
algo='TREE_REORG',
algo='AUTO',
threshold=0.5,
storage_type='DENSE'):
"""
Expand All @@ -415,6 +418,9 @@ class ForestInference(Base):
If true, return a 1 or 0 depending on whether the raw prediction
exceeds the threshold. If False, just return the raw prediction.
algo : string name of the algo (from algo_t enum)
'AUTO' or 'auto' - choose the algorithm automatically;
currently 'BATCH_TREE_REORG' is used for dense storage,
and 'NAIVE' for sparse storage
'NAIVE' or 'naive' - simple inference using shared memory
'TREE_REORG' or 'tree_reorg' - similar to naive but trees
rearranged to be more coalescing-friendly
Expand All @@ -429,7 +435,7 @@ class ForestInference(Base):
(currently DENSE is always used)
'DENSE' or 'dense' - create a dense forest
'SPARSE' or 'sparse' - create a sparse forest;
requires algo='NAIVE'
requires algo='NAIVE' or algo='AUTO'
"""
if isinstance(model, TreeliteModel):
# TreeliteModel defined in this file
Expand All @@ -445,7 +451,7 @@ class ForestInference(Base):
def load_from_sklearn(skl_model,
output_class=False,
threshold=0.50,
algo='TREE_REORG',
algo='AUTO',
storage_type='DENSE',
handle=None):
cuml_fm = ForestInference(handle=handle)
Expand All @@ -459,7 +465,7 @@ class ForestInference(Base):
def load(filename,
output_class=False,
threshold=0.50,
algo='TREE_REORG',
algo='AUTO',
storage_type='DENSE',
model_type="xgboost",
handle=None):
Expand Down Expand Up @@ -500,7 +506,7 @@ class ForestInference(Base):
def load_from_randomforest(self,
model_handle,
output_class=False,
algo='TREE_REORG',
algo='AUTO',
storage_type='DENSE',
threshold=0.50):

Expand Down
20 changes: 10 additions & 10 deletions python/cuml/metrics/accuracy.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import numpy as np

from libc.stdint cimport uintptr_t

import cudf

from cuml.common.handle cimport cumlHandle
from cuml.utils import input_to_dev_array
import cuml.common.handle
Expand All @@ -36,7 +38,7 @@ cdef extern from "cuml/metrics/metrics.hpp" namespace "ML::Metrics":
int n) except +


def accuracy_score(ground_truth, predictions, handle=None):
def accuracy_score(ground_truth, predictions, handle=None, convert_dtype=True):
"""
Calculates the accuracy score of a classification model.
Expand All @@ -57,19 +59,17 @@ def accuracy_score(ground_truth, predictions, handle=None):
if handle is None else handle
cdef cumlHandle* handle_ =\
<cumlHandle*><size_t>handle.getHandle()
if type(predictions) == np.ndarray:
predictions = predictions.astype(np.int32)
ground_truth = ground_truth.astype(np.int32)
else:
predictions = np.asarray(predictions, dtype=np.int32)
ground_truth = np.asarray(ground_truth, dtype=np.int32)

cdef uintptr_t preds_ptr, ground_truth_ptr
preds_m, preds_ptr, n_rows, _, _ = \
input_to_dev_array(predictions)
input_to_dev_array(predictions,
convert_to_dtype=np.int32
if convert_dtype else None)

ground_truth_m, ground_truth_ptr, _, _, ground_truth_dtype = \
input_to_dev_array(ground_truth)
ground_truth_m, ground_truth_ptr, _, _, ground_truth_dtype=\
input_to_dev_array(ground_truth,
convert_to_dtype=np.int32
if convert_dtype else None)

acc = accuracy_score_py(handle_[0],
<int*> preds_ptr,
Expand Down
2 changes: 1 addition & 1 deletion python/cuml/metrics/regression.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def r2_score(y, y_hat, convert_dtype=False, handle=None):

n = len(y)

if y.dtype == 'float32':
if y_m.dtype == 'float32':

result_f32 = regression.r2_score_py(handle_[0],
<float*> y_ptr,
Expand Down
12 changes: 8 additions & 4 deletions python/cuml/neighbors/kneighbors_classifier.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,14 @@ class KNeighborsClassifier(NearestNeighbors):
Output:
-------
.. code-block:: python
array([3, 1, 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 1, 0, 0, 0, 2, 3, 3,
0, 3, 0, 0, 0, 0, 3, 2, 0, 0, 0], dtype=int32)
Notes
------
Expand Down Expand Up @@ -186,6 +186,7 @@ class KNeighborsClassifier(NearestNeighbors):
convert_dtype=convert_dtype)

cdef uintptr_t inds_ctype

inds, inds_ctype, n_rows, _, _ = \
input_to_dev_array(knn_indices, order='C', check_dtype=np.int64,
convert_to_dtype=(np.int64
Expand Down Expand Up @@ -227,7 +228,9 @@ class KNeighborsClassifier(NearestNeighbors):
if isinstance(X, np.ndarray):
return np.array(classes, dtype=np.int32)
elif isinstance(X, cudf.DataFrame):
return cudf.DataFrame.from_gpu_matrix(X)
if classes.ndim == 1:
classes = classes.reshape(classes.shape[0], 1)
return cudf.DataFrame.from_gpu_matrix(classes)
else:
return classes

Expand Down Expand Up @@ -330,6 +333,7 @@ class KNeighborsClassifier(NearestNeighbors):
"""
y_hat = self.predict(X, convert_dtype=convert_dtype)
if isinstance(y_hat, tuple):
return (accuracy_score(y, y_hat_i) for y_hat_i in y_hat)
return (accuracy_score(y, y_hat_i, convert_dtype=convert_dtype)
for y_hat_i in y_hat)
else:
return accuracy_score(y, y_hat)
return accuracy_score(y, y_hat, convert_dtype=convert_dtype)
2 changes: 2 additions & 0 deletions python/cuml/neighbors/kneighbors_regressor.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ class KNeighborsRegressor(NearestNeighbors):
if isinstance(X, np.ndarray):
return np.array(results)
elif isinstance(X, cudf.DataFrame):
if results.ndim == 1:
results = results.reshape(results.shape[0], 1)
return cudf.DataFrame.from_gpu_matrix(results)
else:
return results
Expand Down
Loading

0 comments on commit 3cb4696

Please sign in to comment.