diff --git a/CHANGELOG.md b/CHANGELOG.md index 0bd5b60830..67da00e418 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,7 @@ - PR #3135: Add QuasiNewton tests - PR #3040: Improved Array Conversion with CumlArrayDescriptor and Decorators - PR #3134: Improving the Deprecation Message Formatting in Documentation +- PR #3113: Add tags and prefered memory order tags to estimators - PR #3137: Reorganize Pytest Config and Add Quick Run Option - PR #3144: Adding Ability to Set Arbitrary Cmake Flags in ./build.sh - PR #3155: Eliminate unnecessary warnings from random projection test @@ -58,7 +59,7 @@ - PR #3086: Reverting FIL Notebook Testing - PR #3114: Fixed a typo in SVC's predict_proba AttributeError - PR #3117: Fix two crashes in experimental RF backend -- PR #3119: Fix memset args for benchmark +- PR #3119: Fix memset args for benchmark - PR #3130: Return Python string from `dump_as_json()` of RF - PR #3132: Add `min_samples_split` + Rename `min_rows_per_node` -> `min_samples_leaf` - PR #3136: Fix stochastic gradient descent example diff --git a/python/cuml/cluster/dbscan.pyx b/python/cuml/cluster/dbscan.pyx index 1173c12af9..4222c56372 100644 --- a/python/cuml/cluster/dbscan.pyx +++ b/python/cuml/cluster/dbscan.pyx @@ -346,3 +346,8 @@ class DBSCAN(Base): "max_mbytes_per_batch", "calc_core_sample_indices", ] + + def _more_tags(self): + return { + 'preferred_input_order': 'C' + } diff --git a/python/cuml/cluster/kmeans.pyx b/python/cuml/cluster/kmeans.pyx index b37e619eb0..5b74c39eae 100644 --- a/python/cuml/cluster/kmeans.pyx +++ b/python/cuml/cluster/kmeans.pyx @@ -622,3 +622,8 @@ class KMeans(Base): ['n_init', 'oversampling_factor', 'max_samples_per_batch', 'init', 'max_iter', 'n_clusters', 'random_state', 'tol'] + + def _more_tags(self): + return { + 'preferred_input_order': 'C' + } diff --git a/python/cuml/common/base.pyx b/python/cuml/common/base.pyx index 8c310c1108..19601a022a 100644 --- a/python/cuml/common/base.pyx +++ b/python/cuml/common/base.pyx @@ -27,6 +27,34 @@ from cuml.common.doc_utils import generate_docstring import cuml.common.input_utils +# tag system based on experimental tag system from Scikit-learn >=0.21 +# https://scikit-learn.org/stable/developers/develop.html#estimator-tags +_default_tags = { + # cuML specific tags + 'preferred_input_order': None, + 'X_types_gpu': ['2darray'], + + # Scikit-learn API standard tags + 'non_deterministic': False, + 'requires_positive_X': False, + 'requires_positive_y': False, + 'X_types': ['2darray'], + 'poor_score': False, + 'no_validation': False, + 'multioutput': False, + 'allow_nan': False, + 'stateless': False, + 'multilabel': False, + '_skip_test': False, + '_xfail_checks': False, + 'multioutput_only': False, + 'binary_only': False, + 'requires_fit': True, + 'requires_y': False, + 'pairwise': False, +} + + class Base(metaclass=cuml.internals.BaseMetaClass): """ Base class for all the ML algos. It handles some of the common operations @@ -348,6 +376,16 @@ class Base(metaclass=cuml.internals.BaseMetaClass): else: self.n_features_in_ = X.shape[1] + def _get_tags(self): + # method and code based on scikit-learn 0.21 _get_tags functionality: + # https://scikit-learn.org/stable/developers/develop.html#estimator-tags + collected_tags = _default_tags + for cl in reversed(inspect.getmro(self.__class__)): + if hasattr(cl, '_more_tags') and cl != Base: + more_tags = cl._more_tags(self) + collected_tags.update(more_tags) + return collected_tags + class RegressorMixin: """Mixin class for regression estimators in cuML""" @@ -379,6 +417,11 @@ class RegressorMixin: preds = self.predict(X, **kwargs) return r2_score(y, preds, handle=handle) + def _more_tags(self): + return { + 'requires_y': True + } + class ClassifierMixin: """Mixin class for classifier estimators in cuML""" @@ -410,6 +453,11 @@ class ClassifierMixin: preds = self.predict(X, **kwargs) return accuracy_score(y, preds, handle=handle) + def _more_tags(self): + return { + 'requires_y': True + } + # Internal, non class owned helper functions def _check_output_type_str(output_str): diff --git a/python/cuml/decomposition/pca.pyx b/python/cuml/decomposition/pca.pyx index 0e0e40a720..284b5ada9e 100644 --- a/python/cuml/decomposition/pca.pyx +++ b/python/cuml/decomposition/pca.pyx @@ -728,3 +728,10 @@ class PCA(Base): def __setstate__(self, state): self.__dict__.update(state) self.handle = Handle() + + def _more_tags(self): + return { + 'preferred_input_order': 'F', + 'X_types_gpu': ['2darray', 'sparse'], + 'X_types': ['2darray', 'sparse'] + } diff --git a/python/cuml/decomposition/tsvd.pyx b/python/cuml/decomposition/tsvd.pyx index 7881141172..0d3e4351bf 100644 --- a/python/cuml/decomposition/tsvd.pyx +++ b/python/cuml/decomposition/tsvd.pyx @@ -476,3 +476,8 @@ class TruncatedSVD(Base): def get_param_names(self): return super().get_param_names() + \ ["algorithm", "n_components", "n_iter", "random_state", "tol"] + + def _more_tags(self): + return { + 'preferred_input_order': 'F' + } diff --git a/python/cuml/ensemble/randomforestclassifier.pyx b/python/cuml/ensemble/randomforestclassifier.pyx index 722ab6c879..14d3f61ade 100644 --- a/python/cuml/ensemble/randomforestclassifier.pyx +++ b/python/cuml/ensemble/randomforestclassifier.pyx @@ -959,3 +959,9 @@ class RandomForestClassifier(BaseRandomForestModel, ClassifierMixin): if self.dtype == np.float64: return dump_rf_as_json(rf_forest64).decode('utf-8') return dump_rf_as_json(rf_forest).decode('utf-8') + + def _more_tags(self): + return { + # fit and predict require conflicting memory layouts + 'preferred_input_order': None + } diff --git a/python/cuml/ensemble/randomforestregressor.pyx b/python/cuml/ensemble/randomforestregressor.pyx index d73d5eda87..2ed34497df 100644 --- a/python/cuml/ensemble/randomforestregressor.pyx +++ b/python/cuml/ensemble/randomforestregressor.pyx @@ -756,3 +756,9 @@ class RandomForestRegressor(BaseRandomForestModel, RegressorMixin): if self.dtype == np.float64: return dump_rf_as_json(rf_forest64).decode('utf-8') return dump_rf_as_json(rf_forest).decode('utf-8') + + def _more_tags(self): + return { + # fit and predict require conflicting memory layouts + 'preferred_input_order': None + } diff --git a/python/cuml/fil/fil.pyx b/python/cuml/fil/fil.pyx index 046227f87b..63a48bb5ec 100644 --- a/python/cuml/fil/fil.pyx +++ b/python/cuml/fil/fil.pyx @@ -728,3 +728,8 @@ class ForestInference(Base): # DO NOT RETURN self._impl here!! return self + + def _more_tags(self): + return { + 'preferred_input_order': 'C' + } diff --git a/python/cuml/linear_model/elastic_net.pyx b/python/cuml/linear_model/elastic_net.pyx index fb942bc814..221137b318 100644 --- a/python/cuml/linear_model/elastic_net.pyx +++ b/python/cuml/linear_model/elastic_net.pyx @@ -240,3 +240,8 @@ class ElasticNet(Base, RegressorMixin): "tol", "selection", ] + + def _more_tags(self): + return { + 'preferred_input_order': 'F' + } diff --git a/python/cuml/linear_model/lasso.pyx b/python/cuml/linear_model/lasso.pyx index 21013195e3..0bc545fe23 100644 --- a/python/cuml/linear_model/lasso.pyx +++ b/python/cuml/linear_model/lasso.pyx @@ -208,3 +208,8 @@ class Lasso(Base, RegressorMixin): "tol", "selection", ] + + def _more_tags(self): + return { + 'preferred_input_order': 'F' + } diff --git a/python/cuml/linear_model/linear_regression.pyx b/python/cuml/linear_model/linear_regression.pyx index 54b5c0be7a..0a03c34b50 100644 --- a/python/cuml/linear_model/linear_regression.pyx +++ b/python/cuml/linear_model/linear_regression.pyx @@ -352,3 +352,8 @@ class LinearRegression(Base, RegressorMixin): def get_param_names(self): return super().get_param_names() + \ ['algorithm', 'fit_intercept', 'normalize'] + + def _more_tags(self): + return { + 'preferred_input_order': 'F' + } diff --git a/python/cuml/linear_model/logistic_regression.pyx b/python/cuml/linear_model/logistic_regression.pyx index 5f10cd18b5..0c177b31c9 100644 --- a/python/cuml/linear_model/logistic_regression.pyx +++ b/python/cuml/linear_model/logistic_regression.pyx @@ -424,3 +424,8 @@ class LogisticRegression(Base, ClassifierMixin): super(LogisticRegression, self).__init__(handle=None, verbose=state["verbose"]) self.__dict__.update(state) + + def _more_tags(self): + return { + 'preferred_input_order': 'F' + } diff --git a/python/cuml/linear_model/mbsgd_classifier.pyx b/python/cuml/linear_model/mbsgd_classifier.pyx index 90f923c15d..ad34abdcc4 100644 --- a/python/cuml/linear_model/mbsgd_classifier.pyx +++ b/python/cuml/linear_model/mbsgd_classifier.pyx @@ -219,3 +219,8 @@ class MBSGDClassifier(Base, ClassifierMixin): "batch_size", "n_iter_no_change", ] + + def _more_tags(self): + return { + 'preferred_input_order': 'F' + } diff --git a/python/cuml/linear_model/mbsgd_regressor.pyx b/python/cuml/linear_model/mbsgd_regressor.pyx index 68a8ac6b80..b35dd4b71c 100644 --- a/python/cuml/linear_model/mbsgd_regressor.pyx +++ b/python/cuml/linear_model/mbsgd_regressor.pyx @@ -213,3 +213,8 @@ class MBSGDRegressor(Base, RegressorMixin): "batch_size", "n_iter_no_change", ] + + def _more_tags(self): + return { + 'preferred_input_order': 'F' + } diff --git a/python/cuml/linear_model/ridge.pyx b/python/cuml/linear_model/ridge.pyx index 9634b05f75..724454220b 100644 --- a/python/cuml/linear_model/ridge.pyx +++ b/python/cuml/linear_model/ridge.pyx @@ -214,7 +214,6 @@ class Ridge(Base, RegressorMixin): def __init__(self, alpha=1.0, solver='eig', fit_intercept=True, normalize=False, handle=None, output_type=None, verbose=False): - """ Initializes the linear ridge regression class. @@ -394,3 +393,8 @@ class Ridge(Base, RegressorMixin): def get_param_names(self): return super().get_param_names() + \ ['solver', 'fit_intercept', 'normalize', 'alpha'] + + def _more_tags(self): + return { + 'preferred_input_order': 'F' + } diff --git a/python/cuml/manifold/t_sne.pyx b/python/cuml/manifold/t_sne.pyx index c9fd1e149d..ee21c829f9 100644 --- a/python/cuml/manifold/t_sne.pyx +++ b/python/cuml/manifold/t_sne.pyx @@ -485,3 +485,8 @@ class TSNE(Base): "pre_momentum", "post_momentum", ] + + def _more_tags(self): + return { + 'preferred_input_order': 'C' + } diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index 59271cb876..3a81188430 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -877,3 +877,8 @@ class UMAP(Base): "optim_batch_size", "callback", ] + + def _more_tags(self): + return { + 'preferred_input_order': 'C' + } diff --git a/python/cuml/neighbors/kneighbors_classifier.pyx b/python/cuml/neighbors/kneighbors_classifier.pyx index e89b563449..682e03692c 100644 --- a/python/cuml/neighbors/kneighbors_classifier.pyx +++ b/python/cuml/neighbors/kneighbors_classifier.pyx @@ -305,3 +305,9 @@ class KNeighborsClassifier(NearestNeighbors, ClassifierMixin): def get_param_names(self): return super().get_param_names() + ["weights"] + + def _more_tags(self): + return { + # fit and predict require conflicting memory layouts + 'preferred_input_order': 'F' + } diff --git a/python/cuml/neighbors/kneighbors_regressor.pyx b/python/cuml/neighbors/kneighbors_regressor.pyx index f626d36082..323e92dd74 100644 --- a/python/cuml/neighbors/kneighbors_regressor.pyx +++ b/python/cuml/neighbors/kneighbors_regressor.pyx @@ -231,3 +231,9 @@ class KNeighborsRegressor(NearestNeighbors, RegressorMixin): def get_param_names(self): return super().get_param_names() + ["weights"] + + def _more_tags(self): + return { + # fit and predict require conflicting memory layouts + 'preferred_input_order': 'F' + } diff --git a/python/cuml/neighbors/nearest_neighbors.pyx b/python/cuml/neighbors/nearest_neighbors.pyx index abb8762de5..8631ec9717 100644 --- a/python/cuml/neighbors/nearest_neighbors.pyx +++ b/python/cuml/neighbors/nearest_neighbors.pyx @@ -746,3 +746,8 @@ def kneighbors_graph(X=None, n_neighbors=5, mode='connectivity', verbose=False, query = X.X_m return X.kneighbors_graph(X=query, n_neighbors=n_neighbors, mode=mode) + + def _more_tags(self): + return { + 'preferred_input_order': 'F' + } diff --git a/python/cuml/random_projection/random_projection.pyx b/python/cuml/random_projection/random_projection.pyx index ef35beac76..6b21a82b2d 100644 --- a/python/cuml/random_projection/random_projection.pyx +++ b/python/cuml/random_projection/random_projection.pyx @@ -589,3 +589,8 @@ class SparseRandomProjection(Base, BaseRandomProjection): "dense_output", "random_state" ] + + def _more_tags(self): + return { + 'preferred_input_order': 'F' + } diff --git a/python/cuml/solvers/cd.pyx b/python/cuml/solvers/cd.pyx index f9e6094f21..134c7054c1 100644 --- a/python/cuml/solvers/cd.pyx +++ b/python/cuml/solvers/cd.pyx @@ -349,3 +349,8 @@ class CD(Base): "tol", "shuffle", ] + + def _more_tags(self): + return { + 'preferred_input_order': 'F' + } diff --git a/python/cuml/solvers/qn.pyx b/python/cuml/solvers/qn.pyx index 684bd89a03..8bac34a790 100644 --- a/python/cuml/solvers/qn.pyx +++ b/python/cuml/solvers/qn.pyx @@ -537,3 +537,8 @@ class QN(Base): return super().get_param_names() + \ ['loss', 'fit_intercept', 'l1_strength', 'l2_strength', 'max_iter', 'tol', 'linesearch_max_iter', 'lbfgs_memory'] + + def _more_tags(self): + return { + 'preferred_input_order': 'F' + } diff --git a/python/cuml/solvers/sgd.pyx b/python/cuml/solvers/sgd.pyx index 8c968b2d1b..3d93272c81 100644 --- a/python/cuml/solvers/sgd.pyx +++ b/python/cuml/solvers/sgd.pyx @@ -507,3 +507,8 @@ class SGD(Base): "batch_size", "n_iter_no_change", ] + + def _more_tags(self): + return { + 'preferred_input_order': 'F' + } diff --git a/python/cuml/svm/svc.pyx b/python/cuml/svm/svc.pyx index cfc3d17536..4dc7a6715f 100644 --- a/python/cuml/svm/svc.pyx +++ b/python/cuml/svm/svc.pyx @@ -519,3 +519,8 @@ class SVC(SVMBase, ClassifierMixin): params.remove("epsilon") return params + + def _more_tags(self): + return { + 'preferred_input_order': 'F' + } diff --git a/python/cuml/svm/svm_base.pyx b/python/cuml/svm/svm_base.pyx index bdd38d2300..9babdb85dd 100644 --- a/python/cuml/svm/svm_base.pyx +++ b/python/cuml/svm/svm_base.pyx @@ -574,3 +574,8 @@ class SVMBase(Base): self.__dict__.update(state) self._model = self._get_svm_model() self._freeSvmBuffers = False + + def _more_tags(self): + return { + 'preferred_input_order': 'F' + } diff --git a/python/cuml/test/test_fit_function.py b/python/cuml/test/test_api.py similarity index 55% rename from python/cuml/test/test_fit_function.py rename to python/cuml/test/test_api.py index 8693e9484c..84e9b14ca6 100644 --- a/python/cuml/test/test_fit_function.py +++ b/python/cuml/test/test_api.py @@ -1,3 +1,19 @@ +# +# Copyright (c) 2020, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + import pytest import cuml from cuml.test.utils import ClassEnumerator @@ -33,9 +49,68 @@ def dataset(): models_config = ClassEnumerator(module=cuml) models = models_config.get_models() +# tag system based on experimental tag system from Scikit-learn >=0.21 +# https://scikit-learn.org/stable/developers/develop.html#estimator-tags +tags = { + # cuML specific tags + 'preferred_input_order': None, + 'X_types_gpu': list, + + # Scikit-learn API standard tags + 'non_deterministic': bool, + 'requires_positive_X': bool, + 'requires_positive_y': bool, + 'X_types': list, + 'poor_score': bool, + 'no_validation': bool, + 'multioutput': bool, + 'allow_nan': bool, + 'stateless': bool, + 'multilabel': bool, + '_skip_test': bool, + '_xfail_checks': bool, + 'multioutput_only': bool, + 'binary_only': bool, + 'requires_fit': bool, + 'requires_y': bool, + 'pairwise': bool, +} + + +@pytest.mark.parametrize("model", list(models.values())) +def test_get_tags(model): + # This test ensures that our estimators return the tags defined by + # Scikit-learn and our cuML specific tags + # mod = models[model_name] + # assert hasattr('_get_tags', m) + + # for tag in tags: + # assert + + print(model) + if model in (cuml.tsa.auto_arima.AutoARIMA, cuml.tsa.arima.ARIMA, + cuml.tsa.holtwinters.ExponentialSmoothing): + mod = model(cp.ones(10)) + else: + mod = model() + + assert hasattr(mod, '_get_tags') + + model_tags = mod._get_tags() + for tag, tag_type in tags.items(): + # preferred input order can be None or a string + if tag == 'preferred_input_order': + if model_tags[tag] is not None: + assert isinstance(model_tags[tag], str) + else: + assert isinstance(model_tags[tag], tag_type) + + return True + @pytest.mark.parametrize("model_name", list(models.keys())) def test_fit_function(dataset, model_name): + # This test ensures that our estimators return self after a call to fit if model_name in [ "SparseRandomProjection", "TSNE", diff --git a/wiki/python/ESTIMATOR_GUIDE.md b/wiki/python/ESTIMATOR_GUIDE.md index 7a81031e17..06e2610172 100644 --- a/wiki/python/ESTIMATOR_GUIDE.md +++ b/wiki/python/ESTIMATOR_GUIDE.md @@ -4,15 +4,36 @@ This guide is meant to help developers follow the correct patterns when creating **Note:** This guide is long, because it includes internal details on how cuML manages input and output types for advanced use cases. But for the vast majority of estimators, the requirements are very simple and can follow the example patterns shown below in the [Quick Start Guide](#quick-start-guide). +## Table of Contents + +- [Recommended Scikit-Learn Documentation](#recommended-scikit-learn-documentation) +- [Quick Start Guide](#quick-start-guide) +- [Background](#background) + - [Input and Output Types in cuML](#input-and-output-types-in-cuml) + - [Specifying the Array Output Type](#specifying-the-array-output-type) + - [Ingesting Arrays](#ingesting-arrays) + - [Returning Arrays](#returning-arrays) +- [Estimator Design](#estimator-design) + - [Initialization](#initialization) + - [Implementing `get_param_names()`](#implementing-get_param_names) + - [Estimator Tags and cuML Specific Tags](#estimator-tags-and-cuml-specific-tags) + - [Estimator Array-Like Attributes](#estimator-array-like-attributes) + - [Estimator Methods](#estimator-methods) +- [Do's and Do Not's](#dos-and-do-nots) +- [Appendix](#appendix) + +## Recommended Scikit-Learn Documentation + To start, it's recommended to read the following Scikit-learn documentation: 1. [Scikit-learn's Estimator Docs](https://scikit-learn.org/stable/developers/develop.html) 1. cuML Estimator design follows Scikit-learn very closely. We will only cover portions where our design differs from this document - 2. If short on time, pay close attention to these sections, as these topics have caused pain points in the past: + 2. If short on time, pay attention to these sections, which are the most important (and have caused pain points in the past): 1. [Instantiation](https://scikit-learn.org/stable/developers/develop.html#estimated-attributes) 2. [Estimated Attributes](https://scikit-learn.org/stable/developers/develop.html#estimated-attributes) 3. [`get_params` and `set_params`](https://scikit-learn.org/stable/developers/develop.html#estimated-attributes) - 4. [cloning](https://scikit-learn.org/stable/developers/develop.html#estimated-attributes) + 4. [Cloning](https://scikit-learn.org/stable/developers/develop.html#cloning) + 5. [Estimator tags](https://scikit-learn.org/stable/developers/develop.html#estimator-tags) 2. [Scikit-learn's Docstring Guide](https://scikit-learn.org/stable/developers/contributing.html#guidelines-for-writing-documentation) 1. We follow the same guidelines for specifying array-like objects, array shapes, dtypes, and default values @@ -27,7 +48,7 @@ At a high level, all cuML Estimators must: ... ``` 2. Follow the Scikit-learn estimator guidelines found [here](https://scikit-learn.org/stable/developers/develop.html) -3. Include the `Base.__init__()` arguments available in the new Estimator's `__init__()` +3. Include the `Base.__init__()` arguments available in the new Estimator's `__init__()` ```python class MyEstimator(Base): @@ -49,7 +70,7 @@ At a high level, all cuML Estimators must: 5. Add input and return type annotations to public API functions OR wrap those functions explicitly with conversion decorators (see [this example](#non-standard-predict) for a non-standard use case) ```python class MyEstimator(Base): - + def fit(self, X) -> "MyEstimator": ... @@ -58,11 +79,21 @@ At a high level, all cuML Estimators must: ``` 6. Implement `get_param_names()` including values returned by `super().get_param_names()` ```python - def get_param_names(self): - return super().get_param_names() + [ - "eps", - "min_samples", - ] + def get_param_names(self): + return super().get_param_names() + [ + "eps", + "min_samples", + ] + ``` + +7. Implement `_more_tags()` if any of the [default tags]() need to be overriden for the new estimator: + ```python + def _more_tags(self): + return { + 'preferred_input_order': 'F', + 'X_types_gpu': ['2darray', 'sparse'] + 'X_types': ['2darray', 'sparse'] + } ``` For the majority of estimators, the above steps will be sufficient to correctly work with the cuML library and ensure a consistent API. However, situations may arise where an estimator differs from the standard pattern and some of the functionality needs to be customized. The remainder of this guide takes a deep dive into the estimator functionality to assist developers when building estimators. @@ -92,7 +123,7 @@ Internally, all arrays should be converted to `CumlArray` as much as possible si ### Specifying the Array Output Type -Users can choose which array type should be returned by cuml by either: +Users can choose which array type should be returned by cuml by either: 1. Individually setting the output_type property on an estimator class (i.e `Base(output_type="numpy")`) 2. Globally setting the `cuml.global_output_type` 3. Temporarily setting the `cuml.global_output_type` via the `cuml.using_output_type` context manager @@ -168,6 +199,17 @@ def get_param_names(self): **Note:** Be sure to include `super().get_param_names()` in the returned list to properly set the `super()` attributes. +### Estimator Tags and cuML-Specific Tags + +Scikit-learn introduced estimator tags in version 0.21, which are used to programmatically inspect the capabilities of estimators. These capabilities include items like sparse matrix support and the need for positive inputs, among other things. cuML estimators support _all_ of the tags defined by the Scikit-learn estimator [developer guide](https://scikit-learn.org/stable/developers/index.html), and will add support for any tag added there. + +Additionaly, some tags specific to cuML have been added. These tags may or may not be specific to GPU data types and can even apply outside of automated testing, such as allowing for the optimization of data generation. This can be useful for pipelines and HPO, among other things. These are: + +- `X_types_gpu` (default=['2darray']) + Analogous to `X_types`, indicates what types of GPU objects an estimator can take. `2darray` includes GPU ndarray objects (like CuPy and Numba) and cuDF objects, since they are all processed the same by `input_utils`. `sparse` includes `CuPy` sparse arrays. + - `preferred_input_order` (default=None) + One of ['F', 'C', None]. Whether an estimator "prefers" data in column-major ('F') or row-major ('C') contiguous memory layout. If different methods prefer different layouts or neither format is benefitial, then it is defined to `None` unless there is a good reason to chose either `F` or `C`. For example, all of `fit`, `predict`, etc. in an estimator use `F` but only `score` uses`C`. + ### Estimator Array-Like Attributes Any array-like attribute stored in an estimator needs to be convertible to the user's desired output type. To make it easier to store array-like objects in a class that derives from `Base`, the `cuml.common.array_descriptor.CumlArrayDescriptor` was created. The `CumlArrayDescriptor` class is a Python descriptor object which allows cuML to implement customized attribute lookup, storage and deletion code that can be reused on all estimators. @@ -178,7 +220,7 @@ Performing the arrray conversion lazily (i.e. converting the input array to the #### Defining Array-Like Attributes -To use the `CumlArrayDescriptor` in an estimator, any array-like attributes need to be specified by creating a `CumlArrayDescriptor` as a class variable. +To use the `CumlArrayDescriptor` in an estimator, any array-like attributes need to be specified by creating a `CumlArrayDescriptor` as a class variable. ```python from cuml.common.array_descriptor import CumlArrayDescriptor @@ -242,9 +284,9 @@ def score(self): # Accessing my_cuml_array_ will return a numpy array and # the result can be returned directly return np.sum(self.my_cuml_array_, axis=0) -``` +``` -This has the same benefits of lazy conversion and caching as when descriptors are used externally. +This has the same benefits of lazy conversion and caching as when descriptors are used externally. #### CumlArrayDescriptor External Functionality @@ -297,11 +339,11 @@ To allow estimator methods to accept a wide variety of inputs and outputs, a set #### Option 1: Automatic Array Conversion From Type Annotation -To automatically convert array-like objects being returned by an Estimator method, a new metaclass has been added to `Base` that can scan the return type information of an Estimator method and infer which, if any, array conversion should be done. For example, if a method returns a type of `Base`, cuML can assume this method is likely similar to `fit()` and should call `Base._set_base_attributes()` before calling the method. If a method returns a type of `CumlArray`, cuML can assume this method is similar to `predict()` or `transform()`, and the return value is an array that may need to be converted using the output type calculated in `Base._get_output_type()`. +To automatically convert array-like objects being returned by an Estimator method, a new metaclass has been added to `Base` that can scan the return type information of an Estimator method and infer which, if any, array conversion should be done. For example, if a method returns a type of `Base`, cuML can assume this method is likely similar to `fit()` and should call `Base._set_base_attributes()` before calling the method. If a method returns a type of `CumlArray`, cuML can assume this method is similar to `predict()` or `transform()`, and the return value is an array that may need to be converted using the output type calculated in `Base._get_output_type()`. The full set of return types rules that will be applied by the `Base` metaclass are: -| Return Type | Converts Array Type? | Common Methods | Notes | +| Return Type | Converts Array Type? | Common Methods | Notes | | :---------: | :-----------: | :----------- | :----------- | | `Base` | No | `fit()` | Any type that inherits or `isinstance` of `Base` will work | | `CumlArray` | Yes | `predict()`, `transform()` | Functions can return any array-like object (`np.ndarray`, `cp.ndarray`, etc. all accepted) | @@ -340,7 +382,7 @@ def predict(self, X) -> CumlArray: ``` **Notes:** - - Its not necessary to convert to `CumlArray` and casting with `to_output` before returning. This function directly returned a `cp.ndarray` object. Any array-like object can be returned. + - It's not necessary to convert to `CumlArray` and cast with `to_output` before returning. This function directly returned a `cp.ndarray` object. Any array-like object can be returned. #### Option 2: Manual Estimator Method Decoration @@ -353,7 +395,7 @@ Which decorator to use for an estimator function is determined by 2 factors: The full set of descriptors can be organized by these two factors: -| Return Type-> | Array-Like | Sparse Array-Like | Generic | Any | +| Return Type-> | Array-Like | Sparse Array-Like | Generic | Any | | -----------: | :-----------: | :-----------: | :-----------: | :-----------: | | `Base` | `@api_base_return_array` | `@api_base_return_sparse_array` |`@api_base_return_generic` | `@api_base_return_any` | | Non-`Base` | `@api_return_array` | `@api_return_sparse_array` | `@api_return_generic` | `@api_return_any` | @@ -383,7 +425,7 @@ def predict(self, X): - It's not necessary to convert to `CumlArray` and casting with `to_output` before returning. This function directly returned a `cp.ndarray` object. Any array-like object can be returned. - Specifying `get_output_dtype=True` in the decorator argument instructs the decorator to also calculate the dtype in addition to the output type. -## Do's And Don'ts +## Do's And Do Not's ### **Do:** Add Return Typing Information to Estimator Functions @@ -520,7 +562,7 @@ Adding decorators to every estimator function just to use the decorator default 3. If an estimator function returns a `SparseCumlArray`, then `@api_base_return_sparse_array()` will be applied. 4. If an estimator function returns a `dict`, `tuple`, `list` or `typing.Union`, then `@api_base_return_generic()` will be applied. -| Return Type | Decorator | Notes | +| Return Type | Decorator | Notes | | :-----------: | :-----------: | :----------- | | `Base` | `@api_base_return_any(set_output_type=True, set_n_features_in=True)` | Any type that `isinstance` of `Base` will work | | `CumlArray` | `@api_base_return_array(get_output_type=True)` | Functions can return any array-like object | @@ -601,7 +643,7 @@ Every function in `cuml` is slightly different and some `fit()` functions may ne Since the decorator's functionality is very similar, so are their arguments. All of the decorators take similar arguments that will be outlined below. -| Argument | Type | Default | Meaning | +| Argument | Type | Default | Meaning | | :-----------: | :-----------: | :-----------: | :----------- | | `input_arg` | `str` | `'X'` or 1st non-self argument | Determines which input argument to use for `_set_output_type()` and `_set_n_features_in()` | | `target_arg` | `str` | `'y'` or 2nd non-self argument | Determines which input argument to use for `_set_target_dtype()` | @@ -677,7 +719,7 @@ def fit(self, X): ```python - + def fit(self, X) -> "KMeans": @@ -694,7 +736,7 @@ def fit(self, X) -> "KMeans": **Notes:** - `@with_cupy_rmm` is no longer needed. This is automatically applied for every public method of estimators - - `self._set_base_attributes()` no longer needs to be called. + - `self._set_base_attributes()` no longer needs to be called. ##### `predict()` @@ -731,7 +773,7 @@ def predict(self, X, y): ```python - + def predict(self, X) -> CumlArray: @@ -773,7 +815,7 @@ def predict(self, X) -> CumlArray: ```python - + @with_cupy_rmm def predict(self, X_in): # Determine the output_type @@ -815,7 +857,7 @@ def predict(self, X): # Return the cupy array directly return X_m - + ``` @@ -829,4 +871,4 @@ def predict(self, X): - In reality, this isn't necessary for this example. The decorator will look for an argument named `"X"` or default to the first, non `self`, argument. - `self._get_output_type()` and `self._get_target_dtype()` no longer needs to be called. Both the output type and dtype are determined automatically - It's not necessary to convert to `CumlArray` and casting with `to_output` before returning. This function directly returned a `cp.ndarray` object. Any array-like object can be returned. - - Specifying `get_output_dtype=True` in the decorator argument instructs the decorator to also calculate the dtype in addition to the output type. \ No newline at end of file + - Specifying `get_output_dtype=True` in the decorator argument instructs the decorator to also calculate the dtype in addition to the output type.