Skip to content

Commit

Permalink
RF: python api behaviour refactor (#4207)
Browse files Browse the repository at this point in the history
This PR ⬇️ 
* fixes #4193 and fixes #4194 that relates to API incompatibility with dask-ml GridSearchCV
* changes the behaviour of cuml RF in the following cases:
    * In the not-so-uncommon case when `n_bins` > number of rows in training sample, instead of throwing error and exiting, the estimator is made to print a warning and use the `n_bins` as the number of training samples. 
    * When `.predict()` is called using `float64` data, instead of throwing an error asking user to explicitly specify `predict_model="CPU"` and rerun, a warning is displayed and implicity defaults to CPU-based prediction from the default GPU-based prediction.
 * Corresponding tests to capture the warnings from above added
 * the estimators now accept both numbers and strings as input for `split_criterion` parameter thus in parity with sklearn's API that takes in strings as criterion.
 * `split_algo` and `use_experimental_backend` parameters of the estimator class have now been completely removed from both documentation and warnings after deprecation in previous releases (from both single-gpu and dask RF). 
 * `num_classes` parameter of predict and score methods have also been similarly removed

Authors:
  - Venkat (https://github.com/venkywonka)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)
  - Rory Mitchell (https://github.com/RAMitchell)

URL: #4207
  • Loading branch information
venkywonka authored Sep 21, 2021
1 parent f415f92 commit b375320
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 165 deletions.
26 changes: 5 additions & 21 deletions python/cuml/dask/ensemble/randomforestclassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,11 @@ class RandomForestClassifier(BaseRandomForestModel, DelayedPredictionMixin,
run different models concurrently in different streams by creating
handles in several streams.
If it is None, a new one is created.
split_criterion : The criterion used to split nodes.
0 for GINI, 1 for ENTROPY, 4 for CRITERION_END.
2 and 3 not valid for classification
(default = 0)
split_algo : 0 for HIST and 1 for GLOBAL_QUANTILE (default = 1)
the algorithm to determine how nodes are split in the tree.
split_criterion : The criterion used to split nodes.
0 for GINI, 1 for ENTROPY, 4 for CRITERION_END.
2 and 3 not valid for classification
(default = 0)
split_criterion : int or string (default = 0 ('gini'))
The criterion used to split nodes.
0 or 'gini' for GINI, 1 or 'entropy' for ENTROPY,
2 or 'mse' for MSE
2 or 'mse' not valid for classification
bootstrap : boolean (default = True)
Control bootstrapping.
If set, each tree in the forest is built
Expand Down Expand Up @@ -112,17 +107,6 @@ class RandomForestClassifier(BaseRandomForestModel, DelayedPredictionMixin,
If float, then min_samples_split represents a fraction and
ceil(min_samples_split * n_rows) is the minimum number of samples
for each split.
quantile_per_tree : boolean (default = False)
Whether quantile is computed for individual RF trees.
Only relevant for GLOBAL_QUANTILE split_algo.
use_experimental_backend : boolean (default = True)
If set to true and the following conditions are also met, a new
experimental backend for decision tree training will be used. The
new backend is available only if `split_algo = 1` (GLOBAL_QUANTILE)
and `quantile_per_tree = False` (No per tree quantile computation).
The new backend is considered stable for classification tasks but
not yet for regression tasks. The RAPIDS team is continuing
optimization and evaluation of the new backend for regression tasks.
n_streams : int (default = 4 )
Number of parallel streams used for forest building
workers : optional, list of strings
Expand Down
22 changes: 4 additions & 18 deletions python/cuml/dask/ensemble/randomforestregressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,11 @@ class RandomForestRegressor(BaseRandomForestModel, DelayedPredictionMixin,
run different models concurrently in different streams by creating
handles in several streams.
If it is None, a new one is created.
split_algo : int (default = 1)
0 for HIST, 1 for GLOBAL_QUANTILE
The type of algorithm to be used to create the trees.
split_criterion : int (default = 2)
split_criterion : int or string (default = 2 ('mse'))
The criterion used to split nodes.
0 for GINI, 1 for ENTROPY,
2 for MSE, 3 for MAE and 4 for CRITERION_END.
0 and 1 not valid for regression
0 or 'gini' for GINI, 1 or 'entropy' for ENTROPY,
2 or 'mse' for MSE
only 2 or 'mse' valid for regression
bootstrap : boolean (default = True)
Control bootstrapping.
If set, each tree in the forest is built
Expand Down Expand Up @@ -118,17 +115,6 @@ class RandomForestRegressor(BaseRandomForestModel, DelayedPredictionMixin,
for median of abs error : 'median_ae'
for mean of abs error : 'mean_ae'
for mean square error' : 'mse'
quantile_per_tree : boolean (default = False)
Whether quantile is computed for individual RF trees.
Only relevant for GLOBAL_QUANTILE split_algo.
use_experimental_backend : boolean (default = False)
If set to true and the following conditions are also met, a new
experimental backend for decision tree training will be used. The
new backend is available only if `split_algo = 1` (GLOBAL_QUANTILE)
and `quantile_per_tree = False` (No per tree quantile computation).
The new backend is considered stable for classification tasks but
not yet for regression tasks. The RAPIDS team is continuing
optimization and evaluation of the new backend for regression tasks.
n_streams : int (default = 4 )
Number of parallel streams used for forest building
workers : optional, list of strings
Expand Down
43 changes: 20 additions & 23 deletions python/cuml/ensemble/randomforest_common.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

import ctypes
import cupy as cp
import math
Expand Down Expand Up @@ -41,21 +40,24 @@ from cuml.common.array_descriptor import CumlArrayDescriptor
class BaseRandomForestModel(Base):
_param_names = ['n_estimators', 'max_depth', 'handle',
'max_features', 'n_bins',
'split_algo', 'split_criterion', 'min_samples_leaf',
'split_criterion', 'min_samples_leaf',
'min_samples_split',
'min_impurity_decrease',
'bootstrap',
'verbose', 'max_samples',
'max_leaves',
'accuracy_metric', 'use_experimental_backend',
'accuracy_metric',
'max_batch_size', 'n_streams', 'dtype',
'output_type', 'min_weight_fraction_leaf', 'n_jobs',
'max_leaf_nodes', 'min_impurity_split', 'oob_score',
'random_state', 'warm_start', 'class_weight',
'criterion']

criterion_dict = {'0': GINI, '1': ENTROPY, '2': MSE,
'3': MAE, '4': CRITERION_END}
criterion_dict = {'0': GINI, 'gini': GINI,
'1': ENTROPY, 'entropy': ENTROPY,
'2': MSE, 'mse': MSE,
'3': MAE, 'mae': MAE,
'4': CRITERION_END}

classes_ = CumlArrayDescriptor()

Expand Down Expand Up @@ -104,14 +106,6 @@ class BaseRandomForestModel(Base):
"recommended. If n_streams is > 1, results may vary "
"due to stream/thread timing differences, even when "
"random_state is set")
if 'use_experimental_backend' in kwargs.keys():
warnings.warn("The 'use_experimental_backend' parameter is "
"deprecated and has no effect. "
"It will be removed in 21.10 release.")
if 'split_algo' in kwargs.keys():
warnings.warn("The 'split_algo' parameter is "
"deprecated and has no effect. "
"It will be removed in 21.10 release.")
if handle is None:
handle = Handle(n_streams)

Expand Down Expand Up @@ -247,8 +241,10 @@ class BaseRandomForestModel(Base):
input_to_cuml_array(X, check_dtype=[np.float32, np.float64],
order='F')
if self.n_bins > self.n_rows:
raise ValueError("The number of bins,`n_bins` can not be greater"
" than the number of samples used for training.")
warnings.warn("The number of bins, `n_bins` is greater than "
"the number of samples used for training. "
"Changing `n_bins` to number of training samples.")
self.n_bins = self.n_rows

if self.RF_type == CLASSIFICATION:
y_m, _, _, y_dtype = \
Expand Down Expand Up @@ -329,14 +325,14 @@ class BaseRandomForestModel(Base):
check_cols=self.n_cols)

if dtype == np.float64 and not convert_dtype:
raise TypeError("GPU based predict only accepts np.float32 data. \
Please set convert_dtype=True to convert the test \
data to the same dtype as the data used to train, \
ie. np.float32. If you would like to use test \
data of dtype=np.float64 please set \
predict_model='CPU' to use the CPU implementation \
of predict.")

warnings.warn("GPU based predict only accepts "
"np.float32 data. The model was "
"trained on np.float64 data hence "
"cannot use GPU-based prediction! "
"\nDefaulting to CPU-based Prediction. "
"\nTo predict on float-64 data, set "
"parameter predict_model = 'CPU'")
return self._predict_model_on_cpu(X, convert_dtype=convert_dtype)
treelite_handle = self._obtain_treelite_handle()

storage_type = \
Expand Down Expand Up @@ -365,6 +361,7 @@ class BaseRandomForestModel(Base):
self.treelite_serialized_model = None

super().set_params(**params)
return self


def _check_fil_parameter_validity(depth, algo, fil_sparse_format):
Expand Down
78 changes: 17 additions & 61 deletions python/cuml/ensemble/randomforestclassifier.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#

# distutils: language = c++

import numpy as np
import rmm
import warnings
Expand Down Expand Up @@ -176,13 +175,11 @@ class RandomForestClassifier(BaseRandomForestModel,
-----------
n_estimators : int (default = 100)
Number of trees in the forest. (Default changed to 100 in cuML 0.11)
split_criterion : The criterion used to split nodes.
0 for GINI, 1 for ENTROPY
2 and 3 not valid for classification
(default = 0)
split_algo : int (default = 1)
Deprecated and currrently has no effect.
.. deprecated:: 21.06
split_criterion : int or string (default = 0 ('gini'))
The criterion used to split nodes.
0 or 'gini' for GINI, 1 or 'entropy' for ENTROPY,
2 or 'mse' for MSE
2 or 'mse' not valid for classification
bootstrap : boolean (default = True)
Control bootstrapping.
If True, each tree in the forest is built
Expand Down Expand Up @@ -226,9 +223,6 @@ class RandomForestClassifier(BaseRandomForestModel,
min_impurity_decrease : float (default = 0.0)
Minimum decrease in impurity requried for
node to be spilt.
use_experimental_backend : boolean (default = True)
Deprecated and currrently has no effect.
.. deprecated:: 21.08
max_batch_size: int (default = 4096)
Maximum number of nodes that can be processed in a given batch.
random_state : int (default = None)
Expand Down Expand Up @@ -559,8 +553,7 @@ class RandomForestClassifier(BaseRandomForestModel,
@insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')],
return_values=[('dense', '(n_samples, 1)')])
def predict(self, X, predict_model="GPU", threshold=0.5,
algo='auto', num_classes=None,
convert_dtype=True,
algo='auto', convert_dtype=True,
fil_sparse_format='auto') -> CumlArray:
"""
Predicts the labels for X.
Expand Down Expand Up @@ -589,13 +582,6 @@ class RandomForestClassifier(BaseRandomForestModel,
threshold : float (default = 0.5)
Threshold used for classification. Optional and required only
while performing the predict operation on the GPU.
num_classes : int (default = None)
number of different classes present in the dataset.

.. deprecated:: 0.16
Parameter 'num_classes' is deprecated and will be removed in
an upcoming version. The number of classes passed must match
the number of classes the model was trained on.

convert_dtype : bool, optional (default = True)
When set to True, the predict method will, when necessary, convert
Expand All @@ -617,24 +603,19 @@ class RandomForestClassifier(BaseRandomForestModel,
y : {}
"""
nvtx_range_push("predict RF-Classifier @randomforestclassifier.pyx")
if num_classes:
warnings.warn("num_classes is deprecated and will be removed"
" in an upcoming version")
if num_classes != self.num_classes:
raise NotImplementedError("limiting num_classes for predict"
" is not implemented")
if predict_model == "CPU":
preds = self._predict_model_on_cpu(X,
convert_dtype=convert_dtype)

elif self.dtype == np.float64:
raise TypeError("GPU based predict only accepts np.float32 data. \
In order use the GPU predict the model should \
also be trained using a np.float32 dataset. \
If you would like to use np.float64 dtype \
then please use the CPU based predict by \
setting predict_model = 'CPU'")

warnings.warn("GPU based predict only accepts "
"np.float32 data. The model was "
"trained on np.float64 data hence "
"cannot use GPU-based prediction! "
"\nDefaulting to CPU-based Prediction. "
"\nTo predict on float-64 data, set "
"parameter predict_model = 'CPU'")
preds = self._predict_model_on_cpu(X,
convert_dtype=convert_dtype)
else:
preds = \
self._predict_model_on_gpu(X=X, output_class=True,
Expand All @@ -650,7 +631,7 @@ class RandomForestClassifier(BaseRandomForestModel,
@insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')],
return_values=[('dense', '(n_samples, 1)')])
def predict_proba(self, X, algo='auto',
num_classes=None, convert_dtype=True,
convert_dtype=True,
fil_sparse_format='auto') -> CumlArray:
"""
Predicts class probabilites for X. This function uses the GPU
Expand All @@ -673,14 +654,6 @@ class RandomForestClassifier(BaseRandomForestModel,
* ``'batch_tree_reorg'`` is used for dense storage
and 'naive' for sparse storage

num_classes : int (default = None)
number of different classes present in the dataset.

.. deprecated:: 0.16
Parameter 'num_classes' is deprecated and will be removed in
an upcoming version. The number of classes passed must match
the number of classes the model was trained on.

convert_dtype : bool, optional (default = True)
When set to True, the predict method will, when necessary, convert
the input to the data type which was used to train the model. This
Expand Down Expand Up @@ -708,15 +681,6 @@ class RandomForestClassifier(BaseRandomForestModel,
then please use the CPU based predict by \
setting predict_model = 'CPU'")

if num_classes:
warnings.warn("num_classes is deprecated and will be removed"
" in an upcoming version")
if num_classes != self.num_classes:
raise NotImplementedError("The number of classes in the test "
"dataset should be equal to the "
"number of classes present in the "
"training dataset.")

preds_proba = \
self._predict_model_on_gpu(X, output_class=True,
algo=algo,
Expand All @@ -729,7 +693,7 @@ class RandomForestClassifier(BaseRandomForestModel,
@insert_into_docstring(parameters=[('dense', '(n_samples, n_features)'),
('dense_intdtype', '(n_samples, 1)')])
def score(self, X, y, threshold=0.5,
algo='auto', num_classes=None, predict_model="GPU",
algo='auto', predict_model="GPU",
convert_dtype=True, fil_sparse_format='auto'):
"""
Calculates the accuracy metric score of the model for X.
Expand All @@ -755,13 +719,6 @@ class RandomForestClassifier(BaseRandomForestModel,
threshold is used to for classification
This is optional and required only while performing the
predict operation on the GPU.
num_classes : int (default = None)
number of different classes present in the dataset.
.. deprecated:: 0.16
Parameter 'num_classes' is deprecated and will be removed in
an upcoming version. The number of classes passed must match
the number of classes the model was trained on.
convert_dtype : boolean, default=True
whether to convert input data to correct dtype automatically
Expand Down Expand Up @@ -803,7 +760,6 @@ class RandomForestClassifier(BaseRandomForestModel,
threshold=threshold, algo=algo,
convert_dtype=convert_dtype,
predict_model=predict_model,
num_classes=num_classes,
fil_sparse_format=fil_sparse_format)

cdef uintptr_t preds_ptr
Expand Down
Loading

0 comments on commit b375320

Please sign in to comment.